| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650 |
"""
Database operation wrappers.

Provides CRUD helpers for users, image records, text templates, AI process
records and system configuration entries.
"""
import hashlib
import json
import uuid
from datetime import datetime
from typing import Any, Dict, List, Optional, Union

from sqlalchemy import and_, asc, desc, or_
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session

# Project-local imports kept in their original order.
from .connection import DatabaseConnection
from .models import User, ImageRecord, TextRecord, ProcessRecord, SystemConfig, DatabaseVersion
from backend.utils.logger_config import setup_logger

logger = setup_logger(__name__)
class DatabaseOperations:
    """CRUD wrapper over the project's SQLAlchemy models.

    Each public method opens its own session via
    ``self.db_connection.get_session()``. Apart from ``set_config`` and
    ``list_configs``, which return detached ORM objects, every method converts
    rows to plain dicts before the session is closed.

    NOTE(review): only ``update_text_record`` and ``update_process_record``
    call ``commit()`` themselves. Every other write relies on ``get_session()``
    committing on a clean exit — confirm against ``DatabaseConnection``.
    """

    def __init__(self, db_connection: Optional["DatabaseConnection"] = None):
        """Initialise the wrapper.

        Args:
            db_connection: Connection to use. When falsy (e.g. None), a new
                ``DatabaseConnection()`` is created.
        """
        self.db_connection = db_connection or DatabaseConnection()

    # ==================== Internal helpers ====================

    @staticmethod
    def _normalize_paging(page: int, page_size: int) -> tuple:
        """Clamp ``page`` and ``page_size`` to a minimum of 1.

        Without this, ``page < 1`` produces a negative OFFSET. ``page_size``
        of 0 makes the ``total_pages`` computation divide by zero, and a
        negative ``page_size`` makes it negative.
        """
        return max(page, 1), max(page_size, 1)

    @staticmethod
    def _total_pages(total: int, page_size: int) -> int:
        """Return ``ceil(total / page_size)`` using integer arithmetic."""
        return (total + page_size - 1) // page_size

    @classmethod
    def _paginate(cls, query, order_by, page: int, page_size: int) -> tuple:
        """Fetch one page of ``query`` ordered by ``order_by``.

        Returns:
            tuple: ``(rows, meta)``. ``meta`` holds the ``total``, ``page``,
            ``page_size`` and ``total_pages`` keys shared by the paginated
            getters, with the page values already normalised.
        """
        page, page_size = cls._normalize_paging(page, page_size)
        total = query.count()
        rows = (
            query.order_by(order_by)
            .offset((page - 1) * page_size)
            .limit(page_size)
            .all()
        )
        meta = {
            "total": total,
            "page": page,
            "page_size": page_size,
            "total_pages": cls._total_pages(total, page_size),
        }
        return rows, meta

    @staticmethod
    def _apply_updates(obj, values: Dict[str, Any]) -> None:
        """Copy each entry of ``values`` onto ``obj`` when ``obj`` has that attribute.

        Keys that name attributes ``obj`` lacks are silently ignored.
        """
        for key, value in values.items():
            if hasattr(obj, key):
                setattr(obj, key, value)

    @staticmethod
    def _user_to_dict(user, include_password_hash: bool = True) -> Dict[str, Any]:
        """Convert a ``User`` row to a dict, optionally omitting ``password_hash``."""
        data = {"id": user.id, "username": user.username}
        if include_password_hash:
            data["password_hash"] = user.password_hash
        data["is_admin"] = user.is_admin
        data["is_active"] = user.is_active
        data["created_at"] = user.created_at
        data["updated_at"] = user.updated_at
        data["last_login"] = user.last_login
        return data

    @staticmethod
    def _image_to_dict(record) -> Dict[str, Any]:
        """Convert an ``ImageRecord`` row to a dict."""
        return {
            "id": record.id,
            "user_id": record.user_id,
            "image_type": record.image_type,
            "original_filename": record.original_filename,
            "stored_path": record.stored_path,
            "file_size": record.file_size,
            "image_hash": record.image_hash,
            "is_deleted": record.is_deleted,
            "created_at": record.created_at,
            "updated_at": record.updated_at,
        }

    @staticmethod
    def _text_to_dict(record) -> Dict[str, Any]:
        """Convert a ``TextRecord`` row to a dict."""
        return {
            "id": record.id,
            "user_id": record.user_id,
            "text_type": record.text_type,
            "text_name": record.text_name,
            "text_label": record.text_label,
            "text_content": record.text_content,
            "created_at": record.created_at,
        }

    @staticmethod
    def _process_to_dict(record) -> Dict[str, Any]:
        """Convert a ``ProcessRecord`` row to a dict."""
        return {
            "id": record.id,
            "user_id": record.user_id,
            "face_image_id": record.face_image_id,
            "cloth_image_id": record.cloth_image_id,
            "result_image_id": record.result_image_id,
            "generated_text": record.generated_text,
            "status": record.status,
            "task_type": record.task_type,
            "prompt": record.prompt,
            "completed_at": record.completed_at,
        }

    @staticmethod
    def _config_to_dict(config) -> Dict[str, Any]:
        """Convert a ``SystemConfig`` row to a dict."""
        return {
            "id": config.id,
            "config_key": config.config_key,
            "config_value": config.config_value,
            "config_type": config.config_type,
            "config_description": config.config_description,
            "is_public": config.is_public,
            "created_at": config.created_at,
            "updated_at": config.updated_at,
        }

    # ==================== Users ====================

    def create_user(self, username: str, password_hash: str, is_admin: bool = False) -> Dict[str, Any]:
        """Insert a user.

        Args:
            username: User name.
            password_hash: Already-hashed password.
            is_admin: Whether the user is an administrator.

        Returns:
            dict: The new user's fields. Unlike the getters, the dict does
            NOT include ``password_hash``.
        """
        with self.db_connection.get_session() as session:
            user = User(username=username, password_hash=password_hash, is_admin=is_admin)
            session.add(user)
            session.flush()
            # Reload the row so any database-generated column values are present.
            session.refresh(user)
            return self._user_to_dict(user, include_password_hash=False)

    def get_user_by_id(self, user_id: int) -> Optional[Dict[str, Any]]:
        """Return the user with ``user_id`` (including ``password_hash``), or None."""
        with self.db_connection.get_session() as session:
            user = session.query(User).filter(User.id == user_id).first()
            return None if user is None else self._user_to_dict(user)

    def get_user_by_username(self, username: str) -> Optional[Dict[str, Any]]:
        """Return the user named ``username`` (including ``password_hash``), or None."""
        with self.db_connection.get_session() as session:
            user = session.query(User).filter(User.username == username).first()
            return None if user is None else self._user_to_dict(user)

    def update_user(self, user_id: int, **kwargs) -> Optional[Dict[str, Any]]:
        """Update a user's attributes.

        Each keyword naming an existing attribute is set, and unknown keywords
        are ignored. ``updated_at`` is then set to ``datetime.now()``,
        overriding any value passed for it.

        Returns:
            dict | None: The updated user, or None if ``user_id`` does not exist.
        """
        with self.db_connection.get_session() as session:
            user = session.query(User).filter(User.id == user_id).first()
            if user is None:
                return None
            self._apply_updates(user, kwargs)
            user.updated_at = datetime.now()
            session.flush()
            logger.info(f"更新用户成功: {user_id}")
            return self._user_to_dict(user)

    def delete_user(self, user_id: int) -> bool:
        """Hard-delete a user. Returns True if the user existed."""
        with self.db_connection.get_session() as session:
            user = session.query(User).filter(User.id == user_id).first()
            if user is None:
                return False
            session.delete(user)
            logger.info(f"删除用户成功: {user_id}")
            return True

    # ==================== Image records ====================

    def create_image_record(self, user_id: int, image_type: str, original_filename: str, stored_path: str,
                            file_size: Optional[int] = None, image_hash: Optional[str] = None) -> Dict[str, Any]:
        """Insert an image record.

        Args:
            user_id: Owning user ID.
            image_type: Image category; this module queries "face", "cloth"
                and "result".
            original_filename: Name of the uploaded file.
            stored_path: Where the file is stored.
            file_size: File size, if known.
            image_hash: Content hash, if known.

        Returns:
            dict: The created record.
        """
        with self.db_connection.get_session() as session:
            image_record = ImageRecord(
                user_id=user_id,
                image_type=image_type,
                original_filename=original_filename,
                stored_path=stored_path,
                file_size=file_size,
                image_hash=image_hash,
            )
            session.add(image_record)
            session.flush()
            session.refresh(image_record)
            return self._image_to_dict(image_record)

    def get_image_record_by_id(self, image_id: int) -> Optional[Dict[str, Any]]:
        """Return the image record with ``image_id`` (soft-deleted ones included), or None."""
        with self.db_connection.get_session() as session:
            image_record = session.query(ImageRecord).filter(ImageRecord.id == image_id).first()
            return None if image_record is None else self._image_to_dict(image_record)

    def get_user_images(self, user_id: int, image_type: Optional[str] = None,
                        page: int = 1, page_size: int = 20) -> Dict[str, Any]:
        """Return one page of a user's non-deleted images, newest ``created_at`` first.

        Args:
            user_id: Owning user ID.
            image_type: Optional filter on ``image_type``.
            page: 1-based page number (values below 1 are treated as 1).
            page_size: Rows per page (values below 1 are treated as 1).

        Returns:
            dict: ``images`` plus ``total``, ``page``, ``page_size`` and ``total_pages``.
        """
        with self.db_connection.get_session() as session:
            query = session.query(ImageRecord).filter(
                # `== False` builds a SQL expression; `is False` would not.
                and_(ImageRecord.user_id == user_id, ImageRecord.is_deleted == False)  # noqa: E712
            )
            if image_type:
                query = query.filter(ImageRecord.image_type == image_type)
            images, meta = self._paginate(query, desc(ImageRecord.created_at), page, page_size)
            return {"images": [self._image_to_dict(img) for img in images], **meta}

    def update_image_record(self, image_id: int, **kwargs) -> Optional[Dict[str, Any]]:
        """Update an image record's existing attributes and refresh ``updated_at``.

        Returns:
            dict | None: The updated record, or None if it does not exist.
        """
        with self.db_connection.get_session() as session:
            image_record = session.query(ImageRecord).filter(ImageRecord.id == image_id).first()
            if image_record is None:
                return None
            self._apply_updates(image_record, kwargs)
            image_record.updated_at = datetime.now()
            session.flush()
            logger.info(f"更新图片记录成功: {image_id}")
            return self._image_to_dict(image_record)

    def delete_image_record(self, image_id: int, soft_delete: bool = True) -> bool:
        """Delete an image record.

        Args:
            image_id: Record ID.
            soft_delete: When True (the default), only set ``is_deleted`` and
                ``updated_at``. Otherwise remove the row.

        Returns:
            bool: True if the record existed.
        """
        with self.db_connection.get_session() as session:
            image_record = session.query(ImageRecord).filter(ImageRecord.id == image_id).first()
            if image_record is None:
                return False
            if soft_delete:
                image_record.is_deleted = True
                image_record.updated_at = datetime.now()
            else:
                session.delete(image_record)
            logger.info(f"删除图片记录成功: {image_id}")
            return True

    # ==================== Text templates ====================

    def create_text_record(self, user_id: int, text_type: str, text_name: str, text_label: str, text_content: str) -> Dict[str, Any]:
        """Insert a text template record.

        Args:
            user_id: Owning user ID.
            text_type: Template type (e.g. "prompt" / "copywrite").
            text_name: Template name.
            text_label: Template label.
            text_content: Template body.

        Returns:
            dict: The created record.
        """
        with self.db_connection.get_session() as session:
            text_record = TextRecord(
                user_id=user_id,
                text_type=text_type,
                text_name=text_name,
                text_label=text_label,
                text_content=text_content,
            )
            session.add(text_record)
            session.flush()
            session.refresh(text_record)
            return self._text_to_dict(text_record)

    def get_text_record_by_id(self, text_id: int) -> Optional[Dict[str, Any]]:
        """Return the text template with ``text_id``, or None."""
        with self.db_connection.get_session() as session:
            text_record = session.query(TextRecord).filter(TextRecord.id == text_id).first()
            return None if text_record is None else self._text_to_dict(text_record)

    def get_user_text_records(self, user_id: int, text_type: Optional[str] = None, page: int = 1, page_size: int = 20) -> Dict[str, Any]:
        """Return one page of a user's text templates, newest ``created_at`` first.

        Returns:
            dict: ``records`` plus ``total``, ``page``, ``page_size`` and ``total_pages``.
        """
        with self.db_connection.get_session() as session:
            query = session.query(TextRecord).filter(TextRecord.user_id == user_id)
            if text_type:
                query = query.filter(TextRecord.text_type == text_type)
            text_records, meta = self._paginate(query, desc(TextRecord.created_at), page, page_size)
            return {"records": [self._text_to_dict(record) for record in text_records], **meta}

    def update_text_record(self, text_id: int, **kwargs) -> Optional[Dict[str, Any]]:
        """Update a text template's existing attributes.

        ``updated_at`` is refreshed only if the model has that attribute.

        Returns:
            dict | None: The updated record, or None if it does not exist.
        """
        with self.db_connection.get_session() as session:
            text_record = session.query(TextRecord).filter(TextRecord.id == text_id).first()
            if text_record is None:
                return None
            self._apply_updates(text_record, kwargs)
            if hasattr(text_record, "updated_at"):
                text_record.updated_at = datetime.now()
            session.flush()
            # NOTE(review): explicit commit kept from the original; most other
            # writers rely on get_session() to commit.
            session.commit()
            logger.info(f"更新文本模板记录成功: {text_id}")
            return self._text_to_dict(text_record)

    def delete_text_record(self, text_id: int) -> bool:
        """Hard-delete a text template. Returns True if it existed."""
        with self.db_connection.get_session() as session:
            text_record = session.query(TextRecord).filter(TextRecord.id == text_id).first()
            if text_record is None:
                return False
            session.delete(text_record)
            logger.info(f"删除文本模板记录成功: {text_id}")
            return True

    # ==================== Process records ====================

    def create_process_record(self, user_id: int, face_image_id: int, cloth_image_id: int,
                              result_image_id: int, generated_text: Optional[str] = None, status: str = "待审核", task_type: str = "swap_face", prompt: str = "") -> Dict[str, Any]:
        """Insert an AI task process record.

        ``completed_at`` is set to ``datetime.now()`` at creation time,
        regardless of ``status``.

        Args:
            user_id: Owning user ID.
            face_image_id: Face image record ID.
            cloth_image_id: Clothing image record ID.
            result_image_id: Result image record ID.
            generated_text: AI-generated copy, if any.
            status: Record status.
            task_type: Task type.
            prompt: Prompt used.

        Returns:
            dict: The created record.
        """
        with self.db_connection.get_session() as session:
            process_record = ProcessRecord(
                user_id=user_id,
                face_image_id=face_image_id,
                cloth_image_id=cloth_image_id,
                result_image_id=result_image_id,
                generated_text=generated_text,
                status=status,
                task_type=task_type,
                prompt=prompt,
                completed_at=datetime.now(),
            )
            session.add(process_record)
            session.flush()
            session.refresh(process_record)
            return self._process_to_dict(process_record)

    def get_process_record_by_id(self, process_id: int) -> Optional[Dict[str, Any]]:
        """Return the process record with ``process_id``, or None."""
        with self.db_connection.get_session() as session:
            process_record = session.query(ProcessRecord).filter(ProcessRecord.id == process_id).first()
            return None if process_record is None else self._process_to_dict(process_record)

    def get_user_process_records(self, user_id: int, page: int = 1, page_size: int = 20) -> Dict[str, Any]:
        """Return one page of a user's process records, newest ``completed_at`` first.

        Returns:
            dict: ``records`` plus ``total``, ``page``, ``page_size`` and ``total_pages``.
        """
        with self.db_connection.get_session() as session:
            query = session.query(ProcessRecord).filter(ProcessRecord.user_id == user_id)
            records, meta = self._paginate(query, desc(ProcessRecord.completed_at), page, page_size)
            return {"records": [self._process_to_dict(record) for record in records], **meta}

    def update_process_record(self, process_id: int, update_data: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """Update a process record.

        Args:
            process_id: Record ID.
            update_data: Field -> new value. Keys that are not attributes of
                the record are ignored.

        Returns:
            dict | None: The updated record, or None if it does not exist.
        """
        with self.db_connection.get_session() as session:
            process_record = session.query(ProcessRecord).filter(ProcessRecord.id == process_id).first()
            if process_record is None:
                return None
            self._apply_updates(process_record, update_data)
            session.flush()
            # NOTE(review): explicit commit kept from the original; most other
            # writers rely on get_session() to commit.
            session.commit()
            logger.info(f"更新处理记录成功: {process_id}")
            return self._process_to_dict(process_record)

    def delete_process_record(self, process_id: int) -> bool:
        """Hard-delete a process record. Returns True if it existed."""
        with self.db_connection.get_session() as session:
            process_record = session.query(ProcessRecord).filter(ProcessRecord.id == process_id).first()
            if process_record is None:
                return False
            session.delete(process_record)
            logger.info(f"删除处理记录成功: {process_id}")
            return True

    # ==================== System configuration ====================

    def set_config(self, config_key: str, config_value: Any, config_type: str = "string",
                   config_description: Optional[str] = None, is_public: bool = True) -> "SystemConfig":
        """Create or update a configuration entry (upsert on ``config_key``).

        Values are serialised before storage:
        - ``json`` dict/list values go through ``json.dumps(..., ensure_ascii=False)``;
        - ``int``/``float``/``bool`` values go through ``str()``, so True becomes "True";
        - anything else is stored as given.

        Args:
            config_key: Configuration key.
            config_value: Value to store.
            config_type: Declared value type.
            config_description: Optional description. It overwrites the
                existing description on update, even when None.
            is_public: Whether the entry is public.

        Returns:
            SystemConfig: The saved entry, detached from its session with all
            column attributes loaded.
        """
        with self.db_connection.get_session() as session:
            if config_type == "json" and isinstance(config_value, (dict, list)):
                config_value = json.dumps(config_value, ensure_ascii=False)
            elif config_type in ("int", "float", "bool"):
                config_value = str(config_value)

            config = session.query(SystemConfig).filter(SystemConfig.config_key == config_key).first()
            if config is not None:
                config.config_value = config_value
                config.config_type = config_type
                config.config_description = config_description
                config.is_public = is_public
                config.updated_at = datetime.now()
            else:
                config = SystemConfig(
                    config_key=config_key,
                    config_value=config_value,
                    config_type=config_type,
                    config_description=config_description,
                    is_public=is_public,
                )
                session.add(config)

            session.flush()
            session.refresh(config)
            # The SQL has already been flushed into the transaction. Detaching
            # the loaded instance keeps its attributes readable after the
            # session closes. Under the default expire_on_commit=True, a commit
            # in get_session() would otherwise expire them, and access would
            # raise DetachedInstanceError.
            session.expunge(config)
            logger.info(f"设置配置成功: {config_key}")
            return config

    def get_config(self, config_key: str) -> Optional[Dict[str, Any]]:
        """Return the configuration entry for ``config_key`` as a dict, or None.

        The stored ``config_value`` string is returned as-is; it is not deserialised.
        """
        with self.db_connection.get_session() as session:
            config = session.query(SystemConfig).filter(SystemConfig.config_key == config_key).first()
            return None if config is None else self._config_to_dict(config)

    def list_configs(self, is_public: Optional[bool] = None) -> List["SystemConfig"]:
        """List configuration entries ordered by ``config_key``.

        Args:
            is_public: When not None, only entries with this ``is_public``
                value are returned.

        Returns:
            list[SystemConfig]: Detached instances with their column attributes loaded.
        """
        with self.db_connection.get_session() as session:
            query = session.query(SystemConfig)
            if is_public is not None:
                query = query.filter(SystemConfig.is_public == is_public)
            configs = query.order_by(SystemConfig.config_key).all()
            # Detach so a commit on session exit cannot expire the returned objects.
            for config in configs:
                session.expunge(config)
            return configs

    def delete_config(self, config_key: str) -> bool:
        """Delete the configuration entry for ``config_key``. Returns True if it existed."""
        with self.db_connection.get_session() as session:
            config = session.query(SystemConfig).filter(SystemConfig.config_key == config_key).first()
            if config is None:
                return False
            session.delete(config)
            logger.info(f"删除配置成功: {config_key}")
            return True

    # ==================== Utilities ====================

    def _calculate_image_hash(self, image_data: bytes) -> str:
        """Return the hex MD5 digest of ``image_data``."""
        return hashlib.md5(image_data).hexdigest()

    def get_statistics(self) -> Dict[str, Any]:
        """Return row counts for users, process records and non-deleted images by type."""
        with self.db_connection.get_session() as session:
            def live_image_count(image_type: str) -> int:
                # Count images of one type that are not soft-deleted.
                return session.query(ImageRecord).filter(
                    and_(ImageRecord.image_type == image_type, ImageRecord.is_deleted == False)  # noqa: E712
                ).count()

            return {
                "total_users": session.query(User).count(),
                "active_users": session.query(User).filter(User.is_active == True).count(),  # noqa: E712
                "total_processes": session.query(ProcessRecord).count(),
                "face_images": live_image_count("face"),
                "cloth_images": live_image_count("cloth"),
                "result_images": live_image_count("result"),
            }
|