# media_captioner.py
import asyncio
import base64
import functools
import io
import os
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Callable, Dict, List, Literal, Optional, Union

import aiohttp
from dotenv import load_dotenv
from PIL import Image
from volcenginesdkarkruntime import Ark

from utils.config_manager import ConfigManager
from utils.logger_config import setup_logger
# Load variables from a .env file into the process environment so that
# os.getenv lookups (e.g. VOLC_API_KEY in MediaCaptioner.__init__) can see them.
load_dotenv()
# Module-wide logger created via the project's logging helper.
logger = setup_logger(__name__)
  18. def async_performance_monitor(func: Callable):
  19. """异步方法性能监控装饰器"""
  20. @functools.wraps(func)
  21. async def wrapper(*args, **kwargs):
  22. start_time = time.time()
  23. try:
  24. result = await func(*args, **kwargs)
  25. end_time = time.time()
  26. execution_time = end_time - start_time
  27. logger.info(f"{func.__name__} completed in {execution_time:.2f} seconds")
  28. return result
  29. except Exception as e:
  30. end_time = time.time()
  31. execution_time = end_time - start_time
  32. logger.error(f"{func.__name__} failed after {execution_time:.2f} seconds: {str(e)}")
  33. raise
  34. return wrapper
  35. def sync_performance_monitor(func: Callable):
  36. """同步方法性能监控装饰器"""
  37. @functools.wraps(func)
  38. def wrapper(*args, **kwargs):
  39. start_time = time.time()
  40. try:
  41. result = func(*args, **kwargs)
  42. end_time = time.time()
  43. execution_time = end_time - start_time
  44. logger.info(f"{func.__name__} completed in {execution_time:.2f} seconds")
  45. return result
  46. except Exception as e:
  47. end_time = time.time()
  48. execution_time = end_time - start_time
  49. logger.error(f"{func.__name__} failed after {execution_time:.2f} seconds: {str(e)}")
  50. raise
  51. return wrapper
  52. class MediaCaptioner:
  53. """媒体描述生成器,使用火山引擎API进行视频、图像和文本内容理解"""
  54. def __init__(self, api_key: Optional[str] = None,
  55. base_url: str = "https://ark.cn-beijing.volces.com/api/v3",
  56. model: str = "doubao-seed-1-6-250615",
  57. config_path: Optional[str] = None):
  58. """
  59. 初始化媒体描述生成器
  60. Args:
  61. api_key: 火山引擎API密钥,如果为None则从环境变量获取
  62. base_url: API基础URL
  63. model: 使用的模型ID
  64. config_path: 提示词配置文件路径
  65. """
  66. try:
  67. self.api_key = api_key or os.getenv("VOLC_API_KEY")
  68. if not self.api_key:
  69. raise ValueError("API key must be provided either through constructor or environment variable VOLC_API_KEY")
  70. self.client = Ark(
  71. api_key=self.api_key,
  72. base_url=base_url
  73. )
  74. self.base_url = base_url
  75. self.model = model
  76. self.config_manager = ConfigManager(config_path)
  77. logger.info(f"Initialized MediaCaptioner with model: {model}")
  78. except Exception as e:
  79. logger.error(f"Failed to initialize MediaCaptioner: {str(e)}")
  80. raise
  81. @sync_performance_monitor
  82. def _encode_video(self, video_path: str) -> str:
  83. """
  84. 将视频文件转换为base64编码
  85. Args:
  86. video_path: 视频文件路径
  87. Returns:
  88. str: base64编码的视频数据
  89. Raises:
  90. FileNotFoundError: 视频文件不存在
  91. IOError: 读取文件失败
  92. """
  93. if not os.path.exists(video_path):
  94. raise FileNotFoundError(f"Video file not found: {video_path}")
  95. with open(video_path, "rb") as f:
  96. return base64.b64encode(f.read()).decode("utf-8")
  97. @sync_performance_monitor
  98. def _encode_image(self, image_path: str) -> str:
  99. """
  100. 将图片文件转换为base64编码
  101. Args:
  102. image_path: 图片文件路径
  103. Returns:
  104. str: base64编码的图片数据
  105. Raises:
  106. FileNotFoundError: 图片文件不存在
  107. IOError: 读取或处理图片失败
  108. """
  109. if not os.path.exists(image_path):
  110. raise FileNotFoundError(f"Image file not found: {image_path}")
  111. with Image.open(image_path) as img:
  112. buffered = io.BytesIO()
  113. img.save(buffered, format="JPEG")
  114. return base64.b64encode(buffered.getvalue()).decode("utf-8")
  115. def generate_video_caption(self,
  116. video_path: str,
  117. prompt_type: str = "caption",
  118. scenario: Optional[str] = None,
  119. fps: int = 2,
  120. context_info: Optional[str] = None) -> Optional[str]:
  121. """
  122. 生成视频描述的同步包装器
  123. Args:
  124. video_path: 视频文件路径
  125. prompt_type: 提示词类型
  126. scenario: 场景类型
  127. fps: 视频采样帧率
  128. Returns:
  129. str: 视频描述,如果处理失败则返回None
  130. """
  131. loop = asyncio.get_event_loop()
  132. return loop.run_until_complete(
  133. self._process_video_async(
  134. file_path=video_path,
  135. prompt_type=prompt_type,
  136. scenario=scenario,
  137. fps=fps,
  138. context_info=context_info
  139. )
  140. )
  141. async def generate_video_caption_async(self,
  142. video_path: str,
  143. prompt_type: str = "caption",
  144. scenario: Optional[str] = None,
  145. fps: int = 2,
  146. context_info: Optional[str] = None) -> Optional[str]:
  147. """
  148. 异步生成视频描述
  149. Args:
  150. video_path: 视频文件路径
  151. prompt_type: 提示词类型
  152. scenario: 场景类型
  153. fps: 视频采样帧率
  154. Returns:
  155. str: 视频描述,如果处理失败则返回None
  156. """
  157. return await self._process_video_async(
  158. file_path=video_path,
  159. prompt_type=prompt_type,
  160. scenario=scenario,
  161. fps=fps,
  162. context_info=context_info
  163. )
  164. def generate_image_caption(self,
  165. image_path: str,
  166. prompt_type: str = "caption",
  167. scenario: Optional[str] = None,
  168. context_info: Optional[str] = None) -> Optional[str]:
  169. """
  170. 生成图片描述的同步包装器
  171. Args:
  172. image_path: 图片文件路径
  173. prompt_type: 提示词类型
  174. scenario: 场景类型
  175. Returns:
  176. str: 图片描述,如果处理失败则返回None
  177. """
  178. loop = asyncio.get_event_loop()
  179. return loop.run_until_complete(
  180. self._process_image_async(
  181. file_path=image_path,
  182. prompt_type=prompt_type,
  183. scenario=scenario,
  184. context_info=context_info
  185. )
  186. )
  187. async def generate_image_caption_async(self,
  188. image_path: str,
  189. prompt_type: str = "caption",
  190. scenario: Optional[str] = None,
  191. context_info: Optional[str] = None) -> Optional[str]:
  192. """
  193. 异步生成图片描述
  194. Args:
  195. image_path: 图片文件路径
  196. prompt_type: 提示词类型
  197. scenario: 场景类型
  198. Returns:
  199. str: 图片描述,如果处理失败则返回None
  200. """
  201. return await self._process_image_async(
  202. file_path=image_path,
  203. prompt_type=prompt_type,
  204. scenario=scenario,
  205. context_info=context_info
  206. )
  207. def generate_text_understanding(self,
  208. text: str,
  209. prompt_type: str = "summary",
  210. scenario: Optional[str] = None,
  211. max_length: Optional[int] = None,
  212. context_info: Optional[str] = None) -> Optional[str]:
  213. """
  214. 生成文本理解结果的同步包装器
  215. Args:
  216. text: 需要理解的文本内容
  217. prompt_type: 提示词类型
  218. scenario: 场景类型
  219. max_length: 最大输出长度
  220. Returns:
  221. str: 文本理解结果,如果处理失败则返回None
  222. """
  223. loop = asyncio.get_event_loop()
  224. return loop.run_until_complete(
  225. self._process_text_async(
  226. text=text,
  227. prompt_type=prompt_type,
  228. scenario=scenario,
  229. max_length=max_length,
  230. context_info=context_info
  231. )
  232. )
  233. async def generate_text_understanding_async(self,
  234. text: str,
  235. prompt_type: str = "summary",
  236. scenario: Optional[str] = None,
  237. max_length: Optional[int] = None,
  238. context_info: Optional[str] = None) -> Optional[str]:
  239. """
  240. 异步生成文本理解结果
  241. Args:
  242. text: 需要理解的文本内容
  243. prompt_type: 提示词类型
  244. scenario: 场景类型
  245. max_length: 最大输出长度
  246. Returns:
  247. str: 文本理解结果,如果处理失败则返回None
  248. """
  249. return await self._process_text_async(
  250. text=text,
  251. prompt_type=prompt_type,
  252. scenario=scenario,
  253. max_length=max_length,
  254. context_info=context_info
  255. )
  256. def generate_multi_aspect_understanding(self,
  257. text: str,
  258. prompt_types: List[str],
  259. scenario: Optional[str] = None) -> Dict[str, Optional[str]]:
  260. """
  261. 从多个角度生成文本理解结果的同步包装器
  262. Args:
  263. text: 需要理解的文本内容
  264. prompt_types: 提示词类型列表
  265. scenario: 场景类型
  266. Returns:
  267. Dict[str, Optional[str]]: 提示词类型到理解结果的映射
  268. """
  269. loop = asyncio.get_event_loop()
  270. return loop.run_until_complete(
  271. self.generate_multi_aspect_understanding_async(
  272. text=text,
  273. prompt_types=prompt_types,
  274. scenario=scenario
  275. )
  276. )
  277. async def generate_multi_aspect_understanding_async(self,
  278. text: str,
  279. prompt_types: List[str],
  280. scenario: Optional[str] = None,
  281. context_info: Optional[str] = None) -> Dict[str, Optional[str]]:
  282. """
  283. 异步从多个角度生成文本理解结果
  284. Args:
  285. text: 需要理解的文本内容
  286. prompt_types: 提示词类型列表
  287. scenario: 场景类型
  288. Returns:
  289. Dict[str, Optional[str]]: 提示词类型到理解结果的映射
  290. """
  291. tasks = [
  292. self._process_text_async(
  293. text=text,
  294. prompt_type=prompt_type,
  295. scenario=scenario,
  296. context_info=context_info
  297. )
  298. for prompt_type in prompt_types
  299. ]
  300. results = await asyncio.gather(*tasks)
  301. return dict(zip(prompt_types, results))
  302. async def _make_api_request(self,
  303. endpoint: str,
  304. payload: Dict[str, Any],
  305. timeout: int = 180) -> Dict[str, Any]:
  306. """
  307. 发送API请求的通用方法
  308. Args:
  309. endpoint: API端点
  310. payload: 请求负载
  311. timeout: 超时时间(秒)
  312. Returns:
  313. Dict[str, Any]: API响应
  314. Raises:
  315. aiohttp.ClientError: API请求失败
  316. asyncio.TimeoutError: 请求超时
  317. """
  318. try:
  319. async with aiohttp.ClientSession() as session:
  320. async with session.post(
  321. f"{self.base_url}/{endpoint}",
  322. json=payload,
  323. headers={"Authorization": f"Bearer {self.api_key}"},
  324. timeout=timeout
  325. ) as response:
  326. if response.status != 200:
  327. error_text = await response.text()
  328. raise aiohttp.ClientError(f"API request failed with status {response.status}: {error_text}")
  329. return await response.json()
  330. except asyncio.TimeoutError:
  331. logger.error(f"API request timed out after {timeout} seconds")
  332. raise
  333. except Exception as e:
  334. logger.error(f"API request failed: {str(e)}")
  335. raise
  336. @async_performance_monitor
  337. async def _process_video_async(self, file_path: str, prompt_type: str,
  338. scenario: Optional[str] = None, fps: int = 2, context_info: Optional[str] = None) -> Optional[str]:
  339. """异步处理视频文件"""
  340. try:
  341. # 在线程池中执行文件IO操作
  342. loop = asyncio.get_event_loop()
  343. with ThreadPoolExecutor() as pool:
  344. base64_video = await loop.run_in_executor(
  345. pool, self._encode_video, file_path
  346. )
  347. prompt = self.config_manager.get_prompt("video", prompt_type, scenario)
  348. # 构建API请求
  349. payload = {
  350. "model": self.model,
  351. "messages": [{
  352. "role": "system",
  353. "content": [
  354. {
  355. "type": "video_url",
  356. "video_url": {
  357. "url": f"data:video/mp4;base64,{base64_video}",
  358. "fps": fps
  359. }
  360. },
  361. {
  362. "type": "text",
  363. "text": prompt
  364. }
  365. ]
  366. },
  367. {
  368. "role": "user",
  369. "content": f"上下文信息:{context_info}"
  370. }
  371. ]
  372. }
  373. # 发送API请求
  374. response = await self._make_api_request("chat/completions", payload)
  375. return response["choices"][0]["message"]["content"]
  376. except Exception as e:
  377. logger.error(f"Failed to process video async: {str(e)}")
  378. return None
  379. @async_performance_monitor
  380. async def _process_image_async(self, file_path: str, prompt_type: str,
  381. scenario: Optional[str] = None, context_info: Optional[str] = None) -> Optional[str]:
  382. """异步处理图片文件"""
  383. try:
  384. # 在线程池中执行文件IO操作
  385. loop = asyncio.get_event_loop()
  386. with ThreadPoolExecutor() as pool:
  387. base64_image = await loop.run_in_executor(
  388. pool, self._encode_image, file_path
  389. )
  390. prompt = self.config_manager.get_prompt("image", prompt_type, scenario)
  391. # 构建API请求
  392. payload = {
  393. "model": self.model,
  394. "messages": [{
  395. "role": "system",
  396. "content": [
  397. {
  398. "type": "image_url",
  399. "image_url": {
  400. "url": f"data:image/jpeg;base64,{base64_image}"
  401. }
  402. },
  403. {
  404. "type": "text",
  405. "text": prompt
  406. }
  407. ]
  408. },
  409. {
  410. "role": "user",
  411. "content": f"上下文信息:{context_info}"
  412. }
  413. ]
  414. }
  415. # 发送API请求
  416. response = await self._make_api_request("chat/completions", payload)
  417. return response["choices"][0]["message"]["content"]
  418. except Exception as e:
  419. logger.error(f"Failed to process image async: {str(e)}")
  420. return None
  421. @async_performance_monitor
  422. async def _process_text_async(self, text: str, prompt_type: str,
  423. scenario: Optional[str] = None,
  424. max_length: Optional[int] = None,
  425. context_info: Optional[str] = None) -> Optional[str]:
  426. """异步处理文本内容"""
  427. # try:
  428. if not text.strip():
  429. logger.error("Empty text provided")
  430. return None
  431. prompt = self.config_manager.get_prompt("video", prompt_type, scenario)
  432. # 构建API请求
  433. payload = {
  434. "model": self.model,
  435. "messages": [
  436. {
  437. "role": "system",
  438. "content": prompt
  439. },
  440. {
  441. "role": "user",
  442. "content": text
  443. },
  444. {
  445. "role": "user",
  446. "content": f"上下文信息:{context_info}"
  447. }
  448. ],
  449. "max_tokens": max_length if max_length else None
  450. }
  451. # 发送API请求
  452. response = await self._make_api_request("chat/completions", payload)
  453. return response["choices"][0]["message"]["content"]
  454. # except Exception as e:
  455. # logger.error(f"Failed to process text async: {str(e)}")
  456. # return None
  457. async def generate_batch_captions_async(self,
  458. files: Dict[str, Dict[str, Union[str, int]]],
  459. scenario: Optional[str] = None,
  460. max_concurrent: int = 5) -> Dict[str, Optional[str]]:
  461. """
  462. 异步批量生成媒体描述
  463. Args:
  464. files: 文件配置字典
  465. scenario: 场景类型
  466. max_concurrent: 最大并发数
  467. Returns:
  468. Dict[str, Optional[str]]: 文件路径或标识符到描述的映射
  469. """
  470. results = {}
  471. # 创建信号量控制并发
  472. semaphore = asyncio.Semaphore(max_concurrent)
  473. async def process_single_file(file_path: str, config: Dict[str, Any]) -> tuple[str, Optional[str]]:
  474. """处理单个文件的异步函数"""
  475. async with semaphore: # 使用信号量控制并发
  476. try:
  477. media_type = config["type"]
  478. prompt_type = config.get("prompt_type", "caption" if media_type != "text" else "summary")
  479. if media_type == "video":
  480. fps = config.get("fps", 2)
  481. result = await self._process_video_async(
  482. file_path=file_path,
  483. prompt_type=prompt_type,
  484. scenario=scenario,
  485. fps=fps,
  486. context_info=config.get("context_info")
  487. )
  488. elif media_type == "image":
  489. result = await self._process_image_async(
  490. file_path=file_path,
  491. prompt_type=prompt_type,
  492. scenario=scenario,
  493. context_info=config.get("context_info")
  494. )
  495. elif media_type == "text":
  496. if "content" not in config:
  497. logger.error(f"Text content not provided for {file_path}")
  498. return file_path, None
  499. result = await self._process_text_async(
  500. text=config["content"],
  501. prompt_type=prompt_type,
  502. scenario=scenario,
  503. max_length=config.get("max_length"),
  504. context_info=config.get("context_info")
  505. )
  506. else:
  507. logger.warning(f"Unsupported media type: {media_type}")
  508. return file_path, None
  509. return file_path, result
  510. except Exception as e:
  511. logger.error(f"Failed to process file {file_path}: {str(e)}")
  512. return file_path, None
  513. # 创建所有任务
  514. tasks = [
  515. process_single_file(file_path, config)
  516. for file_path, config in files.items()
  517. ]
  518. # 并行执行所有任务
  519. completed_tasks = await asyncio.gather(*tasks)
  520. # 整理结果
  521. results = dict(completed_tasks)
  522. return results
  523. def generate_batch_captions(self,
  524. files: Dict[str, Dict[str, Union[str, int]]],
  525. scenario: Optional[str] = None) -> Dict[str, Optional[str]]:
  526. """
  527. 批量生成媒体描述的同步包装器
  528. Args:
  529. files: 文件配置字典,格式为:
  530. {
  531. "file_path": {
  532. "type": "video"|"image"|"text",
  533. "prompt_type": str, # 可选
  534. "fps": int, # 仅视频可用
  535. "content": str, # 仅文本类型需要
  536. "max_length": int # 可选,仅文本类型可用
  537. }
  538. }
  539. scenario: 场景类型
  540. Returns:
  541. Dict[str, Optional[str]]: 文件路径或标识符到描述的映射
  542. """
  543. # 创建事件循环
  544. loop = asyncio.get_event_loop()
  545. # 运行异步方法
  546. return loop.run_until_complete(
  547. self.generate_batch_captions_async(files, scenario)
  548. )
  549. media_captioner: MediaCaptioner = MediaCaptioner()
  550. if __name__ == "__main__":
  551. async def main():
  552. # 初始化
  553. captioner = MediaCaptioner()
  554. # 处理文本
  555. text_content = """
  556. 近日,研究人员在深海发现了一种新的海洋生物物种。
  557. 这种生物具有独特的生物发光能力,可以在完全黑暗的环境中发出蓝绿色的光。
  558. 科学家们认为,这一发现对于了解深海生态系统具有重要意义。
  559. """
  560. # 批量处理示例
  561. files = {
  562. "./test_data/sample_video.mp4": {
  563. "type": "video",
  564. "prompt_type": "caption",
  565. "fps": 2
  566. },
  567. "./test_data/sample_image.jpg": {
  568. "type": "image",
  569. "prompt_type": "caption"
  570. },
  571. "text_sample": {
  572. "type": "text",
  573. "content": text_content,
  574. "prompt_type": "summary",
  575. "max_length": 200
  576. }
  577. }
  578. # 异步批量处理
  579. results = await captioner.generate_batch_captions_async(
  580. files,
  581. scenario="academic",
  582. max_concurrent=5
  583. )
  584. print("批量处理结果:", results)
  585. # 运行异步主函数
  586. asyncio.run(main())