"""Thin wrappers around the Volcengine Ark image / video generation HTTP APIs
and a project TTS endpoint.

The module exposes:

* ``ArkImageGenerator`` - async text-to-image and image-to-image requests.
* ``ArkVideoGenerator`` - image-to-video task submission, querying and
  background polling with a completion callback.
* ``handle_video_result`` - a ready-made polling callback that downloads the
  finished video.
* ``audio_generator`` - synchronous call to the TTS endpoint configured via
  the ``AUDIO_GEN_API`` environment variable.
"""

import asyncio
import json
import os
import threading
import time
from typing import Callable, Dict, Optional, Union

import aiohttp
import requests
from dotenv import load_dotenv

from interfaces.image_output import ImageOutput
from utils.tools import encode_image, download_image, download_video
from utils.upload import upload_file_to_tos
from utils.logger_config import setup_logger

load_dotenv()
logger = setup_logger(__name__)

# Signature of the polling callback: (task_id, filename, result, error).
# Exactly one of ``result`` / ``error`` is expected to be non-None.
_VideoCallback = Callable[[str, str, Optional[Dict], Optional[str]], None]


def _resolve_image_reference(image: str) -> str:
    """Return an image reference suitable for the Ark request body.

    HTTP(S) URLs are passed through untouched; anything else is treated as a
    local path and handed to ``encode_image`` (base64 encoding, per the
    original inline comment).

    A prefix check is used instead of a substring check so that a local path
    which merely contains the text "http" is still encoded.
    """
    if image.strip().lower().startswith(("http://", "https://")):
        return image
    return encode_image(image)


class ArkImageGenerator:
    """Wrapper around the Ark image generation API."""

    def __init__(
        self,
        auth_token: Optional[str] = None,
        model: str = "doubao-seedream-4-0-250828",
        sequential_generation: str = "disabled",
        response_format: str = "url",
        stream: bool = False,
        watermark: bool = True,
        timeout: int = 120
    ):
        """Initialise the image generator.

        Args:
            auth_token: Bearer token. Falls back to the ``ARK_API_KEY``
                environment variable when empty.
            model: Model name (fixed per instance).
            sequential_generation: Value sent as
                ``sequential_image_generation`` (fixed per instance).
            response_format: Response format (fixed per instance). Note that
                the result parsing below reads the ``url`` field only.
            stream: Streaming switch (fixed per instance).
            watermark: Watermark switch (fixed per instance).
            timeout: Total request timeout in seconds.
        """
        self.api_url = "https://ark.cn-beijing.volces.com/api/v3/images/generations"
        if not auth_token:
            auth_token = os.getenv("ARK_API_KEY")
        self.headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {auth_token}"
        }
        # Fixed configuration merged into every request payload.
        self.config = {
            "model": model,
            "sequential_image_generation": sequential_generation,
            "response_format": response_format,
            "stream": stream,
            "watermark": watermark
        }
        self.timeout = timeout

    async def _request_image(self, payload: Dict) -> Optional[ImageOutput]:
        """POST ``payload`` to the image endpoint and wrap the first result URL.

        Returns:
            An ``ImageOutput`` holding the URL of the first generated image,
            or ``None`` when the request or the response parsing fails (the
            failure is logged, not raised).
        """
        try:
            # aiohttp expects a ClientTimeout object; bare numbers are a
            # deprecated form.
            request_timeout = aiohttp.ClientTimeout(total=self.timeout)
            async with aiohttp.ClientSession(timeout=request_timeout) as session:
                async with session.post(
                    url=self.api_url,
                    headers=self.headers,
                    json=payload
                ) as response:
                    response.raise_for_status()
                    result_data = await response.json()
            result_image = result_data["data"][0]["url"]
            return ImageOutput(fmt="url", ext="png", data=result_image)
        except Exception as e:
            logger.info(f"错误:生成图片时发生异常:{e}")
            return None

    async def generate_without_refer(
        self,
        prompt: str,
        size: str = "1440x2560"
    ) -> Optional[ImageOutput]:
        """Generate an image from a text prompt only.

        Args:
            prompt: Text description; must be non-empty.
            size: Output size sent as the ``size`` field.

        Returns:
            An ``ImageOutput`` with the generated image URL, or ``None`` on
            invalid input or request failure.
        """
        if not prompt:
            logger.info("错误:prompt不能为空")
            return None

        payload = {
            **self.config,
            "prompt": prompt,
            "size": size
        }
        return await self._request_image(payload)

    async def generate(
        self,
        prompt: str,
        image_url: Union[str, list[str]],
        size: str = "1440x2560"
    ) -> Optional[ImageOutput]:
        """Generate an image from a prompt plus one or more reference images.

        Args:
            prompt: Text description; must be non-empty.
            image_url: Reference images, each either an HTTP(S) URL or a
                local file path. A single string is accepted and treated as a
                one-element list.
            size: Output size sent as the ``size`` field.

        Returns:
            An ``ImageOutput`` with the generated image URL, or ``None`` on
            invalid input or request failure.
        """
        # Validate required arguments.
        if not prompt or not image_url:
            logger.info("错误:prompt和image_url不能为空")
            return None

        # A bare string would otherwise be iterated character by character.
        if isinstance(image_url, str):
            image_url = [image_url]

        # Local paths are encoded; remote URLs are forwarded as-is.
        reference_image = [_resolve_image_reference(item) for item in image_url]

        # Request body = fixed configuration + per-call parameters.
        payload = {
            **self.config,
            "prompt": prompt,
            "image": reference_image,
            "size": size
        }
        return await self._request_image(payload)


