media_generator.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405
  1. import os
  2. import time
  3. import requests
  4. import json
  5. import threading
  6. import asyncio
  7. import aiohttp
  8. from typing import Optional, Dict, Callable
  9. from dotenv import load_dotenv
  10. from utils.tools import encode_image, download_image, download_video
  11. from utils.upload import upload_file_to_tos
  12. from utils.logger_config import setup_logger
# Populate environment variables from a local .env file (python-dotenv) before
# anything below reads them via os.getenv (ARK_API_KEY, AUDIO_GEN_API).
load_dotenv()
# Module-level logger built by the project's utils.logger_config.setup_logger helper.
logger = setup_logger(__name__)
  15. class ArkImageGenerator:
  16. """Ark 图片生成 API 封装类"""
  17. def __init__(
  18. self,
  19. auth_token: str = None,
  20. model: str = "doubao-seedream-4-0-250828",
  21. sequential_generation: str = "disabled",
  22. response_format: str = "url",
  23. stream: bool = False,
  24. watermark: bool = True,
  25. timeout: int = 120
  26. ):
  27. """
  28. 初始化图片生成器
  29. 参数:
  30. auth_token: 认证令牌(Bearer Token)
  31. model: 模型名称(固定配置)
  32. sequential_generation: 序列生成开关(固定配置)
  33. response_format: 响应格式(固定配置)
  34. stream: 流式响应开关(固定配置)
  35. watermark: 水印开关(固定配置)
  36. timeout: 请求超时时间(秒)
  37. """
  38. self.api_url = "https://ark.cn-beijing.volces.com/api/v3/images/generations"
  39. if not auth_token:
  40. auth_token = os.getenv("ARK_API_KEY")
  41. self.headers = {
  42. "Content-Type": "application/json",
  43. "Authorization": f"Bearer {auth_token}"
  44. }
  45. # 固定配置参数
  46. self.config = {
  47. "model": model,
  48. "sequential_image_generation": sequential_generation,
  49. "response_format": response_format,
  50. "stream": stream,
  51. "watermark": watermark
  52. }
  53. self.timeout = timeout
  54. async def generate_without_refer(self, prompt: str, filename: str, size: str = "1440x2560") -> Optional[Dict]:
  55. if not prompt:
  56. logger.info("错误:prompt不能为空")
  57. return None
  58. payload = {
  59. **self.config,
  60. "prompt": prompt,
  61. "size": size
  62. }
  63. try:
  64. # 使用 aiohttp 进行异步请求
  65. async with aiohttp.ClientSession() as session:
  66. async with session.post(
  67. url=self.api_url,
  68. headers=self.headers,
  69. json=payload,
  70. timeout=self.timeout
  71. ) as response:
  72. response.raise_for_status()
  73. result_data = await response.json()
  74. result_image = result_data["data"][0]["url"]
  75. output_path = "./output/" + filename
  76. download_image(result_image, output_path)
  77. return result_image
  78. except Exception as e:
  79. logger.info(f"错误:生成图片时发生异常:{e}")
  80. return None
  81. def generate(self, prompt: str, image_url: str, filename: str, size: str = "1440x2560") -> Optional[Dict]:
  82. # 验证必填参数
  83. if not prompt or not image_url:
  84. logger.info("错误:prompt和image_url不能为空")
  85. return None
  86. # 如果image_url为图片路径,则编码为base64格式
  87. reference_image = encode_image(image_url) if "http" not in image_url else image_url
  88. # 构建请求体(合并固定配置和动态参数)
  89. payload = {
  90. **self.config,
  91. "prompt": prompt,
  92. "image": reference_image,
  93. "size": size
  94. }
  95. try:
  96. response = requests.post(
  97. url=self.api_url,
  98. headers=self.headers,
  99. data=json.dumps(payload),
  100. timeout=self.timeout
  101. )
  102. response.raise_for_status()
  103. result_image = response.json()["data"][0]["url"]
  104. output_path = "./output/" + filename
  105. download_image(result_image, output_path)
  106. return result_image
  107. except requests.exceptions.Timeout:
  108. logger.info(f"错误:请求超时({self.timeout}秒)")
  109. except requests.exceptions.ConnectionError:
  110. logger.info("错误:网络连接失败,请检查API地址")
  111. except requests.exceptions.HTTPError as e:
  112. logger.info(f"错误:HTTP请求失败(状态码{response.status_code}):{e}")
  113. if response.text:
  114. logger.info(f"API错误详情:{response.text}")
  115. except json.JSONDecodeError:
  116. logger.info("错误:响应内容不是合法JSON")
  117. except Exception as e:
  118. logger.info(f"错误:生成图片时发生异常:{e}")
  119. return None
  120. class ArkVideoGenerator:
  121. """Ark 图生视频 API 封装类,支持通过参考图和文本描述生成视频"""
  122. def __init__(
  123. self,
  124. auth_token: str = None,
  125. model: str = "doubao-seedance-1-0-pro-250528",
  126. timeout: int = 60,
  127. poll_interval: int = 5, # 轮询间隔(秒)
  128. max_poll_time: int = 500 # 最大轮询总时间(秒)
  129. ):
  130. # 固定 API 端点
  131. self.api_url = "https://ark.cn-beijing.volces.com/api/v3/contents/generations/tasks"
  132. if not auth_token:
  133. auth_token = os.getenv("ARK_API_KEY")
  134. # 固定请求头
  135. self.headers = {
  136. "Content-Type": "application/json",
  137. "Authorization": f"Bearer {auth_token}"
  138. }
  139. # 固定配置参数(模型、超时)
  140. self.model = model
  141. self.timeout = timeout
  142. self.poll_interval = poll_interval
  143. self.max_poll_time = max_poll_time
  144. def create_video_task(
  145. self,
  146. prompt: str,
  147. image_url: str,
  148. gen_params: str = ""
  149. ) -> Optional[Dict]:
  150. # 1. 验证动态参数合法性
  151. if not prompt.strip():
  152. logger.info("错误:文本描述(text_prompt)不能为空")
  153. return None
  154. if not image_url.strip():
  155. logger.info("错误:参考图 URL(reference_image_url)不能为空")
  156. return None
  157. reference_image = upload_file_to_tos(image_url) if "http" not in image_url else image_url
  158. logger.info(f"视频生成提示词: {prompt + gen_params}")
  159. logger.info(f"视频生成参考图片: {reference_image}")
  160. # refernece_image = encode_image(image_url) if "http" not in image_url else image_url
  161. # 检查reference_image是否可访问
  162. try:
  163. head_response = requests.head(reference_image, timeout=5)
  164. if head_response.status_code != 200:
  165. logger.warning(f"参考图片URL可能无法访问,状态码: {head_response.status_code}")
  166. except Exception as e:
  167. logger.warning(f"验证参考图片URL可访问性时出错: {str(e)}")
  168. return None
  169. # 2. 构建请求体(按 API 要求格式组装 content 列表)
  170. payload = {
  171. "model": self.model,
  172. "content": [
  173. {
  174. "type": "text",
  175. "text": prompt + gen_params
  176. },
  177. {
  178. "type": "image_url",
  179. "image_url": {
  180. "url": reference_image
  181. }
  182. }
  183. ]
  184. }
  185. # 3. 发送 POST 请求并处理响应
  186. try:
  187. response = requests.post(
  188. url=self.api_url,
  189. headers=self.headers,
  190. json=payload,
  191. timeout=self.timeout
  192. )
  193. response.raise_for_status()
  194. return response.json()
  195. except Exception as e:
  196. logger.info(f"创建任务失败:{str(e)}")
  197. return None
  198. def query_video_task(self, task_id: str) -> Optional[Dict]:
  199. """
  200. 新增:查询图生视频任务结果
  201. 参数:
  202. task_id: 视频任务 ID(从 create_video_task 响应中获取)
  203. 返回:
  204. 任务结果详情(含视频状态、视频 URL 等)或 None
  205. """
  206. # 1. 验证任务 ID
  207. if not task_id.strip():
  208. logger.info("错误:任务 ID(task_id)不能为空")
  209. return None
  210. # 2. 构建查询 URL(拼接 task_id)
  211. query_url = f"{self.api_url}/{task_id}"
  212. # 3. 发送 GET 请求查询结果
  213. try:
  214. response = requests.get(
  215. url=query_url,
  216. headers=self.headers,
  217. timeout=self.timeout
  218. )
  219. # 触发 HTTP 错误(如 404 任务不存在、401 令牌无效)
  220. response.raise_for_status()
  221. return response.json()
  222. except Exception as e:
  223. logger.info(f"查询任务失败:{str(e)}")
  224. return None
  225. def _background_poll(
  226. self,
  227. task_id: str,
  228. filename: str,
  229. callback: Callable[[str, str, Optional[Dict], Optional[str]], None]
  230. ):
  231. """
  232. 后台轮询任务状态的线程函数
  233. :param task_id: 任务ID
  234. :param callback: 回调函数,参数为 (task_id, 成功结果, 错误信息)
  235. """
  236. start_time = time.time()
  237. while True:
  238. elapsed = time.time() - start_time
  239. if elapsed > self.max_poll_time:
  240. callback(task_id, filename, None, f"任务超时(超过 {self.max_poll_time} 秒)")
  241. # 查询任务状态
  242. result = self.query_video_task(task_id)
  243. if not result:
  244. time.sleep(self.poll_interval)
  245. continue
  246. # 解析状态
  247. status = result.get("status", "").lower()
  248. if status == "succeeded":
  249. callback(task_id, filename, result, None)
  250. return
  251. elif status == "failed":
  252. error_msg = result.get("error", {}).get("message", "未知错误")
  253. callback(task_id, filename, None, error_msg)
  254. return
  255. elif status in ["pending", "processing"]:
  256. logger.info(f"任务 {task_id} 处理中({int(elapsed)}秒),状态:{status}")
  257. time.sleep(self.poll_interval)
  258. else:
  259. logger.info(f"任务 {task_id} 未知状态:{status},继续等待...")
  260. time.sleep(self.poll_interval)
  261. def create_video_task_async(
  262. self,
  263. prompt: str,
  264. image_url: str,
  265. gen_params: str,
  266. filename: str,
  267. callback: Callable[[str, str, Optional[Dict], Optional[str]], None]
  268. ) -> Optional[str]:
  269. # 1. 提交任务
  270. task_response = self.create_video_task(prompt, image_url, gen_params)
  271. print('xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx')
  272. print(f"task_response: {task_response}")
  273. if not task_response or "id" not in task_response:
  274. logger.info("任务提交失败,无法启动后台轮询")
  275. return None
  276. task_id = task_response["id"]
  277. logger.info(f"任务提交成功,task_id: {task_id},启动后台轮询...")
  278. # 2. 启动后台线程轮询结果
  279. poll_thread = threading.Thread(
  280. target=self._background_poll,
  281. args=(task_id, filename, callback),
  282. daemon=True # 守护线程:主程序退出时自动结束
  283. )
  284. poll_thread.start()
  285. return task_id
  286. # 1. 定义回调函数:任务完成/失败时会被调用
  287. def handle_video_result(task_id: str, filename, result: Optional[Dict], error: Optional[str]) -> None:
  288. if error:
  289. logger.info(f"\n任务 {task_id} 处理失败:{error}")
  290. else:
  291. video_url = result.get("content", {}).get("video_url")
  292. output_path = "./output/" + filename
  293. download_video(video_url, output_path)
  294. logger.info(f"生成视频已下载:{output_path}")
  295. # API配置
  296. API_URL = os.getenv("AUDIO_GEN_API")
  297. def audio_generator(text, spk_audio="./data/audio/voice_07.wav", emo_audio="./data/audio/emo_sad.wav"):
  298. """调用TTS API生成语音"""
  299. payload = {
  300. "text": text,
  301. "spk_audio_prompt": spk_audio,
  302. "emo_audio_prompt": emo_audio
  303. }
  304. response = requests.post(API_URL, json=payload)
  305. if response.status_code == 200:
  306. result = response.json()
  307. if result["status"] == "success":
  308. print(f"语音生成成功: {result['audio_file']}")
  309. return result["audio_file"]
  310. else:
  311. print(f"请求失败: {response.text}")
  312. return None
