| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419 |
- """
- 数据库迁移模块
- 处理数据库版本迁移和初始化
- """
import json
import os
from datetime import datetime
from functools import cmp_to_key
from typing import Any, Dict, List, Optional

from sqlalchemy import inspect, text
from sqlalchemy.exc import OperationalError

from backend.utils.logger_config import setup_logger

from .connection import DatabaseConnection
from .models import Base, DatabaseVersion
# Module-level logger obtained from the project's setup_logger helper
# (backend.utils.logger_config); used by DatabaseMigration for all messages.
logger = setup_logger(__name__)
class DatabaseMigration:
    """Manage database schema versions and JSON-file based migrations.

    Migration files are stored as ``*.json`` in a ``migrations`` directory next
    to this module. Each file holds ``version``, ``description``,
    ``created_at``, ``up_sql`` and ``down_sql`` (as written by
    :meth:`create_migration`). Applied versions are recorded as
    ``DatabaseVersion`` rows.
    """

    def __init__(self, db_connection: Optional[DatabaseConnection] = None):
        """
        Initialize the migration manager.

        Args:
            db_connection: Database connection to use; a new
                ``DatabaseConnection()`` is created when omitted.
        """
        self.db_connection = db_connection or DatabaseConnection()
        self.migrations_dir = os.path.join(os.path.dirname(__file__), "migrations")

        # Make sure the migrations directory exists; exist_ok avoids the
        # check-then-create race of os.path.exists() followed by makedirs().
        os.makedirs(self.migrations_dir, exist_ok=True)

    def init_database(self) -> bool:
        """
        Create all tables and seed the initial version record if none exists.

        Returns:
            bool: True on success, False if any error occurred (it is logged).
        """
        try:
            with self.db_connection.get_session() as session:
                # Create every table declared on Base's metadata.
                Base.metadata.create_all(bind=self.db_connection.get_engine())

                # Only seed the initial version when no version row exists yet.
                version_count = session.query(DatabaseVersion).count()
                if version_count == 0:
                    initial_version = DatabaseVersion(
                        version="1.0.0",
                        description="初始数据库版本",
                        migration_file="initial_migration"
                    )
                    session.add(initial_version)
                    session.flush()
                    logger.info("数据库初始化成功,版本: 1.0.0")
                else:
                    logger.info("数据库已存在,跳过初始化")

                return True

        except Exception as e:
            logger.error(f"数据库初始化失败: {e}")
            return False

    def get_current_version(self) -> Optional[str]:
        """
        Return the highest applied version.

        The highest version is chosen with :meth:`_compare_versions` rather than
        by ``applied_at``. Several migrations applied in quick succession can
        share a timestamp, which would make an ``applied_at`` ordering ambiguous
        and could cause already-applied migrations to be seen as pending.

        Returns:
            Optional[str]: The current version, or None if no version is
            recorded or the lookup failed (the error is logged).
        """
        try:
            with self.db_connection.get_session() as session:
                versions = [row[0] for row in session.query(DatabaseVersion.version).all()]
            if not versions:
                return None
            return max(versions, key=cmp_to_key(self._compare_versions))
        except Exception as e:
            logger.error(f"获取当前版本失败: {e}")
            return None

    def get_all_versions(self) -> List[DatabaseVersion]:
        """
        Return all applied version records, oldest ``applied_at`` first.

        Returns:
            List[DatabaseVersion]: Version rows, or an empty list on error
            (the error is logged).
        """
        try:
            with self.db_connection.get_session() as session:
                versions = session.query(DatabaseVersion).order_by(
                    DatabaseVersion.applied_at.asc()
                ).all()
                # Detach the rows so callers can still read their attributes
                # once the session is closed. Other methods here only flush(),
                # so get_session presumably commits on exit, and SQLAlchemy
                # expires session-bound instances on commit by default.
                # TODO confirm against DatabaseConnection.get_session.
                session.expunge_all()
                return versions
        except Exception as e:
            logger.error(f"获取版本列表失败: {e}")
            return []

    @staticmethod
    def _safe_filename_part(value: Any) -> str:
        """Return *value* as text with every character that is not a letter,
        digit, ``-``, ``_`` or ``.`` replaced by ``_`` (keeps path separators
        and other unsafe characters out of generated filenames)."""
        return "".join(ch if ch.isalnum() or ch in "-_." else "_" for ch in str(value))

    def create_migration(self, version: str, description: str,
                         up_sql: str, down_sql: str) -> str:
        """
        Write a new migration file into the migrations directory.

        Args:
            version: Version the migration upgrades to.
            description: Human-readable description; also used, sanitized, in
                the filename.
            up_sql: Upgrade SQL (multiple statements separated by ``;``).
            down_sql: Rollback SQL (multiple statements separated by ``;``).

        Returns:
            str: Path of the created migration file.
        """
        # Take a single timestamp so the filename and created_at agree.
        now = datetime.now()
        timestamp = now.strftime("%Y%m%d_%H%M%S")
        filename = (
            f"{timestamp}_{self._safe_filename_part(version)}_"
            f"{self._safe_filename_part(description)}.json"
        )
        filepath = os.path.join(self.migrations_dir, filename)

        migration_data = {
            "version": version,
            "description": description,
            "created_at": now.isoformat(),
            "up_sql": up_sql,
            "down_sql": down_sql
        }

        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump(migration_data, f, ensure_ascii=False, indent=2)

        logger.info(f"创建迁移文件: {filepath}")
        return filepath

    def _load_migration_files(self) -> List[Dict[str, Any]]:
        """
        Load every ``*.json`` migration file, in filename order.

        Each returned dict is the file's JSON content plus a ``filepath`` key.
        Files that cannot be read, are not valid JSON, or lack a string
        ``version`` are logged and skipped.

        Returns:
            List[Dict]: Loaded migration definitions.
        """
        migrations = []
        for filename in sorted(os.listdir(self.migrations_dir)):
            if not filename.endswith('.json'):
                continue
            filepath = os.path.join(self.migrations_dir, filename)
            try:
                with open(filepath, 'r', encoding='utf-8') as f:
                    migration_data = json.load(f)
                if not isinstance(migration_data, dict) or not isinstance(
                        migration_data.get('version'), str):
                    raise ValueError("missing or non-string 'version'")
            except (OSError, ValueError) as e:
                # json.JSONDecodeError and UnicodeDecodeError are ValueErrors.
                logger.error(f"读取迁移文件失败 {filename}: {e}")
                continue
            migration_data['filepath'] = filepath
            migrations.append(migration_data)
        return migrations

    def get_pending_migrations(self) -> List[Dict[str, Any]]:
        """
        Return migrations whose version is newer than the current version.

        Returns:
            List[Dict]: Pending migration definitions (each with a
            ``filepath`` key), sorted by version. Versions are compared
            numerically, so "1.10.0" sorts after "1.2.0".
        """
        current_version = self.get_current_version()

        pending_migrations = [
            migration for migration in self._load_migration_files()
            if not current_version
            or self._compare_versions(migration['version'], current_version) > 0
        ]

        # Sort by version semantics rather than raw string order.
        version_key = cmp_to_key(self._compare_versions)
        pending_migrations.sort(key=lambda m: version_key(m['version']))
        return pending_migrations

    def migrate_up(self, target_version: Optional[str] = None) -> bool:
        """
        Apply pending migrations in version order.

        Args:
            target_version: Highest version to apply (inclusive); when None,
                all pending migrations are applied.

        Returns:
            bool: True if every selected migration was applied, False on the
            first failure (earlier migrations stay applied).
        """
        try:
            pending_migrations = self.get_pending_migrations()

            if not pending_migrations:
                logger.info("没有待执行的迁移")
                return True

            # When a target is given, only apply migrations up to that version.
            if target_version:
                pending_migrations = [
                    m for m in pending_migrations
                    if self._compare_versions(m['version'], target_version) <= 0
                ]

            for migration in pending_migrations:
                if not self._apply_migration(migration):
                    logger.error(f"应用迁移失败: {migration['version']}")
                    return False

            logger.info(f"数据库升级成功,共执行 {len(pending_migrations)} 个迁移")
            return True

        except Exception as e:
            logger.error(f"数据库升级失败: {e}")
            return False

    def migrate_down(self, target_version: Optional[str] = None) -> bool:
        """
        Roll back every applied version newer than *target_version*.

        Versions are rolled back newest first, ordered by version, not by
        ``applied_at``. If any version to roll back has no migration file,
        nothing is rolled back and False is returned. This applies, for
        example, to the seed record written by :meth:`init_database`.

        Args:
            target_version: Version to roll back to (kept applied). Required;
                passing None logs an error and returns False.

        Returns:
            bool: True if the database is at or below the target afterwards,
            False otherwise.
        """
        try:
            if target_version is None:
                # Previously this crashed inside _compare_versions(..., None).
                logger.error("回滚需要指定目标版本")
                return False

            current_version = self.get_current_version()
            if not current_version:
                logger.error("数据库未初始化")
                return False

            if self._compare_versions(current_version, target_version) <= 0:
                logger.info("当前版本已经小于等于目标版本,无需回滚")
                return True

            # Applied versions above the target, deduplicated, newest first.
            newer_versions = dict.fromkeys(
                record.version for record in self.get_all_versions()
                if self._compare_versions(record.version, target_version) > 0
            )
            versions_to_rollback = sorted(
                newer_versions, key=cmp_to_key(self._compare_versions), reverse=True
            )
            if not versions_to_rollback:
                # The current version is above the target, so the version
                # listing must have failed.
                logger.error("没有找到需要回滚的版本记录")
                return False

            # Index migration files once instead of rescanning per version;
            # the first file (in filename order) wins for duplicate versions.
            migrations_by_version: Dict[str, Dict[str, Any]] = {}
            for migration in self._load_migration_files():
                migrations_by_version.setdefault(migration['version'], migration)

            # Resolve every file before executing anything, so a missing file
            # cannot leave the database half rolled back.
            rollback_migrations = []
            for version in versions_to_rollback:
                migration = migrations_by_version.get(version)
                if migration is None:
                    logger.error(f"未找到迁移文件,无法回滚: {version}")
                    return False
                rollback_migrations.append(migration)

            for migration in rollback_migrations:
                if not self._rollback_migration(migration):
                    logger.error(f"回滚迁移失败: {migration['version']}")
                    return False

            logger.info(f"数据库回滚成功,共回滚 {len(rollback_migrations)} 个迁移")
            return True

        except Exception as e:
            logger.error(f"数据库回滚失败: {e}")
            return False

    @staticmethod
    def _execute_sql_script(session, script: str) -> None:
        """
        Execute each ``;``-separated statement of *script* in *session*.

        NOTE: the split is purely textual. A ``;`` inside a string literal or a
        compound statement (e.g. a trigger body) breaks that statement apart.
        """
        for statement in script.split(';'):
            statement = statement.strip()
            if statement:
                session.execute(text(statement))

    def _apply_migration(self, migration: Dict[str, Any]) -> bool:
        """
        Run a migration's ``up_sql`` and record its version in one session.

        Args:
            migration: Migration definition including ``filepath``.

        Returns:
            bool: True on success, False on error (the error is logged).
        """
        try:
            with self.db_connection.get_session() as session:
                self._execute_sql_script(session, migration['up_sql'])

                # Record the applied version in the same session as the SQL.
                version_record = DatabaseVersion(
                    version=migration['version'],
                    description=migration['description'],
                    migration_file=os.path.basename(migration['filepath'])
                )
                session.add(version_record)
                session.flush()

                logger.info(f"应用迁移成功: {migration['version']} - {migration['description']}")
                return True

        except Exception as e:
            logger.error(f"应用迁移失败 {migration['version']}: {e}")
            return False

    def _rollback_migration(self, migration: Dict[str, Any]) -> bool:
        """
        Run a migration's ``down_sql`` and delete its version record(s).

        Args:
            migration: Migration definition.

        Returns:
            bool: True on success, False on error (the error is logged).
        """
        try:
            with self.db_connection.get_session() as session:
                self._execute_sql_script(session, migration['down_sql'])

                # Remove every record for this version.
                session.query(DatabaseVersion).filter(
                    DatabaseVersion.version == migration['version']
                ).delete()

                logger.info(f"回滚迁移成功: {migration['version']} - {migration['description']}")
                return True

        except Exception as e:
            logger.error(f"回滚迁移失败 {migration['version']}: {e}")
            return False

    def _find_migration_file(self, version: str) -> Optional[str]:
        """
        Find the migration file declaring *version*.

        Args:
            version: Version to look for (exact string match).

        Returns:
            Optional[str]: Path of the first matching file in filename order,
            or None if none matches.
        """
        for migration in self._load_migration_files():
            if migration['version'] == version:
                return migration['filepath']
        return None

    def _compare_versions(self, version1: str, version2: str) -> int:
        """
        Compare two version strings.

        Dotted numeric versions are compared component-wise. Missing trailing
        components count as 0, so "1.0" equals "1.0.0". If either version is
        not dotted-numeric, plain string comparison is used.

        Args:
            version1: First version.
            version2: Second version.

        Returns:
            int: 1 if version1 > version2, -1 if version1 < version2, 0 if equal.

        Raises:
            TypeError: If the fallback string comparison is impossible
                (e.g. a version is None).
        """
        def version_to_tuple(v):
            return tuple(int(x) for x in v.split('.'))

        try:
            v1_tuple = version_to_tuple(version1)
            v2_tuple = version_to_tuple(version2)
        except (AttributeError, TypeError, ValueError):
            # Malformed version numbers fall back to string comparison.
            return (version1 > version2) - (version1 < version2)

        # Pad the shorter tuple with zeros so "1.0" == "1.0.0".
        width = max(len(v1_tuple), len(v2_tuple))
        v1_tuple += (0,) * (width - len(v1_tuple))
        v2_tuple += (0,) * (width - len(v2_tuple))
        return (v1_tuple > v2_tuple) - (v1_tuple < v2_tuple)

    def check_database_health(self) -> Dict[str, Any]:
        """
        Report connection status, current version and migration counts.

        Returns:
            Dict: Keys ``status`` ("unknown", "error", "outdated" or
            "healthy"), ``current_version``, ``total_migrations``
            (applied version rows), ``pending_migrations`` and ``errors``
            (list of messages).
        """
        health_info = {
            "status": "unknown",
            "current_version": None,
            "total_migrations": 0,
            "pending_migrations": 0,
            "errors": []
        }

        try:
            # Check connectivity first; nothing else is meaningful without it.
            if not self.db_connection.test_connection():
                health_info["status"] = "error"
                health_info["errors"].append("数据库连接失败")
                return health_info

            health_info["current_version"] = self.get_current_version()
            health_info["total_migrations"] = len(self.get_all_versions())
            health_info["pending_migrations"] = len(self.get_pending_migrations())

            if health_info["errors"]:
                health_info["status"] = "error"
            elif health_info["pending_migrations"] > 0:
                health_info["status"] = "outdated"
            else:
                health_info["status"] = "healthy"

        except Exception as e:
            health_info["status"] = "error"
            health_info["errors"].append(str(e))

        return health_info
if __name__ == "__main__":
    import sys

    # Initialize the database when run as a script. Report the outcome through
    # the process exit status (0 = success, 1 = failure) so shells and CI
    # can detect a failed initialization.
    migration = DatabaseMigration()
    sys.exit(0 if migration.init_database() else 1)
|