migrations.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419
  1. """
  2. 数据库迁移模块
  3. 处理数据库版本迁移和初始化
  4. """
import functools
import json
import os
import re
from datetime import datetime
from typing import Any, Dict, Iterator, List, Optional, Tuple

from sqlalchemy import inspect, text
from sqlalchemy.exc import OperationalError

from backend.utils.logger_config import setup_logger

from .connection import DatabaseConnection
from .models import Base, DatabaseVersion
# Module-level logger obtained from the project's setup_logger helper.
logger = setup_logger(__name__)
  15. class DatabaseMigration:
  16. """数据库迁移管理类"""
  17. def __init__(self, db_connection: Optional[DatabaseConnection] = None):
  18. """
  19. 初始化数据库迁移
  20. Args:
  21. db_connection: 数据库连接对象
  22. """
  23. self.db_connection = db_connection or DatabaseConnection()
  24. self.migrations_dir = os.path.join(os.path.dirname(__file__), "migrations")
  25. # 确保迁移目录存在
  26. if not os.path.exists(self.migrations_dir):
  27. os.makedirs(self.migrations_dir)
  28. def init_database(self) -> bool:
  29. """
  30. 初始化数据库
  31. Returns:
  32. bool: 初始化是否成功
  33. """
  34. try:
  35. with self.db_connection.get_session() as session:
  36. # 创建所有表
  37. Base.metadata.create_all(bind=self.db_connection.get_engine())
  38. # 检查是否已有版本记录
  39. version_count = session.query(DatabaseVersion).count()
  40. if version_count == 0:
  41. # 创建初始版本记录
  42. initial_version = DatabaseVersion(
  43. version="1.0.0",
  44. description="初始数据库版本",
  45. migration_file="initial_migration"
  46. )
  47. session.add(initial_version)
  48. session.flush()
  49. logger.info("数据库初始化成功,版本: 1.0.0")
  50. else:
  51. logger.info("数据库已存在,跳过初始化")
  52. return True
  53. except Exception as e:
  54. logger.error(f"数据库初始化失败: {e}")
  55. return False
  56. def get_current_version(self) -> Optional[str]:
  57. """
  58. 获取当前数据库版本
  59. Returns:
  60. str: 当前版本号,如果未初始化则返回None
  61. """
  62. try:
  63. with self.db_connection.get_session() as session:
  64. latest_version = session.query(DatabaseVersion).order_by(
  65. DatabaseVersion.applied_at.desc()
  66. ).first()
  67. return latest_version.version if latest_version else None
  68. except Exception as e:
  69. logger.error(f"获取当前版本失败: {e}")
  70. return None
  71. def get_all_versions(self) -> List[DatabaseVersion]:
  72. """
  73. 获取所有已应用的版本
  74. Returns:
  75. List[DatabaseVersion]: 版本列表
  76. """
  77. try:
  78. with self.db_connection.get_session() as session:
  79. return session.query(DatabaseVersion).order_by(
  80. DatabaseVersion.applied_at.asc()
  81. ).all()
  82. except Exception as e:
  83. logger.error(f"获取版本列表失败: {e}")
  84. return []
  85. def create_migration(self, version: str, description: str,
  86. up_sql: str, down_sql: str) -> str:
  87. """
  88. 创建迁移文件
  89. Args:
  90. version: 版本号
  91. description: 版本描述
  92. up_sql: 升级SQL
  93. down_sql: 回滚SQL
  94. Returns:
  95. str: 迁移文件路径
  96. """
  97. timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
  98. filename = f"{timestamp}_{version}_{description.replace(' ', '_')}.json"
  99. filepath = os.path.join(self.migrations_dir, filename)
  100. migration_data = {
  101. "version": version,
  102. "description": description,
  103. "created_at": datetime.now().isoformat(),
  104. "up_sql": up_sql,
  105. "down_sql": down_sql
  106. }
  107. with open(filepath, 'w', encoding='utf-8') as f:
  108. json.dump(migration_data, f, ensure_ascii=False, indent=2)
  109. logger.info(f"创建迁移文件: {filepath}")
  110. return filepath
  111. def get_pending_migrations(self) -> List[Dict[str, Any]]:
  112. """
  113. 获取待执行的迁移
  114. Returns:
  115. List[Dict]: 待执行的迁移列表
  116. """
  117. current_version = self.get_current_version()
  118. pending_migrations = []
  119. # 读取迁移目录中的所有迁移文件
  120. for filename in os.listdir(self.migrations_dir):
  121. if filename.endswith('.json'):
  122. filepath = os.path.join(self.migrations_dir, filename)
  123. try:
  124. with open(filepath, 'r', encoding='utf-8') as f:
  125. migration_data = json.load(f)
  126. # 检查是否已应用
  127. if not current_version or self._compare_versions(migration_data['version'], current_version) > 0:
  128. migration_data['filepath'] = filepath
  129. pending_migrations.append(migration_data)
  130. except Exception as e:
  131. logger.error(f"读取迁移文件失败 {filename}: {e}")
  132. # 按版本号排序
  133. pending_migrations.sort(key=lambda x: x['version'])
  134. return pending_migrations
  135. def migrate_up(self, target_version: Optional[str] = None) -> bool:
  136. """
  137. 执行数据库升级
  138. Args:
  139. target_version: 目标版本,如果为None则升级到最新版本
  140. Returns:
  141. bool: 升级是否成功
  142. """
  143. try:
  144. pending_migrations = self.get_pending_migrations()
  145. if not pending_migrations:
  146. logger.info("没有待执行的迁移")
  147. return True
  148. # 如果指定了目标版本,只执行到该版本
  149. if target_version:
  150. pending_migrations = [
  151. m for m in pending_migrations
  152. if self._compare_versions(m['version'], target_version) <= 0
  153. ]
  154. for migration in pending_migrations:
  155. if not self._apply_migration(migration):
  156. logger.error(f"应用迁移失败: {migration['version']}")
  157. return False
  158. logger.info(f"数据库升级成功,共执行 {len(pending_migrations)} 个迁移")
  159. return True
  160. except Exception as e:
  161. logger.error(f"数据库升级失败: {e}")
  162. return False
  163. def migrate_down(self, target_version: Optional[str] = None) -> bool:
  164. """
  165. 执行数据库回滚
  166. Args:
  167. target_version: 目标版本
  168. Returns:
  169. bool: 回滚是否成功
  170. """
  171. try:
  172. current_version = self.get_current_version()
  173. if not current_version:
  174. logger.error("数据库未初始化")
  175. return False
  176. if self._compare_versions(current_version, target_version) <= 0:
  177. logger.info("当前版本已经小于等于目标版本,无需回滚")
  178. return True
  179. # 获取需要回滚的迁移(按倒序)
  180. applied_versions = self.get_all_versions()
  181. rollback_migrations = []
  182. for version_record in reversed(applied_versions):
  183. if self._compare_versions(version_record.version, target_version) <= 0:
  184. break
  185. # 查找对应的迁移文件
  186. migration_file = self._find_migration_file(version_record.version)
  187. if migration_file:
  188. with open(migration_file, 'r', encoding='utf-8') as f:
  189. migration_data = json.load(f)
  190. rollback_migrations.append(migration_data)
  191. # 执行回滚
  192. for migration in rollback_migrations:
  193. if not self._rollback_migration(migration):
  194. logger.error(f"回滚迁移失败: {migration['version']}")
  195. return False
  196. logger.info(f"数据库回滚成功,共回滚 {len(rollback_migrations)} 个迁移")
  197. return True
  198. except Exception as e:
  199. logger.error(f"数据库回滚失败: {e}")
  200. return False
  201. def _apply_migration(self, migration: Dict[str, Any]) -> bool:
  202. """
  203. 应用单个迁移
  204. Args:
  205. migration: 迁移数据
  206. Returns:
  207. bool: 是否成功
  208. """
  209. try:
  210. with self.db_connection.get_session() as session:
  211. # 执行升级SQL(支持多条语句)
  212. if migration['up_sql'].strip():
  213. for sql in migration['up_sql'].split(';'):
  214. sql = sql.strip()
  215. if sql:
  216. session.execute(text(sql))
  217. # 记录版本
  218. version_record = DatabaseVersion(
  219. version=migration['version'],
  220. description=migration['description'],
  221. migration_file=os.path.basename(migration['filepath'])
  222. )
  223. session.add(version_record)
  224. session.flush()
  225. logger.info(f"应用迁移成功: {migration['version']} - {migration['description']}")
  226. return True
  227. except Exception as e:
  228. logger.error(f"应用迁移失败 {migration['version']}: {e}")
  229. return False
  230. def _rollback_migration(self, migration: Dict[str, Any]) -> bool:
  231. """
  232. 回滚单个迁移
  233. Args:
  234. migration: 迁移数据
  235. Returns:
  236. bool: 是否成功
  237. """
  238. try:
  239. with self.db_connection.get_session() as session:
  240. # 执行回滚SQL(支持多条语句)
  241. if migration['down_sql'].strip():
  242. for sql in migration['down_sql'].split(';'):
  243. sql = sql.strip()
  244. if sql:
  245. session.execute(text(sql))
  246. # 删除版本记录
  247. session.query(DatabaseVersion).filter(
  248. DatabaseVersion.version == migration['version']
  249. ).delete()
  250. logger.info(f"回滚迁移成功: {migration['version']} - {migration['description']}")
  251. return True
  252. except Exception as e:
  253. logger.error(f"回滚迁移失败 {migration['version']}: {e}")
  254. return False
  255. def _find_migration_file(self, version: str) -> Optional[str]:
  256. """
  257. 查找迁移文件
  258. Args:
  259. version: 版本号
  260. Returns:
  261. str: 文件路径,如果未找到则返回None
  262. """
  263. for filename in os.listdir(self.migrations_dir):
  264. if filename.endswith('.json'):
  265. filepath = os.path.join(self.migrations_dir, filename)
  266. try:
  267. with open(filepath, 'r', encoding='utf-8') as f:
  268. migration_data = json.load(f)
  269. if migration_data['version'] == version:
  270. return filepath
  271. except:
  272. continue
  273. return None
  274. def _compare_versions(self, version1: str, version2: str) -> int:
  275. """
  276. 比较版本号
  277. Args:
  278. version1: 版本1
  279. version2: 版本2
  280. Returns:
  281. int: 1表示version1>version2, -1表示version1<version2, 0表示相等
  282. """
  283. def version_to_tuple(v):
  284. return tuple(int(x) for x in v.split('.'))
  285. try:
  286. v1_tuple = version_to_tuple(version1)
  287. v2_tuple = version_to_tuple(version2)
  288. if v1_tuple > v2_tuple:
  289. return 1
  290. elif v1_tuple < v2_tuple:
  291. return -1
  292. else:
  293. return 0
  294. except:
  295. # 如果版本号格式不正确,按字符串比较
  296. if version1 > version2:
  297. return 1
  298. elif version1 < version2:
  299. return -1
  300. else:
  301. return 0
  302. def check_database_health(self) -> Dict[str, Any]:
  303. """
  304. 检查数据库健康状态
  305. Returns:
  306. Dict: 健康状态信息
  307. """
  308. health_info = {
  309. "status": "unknown",
  310. "current_version": None,
  311. "total_migrations": 0,
  312. "pending_migrations": 0,
  313. "errors": []
  314. }
  315. try:
  316. # 测试连接
  317. if not self.db_connection.test_connection():
  318. health_info["status"] = "error"
  319. health_info["errors"].append("数据库连接失败")
  320. return health_info
  321. # 获取当前版本
  322. current_version = self.get_current_version()
  323. health_info["current_version"] = current_version
  324. # 获取版本统计
  325. all_versions = self.get_all_versions()
  326. health_info["total_migrations"] = len(all_versions)
  327. # 获取待执行迁移
  328. pending_migrations = self.get_pending_migrations()
  329. health_info["pending_migrations"] = len(pending_migrations)
  330. # 判断状态
  331. if health_info["errors"]:
  332. health_info["status"] = "error"
  333. elif health_info["pending_migrations"] > 0:
  334. health_info["status"] = "outdated"
  335. else:
  336. health_info["status"] = "healthy"
  337. except Exception as e:
  338. health_info["status"] = "error"
  339. health_info["errors"].append(str(e))
  340. return health_info
  341. if __name__ == "__main__":
  342. migration = DatabaseMigration()
  343. migration.init_database()