""" 数据库操作封装模块 提供完整的CRUD操作功能 包括用户管理、图片管理、处理记录管理、系统配置管理、数据库版本管理 """ import json import hashlib import uuid from datetime import datetime from typing import Optional, List, Dict, Any, Union from sqlalchemy.orm import Session from sqlalchemy import and_, or_, desc, asc from sqlalchemy.exc import IntegrityError from .connection import DatabaseConnection from .models import User, ImageRecord, TextRecord, ProcessRecord, SystemConfig, DatabaseVersion from backend.utils.logger_config import setup_logger logger = setup_logger(__name__) class DatabaseOperations: """数据库操作封装类""" def __init__(self, db_connection: Optional[DatabaseConnection] =None): """ 初始化数据库操作 Args: db_connection: 数据库连接对象,如果为None则使用默认连接 """ self.db_connection = db_connection or DatabaseConnection() # ==================== 用户相关操作 ==================== def create_user(self, username: str, password_hash: str, is_admin: bool =False) -> dict: """ 创建用户 Args: username: 用户名 password_hash: 密码哈希 is_admin: 是否管理员 Returns: dict: 创建的用户信息 """ with self.db_connection.get_session() as session: user = User( username=username, password_hash=password_hash, is_admin=is_admin ) session.add(user) session.flush() session.refresh(user) result = { "id": user.id, "username": user.username, "is_admin": user.is_admin, "is_active": user.is_active, "created_at": user.created_at, "updated_at": user.updated_at, "last_login": user.last_login } return result def get_user_by_id(self, user_id: int) -> dict: """根据ID获取用户""" with self.db_connection.get_session() as session: user = session.query(User).filter(User.id == user_id).first() if not user: return None return { "id": user.id, "username": user.username, "password_hash": user.password_hash, "is_admin": user.is_admin, "is_active": user.is_active, "created_at": user.created_at, "updated_at": user.updated_at, "last_login": user.last_login } def get_user_by_username(self, username: str) -> dict: """根据用户名获取用户""" with self.db_connection.get_session() as session: user = session.query(User).filter(User.username == username).first() if not user: return None return { "id": user.id, "username": user.username, "password_hash": user.password_hash, "is_admin": user.is_admin, "is_active": user.is_active, "created_at": user.created_at, "updated_at": user.updated_at, "last_login": user.last_login } def update_user(self, user_id: int, **kwargs) -> dict: """更新用户信息""" with self.db_connection.get_session() as session: user = session.query(User).filter(User.id == user_id).first() if user: for key, value in kwargs.items(): if hasattr(user, key): setattr(user, key, value) user.updated_at = datetime.now() session.flush() logger.info(f"更新用户成功: {user_id}") return { "id": user.id, "username": user.username, "password_hash": user.password_hash, "is_admin": user.is_admin, "is_active": user.is_active, "created_at": user.created_at, "updated_at": user.updated_at, "last_login": user.last_login } return None def delete_user(self, user_id: int) -> bool: """删除用户""" with self.db_connection.get_session() as session: user = session.query(User).filter(User.id == user_id).first() if user: session.delete(user) logger.info(f"删除用户成功: {user_id}") return True return False # ==================== 图片记录相关操作 ==================== def create_image_record(self, user_id: int, image_type: str, original_filename: str, stored_path: str, file_size: Optional[int] = None, image_hash: Optional[str] = None) -> dict: """ 创建图片记录 Args: user_id: 用户ID image_type: 图片类型(face/cloth/result) original_filename: 原始文件名 stored_path: 存储路径 file_size: 文件大小 image_hash: 图片哈希 Returns: dict: 创建的图片记录信息 """ with self.db_connection.get_session() as session: image_record = ImageRecord( user_id=user_id, image_type=image_type, original_filename=original_filename, stored_path=stored_path, file_size=file_size, image_hash=image_hash ) session.add(image_record) session.flush() session.refresh(image_record) result = { "id": image_record.id, "user_id": image_record.user_id, "image_type": image_record.image_type, "original_filename": image_record.original_filename, "stored_path": image_record.stored_path, "file_size": image_record.file_size, "image_hash": image_record.image_hash, "is_deleted": image_record.is_deleted, "created_at": image_record.created_at, "updated_at": image_record.updated_at } return result def get_image_record_by_id(self, image_id: int) -> dict: """根据ID获取图片记录""" with self.db_connection.get_session() as session: image_record = session.query(ImageRecord).filter(ImageRecord.id == image_id).first() if not image_record: return None return { "id": image_record.id, "user_id": image_record.user_id, "image_type": image_record.image_type, "original_filename": image_record.original_filename, "stored_path": image_record.stored_path, "file_size": image_record.file_size, "image_hash": image_record.image_hash, "is_deleted": image_record.is_deleted, "created_at": image_record.created_at, "updated_at": image_record.updated_at } def get_user_images(self, user_id: int, image_type: Optional[str] = None, page: int = 1, page_size: int = 20) -> Dict[str, Any]: """获取用户的图片记录""" with self.db_connection.get_session() as session: query = session.query(ImageRecord).filter( and_(ImageRecord.user_id == user_id, ImageRecord.is_deleted == False) ) if image_type: query = query.filter(ImageRecord.image_type == image_type) total = query.count() images = query.order_by(desc(ImageRecord.created_at)).offset( (page - 1) * page_size ).limit(page_size).all() return { "images": [ { "id": img.id, "user_id": img.user_id, "image_type": img.image_type, "original_filename": img.original_filename, "stored_path": img.stored_path, "file_size": img.file_size, "image_hash": img.image_hash, "is_deleted": img.is_deleted, "created_at": img.created_at, "updated_at": img.updated_at } for img in images ], "total": total, "page": page, "page_size": page_size, "total_pages": (total + page_size - 1) // page_size } def update_image_record(self, image_id: int, **kwargs) -> dict: """更新图片记录""" with self.db_connection.get_session() as session: image_record = session.query(ImageRecord).filter(ImageRecord.id == image_id).first() if image_record: for key, value in kwargs.items(): if hasattr(image_record, key): setattr(image_record, key, value) image_record.updated_at = datetime.now() session.flush() logger.info(f"更新图片记录成功: {image_id}") return { "id": image_record.id, "user_id": image_record.user_id, "image_type": image_record.image_type, "original_filename": image_record.original_filename, "stored_path": image_record.stored_path, "file_size": image_record.file_size, "image_hash": image_record.image_hash, "is_deleted": image_record.is_deleted, "created_at": image_record.created_at, "updated_at": image_record.updated_at } return None def delete_image_record(self, image_id: int, soft_delete: bool = True) -> bool: """删除图片记录""" with self.db_connection.get_session() as session: image_record = session.query(ImageRecord).filter(ImageRecord.id == image_id).first() if image_record: if soft_delete: image_record.is_deleted = True image_record.updated_at = datetime.now() else: session.delete(image_record) logger.info(f"删除图片记录成功: {image_id}") return True return False # ==================== 文本模板相关操作 ==================== def create_text_record(self, user_id: int, text_type: str, text_name: str, text_label: str, text_content: str) -> dict: """ 创建文本模板记录 Args: user_id: 用户ID text_type: 文本模板类型(prompt/copywrite) text_name: 文本名称 text_label: 文本模板标签 text_content: 文本模板内容 Returns: dict: 创建的文本模板记录信息 """ with self.db_connection.get_session() as session: text_record = TextRecord( user_id=user_id, text_type=text_type, text_name=text_name, text_label=text_label, text_content=text_content ) session.add(text_record) session.flush() session.refresh(text_record) result = { "id": text_record.id, "user_id": text_record.user_id, "text_type": text_record.text_type, "text_name": text_record.text_name, "text_label": text_record.text_label, "text_content": text_record.text_content, "created_at": text_record.created_at, } return result def get_text_record_by_id(self, text_id: int) -> dict: """根据ID获取文本模板记录""" with self.db_connection.get_session() as session: text_record = session.query(TextRecord).filter(TextRecord.id == text_id).first() if not text_record: return None return { "id": text_record.id, "user_id": text_record.user_id, "text_type": text_record.text_type, "text_name": text_record.text_name, "text_label": text_record.text_label, "text_content": text_record.text_content, "created_at": text_record.created_at, } def get_user_text_records(self, user_id: int, text_type: Optional[str] = None, page: int = 1, page_size: int = 20) -> Dict[str, Any]: """获取用户的文本模板记录""" with self.db_connection.get_session() as session: query = session.query(TextRecord).filter(TextRecord.user_id == user_id) if text_type: query = query.filter(TextRecord.text_type == text_type) total = query.count() text_records = query.order_by(desc(TextRecord.created_at)).offset( (page - 1) * page_size ).limit(page_size).all() return { "records": [ { "id": record.id, "user_id": record.user_id, "text_type": record.text_type, "text_name": record.text_name, "text_label": record.text_label, "text_content": record.text_content, "created_at": record.created_at, } for record in text_records ], "total": total, "page": page, "page_size": page_size, "total_pages": (total + page_size - 1) // page_size } def update_text_record(self, text_id: int, **kwargs) -> dict: """更新文本模板记录""" with self.db_connection.get_session() as session: text_record = session.query(TextRecord).filter(TextRecord.id == text_id).first() if not text_record: return None for key, value in kwargs.items(): if hasattr(text_record, key): setattr(text_record, key, value) # 更新时间 if hasattr(text_record, "updated_at"): text_record.updated_at = datetime.now() session.flush() session.commit() logger.info(f"更新文本模板记录成功: {text_id}") return { "id": text_record.id, "user_id": text_record.user_id, "text_type": text_record.text_type, "text_name": text_record.text_name, "text_label": text_record.text_label, "text_content": text_record.text_content, "created_at": text_record.created_at, } def delete_text_record(self, text_id: int) -> bool: """删除文本模板记录""" with self.db_connection.get_session() as session: text_record = session.query(TextRecord).filter(TextRecord.id == text_id).first() if text_record: session.delete(text_record) logger.info(f"删除文本模板记录成功: {text_id}") return True return False # ==================== 处理记录相关操作 ==================== def create_process_record(self, user_id: int, face_image_id: int, cloth_image_id: int, result_image_id: int, generated_text: Optional[str] = None, status: str = "待审核", task_type: str = "swap_face", prompt: str = "") -> dict: """ 创建AI任务处理记录 Args: user_id: 用户ID face_image_id: 人脸图片ID cloth_image_id: 服装图片ID result_image_id: 结果图片ID generated_text: AI生成的文案内容 status: 处理记录状态 task_type: 任务类型 prompt: 提示词 Returns: dict: 创建的处理记录信息 """ with self.db_connection.get_session() as session: process_record = ProcessRecord( user_id=user_id, face_image_id=face_image_id, cloth_image_id=cloth_image_id, result_image_id=result_image_id, generated_text=generated_text, status=status, task_type=task_type, prompt=prompt, completed_at=datetime.now() ) session.add(process_record) session.flush() session.refresh(process_record) result = { "id": process_record.id, "user_id": process_record.user_id, "face_image_id": process_record.face_image_id, "cloth_image_id": process_record.cloth_image_id, "result_image_id": process_record.result_image_id, "generated_text": process_record.generated_text, "status": process_record.status, "task_type": process_record.task_type, "prompt": process_record.prompt, "completed_at": process_record.completed_at } return result def get_process_record_by_id(self, process_id: int) -> dict: """根据ID获取处理记录""" with self.db_connection.get_session() as session: process_record = session.query(ProcessRecord).filter(ProcessRecord.id == process_id).first() if not process_record: return None return { "id": process_record.id, "user_id": process_record.user_id, "face_image_id": process_record.face_image_id, "cloth_image_id": process_record.cloth_image_id, "result_image_id": process_record.result_image_id, "generated_text": process_record.generated_text, "status": process_record.status, "task_type": process_record.task_type, "prompt": process_record.prompt, "completed_at": process_record.completed_at } def get_user_process_records(self, user_id: int, page: int = 1, page_size: int = 20) -> Dict[str, Any]: """获取用户的处理记录""" with self.db_connection.get_session() as session: query = session.query(ProcessRecord).filter(ProcessRecord.user_id == user_id) total = query.count() records = query.order_by(desc(ProcessRecord.completed_at)).offset( (page - 1) * page_size ).limit(page_size).all() return { "records": [ { "id": record.id, "user_id": record.user_id, "face_image_id": record.face_image_id, "cloth_image_id": record.cloth_image_id, "result_image_id": record.result_image_id, "generated_text": record.generated_text, "status": record.status, "task_type": record.task_type, "prompt": record.prompt, "completed_at": record.completed_at } for record in records ], "total": total, "page": page, "page_size": page_size, "total_pages": (total + page_size - 1) // page_size } def update_process_record(self, process_id: int, update_data: Dict[str, Any]) -> dict: """ 更新处理记录 Args: process_id: 处理记录ID update_data: 要更新的字段和值的字典 Returns: dict: 更新后的处理记录信息,如果记录不存在则返回None """ with self.db_connection.get_session() as session: process_record = session.query(ProcessRecord).filter(ProcessRecord.id == process_id).first() if process_record: for key, value in update_data.items(): if hasattr(process_record, key): setattr(process_record, key, value) session.flush() session.commit() logger.info(f"更新处理记录成功: {process_id}") return { "id": process_record.id, "user_id": process_record.user_id, "face_image_id": process_record.face_image_id, "cloth_image_id": process_record.cloth_image_id, "result_image_id": process_record.result_image_id, "generated_text": process_record.generated_text, "status": process_record.status, "task_type": process_record.task_type, "prompt": process_record.prompt, "completed_at": process_record.completed_at } return None def delete_process_record(self, process_id: int) -> bool: """删除处理记录""" with self.db_connection.get_session() as session: process_record = session.query(ProcessRecord).filter(ProcessRecord.id == process_id).first() if process_record: session.delete(process_record) logger.info(f"删除处理记录成功: {process_id}") return True return False # ==================== 系统配置相关操作 ==================== def set_config(self, config_key: str, config_value: Any, config_type: str = "string", config_description: Optional[str] = None, is_public: bool = True) -> SystemConfig: """ 设置系统配置 Args: config_key: 配置键 config_value: 配置值 config_type: 配置类型 config_description: 配置描述 is_public: 是否公开 Returns: SystemConfig: 配置对象 """ with self.db_connection.get_session() as session: # 序列化配置值 if config_type == "json" and isinstance(config_value, (dict, list)): config_value = json.dumps(config_value, ensure_ascii=False) elif config_type in ["int", "float", "bool"]: config_value = str(config_value) # 查找现有配置 config = session.query(SystemConfig).filter(SystemConfig.config_key == config_key).first() if config: # 更新现有配置 config.config_value = config_value config.config_type = config_type config.config_description = config_description config.is_public = is_public config.updated_at = datetime.now() else: # 创建新配置 config = SystemConfig( config_key=config_key, config_value=config_value, config_type=config_type, config_description=config_description, is_public=is_public ) session.add(config) session.flush() session.refresh(config) logger.info(f"设置配置成功: {config_key}") return config def get_config(self, config_key: str) -> Optional[SystemConfig]: """获取系统配置""" with self.db_connection.get_session() as session: config = session.query(SystemConfig).filter(SystemConfig.config_key == config_key).first() if config: return { "id": config.id, "config_key": config.config_key, "config_value": config.config_value, "config_type": config.config_type, "config_description": config.config_description, "is_public": config.is_public, "created_at": config.created_at, "updated_at": config.updated_at } return None def list_configs(self, is_public: Optional[bool] = None) -> List[SystemConfig]: """获取配置列表""" with self.db_connection.get_session() as session: query = session.query(SystemConfig) if is_public is not None: query = query.filter(SystemConfig.is_public == is_public) return query.order_by(SystemConfig.config_key).all() def delete_config(self, config_key: str) -> bool: """删除配置""" with self.db_connection.get_session() as session: config = session.query(SystemConfig).filter(SystemConfig.config_key == config_key).first() if config: session.delete(config) logger.info(f"删除配置成功: {config_key}") return True return False # ==================== 工具方法 ==================== def _calculate_image_hash(self, image_data: bytes) -> str: """计算图片哈希值""" return hashlib.md5(image_data).hexdigest() def get_statistics(self) -> Dict[str, Any]: """获取系统统计信息""" with self.db_connection.get_session() as session: user_count = session.query(User).count() active_user_count = session.query(User).filter(User.is_active == True).count() process_count = session.query(ProcessRecord).count() # AI换脸相关统计 face_images = session.query(ImageRecord).filter( and_(ImageRecord.image_type == "face", ImageRecord.is_deleted == False) ).count() cloth_images = session.query(ImageRecord).filter( and_(ImageRecord.image_type == "cloth", ImageRecord.is_deleted == False) ).count() result_images = session.query(ImageRecord).filter( and_(ImageRecord.image_type == "result", ImageRecord.is_deleted == False) ).count() return { "total_users": user_count, "active_users": active_user_count, "total_processes": process_count, "face_images": face_images, "cloth_images": cloth_images, "result_images": result_images }