# llm_director.py
  1. import os
  2. import re
  3. import json
  4. import httpx
  5. from volcenginesdkarkruntime import Ark
  6. client = Ark(
  7. base_url="https://ark.cn-beijing.volces.com/api/v3",
  8. api_key="817dff39-5586-4f9b-acba-55004167c0b1",
  9. timeout=1800
  10. )
  11. def read_input(json_file):
  12. with open(json_file, 'r', encoding='utf-8') as file:
  13. data = json.load(file)
  14. # 将数据转换为字符串
  15. json_string = json.dumps(data, ensure_ascii=False, indent=4)
  16. return json_string
  17. def convert_json_to_dict(json_string, output_file_path=None):
  18. match = re.search(r'\{.*\}', json_string, re.DOTALL)
  19. if match:
  20. json_content = match.group(0) # 提取匹配的内容
  21. else:
  22. print("错误: 未找到有效的 JSON 内容")
  23. return {}
  24. try:
  25. # 尝试将提取的内容转换为字典
  26. data_dict = json.loads(json_content)
  27. if output_file_path:
  28. with open(output_file_path, 'w', encoding='utf-8') as file:
  29. json.dump(data_dict, file, ensure_ascii=False, indent=4)
  30. return data_dict
  31. except json.JSONDecodeError as e:
  32. print(f"JSON 解码错误: {e}")
  33. return {}
  34. def caption_correct(user_prompt):
  35. count = len(user_prompt)
  36. cut = "不需要断句,直接输出纠正错别字后的文本"
  37. if count > 15:
  38. cut = """必须要断句,且只能断句一处;用符号"-"进行断句,断句位置要合理"""
  39. system_prompt = f"""
  40. ## 对输入文本进行错别字纠正,错别字通常都是因为字词的发音相似而引起的。如"歌莉娅"写成了"哥李呀";"的"写成了"得"。
  41. ## 用户输入:{user_prompt}
  42. ## 要求:
  43. - 需要去除文本中的标点符号。
  44. - {cut}
  45. - 与"歌莉娅"发音相似的词,都要纠正为"歌莉娅"
  46. - 只输出纠正错别字后的文本,不能有任何多余的文本输出
  47. """
  48. completion = client.chat.completions.create(
  49. messages = [
  50. {"role": "system", "content": system_prompt},
  51. ],
  52. model="deepseek-v3-241226", # ep-20241018084532-cgm84 deepseek-v3-241226 deepseek-r1-250120
  53. temperature = 0.01,
  54. max_tokens = 500
  55. )
  56. result = completion.choices[0].message.content.replace("-", "\n")
  57. return result
  58. def director(user_prompt):
  59. system_prompt = """
  60. ## 我需要剪辑一个衣服口播讲解视频,该视频需要包含三个视频片段;请从用户输入中挑选讲解衣服面料、版型、工艺的视频片段各一到三个;并输出三个衣服口播讲解视频脚本(各视频片段组合方式);并以JSON格式进行输出。
  61. ## 输出案例:
  62. ```json
  63. {
  64. "面料":["clip_001.mp4", "clip_004.mp4", "clip_006.mp4"],
  65. "版型":["clip_011.mp4", "clip_002.mp4", "clip_007.mp4"],
  66. "工艺":["clip_012.mp4", "clip_013.mp4", "clip_014.mp4"],
  67. "脚本":[["clip_001.mp4", "clip_011.mp4", "clip_013.mp4"], ["clip_004.mp4", "clip_014.mp4", "clip_002.mp4"], ["clip_006.mp4", "clip_011.mp4", "clip_012.mp4"]]
  68. }
  69. ## 严格按照输出案例的格式输出结果,不能输出任何多余的内容。
  70. ## 尽可能找句子比较长的视频片段
  71. """
  72. completion = client.chat.completions.create(
  73. messages = [
  74. {"role": "system", "content": system_prompt},
  75. {"role": "user", "content": user_prompt},
  76. ],
  77. model="deepseek-r1-250120", # ep-20241018084532-cgm84 deepseek-v3-241226 deepseek-r1-250120
  78. temperature = 0.01,
  79. max_tokens = 500
  80. )
  81. return completion.choices[0].message.content
  82. def director_json(json_path):
  83. user_input = read_input(json_path)
  84. answer = director(user_input)
  85. output_path = json_path.replace("filter_4", "script")
  86. dict_answer = convert_json_to_dict(answer, output_path)
  87. return dict_answer
  88. if __name__ == "__main__":
  89. print(director_json("output/filter_4/videoa.json"))