class ArkVideoGenerator:
    """Wrapper around the Ark image-to-video API (reference image + text prompt)."""

    def __init__(
        self,
        auth_token: Optional[str] = None,
        model: str = "doubao-seedance-1-0-pro-250528",
        timeout: int = 60,
        poll_interval: int = 5,   # seconds between status polls
        max_poll_time: int = 500  # overall polling budget in seconds
    ):
        """Initialise the video generator.

        Args:
            auth_token: Bearer token. Falls back to the ``ARK_API_KEY``
                environment variable when empty.
            model: Model name used for every task.
            timeout: Per-HTTP-request timeout in seconds.
            poll_interval: Seconds to sleep between status polls.
            max_poll_time: Maximum total polling time in seconds before the
                task is reported as timed out.
        """
        # Fixed API endpoint.
        self.api_url = "https://ark.cn-beijing.volces.com/api/v3/contents/generations/tasks"
        if not auth_token:
            auth_token = os.getenv("ARK_API_KEY")
        # Fixed request headers.
        self.headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {auth_token}"
        }
        self.model = model
        self.timeout = timeout
        self.poll_interval = poll_interval
        self.max_poll_time = max_poll_time

    def create_video_task(
        self,
        prompt: str,
        image_url: str,
        gen_params: str = ""
    ) -> Optional[Dict]:
        """Submit an image-to-video task.

        Args:
            prompt: Text description; must be non-blank.
            image_url: Reference image as an HTTP(S) URL or a local path;
                must be non-blank.
            gen_params: Extra generation parameters appended verbatim to the
                prompt text.

        Returns:
            The decoded JSON response of the task-creation call, or ``None``
            on invalid input or request failure (logged, not raised).
        """
        # 1. Validate the per-call arguments (None-safe).
        if not prompt or not prompt.strip():
            logger.info("错误:文本描述(text_prompt)不能为空")
            return None
        if not image_url or not image_url.strip():
            logger.info("错误:参考图 URL(reference_image_url)不能为空")
            return None

        reference_image = _resolve_image_reference(image_url)

        # 2. Build the request body in the ``content`` list format.
        payload = {
            "model": self.model,
            "content": [
                {
                    "type": "text",
                    "text": prompt + gen_params
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": reference_image
                    }
                }
            ]
        }

        # 3. Send the POST request and decode the response.
        try:
            # Encode explicitly: a ``str`` body containing non-latin-1
            # characters (e.g. Chinese prompts) can fail to be sent depending
            # on the installed HTTP stack; UTF-8 bytes are always safe.
            body = json.dumps(payload, ensure_ascii=False).encode("utf-8")
            response = requests.post(
                url=self.api_url,
                headers=self.headers,
                data=body,
                timeout=self.timeout
            )
            response.raise_for_status()
            return response.json()
        except Exception as e:
            logger.info(f"创建任务失败:{str(e)}")
            return None

    def query_video_task(self, task_id: str) -> Optional[Dict]:
        """Query the current state of an image-to-video task.

        Args:
            task_id: Task id taken from the ``create_video_task`` response.

        Returns:
            The decoded JSON task details, or ``None`` on invalid input or
            request failure (logged, not raised).
        """
        # 1. Validate the task id (None-safe).
        if not task_id or not task_id.strip():
            logger.info("错误:任务 ID(task_id)不能为空")
            return None

        # 2. The query URL is the task endpoint with the id appended.
        query_url = f"{self.api_url}/{task_id}"

        # 3. Send the GET request.
        try:
            response = requests.get(
                url=query_url,
                headers=self.headers,
                timeout=self.timeout
            )
            # Surface HTTP errors (e.g. 404 unknown task, 401 bad token).
            response.raise_for_status()
            return response.json()
        except Exception as e:
            logger.info(f"查询任务失败:{str(e)}")
            return None

    def _background_poll(
        self,
        task_id: str,
        filename: str,
        callback: _VideoCallback
    ):
        """Poll a task until it finishes, fails or times out (thread target).

        The callback is invoked exactly once, as
        ``callback(task_id, filename, result, error)``:

        * success  -> ``result`` is the task JSON, ``error`` is ``None``
        * failure  -> ``result`` is ``None``, ``error`` is the API message
        * timeout  -> ``result`` is ``None``, ``error`` is a timeout message

        Args:
            task_id: Id of the task to poll.
            filename: Opaque value forwarded to the callback.
            callback: Function called when polling ends.
        """
        start_time = time.time()

        while True:
            elapsed = time.time() - start_time
            if elapsed > self.max_poll_time:
                callback(task_id, filename, None, f"任务超时(超过 {self.max_poll_time} 秒)")
                # Stop polling; without this the loop would never end and
                # the callback would fire on every iteration.
                return

            # A failed query is treated as transient: wait and retry until
            # the overall time budget runs out.
            result = self.query_video_task(task_id)
            if not result:
                time.sleep(self.poll_interval)
                continue

            # ``or`` guards against an explicit null in the response.
            status = str(result.get("status") or "").lower()
            if status == "succeeded":
                callback(task_id, filename, result, None)
                return
            elif status == "failed":
                error_msg = (result.get("error") or {}).get("message", "未知错误")
                callback(task_id, filename, None, error_msg)
                return
            elif status in ["pending", "processing"]:
                logger.info(f"任务 {task_id} 处理中({int(elapsed)}秒),状态:{status}")
                time.sleep(self.poll_interval)
            else:
                logger.info(f"任务 {task_id} 未知状态:{status},继续等待...")
                time.sleep(self.poll_interval)

    def create_video_task_async(
        self,
        prompt: str,
        image_url: str,
        gen_params: str,
        filename: str,
        callback: _VideoCallback
    ) -> Optional[str]:
        """Submit a task and poll for its result on a background thread.

        Args:
            prompt: Text description for the video.
            image_url: Reference image URL or local path.
            gen_params: Extra generation parameters appended to the prompt.
            filename: Opaque value forwarded to ``callback``.
            callback: Called once when the task succeeds, fails or times
                out; see ``_background_poll`` for the argument order.

        Returns:
            The task id when submission succeeded, otherwise ``None``.
        """
        # 1. Submit the task.
        task_response = self.create_video_task(prompt, image_url, gen_params)
        if not task_response or "id" not in task_response:
            logger.info("任务提交失败,无法启动后台轮询")
            return None

        task_id = task_response["id"]
        logger.info(f"任务提交成功,task_id: {task_id},启动后台轮询...")

        # 2. Poll on a daemon thread so it does not block interpreter exit.
        poll_thread = threading.Thread(
            target=self._background_poll,
            args=(task_id, filename, callback),
            daemon=True
        )
        poll_thread.start()

        return task_id


