connection.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152
  1. """
  2. 数据库连接管理模块
  3. 提供数据库连接池和配置管理功能
  4. """
  5. import os
  6. import sqlite3
  7. from typing import Optional, Dict, Any
  8. from contextlib import contextmanager
  9. from sqlalchemy import create_engine, Engine, text
  10. from sqlalchemy.orm import sessionmaker, Session
  11. from sqlalchemy.pool import QueuePool
  12. from backend.utils.logger_config import setup_logger
  13. logger = setup_logger(__name__)
  14. class DatabaseConnection:
  15. """数据库连接管理类"""
  16. def __init__(self, db_type: str = "sqlite", **kwargs):
  17. """
  18. 初始化数据库连接
  19. Args:
  20. db_type: 数据库类型 ("sqlite" 或 "mysql")
  21. **kwargs: 数据库连接参数
  22. """
  23. self.db_type = db_type.lower()
  24. self.engine: Optional[Engine] = None
  25. self.SessionLocal: Optional[sessionmaker] = None
  26. self.connection_params = kwargs
  27. self._setup_connection()
  28. def _setup_connection(self):
  29. """设置数据库连接"""
  30. try:
  31. if self.db_type == "sqlite":
  32. self._setup_sqlite()
  33. elif self.db_type == "mysql":
  34. self._setup_mysql()
  35. else:
  36. raise ValueError(f"不支持的数据库类型: {self.db_type}")
  37. # 创建会话工厂
  38. self.SessionLocal = sessionmaker(
  39. autocommit=False,
  40. autoflush=False,
  41. bind=self.engine
  42. )
  43. logger.info(f"数据库连接初始化成功: {self.db_type}")
  44. except Exception as e:
  45. logger.error(f"数据库连接初始化失败: {e}")
  46. raise
  47. def _setup_sqlite(self):
  48. """设置SQLite连接"""
  49. db_path = self.connection_params.get('database', './backend/db/ai_swap.db')
  50. # 确保数据库目录存在
  51. db_dir = os.path.dirname(db_path)
  52. if db_dir and not os.path.exists(db_dir):
  53. os.makedirs(db_dir)
  54. # 创建SQLite引擎
  55. self.engine = create_engine(
  56. f"sqlite:///{db_path}",
  57. connect_args={"check_same_thread": False},
  58. poolclass=QueuePool,
  59. pool_size=10,
  60. max_overflow=20,
  61. pool_pre_ping=True,
  62. echo=False # 设置为True可以看到SQL语句
  63. )
  64. def _setup_mysql(self):
  65. """设置MySQL连接"""
  66. host = self.connection_params.get('host', 'localhost')
  67. port = self.connection_params.get('port', 3306)
  68. database = self.connection_params.get('database', 'ai_swap')
  69. username = self.connection_params.get('username', 'root')
  70. password = self.connection_params.get('password', '')
  71. charset = self.connection_params.get('charset', 'utf8mb4')
  72. # 构建MySQL连接URL
  73. mysql_url = f"mysql+pymysql://{username}:{password}@{host}:{port}/{database}?charset={charset}"
  74. # 创建MySQL引擎
  75. self.engine = create_engine(
  76. mysql_url,
  77. poolclass=QueuePool,
  78. pool_size=10,
  79. max_overflow=20,
  80. pool_pre_ping=True,
  81. pool_recycle=3600, # 连接回收时间
  82. echo=False
  83. )
  84. @contextmanager
  85. def get_session(self) -> Session:
  86. """
  87. 获取数据库会话的上下文管理器
  88. Yields:
  89. Session: 数据库会话对象
  90. """
  91. session = self.SessionLocal()
  92. try:
  93. yield session
  94. session.commit()
  95. except Exception as e:
  96. session.rollback()
  97. logger.error(f"数据库操作失败: {e}")
  98. raise
  99. finally:
  100. session.close()
  101. def get_engine(self) -> Engine:
  102. """获取数据库引擎"""
  103. return self.engine
  104. def test_connection(self) -> bool:
  105. """
  106. 测试数据库连接
  107. Returns:
  108. bool: 连接是否成功
  109. """
  110. try:
  111. with self.get_session() as session:
  112. if self.db_type == "sqlite":
  113. session.execute(text("SELECT 1"))
  114. else:
  115. session.execute(text("SELECT 1"))
  116. logger.info("数据库连接测试成功")
  117. return True
  118. except Exception as e:
  119. logger.error(f"数据库连接测试失败: {e}")
  120. return False
  121. def close(self):
  122. """关闭数据库连接"""
  123. if self.engine:
  124. self.engine.dispose()
  125. logger.info("数据库连接已关闭")
  126. # 默认数据库连接实例
  127. default_db = DatabaseConnection(
  128. db_type="sqlite",
  129. database="backend/data/ai_swap.db"
  130. )