| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405 |
- import os
- import time
- import requests
- import json
- import threading
- import asyncio
- import aiohttp
- from typing import Optional, Dict, Callable
- from dotenv import load_dotenv
- from utils.tools import encode_image, download_image, download_video
- from utils.upload import upload_file_to_tos
- from utils.logger_config import setup_logger
# Populate os.environ from a .env file before any os.getenv() call below
# reads ARK_API_KEY / AUDIO_GEN_API.
load_dotenv()
# Module-wide logger shared by every generator in this file.
logger = setup_logger(__name__)
class ArkImageGenerator:
    """Wrapper around the Ark image-generation endpoint.

    Offers text-to-image (:meth:`generate_without_refer`, async) and
    reference-image-to-image (:meth:`generate`, sync). On success the
    generated image is downloaded to ``./output/<filename>`` and its URL is
    returned; on any failure the error is logged and ``None`` is returned.
    """

    def __init__(
        self,
        auth_token: Optional[str] = None,
        model: str = "doubao-seedream-4-0-250828",
        sequential_generation: str = "disabled",
        response_format: str = "url",
        stream: bool = False,
        watermark: bool = True,
        timeout: int = 120,
    ):
        """Create a generator with fixed request settings.

        Args:
            auth_token: Bearer token. Falls back to the ``ARK_API_KEY``
                environment variable when empty or None.
            model: Model name sent with every request.
            sequential_generation: Value for ``sequential_image_generation``.
            response_format: Value for ``response_format``. The methods below
                read ``data[0]["url"]`` from the response, so this is
                expected to stay ``"url"``.
            stream: Value for ``stream``.
            watermark: Value for ``watermark``.
            timeout: Per-request timeout in seconds.
        """
        self.api_url = "https://ark.cn-beijing.volces.com/api/v3/images/generations"
        if not auth_token:
            auth_token = os.getenv("ARK_API_KEY")
        if not auth_token:
            # Without a token the header below becomes "Bearer None"; make
            # that visible instead of failing later with an opaque 401.
            logger.warning("警告:未提供 ARK_API_KEY,请求将无法通过鉴权")
        self.headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {auth_token}",
        }
        # Request fields shared by every call; merged into each payload.
        self.config = {
            "model": model,
            "sequential_image_generation": sequential_generation,
            "response_format": response_format,
            "stream": stream,
            "watermark": watermark,
        }
        self.timeout = timeout

    @staticmethod
    def _is_remote(image_url: str) -> bool:
        """Return True if *image_url* is an http(s) URL rather than a local path."""
        # A prefix check, not a substring check: local paths such as
        # "./http_cache/a.png" must still be treated as files.
        return image_url.lower().startswith(("http://", "https://"))

    @staticmethod
    def _extract_image_url(result_data: Dict) -> str:
        """Return ``data[0]["url"]`` from a generation response.

        Raises:
            ValueError: if the response carries no image URL; the response
                is included in the message to aid debugging.
        """
        try:
            return result_data["data"][0]["url"]
        except (KeyError, IndexError, TypeError):
            raise ValueError(f"响应中缺少图片URL:{result_data}") from None

    @staticmethod
    def _download(image_url: str, filename: str) -> Optional[str]:
        """Download *image_url* to ``./output/<filename>``.

        Returns:
            *image_url* on success, or None if the download raised (logged).
        """
        try:
            download_image(image_url, "./output/" + filename)
        except Exception as e:
            logger.error(f"错误:下载图片失败:{e}")
            return None
        return image_url

    async def generate_without_refer(
        self, prompt: str, filename: str, size: str = "1440x2560"
    ) -> Optional[str]:
        """Generate an image from text only and download it.

        Args:
            prompt: Text prompt; must be non-empty.
            filename: File name under ``./output/`` to save the image as.
            size: Output size, e.g. ``"1440x2560"``.

        Returns:
            URL of the generated image, or None on any failure (logged).
        """
        if not prompt:
            logger.error("错误:prompt不能为空")
            return None
        payload = {**self.config, "prompt": prompt, "size": size}
        try:
            client_timeout = aiohttp.ClientTimeout(total=self.timeout)
            async with aiohttp.ClientSession(timeout=client_timeout) as session:
                async with session.post(
                    url=self.api_url,
                    headers=self.headers,
                    json=payload,
                ) as response:
                    if response.status >= 400:
                        # Read the body before bailing out so the API's
                        # error details are not lost.
                        body = await response.text()
                        logger.error(f"错误:HTTP请求失败(状态码{response.status}):{body}")
                        return None
                    result_data = await response.json()
            result_image = self._extract_image_url(result_data)
        except asyncio.TimeoutError:
            logger.error(f"错误:请求超时({self.timeout}秒)")
            return None
        except Exception as e:
            logger.error(f"错误:生成图片时发生异常:{e}")
            return None

        # download_image is synchronous; run it in the default executor so
        # it does not block the event loop.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self._download, result_image, filename)

    def generate(
        self,
        prompt: str,
        image_url: str,
        filename: str,
        size: str = "1440x2560",
    ) -> Optional[str]:
        """Generate an image from a prompt plus a reference image and download it.

        Args:
            prompt: Text prompt; must be non-empty.
            image_url: Reference image. An http(s) URL is sent as-is; anything
                else is treated as a local path and passed through
                ``encode_image`` first.
            filename: File name under ``./output/`` to save the image as.
            size: Output size, e.g. ``"1440x2560"``.

        Returns:
            URL of the generated image, or None on any failure (logged).
        """
        if not prompt or not image_url:
            logger.error("错误:prompt和image_url不能为空")
            return None

        try:
            reference_image = (
                image_url if self._is_remote(image_url) else encode_image(image_url)
            )
            payload = {
                **self.config,
                "prompt": prompt,
                "image": reference_image,
                "size": size,
            }
            response = requests.post(
                url=self.api_url,
                headers=self.headers,
                json=payload,
                timeout=self.timeout,
            )
            response.raise_for_status()
            result_image = self._extract_image_url(response.json())
        except requests.exceptions.Timeout:
            logger.error(f"错误:请求超时({self.timeout}秒)")
            return None
        except requests.exceptions.ConnectionError:
            logger.error("错误:网络连接失败,请检查API地址")
            return None
        except requests.exceptions.HTTPError as e:
            logger.error(f"错误:HTTP请求失败(状态码{e.response.status_code}):{e}")
            if e.response.text:
                logger.error(f"API错误详情:{e.response.text}")
            return None
        except json.JSONDecodeError:
            logger.error("错误:响应内容不是合法JSON")
            return None
        except Exception as e:
            logger.error(f"错误:生成图片时发生异常:{e}")
            return None

        # Downloading is kept out of the try above so a download failure is
        # not misreported as an API HTTP error.
        return self._download(result_image, filename)
