# streamlit_ui.py
  1. """
  2. Streamlit 聊天界面 - idea2video
  3. 提供对话式的视频创作界面
  4. 使用方法:
  5. streamlit run streamlit_ui.py
  6. 功能:
  7. 1. 聊天式输入创意(idea)
  8. 2. 侧边栏设置用户要求、重试次数等
  9. 3. 上传参考图片映射文件(可选)
  10. 4. 实时显示执行进度和步骤状态
  11. 5. 查看历史运行结果
  12. 6. 继续执行未完成的运行
  13. 7. 显示最终视频和中间结果(故事、剧本、角色肖像、视频帧等)
  14. """
  15. import streamlit as st
  16. import asyncio
  17. import json
  18. import time
  19. import logging
  20. import threading
  21. from pathlib import Path
  22. from typing import Dict, Optional, List, Tuple
  23. import sys
  24. import os
  25. from queue import Queue, Empty
  26. import re
  27. import hashlib
  28. import shutil
  29. # 添加项目根目录到路径
  30. sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
  31. from taskflow import TaskManager, FileIOHandler, RunManager
  32. from taskflow import setup_logger
  33. from examples.video_create.pipeline.idea2video_pipeline import Idea2VideoPipeline
  34. # 配置页面
  35. st.set_page_config(
  36. page_title="Idea2Video - 视频创作助手",
  37. page_icon="🎬",
  38. layout="wide",
  39. initial_sidebar_state="expanded"
  40. )
  41. # 初始化 session state
  42. if "messages" not in st.session_state:
  43. st.session_state.messages = []
  44. if "current_run_id" not in st.session_state:
  45. st.session_state.current_run_id = None
  46. if "pipeline_running" not in st.session_state:
  47. st.session_state.pipeline_running = False
  48. if "run_manager" not in st.session_state:
  49. st.session_state.run_manager = RunManager(base_output_dir="output")
  50. if "uploaded_images" not in st.session_state:
  51. st.session_state.uploaded_images = {} # {image_id: {"path": str, "name": str}}
  52. # 设置日志
  53. logger = setup_logger("streamlit_ui", level=logging.INFO)
  54. # 添加全局视频显示 CSS(确保每次页面加载都应用)
  55. def add_video_css():
  56. """添加视频显示的全局 CSS"""
  57. st.markdown("""
  58. <style>
  59. /* 全局视频样式 - 使用通配符和最高优先级选择器 */
  60. video {
  61. width: 100% !important;
  62. max-width: 100% !important;
  63. height: auto !important;
  64. max-height: 70vh !important;
  65. object-fit: contain !important;
  66. display: block !important;
  67. margin: 0 auto !important;
  68. }
  69. /* Streamlit video 容器 */
  70. div[data-testid="stVideo"],
  71. .stVideo {
  72. width: 100% !important;
  73. max-width: 100% !important;
  74. overflow: visible !important;
  75. margin: 0 !important;
  76. padding: 0 !important;
  77. }
  78. div[data-testid="stVideo"] > div,
  79. .stVideo > div {
  80. width: 100% !important;
  81. max-width: 100% !important;
  82. padding: 0 !important;
  83. margin: 0 !important;
  84. }
  85. div[data-testid="stVideo"] video,
  86. .stVideo video {
  87. width: 100% !important;
  88. max-width: 100% !important;
  89. height: auto !important;
  90. max-height: 70vh !important;
  91. object-fit: contain !important;
  92. display: block !important;
  93. margin: 0 auto !important;
  94. }
  95. /* 聊天消息中的视频 */
  96. div[data-testid="stChatMessage"] {
  97. max-width: 100% !important;
  98. width: 100% !important;
  99. }
  100. div[data-testid="stChatMessage"] > div {
  101. max-width: 100% !important;
  102. width: 100% !important;
  103. }
  104. div[data-testid="stChatMessage"] div[data-testid="stVideo"],
  105. div[data-testid="stChatMessage"] .stVideo {
  106. width: 100% !important;
  107. max-width: 100% !important;
  108. margin: 0 !important;
  109. padding: 0 !important;
  110. }
  111. div[data-testid="stChatMessage"] video {
  112. width: 100% !important;
  113. max-width: 100% !important;
  114. height: auto !important;
  115. max-height: 70vh !important;
  116. object-fit: contain !important;
  117. display: block !important;
  118. margin: 0 auto !important;
  119. }
  120. /* 主内容区域 */
  121. section[data-testid="stAppViewContainer"] main {
  122. max-width: 100% !important;
  123. }
  124. section[data-testid="stAppViewContainer"] main > div {
  125. max-width: 100% !important;
  126. }
  127. /* 确保所有可能的容器都不限制视频宽度和高度 */
  128. [class*="video"],
  129. [id*="video"] {
  130. max-width: 100% !important;
  131. }
  132. /* 确保 Streamlit 的列布局不限制视频 */
  133. .stColumn > div,
  134. [data-testid="column"] > div {
  135. max-width: 100% !important;
  136. }
  137. /* 确保所有容器都允许视频完整显示 */
  138. * {
  139. box-sizing: border-box;
  140. }
  141. </style>
  142. """, unsafe_allow_html=True)
  143. def load_task_state(run_output_dir: str) -> Optional[Dict]:
  144. """加载任务状态"""
  145. state_file = Path(run_output_dir) / "task_state.json"
  146. if state_file.exists():
  147. try:
  148. with open(state_file, 'r', encoding='utf-8') as f:
  149. return json.load(f)
  150. except Exception as e:
  151. logger.error(f"加载任务状态失败: {e}")
  152. return None
  153. def get_step_status(state: Dict, step_name: str) -> str:
  154. """获取步骤状态"""
  155. if state is None:
  156. return "pending"
  157. steps = state.get("steps", {})
  158. step_info = steps.get(step_name, {})
  159. return step_info.get("status", "pending")
  160. def format_step_name(step_name: str) -> str:
  161. """格式化步骤名称"""
  162. step_names = {
  163. "step1": "📝 步骤1: 开发故事",
  164. "step2": "📄 步骤2: 开发剧本",
  165. "step3": "👥 步骤3: 提取角色",
  166. "step4": "🎨 步骤4: 创建分镜",
  167. "step5": "🖼️ 步骤5: 生成角色肖像",
  168. "step6": "📹 步骤6: 创建镜头树",
  169. "step7": "🎞️ 步骤7: 生成视频帧",
  170. "step8": "🎬 步骤8: 生成视频片段",
  171. "step9": "🔗 步骤9: 拼接最终视频"
  172. }
  173. return step_names.get(step_name, step_name)
  174. def save_uploaded_image(uploaded_file) -> str:
  175. """保存上传的图片并返回图片ID"""
  176. # 创建临时图片目录
  177. temp_image_dir = Path("temp_uploaded_images")
  178. temp_image_dir.mkdir(parents=True, exist_ok=True)
  179. # 生成图片ID(基于文件名和内容的哈希)
  180. file_content = uploaded_file.read()
  181. file_hash = hashlib.md5(file_content).hexdigest()[:8]
  182. file_name = uploaded_file.name
  183. image_id = f"{Path(file_name).stem}_{file_hash}"
  184. # 保存图片
  185. image_path = temp_image_dir / f"{image_id}{Path(file_name).suffix}"
  186. uploaded_file.seek(0) # 重置文件指针
  187. with open(image_path, "wb") as f:
  188. f.write(file_content)
  189. # 保存到 session_state
  190. st.session_state.uploaded_images[image_id] = {
  191. "path": str(image_path),
  192. "name": file_name
  193. }
  194. return image_id
  195. def parse_refer_image_references(text: str) -> Dict[str, List[str]]:
  196. """
  197. 解析文本中的图片引用,支持两种格式:
  198. 1. @图片ID - 全局引用(所有角色共享)
  199. 2. @角色名:图片ID - 特定角色引用
  200. 返回: {角色名: [图片路径列表]}
  201. """
  202. refer_image_map = {}
  203. # 匹配 @图片ID 或 @角色名:图片ID 的模式
  204. # 支持中文角色名和英文图片ID
  205. pattern = r'@([^@\s:]+)(?::([^@\s]+))?'
  206. matches = re.findall(pattern, text)
  207. for match in matches:
  208. if len(match) == 2:
  209. role_or_id, image_id = match
  210. if image_id: # 格式: @角色名:图片ID
  211. role_name = role_or_id.strip()
  212. if image_id in st.session_state.uploaded_images:
  213. image_path = st.session_state.uploaded_images[image_id]["path"]
  214. if role_name not in refer_image_map:
  215. refer_image_map[role_name] = []
  216. refer_image_map[role_name].append(image_path)
  217. else: # 格式: @图片ID(全局引用)
  218. image_id = role_or_id.strip()
  219. if image_id in st.session_state.uploaded_images:
  220. image_path = st.session_state.uploaded_images[image_id]["path"]
  221. # 全局引用使用特殊键 "__global__"
  222. if "__global__" not in refer_image_map:
  223. refer_image_map["__global__"] = []
  224. refer_image_map["__global__"].append(image_path)
  225. return refer_image_map
  226. def list_available_runs() -> List[Dict]:
  227. """列出所有可用的运行"""
  228. try:
  229. runs = st.session_state.run_manager.list_runs()
  230. return runs
  231. except Exception as e:
  232. logger.error(f"列出运行失败: {e}")
  233. return []
  234. def find_incomplete_run() -> Optional[str]:
  235. """查找未完成的运行"""
  236. runs = list_available_runs()
  237. for run_info in runs:
  238. run_path = Path(run_info["path"])
  239. state_file = run_path / "task_state.json"
  240. if state_file.exists():
  241. try:
  242. with open(state_file, 'r', encoding='utf-8') as f:
  243. state = json.load(f)
  244. steps = state.get("steps", {})
  245. has_failed = any(
  246. step.get("status") == "failed"
  247. for step in steps.values()
  248. )
  249. has_pending = any(
  250. step.get("status") in ["pending", "running"]
  251. for step in steps.values()
  252. )
  253. if has_failed or has_pending:
  254. return run_info["run_id"]
  255. except Exception as e:
  256. logger.warning(f"检查运行 {run_info['run_id']} 状态时出错: {e}")
  257. continue
  258. return None
  259. def run_pipeline_sync(
  260. idea: str,
  261. user_requirement: Optional[str] = None,
  262. refer_image_map: Optional[Dict[str, List[str]]] = None,
  263. run_id: Optional[str] = None,
  264. new_run: bool = False,
  265. max_retries: int = 3,
  266. status_queue: Optional[Queue] = None,
  267. run_manager: Optional[RunManager] = None
  268. ) -> Dict:
  269. """同步包装器,用于在线程中运行异步pipeline"""
  270. return asyncio.run(run_pipeline(
  271. idea=idea,
  272. user_requirement=user_requirement,
  273. refer_image_map=refer_image_map,
  274. run_id=run_id,
  275. new_run=new_run,
  276. max_retries=max_retries,
  277. status_queue=status_queue,
  278. run_manager=run_manager
  279. ))
  280. async def run_pipeline(
  281. idea: str,
  282. user_requirement: Optional[str] = None,
  283. refer_image_map: Optional[Dict[str, List[str]]] = None,
  284. run_id: Optional[str] = None,
  285. new_run: bool = False,
  286. max_retries: int = 3,
  287. status_queue: Optional[Queue] = None,
  288. run_manager: Optional[RunManager] = None
  289. ) -> Dict:
  290. """运行视频创作流程"""
  291. # 如果没有传入 run_manager,创建一个新的(线程中无法访问 session_state)
  292. if run_manager is None:
  293. run_manager = RunManager(base_output_dir="output")
  294. # 确定运行目录策略
  295. if new_run:
  296. run_output_dir = run_manager.create_run_directory()
  297. run_id = run_manager.get_run_id()
  298. elif run_id:
  299. run_output_dir = run_manager.create_run_directory(run_id=run_id)
  300. run_id = run_manager.get_run_id()
  301. else:
  302. # 默认创建新运行
  303. run_output_dir = run_manager.create_run_directory()
  304. run_id = run_manager.get_run_id()
  305. # 通过 status_queue 通知主线程更新 current_run_id(线程中无法直接修改 session_state)
  306. if status_queue:
  307. status_queue.put({
  308. "type": "set_run_id",
  309. "run_id": run_id
  310. })
  311. # 创建文件I/O处理器
  312. io_handler = FileIOHandler()
  313. # 创建任务管理器
  314. state_file = str(Path(run_output_dir) / "task_state.json")
  315. cache_dir = str(Path(run_output_dir) / "task_cache")
  316. manager = TaskManager(
  317. state_file=state_file,
  318. cache_dir=cache_dir
  319. )
  320. # 创建视频创作任务流
  321. pipeline = Idea2VideoPipeline(io_handler, run_output_dir, manager)
  322. # 注册步骤
  323. async def step1_func():
  324. return await pipeline.step1_develop_story(idea=idea, user_requirement=user_requirement)
  325. async def step2_func():
  326. return await pipeline.step2_develop_script(user_requirement=user_requirement)
  327. async def step3_func():
  328. return await pipeline.step3_extract_characters()
  329. async def step4_func():
  330. return await pipeline.step4_create_storyboard(user_requirement=user_requirement)
  331. async def step5_func():
  332. # 处理全局引用和角色特定引用
  333. global_refer_images = None
  334. role_specific_map = None
  335. if refer_image_map:
  336. # 创建副本以避免修改原始字典
  337. refer_image_map_copy = refer_image_map.copy()
  338. # 分离全局引用和角色特定引用
  339. if "__global__" in refer_image_map_copy:
  340. global_refer_images = refer_image_map_copy.pop("__global__")
  341. # 如果还有角色特定的映射,使用它
  342. if refer_image_map_copy:
  343. role_specific_map = refer_image_map_copy
  344. return await pipeline.step5_generate_portrait(
  345. size="2048x2048",
  346. refer_image=global_refer_images,
  347. refer_image_map=role_specific_map,
  348. style="写实"
  349. )
  350. async def step6_func():
  351. return await pipeline.step6_create_camera_tree()
  352. async def step7_func():
  353. return await pipeline.step7_generate_video_frames()
  354. async def step8_func():
  355. return await pipeline.step8_generate_video()
  356. async def step9_func():
  357. return await pipeline.step9_concat_clip()
  358. manager.register_step("step1", step1_func, force_rerun=False)
  359. manager.register_step("step2", step2_func, depends_on=["step1"], force_rerun=False)
  360. manager.register_step("step3", step3_func, depends_on=["step1"], force_rerun=False)
  361. manager.register_step("step4", step4_func, depends_on=["step2", "step3"], force_rerun=False)
  362. manager.register_step("step5", step5_func, depends_on=["step3"], force_rerun=False)
  363. manager.register_step("step6", step6_func, depends_on=["step4"], force_rerun=False)
  364. manager.register_step("step7", step7_func, depends_on=["step5", "step6"], force_rerun=False)
  365. manager.register_step("step8", step8_func, depends_on=["step7"], force_rerun=False)
  366. manager.register_step("step9", step9_func, depends_on=["step8"], force_rerun=False)
  367. # 执行所有步骤
  368. async def run_pipeline_async():
  369. step_order = ["step1", "step2", "step3", "step4", "step5", "step6", "step7", "step8", "step9"]
  370. # 如果提供了状态队列,在执行过程中发送状态更新
  371. if status_queue:
  372. # 发送初始状态
  373. status_queue.put({
  374. "type": "init",
  375. "run_id": run_id,
  376. "run_output_dir": run_output_dir
  377. })
  378. await manager.run_all_async(
  379. step_order=step_order,
  380. continue_on_error=False
  381. )
  382. # 发送完成状态
  383. if status_queue:
  384. status_queue.put({
  385. "type": "completed",
  386. "run_id": run_id,
  387. "run_output_dir": run_output_dir
  388. })
  389. # 重试机制
  390. last_exception = None
  391. total_attempts = max_retries + 1
  392. for attempt in range(total_attempts):
  393. try:
  394. if attempt > 0:
  395. wait_time = min(2 ** (attempt - 1), 60)
  396. await asyncio.sleep(wait_time)
  397. if status_queue:
  398. status_queue.put({
  399. "type": "retry",
  400. "attempt": attempt + 1,
  401. "total_attempts": total_attempts
  402. })
  403. await run_pipeline_async()
  404. break
  405. except Exception as e:
  406. last_exception = e
  407. if status_queue:
  408. status_queue.put({
  409. "type": "error",
  410. "error": str(e),
  411. "attempt": attempt + 1
  412. })
  413. if attempt == total_attempts - 1:
  414. raise last_exception
  415. continue
  416. return {
  417. "run_id": run_id,
  418. "run_output_dir": run_output_dir,
  419. "success": True
  420. }
  421. def display_video(video_path: str, width: str = "100%"):
  422. """显示视频,确保完整显示画面(支持各种宽高比:16:9、9:16、4:3、3:4、1:1等)"""
  423. video_path_obj = Path(video_path)
  424. if not video_path_obj.exists():
  425. st.error(f"视频文件不存在: {video_path}")
  426. return
  427. # 使用容器包装,确保视频有足够的显示空间
  428. with st.container():
  429. # 使用 st.video 显示视频,全局 CSS 会确保完整显示
  430. st.video(str(video_path_obj), format="video/mp4")
  431. def display_step_result(step_name: str, run_output_dir: str, step_data: Optional[Dict] = None):
  432. """显示单个步骤的结果"""
  433. run_path = Path(run_output_dir)
  434. step_display_names = {
  435. "step1": ("📝 步骤1: 开发故事", "story"),
  436. "step2": ("📄 步骤2: 开发剧本", "script"),
  437. "step3": ("👥 步骤3: 提取角色", "characters"),
  438. "step4": ("🎨 步骤4: 创建分镜", "storyboard"),
  439. "step5": ("🖼️ 步骤5: 生成角色肖像", "portrait"),
  440. "step6": ("📹 步骤6: 创建镜头树", "camera_tree"),
  441. "step7": ("🎞️ 步骤7: 生成视频帧", "video_frames"),
  442. "step8": ("🎬 步骤8: 生成视频片段", "video_clips"),
  443. "step9": ("🔗 步骤9: 拼接最终视频", "final_video")
  444. }
  445. display_name, file_prefix = step_display_names.get(step_name, (step_name, ""))
  446. # 根据步骤类型显示不同内容
  447. if step_name == "step1":
  448. story_file = run_path / "step1_story.json"
  449. if story_file.exists():
  450. with open(story_file, 'r', encoding='utf-8') as f:
  451. story = json.load(f)
  452. st.json(story, expanded=False)
  453. elif step_name == "step2":
  454. script_file = run_path / "step2_script.json"
  455. if script_file.exists():
  456. with open(script_file, 'r', encoding='utf-8') as f:
  457. script = json.load(f)
  458. st.json(script, expanded=False)
  459. elif step_name == "step3":
  460. characters_file = run_path / "step3_characters.json"
  461. if characters_file.exists():
  462. with open(characters_file, 'r', encoding='utf-8') as f:
  463. characters = json.load(f)
  464. st.json(characters, expanded=False)
  465. elif step_name == "step4":
  466. storyboard_file = run_path / "step4_storyboard.json"
  467. if storyboard_file.exists():
  468. with open(storyboard_file, 'r', encoding='utf-8') as f:
  469. storyboard = json.load(f)
  470. # 只显示摘要信息
  471. if isinstance(storyboard, dict):
  472. scenes_count = len(storyboard.get("storyboard", []))
  473. st.info(f"已创建 {scenes_count} 个场景的分镜")
  474. with st.expander("查看详细分镜"):
  475. st.json(storyboard, expanded=False)
  476. elif step_name == "step5":
  477. portraits_dir = run_path / "portraits"
  478. if portraits_dir.exists():
  479. portrait_files = sorted(list(portraits_dir.glob("*.jpg")) + list(portraits_dir.glob("*.png")))
  480. if portrait_files:
  481. st.info(f"已生成 {len(portrait_files)} 个角色肖像")
  482. cols = st.columns(min(len(portrait_files), 4))
  483. for idx, portrait_file in enumerate(portrait_files[:4]):
  484. with cols[idx % 4]:
  485. st.image(str(portrait_file), caption=portrait_file.name, use_container_width=True)
  486. elif step_name == "step7":
  487. frames_dir = run_path / "video_frames"
  488. if frames_dir.exists():
  489. frame_files = sorted(list(frames_dir.glob("*.png")))
  490. if frame_files:
  491. st.info(f"已生成 {len(frame_files)} 个视频帧")
  492. # 显示前8张预览
  493. cols = st.columns(4)
  494. for idx, frame_file in enumerate(frame_files[:8]):
  495. with cols[idx % 4]:
  496. st.image(str(frame_file), caption=frame_file.name, use_container_width=True)
  497. elif step_name == "step8":
  498. clips_dir = run_path / "video_clips"
  499. if clips_dir.exists():
  500. clip_files = sorted(list(clips_dir.glob("*.mp4")))
  501. if clip_files:
  502. st.info(f"已生成 {len(clip_files)} 个视频片段")
  503. # 显示第一个片段预览
  504. if clip_files:
  505. display_video(str(clip_files[0]))
  506. elif step_name == "step9":
  507. final_video = run_path / "video_save" / "final_video.mp4"
  508. if final_video.exists():
  509. st.success("✅ 最终视频已生成!")
  510. display_video(str(final_video))
  511. def display_run_results(run_output_dir: str):
  512. """显示运行结果"""
  513. run_path = Path(run_output_dir)
  514. # 显示最终视频
  515. final_video = run_path / "video_save" / "final_video.mp4"
  516. if final_video.exists():
  517. st.success("✅ 视频创作完成!")
  518. display_video(str(final_video))
  519. # 显示中间结果
  520. with st.expander("📊 查看所有中间结果", expanded=False):
  521. col1, col2 = st.columns(2)
  522. with col1:
  523. st.subheader("📝 故事")
  524. story_file = run_path / "step1_story.txt"
  525. if story_file.exists():
  526. with open(story_file, 'r', encoding='utf-8') as f:
  527. story = f.read()
  528. st.text(story)
  529. with col2:
  530. st.subheader("📄 剧本")
  531. script_file = run_path / "step2_script.json"
  532. if script_file.exists():
  533. with open(script_file, 'r', encoding='utf-8') as f:
  534. script = json.load(f)
  535. st.json(script)
  536. # 显示角色肖像
  537. portraits_dir = run_path / "portraits"
  538. if portraits_dir.exists():
  539. st.subheader("🖼️ 角色肖像")
  540. portrait_files = list(portraits_dir.glob("*.jpg")) + list(portraits_dir.glob("*.png"))
  541. if portrait_files:
  542. cols = st.columns(min(len(portrait_files), 4))
  543. for idx, portrait_file in enumerate(portrait_files[:4]):
  544. with cols[idx % 4]:
  545. st.image(str(portrait_file), caption=portrait_file.name)
  546. # 显示视频帧
  547. frames_dir = run_path / "video_frames"
  548. if frames_dir.exists():
  549. st.subheader("🎞️ 视频帧预览")
  550. frame_files = sorted(list(frames_dir.glob("*.png")))[:12] # 最多显示12张
  551. if frame_files:
  552. cols = st.columns(4)
  553. for idx, frame_file in enumerate(frame_files):
  554. with cols[idx % 4]:
  555. st.image(str(frame_file), caption=frame_file.name)
  556. # 侧边栏
  557. with st.sidebar:
  558. st.title("⚙️ 设置")
  559. # 用户要求输入
  560. st.subheader("📋 用户要求(可选)")
  561. user_requirement_input = st.text_area(
  562. "输入额外的用户要求",
  563. help="例如:设计三个场景、使用现代风格等",
  564. height=100
  565. )
  566. # 最大重试次数
  567. st.subheader("🔄 重试设置")
  568. max_retries = st.number_input("最大重试次数", min_value=0, max_value=10, value=3)
  569. # 运行选项
  570. st.subheader("运行选项")
  571. new_run = st.checkbox("强制创建新运行", value=False)
  572. resume_run = st.checkbox("继续未完成的运行", value=False)
  573. # 历史运行
  574. st.subheader("📚 历史运行")
  575. runs = list_available_runs()
  576. if runs:
  577. run_options = [f"{r['run_id']} - {r.get('created_at', 'N/A')}" for r in runs[:10]]
  578. selected_run_idx = st.selectbox("选择运行", options=[""] + run_options)
  579. if selected_run_idx:
  580. selected_run_id = runs[run_options.index(selected_run_idx)]["run_id"]
  581. if st.button("查看运行结果"):
  582. run_info = next((r for r in runs if r["run_id"] == selected_run_id), None)
  583. if run_info:
  584. st.session_state.current_run_id = selected_run_id
  585. st.rerun()
  586. else:
  587. st.info("暂无历史运行")
  588. # 参考图片映射
  589. st.subheader("🖼️ 参考图片映射")
  590. # 图片上传功能
  591. st.markdown("**方式1: 上传图片**")
  592. uploaded_images = st.file_uploader(
  593. "上传参考图片",
  594. type=["jpg", "jpeg", "png", "webp"],
  595. accept_multiple_files=True,
  596. help="上传图片后,可以在输入框中使用 @图片ID 或 @角色名:图片ID 来引用"
  597. )
  598. # 处理上传的图片
  599. if uploaded_images:
  600. for uploaded_file in uploaded_images:
  601. image_id = save_uploaded_image(uploaded_file)
  602. st.success(f"✅ 图片已上传: `{uploaded_file.name}` (ID: `{image_id}`)")
  603. st.caption(f"使用方式: `@{image_id}` 或 `@角色名:{image_id}`")
  604. # 显示已上传的图片
  605. if st.session_state.uploaded_images:
  606. st.markdown("**已上传的图片:**")
  607. for image_id, image_info in st.session_state.uploaded_images.items():
  608. col1, col2 = st.columns([3, 1])
  609. with col1:
  610. st.text(f"ID: `{image_id}` - {image_info['name']}")
  611. with col2:
  612. if st.button("删除", key=f"delete_{image_id}"):
  613. # 删除文件
  614. image_path = Path(image_info["path"])
  615. if image_path.exists():
  616. image_path.unlink()
  617. # 从 session_state 中删除
  618. del st.session_state.uploaded_images[image_id]
  619. st.rerun()
  620. # 显示图片预览
  621. with st.expander("预览已上传的图片"):
  622. cols = st.columns(min(len(st.session_state.uploaded_images), 3))
  623. for idx, (image_id, image_info) in enumerate(st.session_state.uploaded_images.items()):
  624. with cols[idx % 3]:
  625. if Path(image_info["path"]).exists():
  626. st.image(image_info["path"], caption=f"{image_id}\n{image_info['name']}", use_container_width=True)
  627. # JSON文件上传(方式2)
  628. st.markdown("**方式2: 上传JSON映射文件**")
  629. refer_image_file = st.file_uploader(
  630. "上传参考图片映射文件 (JSON)",
  631. type=["json"],
  632. help="格式: {\"角色名\": [\"图片路径1\", \"图片路径2\"]}",
  633. key="refer_image_json_file"
  634. )
  635. refer_image_map_from_file = None
  636. if refer_image_file:
  637. try:
  638. refer_image_map_from_file = json.load(refer_image_file)
  639. st.success("✅ 参考图片映射文件已加载")
  640. st.json(refer_image_map_from_file)
  641. except Exception as e:
  642. st.error(f"❌ 解析文件失败: {e}")
  643. # 使用说明
  644. with st.expander("📖 使用说明"):
  645. st.markdown("""
  646. **在输入框中引用图片的方式:**
  647. 1. **全局引用**(所有角色共享):
  648. ```
  649. @图片ID
  650. ```
  651. 例如: `@img_001` 或 `@abc123`
  652. 2. **特定角色引用**:
  653. ```
  654. @角色名:图片ID
  655. ```
  656. 例如: `@林小星:img_001` 或 `@主角:abc123`
  657. 3. **多个引用**:
  658. 可以在同一句话中使用多个引用,例如:
  659. ```
  660. 我想创作一个故事 @林小星:img_001 @阿凯:img_002
  661. ```
  662. **优先级:**
  663. - JSON文件映射 > 输入框中的@引用
  664. - 如果同时使用,JSON文件的映射会覆盖@引用
  665. """)
  666. # 主界面
  667. # 首先添加全局视频 CSS,确保每次页面加载都应用
  668. add_video_css()
  669. st.title("🎬 Idea2Video - 视频创作助手")
  670. st.markdown("---")
  671. # 显示当前运行
  672. if st.session_state.current_run_id:
  673. st.info(f"当前运行ID: `{st.session_state.current_run_id}`")
  674. # 聊天界面
  675. for message in st.session_state.messages:
  676. with st.chat_message(message["role"]):
  677. st.markdown(message["content"])
  678. if "run_id" in message:
  679. st.caption(f"运行ID: {message['run_id']}")
  680. # 用户输入
  681. if prompt := st.chat_input("请输入您的创意(idea)... 可使用 @图片ID 或 @角色名:图片ID 引用图片"):
  682. # 解析输入中的图片引用
  683. refer_image_map_from_input = parse_refer_image_references(prompt)
  684. # 合并参考图片映射(优先级:JSON文件 > 输入框引用)
  685. refer_image_map = None
  686. if refer_image_map_from_file:
  687. # JSON文件优先级最高,直接使用
  688. refer_image_map = refer_image_map_from_file.copy()
  689. # 只添加JSON文件中没有的角色(包括全局引用)
  690. if refer_image_map_from_input:
  691. for role_name, image_paths in refer_image_map_from_input.items():
  692. if role_name not in refer_image_map:
  693. refer_image_map[role_name] = image_paths
  694. elif refer_image_map_from_input:
  695. # 只有输入框引用时,直接使用
  696. refer_image_map = refer_image_map_from_input.copy()
  697. # 添加用户消息
  698. st.session_state.messages.append({"role": "user", "content": prompt})
  699. with st.chat_message("user"):
  700. # 显示原始输入
  701. st.markdown(prompt)
  702. # 显示解析到的引用(如果有)
  703. if refer_image_map_from_input:
  704. ref_info = []
  705. for role_name, image_paths in refer_image_map_from_input.items():
  706. if role_name == "__global__":
  707. ref_info.append(f"全局引用: {len(image_paths)} 张图片")
  708. else:
  709. ref_info.append(f"{role_name}: {len(image_paths)} 张图片")
  710. if ref_info:
  711. st.info(f"📎 检测到图片引用: {', '.join(ref_info)}")
  712. # 检查是否有用户要求(从侧边栏或之前的消息中获取)
  713. user_requirement = None
  714. # 注意:user_requirement 可以通过后续对话提供,这里先设为 None
  715. # 显示助手响应
  716. with st.chat_message("assistant"):
  717. message_placeholder = st.empty()
  718. progress_placeholder = st.empty()
  719. # 在聊天消息外部创建步骤状态显示区域(避免布局限制)
  720. step_names = ["step1", "step2", "step3", "step4", "step5", "step6", "step7", "step8", "step9"]
  721. steps_status_container = st.container()
  722. try:
  723. message_placeholder.markdown("🤔 正在思考您的创意...")
  724. # 检查是否需要继续运行
  725. run_id_to_use = None
  726. if resume_run:
  727. incomplete_run_id = find_incomplete_run()
  728. if incomplete_run_id:
  729. run_id_to_use = incomplete_run_id
  730. message_placeholder.markdown(f"🔄 继续执行未完成的运行: {incomplete_run_id}")
  731. # 获取用户要求(优先使用侧边栏输入)
  732. final_user_requirement = user_requirement_input if user_requirement_input else user_requirement
  733. # 运行流程
  734. st.session_state.pipeline_running = True
  735. # 创建进度条和状态显示
  736. progress_bar = progress_placeholder.progress(0)
  737. status_text = progress_placeholder.empty()
  738. # 在外部容器中创建步骤状态显示(使用列布局)
  739. with steps_status_container:
  740. st.markdown("**📋 执行步骤状态:**")
  741. # 使用3列布局显示步骤
  742. cols_per_row = 3
  743. step_cols = [st.columns(cols_per_row) for _ in range((len(step_names) + cols_per_row - 1) // cols_per_row)]
  744. # 为每个步骤创建独立的显示区域
  745. step_displays = {}
  746. for idx, step_name in enumerate(step_names):
  747. row_idx = idx // cols_per_row
  748. col_idx = idx % cols_per_row
  749. step_displays[step_name] = {
  750. "display": step_cols[row_idx][col_idx].empty(),
  751. "status": None,
  752. "result_shown": False
  753. }
  754. message_placeholder.markdown("🚀 开始执行视频创作流程...")
  755. status_text.text("⏳ 正在初始化...")
  756. # 创建状态队列用于线程通信
  757. status_queue = Queue()
  758. # 使用字典存储结果,避免nonlocal作用域问题
  759. thread_result = {"result": None, "error": None}
  760. # 在主线程中获取 run_manager(线程中无法访问 session_state)
  761. run_manager = st.session_state.run_manager
  762. # 在线程中运行pipeline
  763. def run_in_thread():
  764. try:
  765. thread_result["result"] = run_pipeline_sync(
  766. idea=prompt,
  767. user_requirement=final_user_requirement,
  768. refer_image_map=refer_image_map,
  769. run_id=run_id_to_use,
  770. new_run=new_run,
  771. max_retries=max_retries,
  772. status_queue=status_queue,
  773. run_manager=run_manager
  774. )
  775. except Exception as e:
  776. thread_result["error"] = e
  777. # 启动执行线程
  778. exec_thread = threading.Thread(target=run_in_thread, daemon=True)
  779. exec_thread.start()
  780. # 实时更新UI
  781. run_output_dir = None
  782. last_update_time = time.time()
  783. update_interval = 1.0 # 每1秒更新一次
  784. # 检查状态队列获取初始run_output_dir
  785. try:
  786. while True:
  787. status_update = status_queue.get_nowait()
  788. if status_update["type"] == "init":
  789. run_output_dir = status_update["run_output_dir"]
  790. message_placeholder.markdown(f"🚀 开始执行视频创作流程...\n运行ID: `{status_update['run_id']}`")
  791. break
  792. except Empty:
  793. pass
  794. # 如果还没有run_output_dir,等待一下
  795. if not run_output_dir:
  796. time.sleep(0.5)
  797. # 尝试从最新的运行中获取
  798. runs = list_available_runs()
  799. if runs:
  800. latest_run = runs[0]
  801. run_output_dir = latest_run["path"]
  802. # 实时更新循环
  803. max_iterations = 3600 # 最多等待1小时(3600秒)
  804. iteration = 0
  805. while exec_thread.is_alive() and iteration < max_iterations:
  806. current_time = time.time()
  807. # 定期更新状态
  808. if current_time - last_update_time >= update_interval:
  809. # 检查状态队列
  810. try:
  811. while True:
  812. status_update = status_queue.get_nowait()
  813. if status_update["type"] == "init":
  814. run_output_dir = status_update["run_output_dir"]
  815. message_placeholder.markdown(f"🚀 开始执行视频创作流程...\n运行ID: `{status_update['run_id']}`")
  816. elif status_update["type"] == "set_run_id":
  817. # 在主线程中更新 session_state
  818. st.session_state.current_run_id = status_update["run_id"]
  819. elif status_update["type"] == "retry":
  820. message_placeholder.markdown(f"🔄 第 {status_update['attempt']} 次重试...")
  821. elif status_update["type"] == "completed":
  822. run_output_dir = status_update["run_output_dir"]
  823. message_placeholder.markdown("✅ 视频创作流程执行完成!")
  824. elif status_update["type"] == "error":
  825. message_placeholder.error(f"❌ 执行出错: {status_update['error']}")
  826. except Empty:
  827. pass
  828. # 更新步骤状态和进度
  829. if run_output_dir and Path(run_output_dir).exists():
  830. state = load_task_state(run_output_dir)
  831. if state:
  832. completed_steps = 0
  833. running_steps = []
  834. for step_name in step_names:
  835. status = get_step_status(state, step_name)
  836. # 更新步骤显示
  837. step_info = step_displays[step_name]
  838. if step_info["status"] != status:
  839. step_info["status"] = status
  840. status_emoji = {
  841. "completed": "✅",
  842. "running": "🔄",
  843. "failed": "❌",
  844. "pending": "⏳"
  845. }.get(status, "❓")
  846. # 显示步骤状态
  847. step_title = f"{status_emoji} {format_step_name(step_name)}"
  848. if status == "running":
  849. step_title += " (执行中...)"
  850. running_steps.append(step_name)
  851. elif status == "completed":
  852. step_title += " (已完成)"
  853. completed_steps += 1
  854. elif status == "failed":
  855. step_title += " (失败)"
  856. # 更新显示
  857. step_info["display"].markdown(f"**{step_title}**")
  858. elif step_info["status"] == "running":
  859. running_steps.append(step_name)
  860. elif step_info["status"] == "completed":
  861. completed_steps += 1
  862. # 更新进度条
  863. progress = completed_steps / len(step_names)
  864. progress_bar.progress(progress)
  865. if running_steps:
  866. current_step_name = running_steps[0]
  867. status_text.text(f"🔄 当前执行: {format_step_name(current_step_name)} ({completed_steps}/{len(step_names)} 已完成)")
  868. else:
  869. status_text.text(f"⏳ 等待中... ({completed_steps}/{len(step_names)} 已完成)")
  870. last_update_time = current_time
  871. iteration += 1
  872. # 短暂休眠,避免CPU占用过高
  873. time.sleep(0.2)
  874. # 等待线程完成
  875. exec_thread.join(timeout=1)
  876. # 处理最终结果
  877. if thread_result["error"]:
  878. message_placeholder.error(f"❌ 执行失败: {str(thread_result['error'])}")
  879. raise thread_result["error"]
  880. # 初始化 pipeline_result,避免未定义错误
  881. pipeline_result = None
  882. if thread_result["result"]:
  883. pipeline_result = thread_result["result"]
  884. run_output_dir = pipeline_result["run_output_dir"]
  885. # 最终更新所有步骤状态
  886. state = load_task_state(run_output_dir)
  887. if state:
  888. completed_steps = 0
  889. for step_name in step_names:
  890. status = get_step_status(state, step_name)
  891. if status == "completed":
  892. completed_steps += 1
  893. progress_bar.progress(1.0)
  894. status_text.text(f"✅ 所有步骤已完成 ({completed_steps}/{len(step_names)})")
  895. message_placeholder.markdown("✅ 视频创作完成!")
  896. # 显示所有步骤的详细结果(使用新的容器)
  897. st.markdown("---")
  898. st.subheader("📊 步骤执行结果详情")
  899. # 按列显示步骤结果(使用3列布局,更紧凑)
  900. cols_per_row = 3
  901. num_rows = (len(step_names) + cols_per_row - 1) // cols_per_row
  902. step_cols_grid = [st.columns(cols_per_row) for _ in range(num_rows)]
  903. for idx, step_name in enumerate(step_names):
  904. if run_output_dir:
  905. state = load_task_state(run_output_dir)
  906. if state:
  907. status = get_step_status(state, step_name)
  908. row_idx = idx // cols_per_row
  909. col_idx = idx % cols_per_row
  910. with step_cols_grid[row_idx][col_idx]:
  911. status_emoji = {
  912. "completed": "✅",
  913. "running": "🔄",
  914. "failed": "❌",
  915. "pending": "⏳"
  916. }.get(status, "❓")
  917. step_title = f"{status_emoji} {format_step_name(step_name)}"
  918. if status == "completed":
  919. with st.expander(step_title, expanded=False):
  920. display_step_result(step_name, run_output_dir)
  921. else:
  922. st.info(step_title)
  923. # 显示最终视频(如果存在)
  924. display_run_results(run_output_dir)
  925. st.session_state.pipeline_running = False
  926. # 添加助手消息
  927. if pipeline_result:
  928. response = f"✅ 视频创作完成!\n\n运行ID: `{pipeline_result['run_id']}`\n输出目录: `{pipeline_result['run_output_dir']}`"
  929. st.session_state.messages.append({
  930. "role": "assistant",
  931. "content": response,
  932. "run_id": pipeline_result["run_id"]
  933. })
  934. except Exception as e:
  935. st.session_state.pipeline_running = False
  936. error_msg = f"❌ 执行失败: {str(e)}"
  937. message_placeholder.error(error_msg)
  938. st.session_state.messages.append({
  939. "role": "assistant",
  940. "content": error_msg
  941. })
  942. logger.error(f"执行失败: {e}", exc_info=True)
  943. # 如果当前有运行ID,显示结果
  944. if st.session_state.current_run_id:
  945. runs = list_available_runs()
  946. current_run = next((r for r in runs if r["run_id"] == st.session_state.current_run_id), None)
  947. if current_run:
  948. st.markdown("---")
  949. st.subheader("📊 当前运行结果")
  950. display_run_results(current_run["path"])
  951. # 显示任务状态
  952. state = load_task_state(current_run["path"])
  953. if state:
  954. st.subheader("📈 任务状态")
  955. steps = state.get("steps", {})
  956. for step_name in ["step1", "step2", "step3", "step4", "step5", "step6", "step7", "step8", "step9"]:
  957. step_info = steps.get(step_name, {})
  958. status = step_info.get("status", "pending")
  959. status_emoji = {
  960. "completed": "✅",
  961. "running": "🔄",
  962. "failed": "❌",
  963. "pending": "⏳"
  964. }.get(status, "❓")
  965. st.write(f"{status_emoji} {format_step_name(step_name)}: {status}")