"""
Database migration module.

Handles database version migrations and initialization. Migrations are JSON
files (keys: ``version``, ``description``, ``created_at``, ``up_sql``,
``down_sql``) stored in the ``migrations`` directory next to this module;
applied versions are recorded as ``DatabaseVersion`` rows.
"""
import os
import re
import json
from datetime import datetime
from functools import cmp_to_key
from typing import List, Dict, Any, Iterator, Optional, Tuple

from sqlalchemy import text, inspect
from sqlalchemy.exc import OperationalError

from .connection import DatabaseConnection
from .models import Base, DatabaseVersion
from backend.utils.logger_config import setup_logger

logger = setup_logger(__name__)

# Anything outside word characters, "." and "-" (path separators, spaces,
# shell metacharacters) is replaced with "_" in generated migration filenames.
_UNSAFE_FILENAME_CHARS = re.compile(r'[^\w.-]')


class DatabaseMigration:
    """Manage schema versions: initialization, upgrades, rollbacks and health checks."""

    def __init__(self, db_connection: Optional[DatabaseConnection] = None):
        """
        Initialize the migration manager.

        Args:
            db_connection: Database connection to use; a new
                ``DatabaseConnection`` is created when omitted.
        """
        self.db_connection = db_connection or DatabaseConnection()
        self.migrations_dir = os.path.join(os.path.dirname(__file__), "migrations")

        # exist_ok avoids the check-then-create race of a separate exists() test.
        os.makedirs(self.migrations_dir, exist_ok=True)

    def init_database(self) -> bool:
        """
        Create all tables and record the initial version "1.0.0" if no version exists yet.

        Returns:
            True on success, False if anything failed (the error is logged).
        """
        try:
            with self.db_connection.get_session() as session:
                # Create every table declared on Base.
                Base.metadata.create_all(bind=self.db_connection.get_engine())

                # Only seed the initial version record into an empty version table.
                version_count = session.query(DatabaseVersion).count()
                if version_count == 0:
                    initial_version = DatabaseVersion(
                        version="1.0.0",
                        description="初始数据库版本",
                        migration_file="initial_migration"
                    )
                    session.add(initial_version)
                    session.flush()
                    logger.info("数据库初始化成功,版本: 1.0.0")
                else:
                    logger.info("数据库已存在,跳过初始化")

            return True

        except Exception as e:
            logger.error(f"数据库初始化失败: {e}")
            return False

    def get_current_version(self) -> Optional[str]:
        """
        Return the current database version.

        The current version is the highest applied version in version-number
        order (see ``_compare_versions``) rather than the row with the latest
        ``applied_at``: migrations applied back-to-back by ``migrate_up`` may
        share an ``applied_at`` value (depending on the column's resolution),
        which made timestamp ordering ambiguous. Because upgrades only ever
        add higher versions and rollbacks remove the highest ones, the two
        agree in normal operation.

        Returns:
            The current version string, or None if no version is recorded or
            the lookup failed (the error is logged).
        """
        try:
            versions = self._get_applied_versions()
            if not versions:
                return None
            return max(versions, key=cmp_to_key(self._compare_versions))
        except Exception as e:
            logger.error(f"获取当前版本失败: {e}")
            return None

    def get_all_versions(self) -> List[DatabaseVersion]:
        """
        Return all applied version records, oldest ``applied_at`` first.

        NOTE(review): the records are returned after their session has been
        closed; whether their attributes remain readable depends on how
        ``get_session`` commits/expires objects — confirm before reading
        fields from them.

        Returns:
            The version records, or an empty list if the query failed (the
            error is logged).
        """
        try:
            with self.db_connection.get_session() as session:
                return session.query(DatabaseVersion).order_by(
                    DatabaseVersion.applied_at.asc()
                ).all()
        except Exception as e:
            logger.error(f"获取版本列表失败: {e}")
            return []

    def _get_applied_versions(self) -> List[str]:
        """Return the version strings of all applied migrations, read inside one session."""
        with self.db_connection.get_session() as session:
            # Select only the column so plain strings (not ORM instances that may
            # be expired/detached once the session ends) leave this block.
            return [version for (version,) in session.query(DatabaseVersion.version).all()]

    def create_migration(self, version: str, description: str,
                         up_sql: str, down_sql: str) -> str:
        """
        Write a new migration file into the migrations directory.

        The filename is ``<timestamp>_<version>_<description>.json`` with any
        character outside word characters, "." and "-" replaced by "_", so a
        version or description cannot introduce path separators.

        Args:
            version: Version number of the migration.
            description: Human-readable description.
            up_sql: SQL applied on upgrade (statements separated by ";").
            down_sql: SQL applied on rollback (statements separated by ";").

        Returns:
            Path of the created migration file.
        """
        now = datetime.now()
        timestamp = now.strftime("%Y%m%d_%H%M%S")
        safe_version = _UNSAFE_FILENAME_CHARS.sub('_', version)
        safe_description = _UNSAFE_FILENAME_CHARS.sub('_', description)
        filename = f"{timestamp}_{safe_version}_{safe_description}.json"
        filepath = os.path.join(self.migrations_dir, filename)

        migration_data = {
            "version": version,
            "description": description,
            "created_at": now.isoformat(),
            "up_sql": up_sql,
            "down_sql": down_sql,
        }

        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump(migration_data, f, ensure_ascii=False, indent=2)

        logger.info(f"创建迁移文件: {filepath}")
        return filepath

    def _iter_migration_files(self) -> Iterator[Tuple[str, str]]:
        """Yield ``(filename, filepath)`` for each ``.json`` file in the migrations directory, in name order."""
        for filename in sorted(os.listdir(self.migrations_dir)):
            if filename.endswith('.json'):
                yield filename, os.path.join(self.migrations_dir, filename)

    @staticmethod
    def _load_migration_file(filepath: str) -> Dict[str, Any]:
        """
        Load one migration file and check that it carries a string ``version``.

        Raises:
            OSError: the file cannot be read.
            ValueError: the file is not valid UTF-8/JSON (``UnicodeDecodeError``
                and ``json.JSONDecodeError`` are ValueError subclasses), or is
                not a JSON object with a string ``version``.
        """
        with open(filepath, 'r', encoding='utf-8') as f:
            migration_data = json.load(f)
        if not isinstance(migration_data, dict) or not isinstance(migration_data.get('version'), str):
            raise ValueError(f"迁移文件格式无效(缺少字符串类型的 version 字段): {filepath}")
        return migration_data

    def get_pending_migrations(self) -> List[Dict[str, Any]]:
        """
        Return migrations newer than the current version, in ascending version order.

        Unreadable or malformed files are logged and skipped. Each returned
        dict is the file's JSON content plus a ``filepath`` key.

        Returns:
            Pending migrations; all migrations if the database has no version.
        """
        current_version = self.get_current_version()
        pending_migrations = []

        for filename, filepath in self._iter_migration_files():
            try:
                migration_data = self._load_migration_file(filepath)
            except (OSError, ValueError) as e:
                logger.error(f"读取迁移文件失败 {filename}: {e}")
                continue

            # Only versions strictly above the current one are pending.
            if not current_version or self._compare_versions(migration_data['version'], current_version) > 0:
                migration_data['filepath'] = filepath
                pending_migrations.append(migration_data)

        # Sort by version *number*: a plain string sort would put "1.10.0" before "1.9.0".
        pending_migrations.sort(
            key=cmp_to_key(lambda a, b: self._compare_versions(a['version'], b['version']))
        )
        return pending_migrations

    def migrate_up(self, target_version: Optional[str] = None) -> bool:
        """
        Apply pending migrations in ascending version order.

        Stops at the first migration that fails; migrations applied before it
        stay applied.

        Args:
            target_version: Highest version to apply (inclusive); None applies
                every pending migration.

        Returns:
            True if all selected migrations were applied (or none were
            pending), False otherwise.
        """
        try:
            pending_migrations = self.get_pending_migrations()

            if not pending_migrations:
                logger.info("没有待执行的迁移")
                return True

            if target_version:
                pending_migrations = [
                    m for m in pending_migrations
                    if self._compare_versions(m['version'], target_version) <= 0
                ]

            for migration in pending_migrations:
                if not self._apply_migration(migration):
                    logger.error(f"应用迁移失败: {migration['version']}")
                    return False

            logger.info(f"数据库升级成功,共执行 {len(pending_migrations)} 个迁移")
            return True

        except Exception as e:
            logger.error(f"数据库升级失败: {e}")
            return False

    def migrate_down(self, target_version: Optional[str] = None) -> bool:
        """
        Roll back every applied migration newer than ``target_version``.

        Migrations are rolled back newest first; ``target_version`` itself
        stays applied. All required migration files are located and loaded
        before any SQL runs, so a missing file aborts the rollback without
        leaving the database half rolled back.

        Args:
            target_version: Version to roll back to. Required — None is
                rejected (it previously crashed inside the version comparison).

        Returns:
            True if the rollback succeeded or was unnecessary, False otherwise.
        """
        if target_version is None:
            logger.error("回滚需要指定目标版本")
            return False

        try:
            current_version = self.get_current_version()
            if not current_version:
                logger.error("数据库未初始化")
                return False

            if self._compare_versions(current_version, target_version) <= 0:
                logger.info("当前版本已经小于等于目标版本,无需回滚")
                return True

            # Versions above the target, newest first.
            versions_to_undo = [
                v for v in self._get_applied_versions()
                if self._compare_versions(v, target_version) > 0
            ]
            versions_to_undo.sort(key=cmp_to_key(self._compare_versions), reverse=True)

            rollback_migrations = []
            for version in versions_to_undo:
                migration_file = self._find_migration_file(version)
                if migration_file is None:
                    # Skipping would leave this version's record behind while
                    # older ones are removed; abort before touching the database.
                    logger.error(f"找不到版本 {version} 的迁移文件,无法回滚")
                    return False
                rollback_migrations.append(self._load_migration_file(migration_file))

            for migration in rollback_migrations:
                if not self._rollback_migration(migration):
                    logger.error(f"回滚迁移失败: {migration['version']}")
                    return False

            logger.info(f"数据库回滚成功,共回滚 {len(rollback_migrations)} 个迁移")
            return True

        except Exception as e:
            logger.error(f"数据库回滚失败: {e}")
            return False

    def _apply_migration(self, migration: Dict[str, Any]) -> bool:
        """
        Run one migration's ``up_sql`` and record its version in the same session.

        NOTE(review): statements are split naively on ";", so a ";" inside a
        string literal or procedure body will break the statement.

        Args:
            migration: Migration dict as returned by ``get_pending_migrations``
                (must include ``filepath``).

        Returns:
            True on success, False on failure (the error is logged).
        """
        try:
            with self.db_connection.get_session() as session:
                if migration['up_sql'].strip():
                    for sql in migration['up_sql'].split(';'):
                        sql = sql.strip()
                        if sql:
                            session.execute(text(sql))

                version_record = DatabaseVersion(
                    version=migration['version'],
                    description=migration['description'],
                    migration_file=os.path.basename(migration['filepath'])
                )
                session.add(version_record)
                session.flush()

            logger.info(f"应用迁移成功: {migration['version']} - {migration['description']}")
            return True

        except Exception as e:
            logger.error(f"应用迁移失败 {migration['version']}: {e}")
            return False

    def _rollback_migration(self, migration: Dict[str, Any]) -> bool:
        """
        Run one migration's ``down_sql`` and delete its version record in the same session.

        NOTE(review): statements are split naively on ";" (see ``_apply_migration``).

        Args:
            migration: Migration dict loaded from its JSON file.

        Returns:
            True on success, False on failure (the error is logged).
        """
        try:
            with self.db_connection.get_session() as session:
                if migration['down_sql'].strip():
                    for sql in migration['down_sql'].split(';'):
                        sql = sql.strip()
                        if sql:
                            session.execute(text(sql))

                session.query(DatabaseVersion).filter(
                    DatabaseVersion.version == migration['version']
                ).delete()

            logger.info(f"回滚迁移成功: {migration['version']} - {migration['description']}")
            return True

        except Exception as e:
            logger.error(f"回滚迁移失败 {migration['version']}: {e}")
            return False

    def _find_migration_file(self, version: str) -> Optional[str]:
        """
        Find the migration file whose ``version`` field equals ``version`` exactly.

        Returns:
            Path of the first matching file in filename order, or None.
        """
        for _, filepath in self._iter_migration_files():
            try:
                if self._load_migration_file(filepath)['version'] == version:
                    return filepath
            except (OSError, ValueError):
                # Unreadable or malformed files cannot match; skip them.
                continue
        return None

    def _compare_versions(self, version1: str, version2: str) -> int:
        """
        Compare two version strings.

        Dot-separated numeric versions are compared component-wise as
        integers, with missing trailing components treated as 0 (so
        "1.10.0" > "1.9.0" and "1.0" == "1.0.0"). If either string is not
        purely numeric, falls back to plain string comparison.

        Returns:
            1 if version1 > version2, -1 if version1 < version2, 0 if equal.
        """
        try:
            parts1 = [int(p) for p in version1.split('.')]
            parts2 = [int(p) for p in version2.split('.')]
        except (AttributeError, ValueError):
            return (version1 > version2) - (version1 < version2)

        width = max(len(parts1), len(parts2))
        parts1 += [0] * (width - len(parts1))
        parts2 += [0] * (width - len(parts2))
        return (parts1 > parts2) - (parts1 < parts2)

    def check_database_health(self) -> Dict[str, Any]:
        """
        Report connection status, current version and migration counts.

        ``status`` is "error" if the connection test fails or an exception
        escapes, "outdated" if migrations are pending, otherwise "healthy".
        Note that the version/migration helpers log and swallow their own
        query errors, so those do not appear in ``errors``.

        Returns:
            Dict with keys ``status``, ``current_version``,
            ``total_migrations``, ``pending_migrations`` and ``errors``.
        """
        health_info = {
            "status": "unknown",
            "current_version": None,
            "total_migrations": 0,
            "pending_migrations": 0,
            "errors": []
        }

        try:
            if not self.db_connection.test_connection():
                health_info["status"] = "error"
                health_info["errors"].append("数据库连接失败")
                return health_info

            current_version = self.get_current_version()
            health_info["current_version"] = current_version

            all_versions = self.get_all_versions()
            health_info["total_migrations"] = len(all_versions)

            pending_migrations = self.get_pending_migrations()
            health_info["pending_migrations"] = len(pending_migrations)

            if health_info["errors"]:
                health_info["status"] = "error"
            elif health_info["pending_migrations"] > 0:
                health_info["status"] = "outdated"
            else:
                health_info["status"] = "healthy"

        except Exception as e:
            health_info["status"] = "error"
            health_info["errors"].append(str(e))

        return health_info


if __name__ == "__main__":
    migration = DatabaseMigration()
    migration.init_database()