image_generator.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372
  1. import os
  2. import time
  3. import requests
  4. import json
  5. import threading
  6. import asyncio
  7. import aiohttp
  8. from typing import Optional, Dict, Callable
  9. from dotenv import load_dotenv
  10. from interfaces.image_output import ImageOutput
  11. from utils.tools import encode_image, download_image, download_video
  12. from utils.upload import upload_file_to_tos
  13. from utils.logger_config import setup_logger
  # Populate the process environment from a .env file (python-dotenv) before
  # anything else runs, so the os.getenv() lookups later in this module
  # (ARK_API_KEY, AUDIO_GEN_API) can see values defined there.
  load_dotenv()

  # Module-wide logger obtained from the project's logging helper,
  # named after this module.
  logger = setup_logger(__name__)
  class ArkImageGenerator:
      """Async wrapper around the Ark image-generation API.

      Every request sends the fixed configuration chosen at construction time
      (model, sequential generation, response format, stream, watermark)
      merged with per-call parameters (prompt, size and, for ``generate``,
      reference images).
      """

      def __init__(
          self,
          auth_token: str = None,
          model: str = "doubao-seedream-4-0-250828",
          sequential_generation: str = "disabled",
          response_format: str = "url",
          stream: bool = False,
          watermark: bool = True,
          timeout: int = 120
      ):
          """Initialize the image generator.

          Args:
              auth_token: Bearer token. When empty, the ``ARK_API_KEY``
                  environment variable is used instead.
              model: Model name (fixed for every request).
              sequential_generation: Value sent as
                  ``sequential_image_generation`` (fixed).
              response_format: Response format (fixed). Only responses that
                  carry ``data[0]["url"]`` are handled by this class.
              stream: Streaming switch (fixed).
              watermark: Watermark switch (fixed).
              timeout: Total request timeout in seconds.
          """
          self.api_url = "https://ark.cn-beijing.volces.com/api/v3/images/generations"
          if not auth_token:
              auth_token = os.getenv("ARK_API_KEY")
          if not auth_token:
              # Without a token every request goes out as "Bearer None" and will
              # most likely be rejected; surface that early.
              logger.warning("警告:未提供 auth_token,且环境变量 ARK_API_KEY 未设置")
          self.headers = {
              "Content-Type": "application/json",
              "Authorization": f"Bearer {auth_token}"
          }
          # Fixed request parameters merged into every payload.
          self.config = {
              "model": model,
              "sequential_image_generation": sequential_generation,
              "response_format": response_format,
              "stream": stream,
              "watermark": watermark
          }
          self.timeout = timeout

      @staticmethod
      def _is_remote_url(ref) -> bool:
          """Return True if *ref* is an http(s) URL.

          The check is case-insensitive and ignores surrounding whitespace. A
          prefix check is used because the old substring test (``"http" in ref``)
          misclassified local paths such as ``./http_cache/a.png``.
          """
          return str(ref).strip().lower().startswith(("http://", "https://"))

      @classmethod
      def _prepare_references(cls, image_url) -> list:
          """Normalize reference images for the ``image`` payload field.

          Accepts a single string or an iterable of strings. http(s) URLs are
          passed through unchanged; anything else is treated as a local path and
          converted with ``encode_image``.
          """
          # A bare string must not be iterated character by character.
          refs = [image_url] if isinstance(image_url, str) else list(image_url)
          return [ref if cls._is_remote_url(ref) else encode_image(ref) for ref in refs]

      async def _request_image(self, payload: Dict) -> "Optional[ImageOutput]":
          """POST *payload* and wrap the first returned image URL.

          Returns:
              ``ImageOutput`` for ``data[0]["url"]``, or None on any HTTP,
              network, timeout or response-shape error (all logged).
          """
          try:
              # aiohttp expects a ClientTimeout object, not a bare number.
              timeout = aiohttp.ClientTimeout(total=self.timeout)
              async with aiohttp.ClientSession() as session:
                  async with session.post(
                      url=self.api_url,
                      headers=self.headers,
                      json=payload,
                      timeout=timeout
                  ) as response:
                      if response.status >= 400:
                          # Keep the API's error body; it explains why the call failed.
                          body = await response.text()
                          logger.error(f"错误:生成图片请求失败,HTTP {response.status}:{body}")
                          return None
                      result_data = await response.json()
              result_image = result_data["data"][0]["url"]
              return ImageOutput(fmt="url", ext="png", data=result_image)
          except Exception as e:
              logger.error(f"错误:生成图片时发生异常:{e}")
              return None

      async def generate_without_refer(self, prompt: str, size: str = "1440x2560") -> "Optional[ImageOutput]":
          """Generate an image from a text prompt only.

          Args:
              prompt: Text description; must not be empty.
              size: Output size string, e.g. ``"1440x2560"``.

          Returns:
              ``ImageOutput`` holding the generated image URL, or None on failure.
          """
          if not prompt:
              logger.error("错误:prompt不能为空")
              return None
          payload = {
              **self.config,
              "prompt": prompt,
              "size": size
          }
          return await self._request_image(payload)

      async def generate(self, prompt: str, image_url: "str | list[str]", size: str = "1440x2560") -> "Optional[ImageOutput]":
          """Generate an image from a prompt plus one or more reference images.

          Args:
              prompt: Text description; must not be empty.
              image_url: A reference image, or a list of them. Each entry is an
                  http(s) URL (sent as-is) or a local path (sent through
                  ``encode_image``). Must not be empty.
              size: Output size string, e.g. ``"1440x2560"``.

          Returns:
              ``ImageOutput`` holding the generated image URL, or None on failure.

          Raises:
              Whatever ``encode_image`` raises for an unreadable local path
              (encoding happens before the request, as before).
          """
          if not prompt or not image_url:
              logger.error("错误:prompt和image_url不能为空")
              return None
          payload = {
              **self.config,
              "prompt": prompt,
              "image": self._prepare_references(image_url),
              "size": size
          }
          return await self._request_image(payload)
  class ArkVideoGenerator:
      """Ark image-to-video API wrapper.

      Submits generation tasks built from a reference image and a text prompt,
      queries task status, and can poll a task on a background thread and
      report the outcome through a callback.
      """

      # Statuses treated as "still working": keep polling and log progress.
      # NOTE(review): "queued"/"running" follow the Ark task docs; "pending"/
      # "processing" are kept from the original code. Confirm against the API.
      _IN_PROGRESS_STATUSES = frozenset({"queued", "running", "pending", "processing"})

      def __init__(
          self,
          auth_token: str = None,
          model: str = "doubao-seedance-1-0-pro-250528",
          timeout: int = 60,
          poll_interval: int = 5,  # seconds between status queries
          max_poll_time: int = 500  # upper bound on total polling time (seconds)
      ):
          """Initialize the video generator.

          Args:
              auth_token: Bearer token. When empty, the ``ARK_API_KEY``
                  environment variable is used instead.
              model: Model name sent with every task.
              timeout: Per-HTTP-request timeout in seconds.
              poll_interval: Seconds to sleep between status queries.
              max_poll_time: Seconds after which background polling gives up.
          """
          # Fixed API endpoint; task queries append "/<task_id>".
          self.api_url = "https://ark.cn-beijing.volces.com/api/v3/contents/generations/tasks"
          if not auth_token:
              auth_token = os.getenv("ARK_API_KEY")
          if not auth_token:
              # Without a token every request goes out as "Bearer None" and will
              # most likely be rejected; surface that early.
              logger.warning("警告:未提供 auth_token,且环境变量 ARK_API_KEY 未设置")
          # Fixed request headers.
          self.headers = {
              "Content-Type": "application/json",
              "Authorization": f"Bearer {auth_token}"
          }
          # Fixed configuration (model, timeouts, polling).
          self.model = model
          self.timeout = timeout
          self.poll_interval = poll_interval
          self.max_poll_time = max_poll_time

      @staticmethod
      def _is_remote_url(ref) -> bool:
          """Return True if *ref* is an http(s) URL.

          The check is case-insensitive and ignores surrounding whitespace.
          """
          return str(ref).strip().lower().startswith(("http://", "https://"))

      def create_video_task(
          self,
          prompt: str,
          image_url: str,
          gen_params: str = ""
      ) -> Optional[Dict]:
          """Submit an image-to-video generation task.

          Args:
              prompt: Text description; must not be blank.
              image_url: http(s) URL of the reference image, or a local path
                  that is converted with ``encode_image``.
              gen_params: Extra text appended verbatim to *prompt*
                  (e.g. generation flags — TODO confirm the expected format).

          Returns:
              The parsed JSON response (callers read the task ``id`` from it),
              or None on validation or request failure.
          """
          # 1. Validate the dynamic arguments.
          if not prompt or not prompt.strip():
              logger.error("错误:文本描述(text_prompt)不能为空")
              return None
          if not image_url or not image_url.strip():
              logger.error("错误:参考图 URL(reference_image_url)不能为空")
              return None
          reference_image = image_url if self._is_remote_url(image_url) else encode_image(image_url)
          # 2. Build the request body (the API expects a `content` list).
          payload = {
              "model": self.model,
              "content": [
                  {
                      "type": "text",
                      "text": prompt + (gen_params or "")
                  },
                  {
                      "type": "image_url",
                      "image_url": {
                          "url": reference_image
                      }
                  }
              ]
          }
          # 3. POST and handle the response.
          try:
              response = requests.post(
                  url=self.api_url,
                  headers=self.headers,
                  # Encode explicitly: a str body with non-ASCII text may be
                  # encoded as Latin-1 by http.client and fail on Chinese prompts.
                  data=json.dumps(payload, ensure_ascii=False).encode("utf-8"),
                  timeout=self.timeout
              )
              response.raise_for_status()
              return response.json()
          except Exception as e:
              logger.error(f"创建任务失败:{str(e)}")
              return None

      def query_video_task(self, task_id: str) -> Optional[Dict]:
          """Query the state of an image-to-video task.

          Args:
              task_id: Task ID taken from the ``create_video_task`` response.

          Returns:
              The parsed task details (status, video URL, ...) or None on
              validation or request failure (e.g. 404 unknown task,
              401 invalid token).
          """
          if not task_id or not str(task_id).strip():
              logger.error("错误:任务 ID(task_id)不能为空")
              return None
          query_url = f"{self.api_url}/{task_id}"
          try:
              response = requests.get(
                  url=query_url,
                  headers=self.headers,
                  timeout=self.timeout
              )
              response.raise_for_status()
              return response.json()
          except Exception as e:
              logger.error(f"查询任务失败:{str(e)}")
              return None

      def _background_poll(
          self,
          task_id: str,
          filename: str,
          callback: Callable[[str, str, Optional[Dict], Optional[str]], None]
      ):
          """Poll a task until it reaches a terminal state, then invoke *callback* once.

          Args:
              task_id: Task to poll.
              filename: Opaque value forwarded to *callback*.
              callback: Called exactly once as
                  ``callback(task_id, filename, result, error)``; on success
                  ``result`` is the task details and ``error`` is None, otherwise
                  ``result`` is None and ``error`` describes the failure
                  (API error, cancellation or timeout).
          """
          start_time = time.time()
          while True:
              elapsed = time.time() - start_time
              if elapsed > self.max_poll_time:
                  callback(task_id, filename, None, f"任务超时(超过 {self.max_poll_time} 秒)")
                  # Stop here: previously polling continued forever and the
                  # callback fired again on every iteration.
                  return
              result = self.query_video_task(task_id)
              if not result:
                  # Transient query failure: retry after the normal interval.
                  time.sleep(self.poll_interval)
                  continue
              # Tolerate a missing or null status instead of crashing the thread.
              status = str(result.get("status") or "").lower()
              if status == "succeeded":
                  callback(task_id, filename, result, None)
                  return
              if status == "failed":
                  error = result.get("error")
                  error_msg = error.get("message") if isinstance(error, dict) else error
                  callback(task_id, filename, None, error_msg or "未知错误")
                  return
              if status == "cancelled":
                  callback(task_id, filename, None, "任务已取消")
                  return
              if status in self._IN_PROGRESS_STATUSES:
                  logger.info(f"任务 {task_id} 处理中({int(elapsed)}秒),状态:{status}")
              else:
                  logger.info(f"任务 {task_id} 未知状态:{status},继续等待...")
              time.sleep(self.poll_interval)

      def create_video_task_async(
          self,
          prompt: str,
          image_url: str,
          gen_params: str,
          filename: str,
          callback: Callable[[str, str, Optional[Dict], Optional[str]], None]
      ) -> Optional[str]:
          """Submit a task and poll it on a daemon thread.

          Args:
              prompt, image_url, gen_params: Forwarded to ``create_video_task``.
              filename: Forwarded to *callback*.
              callback: See ``_background_poll``; it runs on the polling thread.

          Returns:
              The task ID, or None if submission failed (no thread is started).
              Because the thread is a daemon, it does not keep the process
              alive; the caller must wait if it needs the result.
          """
          task_response = self.create_video_task(prompt, image_url, gen_params)
          if not task_response or "id" not in task_response:
              logger.error("任务提交失败,无法启动后台轮询")
              return None
          task_id = task_response["id"]
          logger.info(f"任务提交成功,task_id: {task_id},启动后台轮询...")
          poll_thread = threading.Thread(
              target=self._background_poll,
              args=(task_id, filename, callback),
              daemon=True
          )
          poll_thread.start()
          return task_id
  # Callback for ArkVideoGenerator.create_video_task_async: invoked once when a
  # task finishes, fails or times out.
  def handle_video_result(task_id: str, filename: str, result: Optional[Dict], error: Optional[str]) -> None:
      """Download the generated video of a finished task into ``./output``.

      Runs on the background polling thread, so failures are logged rather
      than raised.

      Args:
          task_id: ID of the finished task.
          filename: File name (relative to ``./output``) to save the video as.
          result: Task details on success (``content.video_url`` is read), else None.
          error: Error message on failure, else None.
      """
      if error:
          logger.error(f"\n任务 {task_id} 处理失败:{error}")
          return
      # Guard against a missing/null "content" or "video_url" in the response
      # instead of calling download_video(None, ...).
      content = (result or {}).get("content") or {}
      video_url = content.get("video_url") if isinstance(content, dict) else None
      if not video_url:
          logger.error(f"任务 {task_id} 未返回视频地址:{result}")
          return
      output_path = os.path.join("./output", filename)
      # Make sure the target directory exists before downloading into it.
      os.makedirs(os.path.dirname(output_path), exist_ok=True)
      download_video(video_url, output_path)
      logger.info(f"生成视频已下载:{output_path}")
  # TTS service endpoint, read once at import time (after load_dotenv()).
  API_URL = os.getenv("AUDIO_GEN_API")
  def audio_generator(text, spk_audio="./data/audio/voice_07.wav", emo_audio="./data/audio/emo_sad.wav", timeout: Optional[float] = 300):
      """Call the TTS API to synthesize speech for *text*.

      Args:
          text: Text to synthesize.
          spk_audio: Speaker reference audio, sent unchanged as
              ``spk_audio_prompt`` (not read locally; presumably a path the
              TTS server can resolve — verify).
          emo_audio: Emotion reference audio, sent unchanged as
              ``emo_audio_prompt`` (same caveat).
          timeout: Request timeout in seconds; previously there was none, so a
              stalled server could block the caller forever. ``None`` restores
              the unbounded wait.

      Returns:
          The ``audio_file`` value from a successful response, or None when the
          endpoint is not configured, the HTTP status is not 200, or the
          response status is not ``"success"``.

      Raises:
          requests.RequestException: on network errors or timeout.
      """
      if not API_URL:
          logger.error("错误:未配置环境变量 AUDIO_GEN_API")
          return None
      payload = {
          "text": text,
          "spk_audio_prompt": spk_audio,
          "emo_audio_prompt": emo_audio
      }
      response = requests.post(API_URL, json=payload, timeout=timeout)
      if response.status_code != 200:
          logger.error(f"请求失败: {response.text}")
          return None
      result = response.json()
      if result.get("status") != "success":
          # Previously this case returned None without any trace.
          logger.error(f"语音生成失败: {result}")
          return None
      logger.info(f"语音生成成功: {result['audio_file']}")
      return result["audio_file"]
  287. else:
  288. print(f"请求失败: {response.text}")
  289. return None
  290. image_generator = ArkImageGenerator()
  291. video_generator = ArkVideoGenerator()
  292. if __name__ == "__main__":
  293. # 1. 初始化生成器(配置固定参数)
  294. image_generator = ArkImageGenerator()
  295. video_generator = ArkVideoGenerator()
  296. # 2. 调用生成方法(仅传入动态参数)
  297. result = image_generator.generate(
  298. prompt="狗狗在草地上追逐蒲公英",
  299. image_url="https://ark-project.tos-cn-beijing.volces.com/doc_image/seedream4_imageToimage.png",
  300. filename="1.jpg"
  301. )
  302. logger.info(f"result:{result}")
  303. # print(video_generator.query_video_task("cgt-20251022103303-9hgrr"))
  304. # cgt-20251022101137-852fw cgt-20251022102536-kt8pj cgt-20251022103303-9hgrr
  305. # task_id = video_generator.create_video_task_async(
  306. # prompt="狗狗不停地在草地上跳跃",
  307. # image_url="https://ark-project.tos-cn-beijing.volces.com/doc_image/seedream4_imageToimage.png",
  308. # gen_params="",
  309. # filename="1.mp4",
  310. # callback=handle_video_result
  311. # )
  312. # if task_id:
  313. # print("\n主流程:任务已提交,开始执行其他操作...")
  314. # for i in range(10):
  315. # print(f"主流程:正在执行第 {i+1} 步操作...")
  316. # time.sleep(1) # 模拟主流程耗时操作
  317. # print("主流程:所有操作执行完毕,等待后台任务结果(若未完成)...")
  318. # # 防止主程序提前退出(实际生产环境可能有其他阻塞逻辑)
  319. # # 这里仅为演示:等待所有后台线程完成
  320. # while threading.active_count() > 1:
  321. # time.sleep(1)