operations.py 27 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650
  1. """
  2. 数据库操作封装模块
  3. 提供完整的CRUD操作功能
  4. 包括用户管理、图片管理、处理记录管理、系统配置管理、数据库版本管理
  5. """
  6. import json
  7. import hashlib
  8. import uuid
  9. from datetime import datetime
  10. from typing import Optional, List, Dict, Any, Union
  11. from sqlalchemy.orm import Session
  12. from sqlalchemy import and_, or_, desc, asc
  13. from sqlalchemy.exc import IntegrityError
  14. from .connection import DatabaseConnection
  15. from .models import User, ImageRecord, TextRecord, ProcessRecord, SystemConfig, DatabaseVersion
  16. from backend.utils.logger_config import setup_logger
# Module-level logger; the update/delete/config methods of DatabaseOperations
# write their success messages through it.
logger = setup_logger(__name__)
  18. class DatabaseOperations:
  19. """数据库操作封装类"""
  20. def __init__(self, db_connection: Optional[DatabaseConnection] =None):
  21. """
  22. 初始化数据库操作
  23. Args:
  24. db_connection: 数据库连接对象,如果为None则使用默认连接
  25. """
  26. self.db_connection = db_connection or DatabaseConnection()
  27. # ==================== 用户相关操作 ====================
  28. def create_user(self, username: str, password_hash: str, is_admin: bool =False) -> dict:
  29. """
  30. 创建用户
  31. Args:
  32. username: 用户名
  33. password_hash: 密码哈希
  34. is_admin: 是否管理员
  35. Returns:
  36. dict: 创建的用户信息
  37. """
  38. with self.db_connection.get_session() as session:
  39. user = User(
  40. username=username,
  41. password_hash=password_hash,
  42. is_admin=is_admin
  43. )
  44. session.add(user)
  45. session.flush()
  46. session.refresh(user)
  47. result = {
  48. "id": user.id,
  49. "username": user.username,
  50. "is_admin": user.is_admin,
  51. "is_active": user.is_active,
  52. "created_at": user.created_at,
  53. "updated_at": user.updated_at,
  54. "last_login": user.last_login
  55. }
  56. return result
  57. def get_user_by_id(self, user_id: int) -> dict:
  58. """根据ID获取用户"""
  59. with self.db_connection.get_session() as session:
  60. user = session.query(User).filter(User.id == user_id).first()
  61. if not user:
  62. return None
  63. return {
  64. "id": user.id,
  65. "username": user.username,
  66. "password_hash": user.password_hash,
  67. "is_admin": user.is_admin,
  68. "is_active": user.is_active,
  69. "created_at": user.created_at,
  70. "updated_at": user.updated_at,
  71. "last_login": user.last_login
  72. }
  73. def get_user_by_username(self, username: str) -> dict:
  74. """根据用户名获取用户"""
  75. with self.db_connection.get_session() as session:
  76. user = session.query(User).filter(User.username == username).first()
  77. if not user:
  78. return None
  79. return {
  80. "id": user.id,
  81. "username": user.username,
  82. "password_hash": user.password_hash,
  83. "is_admin": user.is_admin,
  84. "is_active": user.is_active,
  85. "created_at": user.created_at,
  86. "updated_at": user.updated_at,
  87. "last_login": user.last_login
  88. }
  89. def update_user(self, user_id: int, **kwargs) -> dict:
  90. """更新用户信息"""
  91. with self.db_connection.get_session() as session:
  92. user = session.query(User).filter(User.id == user_id).first()
  93. if user:
  94. for key, value in kwargs.items():
  95. if hasattr(user, key):
  96. setattr(user, key, value)
  97. user.updated_at = datetime.now()
  98. session.flush()
  99. logger.info(f"更新用户成功: {user_id}")
  100. return {
  101. "id": user.id,
  102. "username": user.username,
  103. "password_hash": user.password_hash,
  104. "is_admin": user.is_admin,
  105. "is_active": user.is_active,
  106. "created_at": user.created_at,
  107. "updated_at": user.updated_at,
  108. "last_login": user.last_login
  109. }
  110. return None
  111. def delete_user(self, user_id: int) -> bool:
  112. """删除用户"""
  113. with self.db_connection.get_session() as session:
  114. user = session.query(User).filter(User.id == user_id).first()
  115. if user:
  116. session.delete(user)
  117. logger.info(f"删除用户成功: {user_id}")
  118. return True
  119. return False
  120. # ==================== 图片记录相关操作 ====================
  121. def create_image_record(self, user_id: int, image_type: str, original_filename: str, stored_path: str,
  122. file_size: Optional[int] = None, image_hash: Optional[str] = None) -> dict:
  123. """
  124. 创建图片记录
  125. Args:
  126. user_id: 用户ID
  127. image_type: 图片类型(face/cloth/result)
  128. original_filename: 原始文件名
  129. stored_path: 存储路径
  130. file_size: 文件大小
  131. image_hash: 图片哈希
  132. Returns:
  133. dict: 创建的图片记录信息
  134. """
  135. with self.db_connection.get_session() as session:
  136. image_record = ImageRecord(
  137. user_id=user_id,
  138. image_type=image_type,
  139. original_filename=original_filename,
  140. stored_path=stored_path,
  141. file_size=file_size,
  142. image_hash=image_hash
  143. )
  144. session.add(image_record)
  145. session.flush()
  146. session.refresh(image_record)
  147. result = {
  148. "id": image_record.id,
  149. "user_id": image_record.user_id,
  150. "image_type": image_record.image_type,
  151. "original_filename": image_record.original_filename,
  152. "stored_path": image_record.stored_path,
  153. "file_size": image_record.file_size,
  154. "image_hash": image_record.image_hash,
  155. "is_deleted": image_record.is_deleted,
  156. "created_at": image_record.created_at,
  157. "updated_at": image_record.updated_at
  158. }
  159. return result
  160. def get_image_record_by_id(self, image_id: int) -> dict:
  161. """根据ID获取图片记录"""
  162. with self.db_connection.get_session() as session:
  163. image_record = session.query(ImageRecord).filter(ImageRecord.id == image_id).first()
  164. if not image_record:
  165. return None
  166. return {
  167. "id": image_record.id,
  168. "user_id": image_record.user_id,
  169. "image_type": image_record.image_type,
  170. "original_filename": image_record.original_filename,
  171. "stored_path": image_record.stored_path,
  172. "file_size": image_record.file_size,
  173. "image_hash": image_record.image_hash,
  174. "is_deleted": image_record.is_deleted,
  175. "created_at": image_record.created_at,
  176. "updated_at": image_record.updated_at
  177. }
  178. def get_user_images(self, user_id: int, image_type: Optional[str] = None,
  179. page: int = 1, page_size: int = 20) -> Dict[str, Any]:
  180. """获取用户的图片记录"""
  181. with self.db_connection.get_session() as session:
  182. query = session.query(ImageRecord).filter(
  183. and_(ImageRecord.user_id == user_id, ImageRecord.is_deleted == False)
  184. )
  185. if image_type:
  186. query = query.filter(ImageRecord.image_type == image_type)
  187. total = query.count()
  188. images = query.order_by(desc(ImageRecord.created_at)).offset(
  189. (page - 1) * page_size
  190. ).limit(page_size).all()
  191. return {
  192. "images": [
  193. {
  194. "id": img.id,
  195. "user_id": img.user_id,
  196. "image_type": img.image_type,
  197. "original_filename": img.original_filename,
  198. "stored_path": img.stored_path,
  199. "file_size": img.file_size,
  200. "image_hash": img.image_hash,
  201. "is_deleted": img.is_deleted,
  202. "created_at": img.created_at,
  203. "updated_at": img.updated_at
  204. }
  205. for img in images
  206. ],
  207. "total": total,
  208. "page": page,
  209. "page_size": page_size,
  210. "total_pages": (total + page_size - 1) // page_size
  211. }
  212. def update_image_record(self, image_id: int, **kwargs) -> dict:
  213. """更新图片记录"""
  214. with self.db_connection.get_session() as session:
  215. image_record = session.query(ImageRecord).filter(ImageRecord.id == image_id).first()
  216. if image_record:
  217. for key, value in kwargs.items():
  218. if hasattr(image_record, key):
  219. setattr(image_record, key, value)
  220. image_record.updated_at = datetime.now()
  221. session.flush()
  222. logger.info(f"更新图片记录成功: {image_id}")
  223. return {
  224. "id": image_record.id,
  225. "user_id": image_record.user_id,
  226. "image_type": image_record.image_type,
  227. "original_filename": image_record.original_filename,
  228. "stored_path": image_record.stored_path,
  229. "file_size": image_record.file_size,
  230. "image_hash": image_record.image_hash,
  231. "is_deleted": image_record.is_deleted,
  232. "created_at": image_record.created_at,
  233. "updated_at": image_record.updated_at
  234. }
  235. return None
  236. def delete_image_record(self, image_id: int, soft_delete: bool = True) -> bool:
  237. """删除图片记录"""
  238. with self.db_connection.get_session() as session:
  239. image_record = session.query(ImageRecord).filter(ImageRecord.id == image_id).first()
  240. if image_record:
  241. if soft_delete:
  242. image_record.is_deleted = True
  243. image_record.updated_at = datetime.now()
  244. else:
  245. session.delete(image_record)
  246. logger.info(f"删除图片记录成功: {image_id}")
  247. return True
  248. return False
  249. # ==================== 文本模板相关操作 ====================
  250. def create_text_record(self, user_id: int, text_type: str, text_name: str, text_label: str, text_content: str) -> dict:
  251. """
  252. 创建文本模板记录
  253. Args:
  254. user_id: 用户ID
  255. text_type: 文本模板类型(prompt/copywrite)
  256. text_name: 文本名称
  257. text_label: 文本模板标签
  258. text_content: 文本模板内容
  259. Returns:
  260. dict: 创建的文本模板记录信息
  261. """
  262. with self.db_connection.get_session() as session:
  263. text_record = TextRecord(
  264. user_id=user_id,
  265. text_type=text_type,
  266. text_name=text_name,
  267. text_label=text_label,
  268. text_content=text_content
  269. )
  270. session.add(text_record)
  271. session.flush()
  272. session.refresh(text_record)
  273. result = {
  274. "id": text_record.id,
  275. "user_id": text_record.user_id,
  276. "text_type": text_record.text_type,
  277. "text_name": text_record.text_name,
  278. "text_label": text_record.text_label,
  279. "text_content": text_record.text_content,
  280. "created_at": text_record.created_at,
  281. }
  282. return result
  283. def get_text_record_by_id(self, text_id: int) -> dict:
  284. """根据ID获取文本模板记录"""
  285. with self.db_connection.get_session() as session:
  286. text_record = session.query(TextRecord).filter(TextRecord.id == text_id).first()
  287. if not text_record:
  288. return None
  289. return {
  290. "id": text_record.id,
  291. "user_id": text_record.user_id,
  292. "text_type": text_record.text_type,
  293. "text_name": text_record.text_name,
  294. "text_label": text_record.text_label,
  295. "text_content": text_record.text_content,
  296. "created_at": text_record.created_at,
  297. }
  298. def get_user_text_records(self, user_id: int, text_type: Optional[str] = None, page: int = 1, page_size: int = 20) -> Dict[str, Any]:
  299. """获取用户的文本模板记录"""
  300. with self.db_connection.get_session() as session:
  301. query = session.query(TextRecord).filter(TextRecord.user_id == user_id)
  302. if text_type:
  303. query = query.filter(TextRecord.text_type == text_type)
  304. total = query.count()
  305. text_records = query.order_by(desc(TextRecord.created_at)).offset(
  306. (page - 1) * page_size
  307. ).limit(page_size).all()
  308. return {
  309. "records": [
  310. {
  311. "id": record.id,
  312. "user_id": record.user_id,
  313. "text_type": record.text_type,
  314. "text_name": record.text_name,
  315. "text_label": record.text_label,
  316. "text_content": record.text_content,
  317. "created_at": record.created_at,
  318. }
  319. for record in text_records
  320. ],
  321. "total": total,
  322. "page": page,
  323. "page_size": page_size,
  324. "total_pages": (total + page_size - 1) // page_size
  325. }
  326. def update_text_record(self, text_id: int, **kwargs) -> dict:
  327. """更新文本模板记录"""
  328. with self.db_connection.get_session() as session:
  329. text_record = session.query(TextRecord).filter(TextRecord.id == text_id).first()
  330. if not text_record:
  331. return None
  332. for key, value in kwargs.items():
  333. if hasattr(text_record, key):
  334. setattr(text_record, key, value)
  335. # 更新时间
  336. if hasattr(text_record, "updated_at"):
  337. text_record.updated_at = datetime.now()
  338. session.flush()
  339. session.commit()
  340. logger.info(f"更新文本模板记录成功: {text_id}")
  341. return {
  342. "id": text_record.id,
  343. "user_id": text_record.user_id,
  344. "text_type": text_record.text_type,
  345. "text_name": text_record.text_name,
  346. "text_label": text_record.text_label,
  347. "text_content": text_record.text_content,
  348. "created_at": text_record.created_at,
  349. }
  350. def delete_text_record(self, text_id: int) -> bool:
  351. """删除文本模板记录"""
  352. with self.db_connection.get_session() as session:
  353. text_record = session.query(TextRecord).filter(TextRecord.id == text_id).first()
  354. if text_record:
  355. session.delete(text_record)
  356. logger.info(f"删除文本模板记录成功: {text_id}")
  357. return True
  358. return False
  359. # ==================== 处理记录相关操作 ====================
  360. def create_process_record(self, user_id: int, face_image_id: int, cloth_image_id: int,
  361. result_image_id: int, generated_text: Optional[str] = None, status: str = "待审核", task_type: str = "swap_face", prompt: str = "") -> dict:
  362. """
  363. 创建AI任务处理记录
  364. Args:
  365. user_id: 用户ID
  366. face_image_id: 人脸图片ID
  367. cloth_image_id: 服装图片ID
  368. result_image_id: 结果图片ID
  369. generated_text: AI生成的文案内容
  370. status: 处理记录状态
  371. task_type: 任务类型
  372. prompt: 提示词
  373. Returns:
  374. dict: 创建的处理记录信息
  375. """
  376. with self.db_connection.get_session() as session:
  377. process_record = ProcessRecord(
  378. user_id=user_id,
  379. face_image_id=face_image_id,
  380. cloth_image_id=cloth_image_id,
  381. result_image_id=result_image_id,
  382. generated_text=generated_text,
  383. status=status,
  384. task_type=task_type,
  385. prompt=prompt,
  386. completed_at=datetime.now()
  387. )
  388. session.add(process_record)
  389. session.flush()
  390. session.refresh(process_record)
  391. result = {
  392. "id": process_record.id,
  393. "user_id": process_record.user_id,
  394. "face_image_id": process_record.face_image_id,
  395. "cloth_image_id": process_record.cloth_image_id,
  396. "result_image_id": process_record.result_image_id,
  397. "generated_text": process_record.generated_text,
  398. "status": process_record.status,
  399. "task_type": process_record.task_type,
  400. "prompt": process_record.prompt,
  401. "completed_at": process_record.completed_at
  402. }
  403. return result
  404. def get_process_record_by_id(self, process_id: int) -> dict:
  405. """根据ID获取处理记录"""
  406. with self.db_connection.get_session() as session:
  407. process_record = session.query(ProcessRecord).filter(ProcessRecord.id == process_id).first()
  408. if not process_record:
  409. return None
  410. return {
  411. "id": process_record.id,
  412. "user_id": process_record.user_id,
  413. "face_image_id": process_record.face_image_id,
  414. "cloth_image_id": process_record.cloth_image_id,
  415. "result_image_id": process_record.result_image_id,
  416. "generated_text": process_record.generated_text,
  417. "status": process_record.status,
  418. "task_type": process_record.task_type,
  419. "prompt": process_record.prompt,
  420. "completed_at": process_record.completed_at
  421. }
  422. def get_user_process_records(self, user_id: int, page: int = 1, page_size: int = 20) -> Dict[str, Any]:
  423. """获取用户的处理记录"""
  424. with self.db_connection.get_session() as session:
  425. query = session.query(ProcessRecord).filter(ProcessRecord.user_id == user_id)
  426. total = query.count()
  427. records = query.order_by(desc(ProcessRecord.completed_at)).offset(
  428. (page - 1) * page_size
  429. ).limit(page_size).all()
  430. return {
  431. "records": [
  432. {
  433. "id": record.id,
  434. "user_id": record.user_id,
  435. "face_image_id": record.face_image_id,
  436. "cloth_image_id": record.cloth_image_id,
  437. "result_image_id": record.result_image_id,
  438. "generated_text": record.generated_text,
  439. "status": record.status,
  440. "task_type": record.task_type,
  441. "prompt": record.prompt,
  442. "completed_at": record.completed_at
  443. }
  444. for record in records
  445. ],
  446. "total": total,
  447. "page": page,
  448. "page_size": page_size,
  449. "total_pages": (total + page_size - 1) // page_size
  450. }
  451. def update_process_record(self, process_id: int, update_data: Dict[str, Any]) -> dict:
  452. """
  453. 更新处理记录
  454. Args:
  455. process_id: 处理记录ID
  456. update_data: 要更新的字段和值的字典
  457. Returns:
  458. dict: 更新后的处理记录信息,如果记录不存在则返回None
  459. """
  460. with self.db_connection.get_session() as session:
  461. process_record = session.query(ProcessRecord).filter(ProcessRecord.id == process_id).first()
  462. if process_record:
  463. for key, value in update_data.items():
  464. if hasattr(process_record, key):
  465. setattr(process_record, key, value)
  466. session.flush()
  467. session.commit()
  468. logger.info(f"更新处理记录成功: {process_id}")
  469. return {
  470. "id": process_record.id,
  471. "user_id": process_record.user_id,
  472. "face_image_id": process_record.face_image_id,
  473. "cloth_image_id": process_record.cloth_image_id,
  474. "result_image_id": process_record.result_image_id,
  475. "generated_text": process_record.generated_text,
  476. "status": process_record.status,
  477. "task_type": process_record.task_type,
  478. "prompt": process_record.prompt,
  479. "completed_at": process_record.completed_at
  480. }
  481. return None
  482. def delete_process_record(self, process_id: int) -> bool:
  483. """删除处理记录"""
  484. with self.db_connection.get_session() as session:
  485. process_record = session.query(ProcessRecord).filter(ProcessRecord.id == process_id).first()
  486. if process_record:
  487. session.delete(process_record)
  488. logger.info(f"删除处理记录成功: {process_id}")
  489. return True
  490. return False
  491. # ==================== 系统配置相关操作 ====================
  492. def set_config(self, config_key: str, config_value: Any, config_type: str = "string",
  493. config_description: Optional[str] = None, is_public: bool = True) -> SystemConfig:
  494. """
  495. 设置系统配置
  496. Args:
  497. config_key: 配置键
  498. config_value: 配置值
  499. config_type: 配置类型
  500. config_description: 配置描述
  501. is_public: 是否公开
  502. Returns:
  503. SystemConfig: 配置对象
  504. """
  505. with self.db_connection.get_session() as session:
  506. # 序列化配置值
  507. if config_type == "json" and isinstance(config_value, (dict, list)):
  508. config_value = json.dumps(config_value, ensure_ascii=False)
  509. elif config_type in ["int", "float", "bool"]:
  510. config_value = str(config_value)
  511. # 查找现有配置
  512. config = session.query(SystemConfig).filter(SystemConfig.config_key == config_key).first()
  513. if config:
  514. # 更新现有配置
  515. config.config_value = config_value
  516. config.config_type = config_type
  517. config.config_description = config_description
  518. config.is_public = is_public
  519. config.updated_at = datetime.now()
  520. else:
  521. # 创建新配置
  522. config = SystemConfig(
  523. config_key=config_key,
  524. config_value=config_value,
  525. config_type=config_type,
  526. config_description=config_description,
  527. is_public=is_public
  528. )
  529. session.add(config)
  530. session.flush()
  531. session.refresh(config)
  532. logger.info(f"设置配置成功: {config_key}")
  533. return config
  534. def get_config(self, config_key: str) -> Optional[SystemConfig]:
  535. """获取系统配置"""
  536. with self.db_connection.get_session() as session:
  537. config = session.query(SystemConfig).filter(SystemConfig.config_key == config_key).first()
  538. if config:
  539. return {
  540. "id": config.id,
  541. "config_key": config.config_key,
  542. "config_value": config.config_value,
  543. "config_type": config.config_type,
  544. "config_description": config.config_description,
  545. "is_public": config.is_public,
  546. "created_at": config.created_at,
  547. "updated_at": config.updated_at
  548. }
  549. return None
  550. def list_configs(self, is_public: Optional[bool] = None) -> List[SystemConfig]:
  551. """获取配置列表"""
  552. with self.db_connection.get_session() as session:
  553. query = session.query(SystemConfig)
  554. if is_public is not None:
  555. query = query.filter(SystemConfig.is_public == is_public)
  556. return query.order_by(SystemConfig.config_key).all()
  557. def delete_config(self, config_key: str) -> bool:
  558. """删除配置"""
  559. with self.db_connection.get_session() as session:
  560. config = session.query(SystemConfig).filter(SystemConfig.config_key == config_key).first()
  561. if config:
  562. session.delete(config)
  563. logger.info(f"删除配置成功: {config_key}")
  564. return True
  565. return False
  566. # ==================== 工具方法 ====================
  567. def _calculate_image_hash(self, image_data: bytes) -> str:
  568. """计算图片哈希值"""
  569. return hashlib.md5(image_data).hexdigest()
  570. def get_statistics(self) -> Dict[str, Any]:
  571. """获取系统统计信息"""
  572. with self.db_connection.get_session() as session:
  573. user_count = session.query(User).count()
  574. active_user_count = session.query(User).filter(User.is_active == True).count()
  575. process_count = session.query(ProcessRecord).count()
  576. # AI换脸相关统计
  577. face_images = session.query(ImageRecord).filter(
  578. and_(ImageRecord.image_type == "face", ImageRecord.is_deleted == False)
  579. ).count()
  580. cloth_images = session.query(ImageRecord).filter(
  581. and_(ImageRecord.image_type == "cloth", ImageRecord.is_deleted == False)
  582. ).count()
  583. result_images = session.query(ImageRecord).filter(
  584. and_(ImageRecord.image_type == "result", ImageRecord.is_deleted == False)
  585. ).count()
  586. return {
  587. "total_users": user_count,
  588. "active_users": active_user_count,
  589. "total_processes": process_count,
  590. "face_images": face_images,
  591. "cloth_images": cloth_images,
  592. "result_images": result_images
  593. }