def handle_video_result(task_id: str, filename, result: Optional[Dict], error: Optional[str]) -> None:
    """Polling callback: download the finished video or log the failure.

    Args:
        task_id: Id of the finished task.
        filename: File name (relative to ``./output/``) to save the video as.
        result: Task JSON on success, otherwise ``None``.
        error: Error message on failure, otherwise ``None``.
    """
    if error:
        logger.info(f"\n任务 {task_id} 处理失败:{error}")
        return

    # Guard against a missing result / content / video_url instead of
    # passing None to the downloader.
    video_url = ((result or {}).get("content") or {}).get("video_url")
    if not video_url:
        logger.info(f"\n任务 {task_id} 处理失败:响应中缺少 video_url")
        return

    output_path = "./output/" + filename
    download_video(video_url, output_path)
    logger.info(f"生成视频已下载:{output_path}")


# TTS endpoint, read from the environment (after load_dotenv above).
API_URL = os.getenv("AUDIO_GEN_API")


def audio_generator(
    text,
    spk_audio="./data/audio/voice_07.wav",
    emo_audio="./data/audio/emo_sad.wav",
    timeout: Optional[float] = 300
):
    """Call the TTS API to synthesise speech.

    Args:
        text: Text to synthesise.
        spk_audio: Value sent as ``spk_audio_prompt``.
        emo_audio: Value sent as ``emo_audio_prompt``.
        timeout: Request timeout in seconds passed to ``requests.post``;
            ``None`` disables the timeout (the previous behaviour).

    Returns:
        The ``audio_file`` value from the response on success, otherwise
        ``None``.

    Raises:
        requests.RequestException: On connection problems or timeout.
    """
    payload = {
        "text": text,
        "spk_audio_prompt": spk_audio,
        "emo_audio_prompt": emo_audio
    }

    response = requests.post(API_URL, json=payload, timeout=timeout)

    if response.status_code == 200:
        result = response.json()
        if result.get("status") == "success" and "audio_file" in result:
            print(f"语音生成成功: {result['audio_file']}")
            return result["audio_file"]

    # Non-200 responses and 200 responses without a success status are both
    # reported rather than silently returning None.
    print(f"请求失败: {response.text}")
    return None