class ArkVideoGenerator:
    """Wrapper around the Ark image-to-video task API.

    Generates a video from a reference image plus a text description. Tasks
    are created with :meth:`create_video_task`, queried with
    :meth:`query_video_task`, or submitted and polled on a daemon thread with
    :meth:`create_video_task_async`.
    """

    # Statuses that end polling with an error. "cancelled" follows the Ark
    # task API's documented status set -- confirm against current docs.
    _FAILED_STATUSES = ("failed", "cancelled")
    # Statuses that mean the task is still in progress ("queued"/"running"
    # per the Ark docs; "pending"/"processing" are also accepted).
    _RUNNING_STATUSES = ("queued", "running", "pending", "processing")

    def __init__(
        self,
        auth_token: Optional[str] = None,
        model: str = "doubao-seedance-1-0-pro-250528",
        timeout: int = 60,
        poll_interval: int = 5,
        max_poll_time: int = 500,
    ):
        """Create a video generator.

        Args:
            auth_token: Bearer token. Falls back to the ``ARK_API_KEY``
                environment variable when empty or None.
            model: Model name sent with every task.
            timeout: Per-request HTTP timeout in seconds.
            poll_interval: Seconds to sleep between background status polls.
            max_poll_time: Total seconds the background poller waits before
                reporting a timeout.
        """
        self.api_url = "https://ark.cn-beijing.volces.com/api/v3/contents/generations/tasks"
        if not auth_token:
            auth_token = os.getenv("ARK_API_KEY")
        if not auth_token:
            # Without a token the header below becomes "Bearer None".
            logger.warning("警告:未提供 ARK_API_KEY,请求将无法通过鉴权")
        self.headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {auth_token}",
        }
        self.model = model
        self.timeout = timeout
        self.poll_interval = poll_interval
        self.max_poll_time = max_poll_time

    @staticmethod
    def _is_remote(image_url: str) -> bool:
        """Return True if *image_url* is an http(s) URL rather than a local path."""
        return image_url.lower().startswith(("http://", "https://"))

    @staticmethod
    def _error_message(result: Dict, status: str) -> str:
        """Extract the error message of a failed task.

        Tolerates a missing, null or non-dict ``error`` field.
        """
        error = result.get("error")
        message = error.get("message") if isinstance(error, dict) else None
        return message or f"未知错误(状态:{status})"

    @staticmethod
    def _notify(
        callback: Callable[[str, str, Optional[Dict], Optional[str]], None],
        task_id: str,
        filename: str,
        result: Optional[Dict],
        error: Optional[str],
    ) -> None:
        """Invoke *callback*, logging (not propagating) any exception it raises."""
        try:
            callback(task_id, filename, result, error)
        except Exception:
            logger.exception(f"任务 {task_id} 回调执行失败")

    def create_video_task(
        self,
        prompt: str,
        image_url: str,
        gen_params: str = "",
    ) -> Optional[Dict]:
        """Submit an image-to-video generation task.

        Args:
            prompt: Text description of the video; must be non-blank.
            image_url: Reference image. An http(s) URL is used as-is; anything
                else is treated as a local path and uploaded with
                ``upload_file_to_tos`` first.
            gen_params: Text appended verbatim to *prompt* (presumably inline
                generation flags -- confirm with callers).

        Returns:
            The parsed JSON response of the create call (it carries the task
            ``id`` on success), or None on any failure (logged).
        """
        if not prompt or not prompt.strip():
            logger.error("错误:文本描述(text_prompt)不能为空")
            return None
        if not image_url or not image_url.strip():
            logger.error("错误:参考图 URL(reference_image_url)不能为空")
            return None
        full_prompt = prompt + (gen_params or "")

        if self._is_remote(image_url):
            reference_image = image_url
        else:
            try:
                reference_image = upload_file_to_tos(image_url)
            except Exception as e:
                logger.error(f"上传参考图片失败:{e}")
                return None
            if not reference_image:
                logger.error(f"上传参考图片失败:{image_url}")
                return None
        logger.info(f"视频生成提示词: {full_prompt}")
        logger.info(f"视频生成参考图片: {reference_image}")

        # Pre-flight reachability check on the reference image. A non-200
        # answer only warns (some hosts reject HEAD); a hard error aborts,
        # since the API would be unable to fetch the image either.
        try:
            head_response = requests.head(reference_image, timeout=5, allow_redirects=True)
            if head_response.status_code != 200:
                logger.warning(f"参考图片URL可能无法访问,状态码: {head_response.status_code}")
        except Exception as e:
            logger.warning(f"验证参考图片URL可访问性时出错: {str(e)}")
            return None

        payload = {
            "model": self.model,
            "content": [
                {"type": "text", "text": full_prompt},
                {"type": "image_url", "image_url": {"url": reference_image}},
            ],
        }
        try:
            response = requests.post(
                url=self.api_url,
                headers=self.headers,
                json=payload,
                timeout=self.timeout,
            )
            response.raise_for_status()
            return response.json()
        except Exception as e:
            logger.error(f"创建任务失败:{str(e)}")
            return None

    def query_video_task(self, task_id: str) -> Optional[Dict]:
        """Query an image-to-video task.

        Args:
            task_id: Task ID, as returned in the ``id`` field of
                :meth:`create_video_task`'s response.

        Returns:
            The task detail JSON (status, video URL, ...), or None on any
            failure (logged), e.g. 404 unknown task or 401 bad token.
        """
        if not task_id or not task_id.strip():
            logger.error("错误:任务 ID(task_id)不能为空")
            return None
        query_url = f"{self.api_url}/{task_id}"
        try:
            response = requests.get(url=query_url, headers=self.headers, timeout=self.timeout)
            response.raise_for_status()
            return response.json()
        except Exception as e:
            logger.error(f"查询任务失败:{str(e)}")
            return None

    def _background_poll(
        self,
        task_id: str,
        filename: str,
        callback: Callable[[str, str, Optional[Dict], Optional[str]], None],
    ) -> None:
        """Poll *task_id* until it finishes or times out, then call *callback* once.

        Runs on the thread started by :meth:`create_video_task_async`.

        Args:
            task_id: Task to poll.
            filename: Passed through to *callback* unchanged.
            callback: ``callback(task_id, filename, result, error)``. On
                success ``result`` is the query response and ``error`` is
                None; otherwise ``result`` is None and ``error`` a message.
        """
        # Monotonic clock: elapsed time must not jump with wall-clock changes.
        start = time.monotonic()
        while True:
            elapsed = time.monotonic() - start
            if elapsed > self.max_poll_time:
                self._notify(
                    callback, task_id, filename, None,
                    f"任务超时(超过 {self.max_poll_time} 秒)",
                )
                return  # stop polling once the deadline has passed

            result = self.query_video_task(task_id)
            if not result:
                # Query failed (query_video_task logs why); retry after a pause.
                time.sleep(self.poll_interval)
                continue

            # `or ""` guards against an explicit null status.
            status = str(result.get("status") or "").lower()
            if status == "succeeded":
                self._notify(callback, task_id, filename, result, None)
                return
            if status in self._FAILED_STATUSES:
                self._notify(
                    callback, task_id, filename, None,
                    self._error_message(result, status),
                )
                return
            if status in self._RUNNING_STATUSES:
                logger.info(f"任务 {task_id} 处理中({int(elapsed)}秒),状态:{status}")
            else:
                logger.info(f"任务 {task_id} 未知状态:{status},继续等待...")
            time.sleep(self.poll_interval)

    def create_video_task_async(
        self,
        prompt: str,
        image_url: str,
        gen_params: str,
        filename: str,
        callback: Callable[[str, str, Optional[Dict], Optional[str]], None],
    ) -> Optional[str]:
        """Submit a task and poll it on a background daemon thread.

        Args:
            prompt: See :meth:`create_video_task`.
            image_url: See :meth:`create_video_task`.
            gen_params: See :meth:`create_video_task`.
            filename: Passed through to *callback*.
            callback: Called once as ``callback(task_id, filename, result,
                error)`` when the task succeeds, fails, or polling times out.

        Returns:
            The task ID, or None if submission failed (no thread is started).
            The polling thread is a daemon, so the caller must keep the
            process alive until *callback* has run.
        """
        task_response = self.create_video_task(prompt, image_url, gen_params)
        logger.debug(f"task_response: {task_response}")
        if not task_response or "id" not in task_response:
            logger.error("任务提交失败,无法启动后台轮询")
            return None
        task_id = task_response["id"]
        logger.info(f"任务提交成功,task_id: {task_id},启动后台轮询...")
        threading.Thread(
            target=self._background_poll,
            args=(task_id, filename, callback),
            name=f"ark-video-poll-{task_id}",
            daemon=True,
        ).start()
        return task_id