# Module-level instances created at import time with default settings; the auth
# token comes from the ARK_API_KEY environment variable since none is passed.
# Presumably shared by importers of this module; confirm before removing.
image_generator = ArkImageGenerator()
video_create = ArkVideoGenerator()
  315. if __name__ == "__main__":
  316. # 1. 初始化生成器(配置固定参数)
  317. image_generator = ArkImageGenerator()
  318. video_create = ArkVideoGenerator()
  319. # 2. 调用生成方法(仅传入动态参数)
  320. # result = image_generator.generate(
  321. # prompt="狗狗在草地上追逐蒲公英",
  322. # image_url="https://ark-project.tos-cn-beijing.volces.com/doc_image/seedream4_imageToimage.png",
  323. # filename="1.jpg"
  324. # )
  325. # logger.info(f"result:{result}")
  326. # print(video_generator.query_video_task("cgt-20251022103303-9hgrr"))
  327. # cgt-20251022101137-852fw cgt-20251022102536-kt8pj cgt-20251022103303-9hgrr
  328. # task_id = video_create.create_video_task_async(
  329. # prompt="狗狗不停地在草地上跳跃",
  330. # image_url="https://testdgxcx-oss.gloria.com.cn/video-create/new_frame_scene0_camera0_shot1.png",
  331. # gen_params="",
  332. # filename="1.mp4",
  333. # callback=handle_video_result
  334. # )
  335. task_id = video_create.create_video_task(
  336. prompt="狗狗不停地在草地上跳跃",
  337. image_url="https://testdgxcx-oss.gloria.com.cn/video-create/new_frame_scene0_camera0_shot1.png",
  338. gen_params=""
  339. )
  340. print(f"task_id: {task_id}")
  341. if task_id:
  342. print("\n主流程:任务已提交,开始执行其他操作...")
  343. for i in range(10):
  344. print(f"主流程:正在执行第 {i+1} 步操作...")
  345. time.sleep(1) # 模拟主流程耗时操作
  346. print("主流程:所有操作执行完毕,等待后台任务结果(若未完成)...")
  347. # 防止主程序提前退出(实际生产环境可能有其他阻塞逻辑)
  348. # 这里仅为演示:等待所有后台线程完成
  349. while threading.active_count() > 1:
  350. time.sleep(1)