# text_generator.py
import asyncio
import base64
import functools
import io
import os
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Callable, Coroutine, Dict, List, Literal, Optional, Union

import aiohttp
from dotenv import load_dotenv
from PIL import Image
from volcenginesdkarkruntime import Ark

from utils.config_manager import ConfigManager
from utils.logger_config import setup_logger
  15. # 加载.env文件
  16. load_dotenv()
  17. logger = setup_logger(__name__)
  18. def async_performance_monitor(func: Callable):
  19. """异步方法性能监控装饰器"""
  20. @functools.wraps(func)
  21. async def wrapper(*args, **kwargs):
  22. start_time = time.time()
  23. try:
  24. result = await func(*args, **kwargs)
  25. end_time = time.time()
  26. execution_time = end_time - start_time
  27. logger.info(f"{func.__name__} completed in {execution_time:.2f} seconds")
  28. return result
  29. except Exception as e:
  30. end_time = time.time()
  31. execution_time = end_time - start_time
  32. logger.error(f"{func.__name__} failed after {execution_time:.2f} seconds: {str(e)}")
  33. raise
  34. return wrapper
  35. def sync_performance_monitor(func: Callable):
  36. """同步方法性能监控装饰器"""
  37. @functools.wraps(func)
  38. def wrapper(*args, **kwargs):
  39. start_time = time.time()
  40. try:
  41. result = func(*args, **kwargs)
  42. end_time = time.time()
  43. execution_time = end_time - start_time
  44. logger.info(f"{func.__name__} completed in {execution_time:.2f} seconds")
  45. return result
  46. except Exception as e:
  47. end_time = time.time()
  48. execution_time = end_time - start_time
  49. logger.error(f"{func.__name__} failed after {execution_time:.2f} seconds: {str(e)}")
  50. raise
  51. return wrapper
  52. class MediaCaptioner:
  53. """媒体描述生成器,使用火山引擎API进行视频、图像和文本内容理解"""
  54. def __init__(self, api_key: Optional[str] = None,
  55. base_url: str = "https://ark.cn-beijing.volces.com/api/v3",
  56. model: str = "doubao-seed-1-6-250615",
  57. config_path: Optional[str] = None):
  58. """
  59. 初始化媒体描述生成器
  60. Args:
  61. api_key: 火山引擎API密钥,如果为None则从环境变量获取
  62. base_url: API基础URL
  63. model: 使用的模型ID
  64. config_path: 提示词配置文件路径
  65. """
  66. try:
  67. self.api_key = api_key or os.getenv("VOLC_API_KEY")
  68. if not self.api_key:
  69. raise ValueError("API key must be provided either through constructor or environment variable VOLC_API_KEY")
  70. self.client = Ark(
  71. api_key=self.api_key,
  72. base_url=base_url
  73. )
  74. self.base_url = base_url
  75. self.model = model
  76. self.config_manager = ConfigManager(config_path)
  77. logger.info(f"Initialized MediaCaptioner with model: {model}")
  78. except Exception as e:
  79. logger.error(f"Failed to initialize MediaCaptioner: {str(e)}")
  80. raise
  81. @sync_performance_monitor
  82. def _encode_video(self, video_path: str) -> str:
  83. """
  84. 将视频文件转换为base64编码
  85. Args:
  86. video_path: 视频文件路径
  87. Returns:
  88. str: base64编码的视频数据
  89. Raises:
  90. FileNotFoundError: 视频文件不存在
  91. IOError: 读取文件失败
  92. """
  93. if not os.path.exists(video_path):
  94. raise FileNotFoundError(f"Video file not found: {video_path}")
  95. with open(video_path, "rb") as f:
  96. return base64.b64encode(f.read()).decode("utf-8")
  97. @sync_performance_monitor
  98. def _encode_image(self, image_path: str) -> str:
  99. """
  100. 将图片文件转换为base64编码
  101. Args:
  102. image_path: 图片文件路径
  103. Returns:
  104. str: base64编码的图片数据
  105. Raises:
  106. FileNotFoundError: 图片文件不存在
  107. IOError: 读取或处理图片失败
  108. """
  109. if not os.path.exists(image_path):
  110. raise FileNotFoundError(f"Image file not found: {image_path}")
  111. with Image.open(image_path) as img:
  112. buffered = io.BytesIO()
  113. img.save(buffered, format="JPEG")
  114. return base64.b64encode(buffered.getvalue()).decode("utf-8")
  115. def generate_video_caption(self,
  116. video_path: str,
  117. prompt_type: str = "caption",
  118. scenario: Optional[str] = None,
  119. fps: int = 2,
  120. context_info: Optional[str] = None) -> Optional[str]:
  121. """
  122. 生成视频描述的同步包装器
  123. Args:
  124. video_path: 视频文件路径
  125. prompt_type: 提示词类型
  126. scenario: 场景类型
  127. fps: 视频采样帧率
  128. Returns:
  129. str: 视频描述,如果处理失败则返回None
  130. """
  131. try:
  132. loop = asyncio.get_event_loop()
  133. except RuntimeError:
  134. loop = asyncio.new_event_loop()
  135. asyncio.set_event_loop(loop)
  136. return loop.run_until_complete(
  137. self._process_video_async(
  138. file_path=video_path,
  139. prompt_type=prompt_type,
  140. scenario=scenario,
  141. fps=fps,
  142. context_info=context_info
  143. )
  144. )
  145. async def generate_video_caption_async(self,
  146. video_path: str,
  147. prompt_type: str = "caption",
  148. scenario: Optional[str] = None,
  149. fps: int = 2,
  150. context_info: Optional[str] = None) -> Optional[str]:
  151. """
  152. 异步生成视频描述
  153. Args:
  154. video_path: 视频文件路径
  155. prompt_type: 提示词类型
  156. scenario: 场景类型
  157. fps: 视频采样帧率
  158. Returns:
  159. str: 视频描述,如果处理失败则返回None
  160. """
  161. return await self._process_video_async(
  162. file_path=video_path,
  163. prompt_type=prompt_type,
  164. scenario=scenario,
  165. fps=fps,
  166. context_info=context_info
  167. )
  168. def generate_image_caption(self,
  169. image_path: str,
  170. prompt_type: str = "caption",
  171. scenario: Optional[str] = None,
  172. context_info: Optional[str] = None) -> Optional[str]:
  173. """
  174. 生成图片描述的同步包装器
  175. Args:
  176. image_path: 图片文件路径
  177. prompt_type: 提示词类型
  178. scenario: 场景类型
  179. Returns:
  180. str: 图片描述,如果处理失败则返回None
  181. """
  182. try:
  183. loop = asyncio.get_event_loop()
  184. except RuntimeError:
  185. loop = asyncio.new_event_loop()
  186. asyncio.set_event_loop(loop)
  187. return loop.run_until_complete(
  188. self._process_image_async(
  189. file_path=image_path,
  190. prompt_type=prompt_type,
  191. scenario=scenario,
  192. context_info=context_info
  193. )
  194. )
  195. async def generate_image_caption_async(self,
  196. image_path: str,
  197. prompt_type: str = "caption",
  198. scenario: Optional[str] = None,
  199. context_info: Optional[str] = None) -> Optional[str]:
  200. """
  201. 异步生成图片描述
  202. Args:
  203. image_path: 图片文件路径
  204. prompt_type: 提示词类型
  205. scenario: 场景类型
  206. Returns:
  207. str: 图片描述,如果处理失败则返回None
  208. """
  209. return await self._process_image_async(
  210. file_path=image_path,
  211. prompt_type=prompt_type,
  212. scenario=scenario,
  213. context_info=context_info
  214. )
  215. def generate_text_understanding(self,
  216. user_prompt: str,
  217. system_prompt: str,
  218. max_length: Optional[int] = None,
  219. context_info: Optional[str] = None) -> Optional[str]:
  220. """
  221. 生成文本理解结果的同步包装器
  222. Args:
  223. user_prompt: 需要理解的文本内容
  224. system_prompt: 提示词类型
  225. scenario: 场景类型
  226. max_length: 最大输出长度
  227. Returns:
  228. str: 文本理解结果,如果处理失败则返回None
  229. """
  230. try:
  231. loop = asyncio.get_event_loop()
  232. except RuntimeError:
  233. loop = asyncio.new_event_loop()
  234. asyncio.set_event_loop(loop)
  235. return loop.run_until_complete(
  236. self._process_text_async(
  237. user_prompt=user_prompt,
  238. system_prompt=system_prompt,
  239. max_length=max_length,
  240. context_info=context_info
  241. )
  242. )
  243. async def generate_text_understanding_async(self,
  244. text: str,
  245. prompt_type: str = "summary",
  246. scenario: Optional[str] = None,
  247. max_length: Optional[int] = None,
  248. context_info: Optional[str] = None) -> Optional[str]:
  249. """
  250. 异步生成文本理解结果
  251. Args:
  252. text: 需要理解的文本内容
  253. prompt_type: 提示词类型
  254. scenario: 场景类型
  255. max_length: 最大输出长度
  256. Returns:
  257. str: 文本理解结果,如果处理失败则返回None
  258. """
  259. return await self._process_text_async(
  260. text=text,
  261. prompt_type=prompt_type,
  262. scenario=scenario,
  263. max_length=max_length,
  264. context_info=context_info
  265. )
  266. def generate_multi_aspect_understanding(self,
  267. text: str,
  268. prompt_types: List[str],
  269. scenario: Optional[str] = None) -> Dict[str, Optional[str]]:
  270. """
  271. 从多个角度生成文本理解结果的同步包装器
  272. Args:
  273. text: 需要理解的文本内容
  274. prompt_types: 提示词类型列表
  275. scenario: 场景类型
  276. Returns:
  277. Dict[str, Optional[str]]: 提示词类型到理解结果的映射
  278. """
  279. try:
  280. loop = asyncio.get_event_loop()
  281. except RuntimeError:
  282. loop = asyncio.new_event_loop()
  283. asyncio.set_event_loop(loop)
  284. return loop.run_until_complete(
  285. self.generate_multi_aspect_understanding_async(
  286. text=text,
  287. prompt_types=prompt_types,
  288. scenario=scenario
  289. )
  290. )
  291. async def generate_multi_aspect_understanding_async(self,
  292. text: str,
  293. prompt_types: List[str],
  294. scenario: Optional[str] = None,
  295. context_info: Optional[str] = None) -> Dict[str, Optional[str]]:
  296. """
  297. 异步从多个角度生成文本理解结果
  298. Args:
  299. text: 需要理解的文本内容
  300. prompt_types: 提示词类型列表
  301. scenario: 场景类型
  302. Returns:
  303. Dict[str, Optional[str]]: 提示词类型到理解结果的映射
  304. """
  305. tasks = [
  306. self._process_text_async(
  307. text=text,
  308. prompt_type=prompt_type,
  309. scenario=scenario,
  310. context_info=context_info
  311. )
  312. for prompt_type in prompt_types
  313. ]
  314. results = await asyncio.gather(*tasks)
  315. return dict(zip(prompt_types, results))
  316. async def _make_api_request(self,
  317. endpoint: str,
  318. payload: Dict[str, Any],
  319. timeout: int = 180) -> Dict[str, Any]:
  320. """
  321. 发送API请求的通用方法
  322. Args:
  323. endpoint: API端点
  324. payload: 请求负载
  325. timeout: 超时时间(秒)
  326. Returns:
  327. Dict[str, Any]: API响应
  328. Raises:
  329. aiohttp.ClientError: API请求失败
  330. asyncio.TimeoutError: 请求超时
  331. """
  332. try:
  333. async with aiohttp.ClientSession() as session:
  334. async with session.post(
  335. f"{self.base_url}/{endpoint}",
  336. json=payload,
  337. headers={"Authorization": f"Bearer {self.api_key}"},
  338. timeout=timeout
  339. ) as response:
  340. if response.status != 200:
  341. error_text = await response.text()
  342. raise aiohttp.ClientError(f"API request failed with status {response.status}: {error_text}")
  343. return await response.json()
  344. except asyncio.TimeoutError:
  345. logger.error(f"API request timed out after {timeout} seconds")
  346. raise
  347. except Exception as e:
  348. logger.error(f"API request failed: {str(e)}")
  349. raise
  350. @async_performance_monitor
  351. async def _process_video_async(self, file_path: str, prompt_type: str,
  352. scenario: Optional[str] = None, fps: int = 2, context_info: Optional[str] = None) -> Optional[str]:
  353. """异步处理视频文件"""
  354. try:
  355. # 在线程池中执行文件IO操作
  356. loop = asyncio.get_event_loop()
  357. with ThreadPoolExecutor() as pool:
  358. base64_video = await loop.run_in_executor(
  359. pool, self._encode_video, file_path
  360. )
  361. prompt = self.config_manager.get_prompt("video", prompt_type, scenario)
  362. # 构建API请求
  363. payload = {
  364. "model": self.model,
  365. "messages": [{
  366. "role": "system",
  367. "content": [
  368. {
  369. "type": "video_url",
  370. "video_url": {
  371. "url": f"data:video/mp4;base64,{base64_video}",
  372. "fps": fps
  373. }
  374. },
  375. {
  376. "type": "text",
  377. "text": prompt
  378. }
  379. ]
  380. },
  381. {
  382. "role": "user",
  383. "content": f"上下文信息:{context_info}"
  384. }
  385. ]
  386. }
  387. # 发送API请求
  388. response = await self._make_api_request("chat/completions", payload)
  389. return response["choices"][0]["message"]["content"]
  390. except Exception as e:
  391. logger.error(f"Failed to process video async: {str(e)}")
  392. return None
  393. @async_performance_monitor
  394. async def _process_image_async(self, file_path: str, prompt_type: str,
  395. scenario: Optional[str] = None, context_info: Optional[str] = None) -> Optional[str]:
  396. """异步处理图片文件"""
  397. try:
  398. # 在线程池中执行文件IO操作
  399. loop = asyncio.get_event_loop()
  400. with ThreadPoolExecutor() as pool:
  401. base64_image = await loop.run_in_executor(
  402. pool, self._encode_image, file_path
  403. )
  404. prompt = self.config_manager.get_prompt("image", prompt_type, scenario)
  405. # 构建API请求
  406. payload = {
  407. "model": self.model,
  408. "messages": [{
  409. "role": "system",
  410. "content": [
  411. {
  412. "type": "image_url",
  413. "image_url": {
  414. "url": f"data:image/jpeg;base64,{base64_image}"
  415. }
  416. },
  417. {
  418. "type": "text",
  419. "text": prompt
  420. }
  421. ]
  422. },
  423. {
  424. "role": "user",
  425. "content": f"上下文信息:{context_info}"
  426. }
  427. ]
  428. }
  429. # 发送API请求
  430. response = await self._make_api_request("chat/completions", payload)
  431. return response["choices"][0]["message"]["content"]
  432. except Exception as e:
  433. logger.error(f"Failed to process image async: {str(e)}")
  434. return None
  435. @async_performance_monitor
  436. async def _process_text_async(self, user_prompt: str, system_prompt: str,
  437. max_length: Optional[int] = None,
  438. context_info: Optional[str] = None) -> Optional[str]:
  439. """异步处理文本内容"""
  440. # try:
  441. if not user_prompt.strip():
  442. logger.error("Empty text provided")
  443. return None
  444. # 构建API请求
  445. payload = {
  446. "model": self.model,
  447. "messages": [
  448. {
  449. "role": "system",
  450. "content": system_prompt
  451. },
  452. {
  453. "role": "user",
  454. "content": user_prompt
  455. },
  456. {
  457. "role": "user",
  458. "content": f"上下文信息:{context_info}"
  459. }
  460. ],
  461. "max_tokens": max_length if max_length else None
  462. }
  463. # 发送API请求
  464. response = await self._make_api_request("chat/completions", payload)
  465. return response["choices"][0]["message"]["content"]
  466. # except Exception as e:
  467. # logger.error(f"Failed to process text async: {str(e)}")
  468. # return None
  469. async def generate_batch_captions_async(self,
  470. files: Dict[str, Dict[str, Union[str, int]]],
  471. scenario: Optional[str] = None,
  472. max_concurrent: int = 5) -> Dict[str, Optional[str]]:
  473. """
  474. 异步批量生成媒体描述
  475. Args:
  476. files: 文件配置字典
  477. scenario: 场景类型
  478. max_concurrent: 最大并发数
  479. Returns:
  480. Dict[str, Optional[str]]: 文件路径或标识符到描述的映射
  481. """
  482. results = {}
  483. # 创建信号量控制并发
  484. semaphore = asyncio.Semaphore(max_concurrent)
  485. async def process_single_file(file_path: str, config: Dict[str, Any]) -> tuple[str, Optional[str]]:
  486. """处理单个文件的异步函数"""
  487. async with semaphore: # 使用信号量控制并发
  488. try:
  489. media_type = config["type"]
  490. prompt_type = config.get("prompt_type", "caption" if media_type != "text" else "summary")
  491. if media_type == "video":
  492. fps = config.get("fps", 2)
  493. result = await self._process_video_async(
  494. file_path=file_path,
  495. prompt_type=prompt_type,
  496. scenario=scenario,
  497. fps=fps,
  498. context_info=config.get("context_info")
  499. )
  500. elif media_type == "image":
  501. result = await self._process_image_async(
  502. file_path=file_path,
  503. prompt_type=prompt_type,
  504. scenario=scenario,
  505. context_info=config.get("context_info")
  506. )
  507. elif media_type == "text":
  508. if "content" not in config:
  509. logger.error(f"Text content not provided for {file_path}")
  510. return file_path, None
  511. result = await self._process_text_async(
  512. text=config["content"],
  513. prompt_type=prompt_type,
  514. scenario=scenario,
  515. max_length=config.get("max_length"),
  516. context_info=config.get("context_info")
  517. )
  518. else:
  519. logger.warning(f"Unsupported media type: {media_type}")
  520. return file_path, None
  521. return file_path, result
  522. except Exception as e:
  523. logger.error(f"Failed to process file {file_path}: {str(e)}")
  524. return file_path, None
  525. # 创建所有任务
  526. tasks = [
  527. process_single_file(file_path, config)
  528. for file_path, config in files.items()
  529. ]
  530. # 并行执行所有任务
  531. completed_tasks = await asyncio.gather(*tasks)
  532. # 整理结果
  533. results = dict(completed_tasks)
  534. return results
  535. def generate_batch_captions(self,
  536. files: Dict[str, Dict[str, Union[str, int]]],
  537. scenario: Optional[str] = None) -> Dict[str, Optional[str]]:
  538. """
  539. 批量生成媒体描述的同步包装器
  540. Args:
  541. files: 文件配置字典,格式为:
  542. {
  543. "file_path": {
  544. "type": "video"|"image"|"text",
  545. "prompt_type": str, # 可选
  546. "fps": int, # 仅视频可用
  547. "content": str, # 仅文本类型需要
  548. "max_length": int # 可选,仅文本类型可用
  549. }
  550. }
  551. scenario: 场景类型
  552. Returns:
  553. Dict[str, Optional[str]]: 文件路径或标识符到描述的映射
  554. """
  555. try:
  556. loop = asyncio.get_event_loop()
  557. except RuntimeError:
  558. loop = asyncio.new_event_loop()
  559. asyncio.set_event_loop(loop)
  560. # 运行异步方法
  561. return loop.run_until_complete(
  562. self.generate_batch_captions_async(files, scenario)
  563. )
  564. media_captioner: MediaCaptioner = MediaCaptioner()
  565. if __name__ == "__main__":
  566. async def main():
  567. # 初始化
  568. captioner = MediaCaptioner()
  569. # 处理文本
  570. text_content = """
  571. 近日,研究人员在深海发现了一种新的海洋生物物种。
  572. 这种生物具有独特的生物发光能力,可以在完全黑暗的环境中发出蓝绿色的光。
  573. 科学家们认为,这一发现对于了解深海生态系统具有重要意义。
  574. """
  575. # 批量处理示例
  576. files = {
  577. "./test_data/sample_video.mp4": {
  578. "type": "video",
  579. "prompt_type": "caption",
  580. "fps": 2
  581. },
  582. "./test_data/sample_image.jpg": {
  583. "type": "image",
  584. "prompt_type": "caption"
  585. },
  586. "text_sample": {
  587. "type": "text",
  588. "content": text_content,
  589. "prompt_type": "summary",
  590. "max_length": 200
  591. }
  592. }
  593. # 异步批量处理
  594. results = await captioner.generate_batch_captions_async(
  595. files,
  596. scenario="academic",
  597. max_concurrent=5
  598. )
  599. print("批量处理结果:", results)
  600. # 运行异步主函数
  601. asyncio.run(main())