""" 数据库连接管理模块 提供数据库连接池和配置管理功能 """ import os import sqlite3 from typing import Optional, Dict, Any from contextlib import contextmanager from sqlalchemy import create_engine, Engine, text from sqlalchemy.orm import sessionmaker, Session from sqlalchemy.pool import QueuePool from backend.utils.logger_config import setup_logger logger = setup_logger(__name__) class DatabaseConnection: """数据库连接管理类""" def __init__(self, db_type: str = "sqlite", **kwargs): """ 初始化数据库连接 Args: db_type: 数据库类型 ("sqlite" 或 "mysql") **kwargs: 数据库连接参数 """ self.db_type = db_type.lower() self.engine: Optional[Engine] = None self.SessionLocal: Optional[sessionmaker] = None self.connection_params = kwargs self._setup_connection() def _setup_connection(self): """设置数据库连接""" try: if self.db_type == "sqlite": self._setup_sqlite() elif self.db_type == "mysql": self._setup_mysql() else: raise ValueError(f"不支持的数据库类型: {self.db_type}") # 创建会话工厂 self.SessionLocal = sessionmaker( autocommit=False, autoflush=False, bind=self.engine ) logger.info(f"数据库连接初始化成功: {self.db_type}") except Exception as e: logger.error(f"数据库连接初始化失败: {e}") raise def _setup_sqlite(self): """设置SQLite连接""" db_path = self.connection_params.get('database', './backend/db/ai_swap.db') # 确保数据库目录存在 db_dir = os.path.dirname(db_path) if db_dir and not os.path.exists(db_dir): os.makedirs(db_dir) # 创建SQLite引擎 self.engine = create_engine( f"sqlite:///{db_path}", connect_args={"check_same_thread": False}, poolclass=QueuePool, pool_size=10, max_overflow=20, pool_pre_ping=True, echo=False # 设置为True可以看到SQL语句 ) def _setup_mysql(self): """设置MySQL连接""" host = self.connection_params.get('host', 'localhost') port = self.connection_params.get('port', 3306) database = self.connection_params.get('database', 'ai_swap') username = self.connection_params.get('username', 'root') password = self.connection_params.get('password', '') charset = self.connection_params.get('charset', 'utf8mb4') # 构建MySQL连接URL mysql_url = f"mysql+pymysql://{username}:{password}@{host}:{port}/{database}?charset={charset}" # 创建MySQL引擎 self.engine = create_engine( mysql_url, poolclass=QueuePool, pool_size=10, max_overflow=20, pool_pre_ping=True, pool_recycle=3600, # 连接回收时间 echo=False ) @contextmanager def get_session(self) -> Session: """ 获取数据库会话的上下文管理器 Yields: Session: 数据库会话对象 """ session = self.SessionLocal() try: yield session session.commit() except Exception as e: session.rollback() logger.error(f"数据库操作失败: {e}") raise finally: session.close() def get_engine(self) -> Engine: """获取数据库引擎""" return self.engine def test_connection(self) -> bool: """ 测试数据库连接 Returns: bool: 连接是否成功 """ try: with self.get_session() as session: if self.db_type == "sqlite": session.execute(text("SELECT 1")) else: session.execute(text("SELECT 1")) logger.info("数据库连接测试成功") return True except Exception as e: logger.error(f"数据库连接测试失败: {e}") return False def close(self): """关闭数据库连接""" if self.engine: self.engine.dispose() logger.info("数据库连接已关闭") # 默认数据库连接实例 default_db = DatabaseConnection( db_type="sqlite", database="backend/data/ai_swap.db" )