| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148 |
- #!/usr/bin/env python3
- """
- 初始化测试数据脚本
- 创建测试用户和测试素材数据
- """
- import os
- import sys
- import hashlib
- from pathlib import Path
- # 添加项目根目录到Python路径
- project_root = Path(__file__).parent
- sys.path.insert(0, str(project_root))
- from backend.modules.database import DatabaseMigration, DatabaseOperations
- from backend.utils.logger_config import setup_logger
- logger = setup_logger(__name__)
- def init_test_data():
- """初始化测试数据"""
-
- print("=== 初始化测试数据 ===")
-
- # 1. 初始化数据库
- print("\n1. 初始化数据库...")
- migration = DatabaseMigration()
- if migration.init_database():
- print("✓ 数据库初始化成功")
- else:
- print("✗ 数据库初始化失败")
- return False
-
- # 2. 创建数据库操作实例
- db_ops = DatabaseOperations()
-
- # 3. 创建测试用户
- print("\n2. 创建测试用户...")
-
- # 检查用户是否已存在
- existing_user = db_ops.get_user_by_username("test_user")
- if existing_user:
- print(f"✓ 测试用户已存在: {existing_user['username']} (ID: {existing_user['id']})")
- user_id = existing_user['id']
- else:
- # 创建新用户
- password_hash = hashlib.sha256("test123".encode()).hexdigest()
- user = db_ops.create_user(
- username="test_user",
- password_hash=password_hash,
- is_admin=False
- )
- print(f"✓ 创建测试用户: {user['username']} (ID: {user['id']})")
- user_id = user['id']
-
- # 4. 创建测试素材目录
- print("\n3. 创建测试素材目录...")
- materials_dir = Path("backend/data/materials")
- materials_dir.mkdir(parents=True, exist_ok=True)
- print(f"✓ 素材目录已创建: {materials_dir}")
-
- # 5. 创建测试图片文件
- print("\n4. 创建测试图片文件...")
-
- # 创建测试PNG图片数据
- png_data = b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00\x01\x08\x02\x00\x00\x00\x90wS\xde\x00\x00\x00\x0cIDATx\x9cc```\x00\x00\x00\x04\x00\x01\xf5\x27\xdc\xf9\x00\x00\x00\x00IEND\xaeB`\x82'
-
- # 创建测试素材文件
- test_files = [
- ("test_face_1.png", "face", "测试人脸图片1"),
- ("test_face_2.png", "face", "测试人脸图片2"),
- ("test_cloth_1.png", "cloth", "测试服装图片1"),
- ("test_cloth_2.png", "cloth", "测试服装图片2"),
- ]
-
- created_materials = []
-
- for filename, image_type, description in test_files:
- file_path = materials_dir / filename
-
- # 创建文件
- with open(file_path, "wb") as f:
- f.write(png_data)
-
- # 创建数据库记录
- image_record = db_ops.create_image_record(
- user_id=user_id,
- image_type=image_type,
- original_filename=filename,
- stored_path=str(file_path),
- file_size=len(png_data),
- image_hash=None
- )
-
- if image_record:
- created_materials.append({
- "id": image_record['id'],
- "filename": filename,
- "type": image_type,
- "description": description
- })
- print(f"✓ 创建素材: {filename} (ID: {image_record['id']})")
- else:
- print(f"✗ 创建素材失败: {filename}")
-
- # 6. 验证数据
- print("\n5. 验证测试数据...")
-
- # 检查用户素材
- user_images = db_ops.get_user_images(user_id)
- print(f"✓ 用户素材总数: {user_images['total']}")
-
- # 检查人脸素材
- face_images = db_ops.get_user_images(user_id, image_type="face")
- print(f"✓ 人脸素材数量: {face_images['total']}")
-
- # 检查服装素材
- cloth_images = db_ops.get_user_images(user_id, image_type="cloth")
- print(f"✓ 服装素材数量: {cloth_images['total']}")
-
- # 7. 输出测试信息
- print("\n=== 测试信息 ===")
- print(f"测试用户ID: {user_id}")
- print(f"测试用户名: test_user")
- print(f"测试密码: test123")
- print(f"素材目录: {materials_dir}")
- print(f"创建素材数量: {len(created_materials)}")
-
- print("\n=== 素材列表 ===")
- for material in created_materials:
- print(f"- {material['filename']} ({material['type']}): {material['description']}")
-
- print("\n=== API测试URL ===")
- print(f"获取素材列表: GET http://localhost:8000/api/v1/users/{user_id}/materials")
- print(f"获取人脸素材: GET http://localhost:8000/api/v1/users/{user_id}/materials?material_type=face")
- print(f"获取服装素材: GET http://localhost:8000/api/v1/users/{user_id}/materials?material_type=cloth")
-
- print("\n✓ 测试数据初始化完成!")
- return True
- if __name__ == "__main__":
- try:
- init_test_data()
- except Exception as e:
- print(f"✗ 初始化失败: {e}")
- import traceback
- traceback.print_exc()
|