| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152 |
- """
- 数据库连接管理模块
- 提供数据库连接池和配置管理功能
- """
- import os
- import sqlite3
- from typing import Optional, Dict, Any
- from contextlib import contextmanager
- from sqlalchemy import create_engine, Engine, text
- from sqlalchemy.orm import sessionmaker, Session
- from sqlalchemy.pool import QueuePool
- from backend.utils.logger_config import setup_logger
- logger = setup_logger(__name__)
- class DatabaseConnection:
- """数据库连接管理类"""
-
- def __init__(self, db_type: str = "sqlite", **kwargs):
- """
- 初始化数据库连接
-
- Args:
- db_type: 数据库类型 ("sqlite" 或 "mysql")
- **kwargs: 数据库连接参数
- """
- self.db_type = db_type.lower()
- self.engine: Optional[Engine] = None
- self.SessionLocal: Optional[sessionmaker] = None
- self.connection_params = kwargs
- self._setup_connection()
-
- def _setup_connection(self):
- """设置数据库连接"""
- try:
- if self.db_type == "sqlite":
- self._setup_sqlite()
- elif self.db_type == "mysql":
- self._setup_mysql()
- else:
- raise ValueError(f"不支持的数据库类型: {self.db_type}")
-
- # 创建会话工厂
- self.SessionLocal = sessionmaker(
- autocommit=False,
- autoflush=False,
- bind=self.engine
- )
-
- logger.info(f"数据库连接初始化成功: {self.db_type}")
-
- except Exception as e:
- logger.error(f"数据库连接初始化失败: {e}")
- raise
-
- def _setup_sqlite(self):
- """设置SQLite连接"""
- db_path = self.connection_params.get('database', './backend/db/ai_swap.db')
-
- # 确保数据库目录存在
- db_dir = os.path.dirname(db_path)
- if db_dir and not os.path.exists(db_dir):
- os.makedirs(db_dir)
-
- # 创建SQLite引擎
- self.engine = create_engine(
- f"sqlite:///{db_path}",
- connect_args={"check_same_thread": False},
- poolclass=QueuePool,
- pool_size=10,
- max_overflow=20,
- pool_pre_ping=True,
- echo=False # 设置为True可以看到SQL语句
- )
-
- def _setup_mysql(self):
- """设置MySQL连接"""
- host = self.connection_params.get('host', 'localhost')
- port = self.connection_params.get('port', 3306)
- database = self.connection_params.get('database', 'ai_swap')
- username = self.connection_params.get('username', 'root')
- password = self.connection_params.get('password', '')
- charset = self.connection_params.get('charset', 'utf8mb4')
-
- # 构建MySQL连接URL
- mysql_url = f"mysql+pymysql://{username}:{password}@{host}:{port}/{database}?charset={charset}"
-
- # 创建MySQL引擎
- self.engine = create_engine(
- mysql_url,
- poolclass=QueuePool,
- pool_size=10,
- max_overflow=20,
- pool_pre_ping=True,
- pool_recycle=3600, # 连接回收时间
- echo=False
- )
-
- @contextmanager
- def get_session(self) -> Session:
- """
- 获取数据库会话的上下文管理器
-
- Yields:
- Session: 数据库会话对象
- """
- session = self.SessionLocal()
- try:
- yield session
- session.commit()
- except Exception as e:
- session.rollback()
- logger.error(f"数据库操作失败: {e}")
- raise
- finally:
- session.close()
-
- def get_engine(self) -> Engine:
- """获取数据库引擎"""
- return self.engine
-
- def test_connection(self) -> bool:
- """
- 测试数据库连接
-
- Returns:
- bool: 连接是否成功
- """
- try:
- with self.get_session() as session:
- if self.db_type == "sqlite":
- session.execute(text("SELECT 1"))
- else:
- session.execute(text("SELECT 1"))
- logger.info("数据库连接测试成功")
- return True
- except Exception as e:
- logger.error(f"数据库连接测试失败: {e}")
- return False
-
- def close(self):
- """关闭数据库连接"""
- if self.engine:
- self.engine.dispose()
- logger.info("数据库连接已关闭")
- # 默认数据库连接实例
- default_db = DatabaseConnection(
- db_type="sqlite",
- database="backend/data/ai_swap.db"
- )
|