-
# Callback for ArkVideoGenerator.create_video_task_async: invoked by the
# background poller as (task_id, filename, result, error).
def handle_video_result(
    task_id: str,
    filename: str,
    result: Optional[Dict],
    error: Optional[str],
) -> None:
    """Download the generated video on success, or log the failure.

    Args:
        task_id: Ark task ID.
        filename: File name under ``./output/`` to save the video as.
        result: Task query response on success; expected to carry
            ``content.video_url``. None on failure.
        error: Error message on failure, otherwise None.
    """
    if error:
        logger.error(f"\n任务 {task_id} 处理失败:{error}")
        return
    # Guard each level: a malformed result must not hand None to the downloader.
    video_url = ((result or {}).get("content") or {}).get("video_url")
    if not video_url:
        logger.error(f"任务 {task_id} 结果中缺少视频URL:{result}")
        return
    output_path = "./output/" + filename
    try:
        download_video(video_url, output_path)
    except Exception as e:
        # Runs on the poll thread; log instead of letting the thread die.
        logger.error(f"任务 {task_id} 视频下载失败:{e}")
        return
    logger.info(f"生成视频已下载:{output_path}")
-
# TTS service endpoint, read from the AUDIO_GEN_API environment variable.
API_URL = os.getenv("AUDIO_GEN_API")