# Shared module-level instances with default configuration.
image_generator = ArkImageGenerator()
video_generator = ArkVideoGenerator()


if __name__ == "__main__":
    # Demo: image-to-image generation using the shared generator.
    # ``generate`` is a coroutine, so it must be driven by an event loop, and
    # it takes a list of reference images.
    result = asyncio.run(
        image_generator.generate(
            prompt="狗狗在草地上追逐蒲公英",
            image_url=[
                "https://ark-project.tos-cn-beijing.volces.com/doc_image/seedream4_imageToimage.png"
            ]
        )
    )
    logger.info(f"result:{result}")

    # Example: query an existing video task.
    # print(video_generator.query_video_task("cgt-20251022103303-9hgrr"))
    # Other task ids: cgt-20251022101137-852fw cgt-20251022102536-kt8pj

    # Example: submit a video task and keep working while it is polled.
    # task_id = video_generator.create_video_task_async(
    #     prompt="狗狗不停地在草地上跳跃",
    #     image_url="https://ark-project.tos-cn-beijing.volces.com/doc_image/seedream4_imageToimage.png",
    #     gen_params="",
    #     filename="1.mp4",
    #     callback=handle_video_result
    # )
    # if task_id:
    #     print("main: task submitted, continuing with other work...")
    #     for i in range(10):
    #         print(f"main: running step {i+1}...")
    #         time.sleep(1)  # simulate work on the main thread
    #     print("main: done, waiting for the background task (if unfinished)...")
    #     # Demo only: keep the process alive until the polling thread ends.
    #     while threading.active_count() > 1:
    #         time.sleep(1)