def audio_generator(
    text,
    spk_audio="./data/audio/voice_07.wav",
    emo_audio="./data/audio/emo_sad.wav",
    timeout=600,
):
    """Call the TTS API to synthesize speech for *text*.

    Args:
        text: Text to synthesize.
        spk_audio: Speaker reference audio, sent as ``spk_audio_prompt``.
        emo_audio: Emotion reference audio, sent as ``emo_audio_prompt``.
        timeout: Request timeout in seconds; pass None to wait indefinitely.

    Returns:
        The ``audio_file`` value from the API response, or None on failure.

    Raises:
        requests.exceptions.RequestException: network-level failures are
            not caught here.
    """
    if not API_URL:
        logger.error("错误:未配置 AUDIO_GEN_API 环境变量")
        return None

    payload = {
        "text": text,
        "spk_audio_prompt": spk_audio,
        "emo_audio_prompt": emo_audio,
    }
    response = requests.post(API_URL, json=payload, timeout=timeout)
    if response.status_code != 200:
        logger.error(f"请求失败: {response.text}")
        return None

    result = response.json()
    if result.get("status") != "success":
        # Previously this case fell through silently; surface the response.
        logger.error(f"语音生成失败: {result}")
        return None
    logger.info(f"语音生成成功: {result['audio_file']}")
    return result["audio_file"]
# Shared module-level instances, built at import time with default settings
# (the auth token falls back to the ARK_API_KEY environment variable).
image_generator, video_create = ArkImageGenerator(), ArkVideoGenerator()
-
if __name__ == "__main__":
    # Demo: submit an image-to-video task, keep doing other work while a
    # background thread polls it, then wait for the result callback.
    finished = threading.Event()

    def _on_video_done(task_id, filename, result, error):
        """Forward to handle_video_result, then wake the main thread."""
        try:
            handle_video_result(task_id, filename, result, error)
        finally:
            finished.set()

    task_id = video_create.create_video_task_async(
        prompt="狗狗不停地在草地上跳跃",
        image_url="https://testdgxcx-oss.gloria.com.cn/video-create/new_frame_scene0_camera0_shot1.png",
        gen_params="",
        filename="1.mp4",
        callback=_on_video_done,
    )
    print(f"task_id: {task_id}")
    if task_id:
        print("\n主流程:任务已提交,开始执行其他操作...")
        for step in range(1, 11):
            print(f"主流程:正在执行第 {step} 步操作...")
            time.sleep(1)  # simulate other work on the main thread
        print("主流程:所有操作执行完毕,等待后台任务结果(若未完成)...")
        # The poll thread is a daemon, so block until its callback has run;
        # otherwise interpreter exit would kill it mid-poll.
        finished.wait()
|