ark_client.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368
  1. """
  2. 火山引擎ARK API客户端
  3. 封装ARK API的调用,提供类型安全的接口
  4. """
  5. import io
  6. import os
  7. import base64
  8. import asyncio
  9. from PIL import Image
  10. from typing import List, Optional, Dict, Any, Union
  11. from dataclasses import dataclass, asdict
  12. from enum import Enum
  13. from .base_client import APIClient, APIError
  14. from taskflow.logger import get_logger
  15. from taskflow.config import get_config
# Module-level logger; ArkClient uses it to report initialization, request
# outcomes and response-text extraction failures.
logger = get_logger("api_modules.ark_client")
  17. class ContentType(str, Enum):
  18. """内容类型枚举"""
  19. INPUT_TEXT = "input_text"
  20. INPUT_IMAGE = "input_image"
  21. INPUT_VIDEO = "input_video"
  22. OUTPUT_TEXT = "output_text"
  23. OUTPUT_IMAGE = "output_image"
  24. @dataclass
  25. class ArkTextContent:
  26. """ARK文本内容"""
  27. type: str = ContentType.INPUT_TEXT
  28. text: str = ""
  29. def to_dict(self) -> Dict[str, Any]:
  30. """转换为字典"""
  31. return {"type": self.type, "text": self.text}
  32. @dataclass
  33. class ArkImageContent:
  34. """ARK图片内容"""
  35. type: str = ContentType.INPUT_IMAGE
  36. image_url: str = ""
  37. def to_dict(self) -> Dict[str, Any]:
  38. """转换为字典"""
  39. return {"type": self.type, "image_url": self.image_url}
  40. @dataclass
  41. class ArkVideoContent:
  42. """ARK视频内容"""
  43. type: str = ContentType.INPUT_VIDEO
  44. video_url: str = ""
  45. def to_dict(self) -> Dict[str, Any]:
  46. """转换为字典"""
  47. return {"type": self.type, "video_url": self.video_url}
  48. @dataclass
  49. class ArkMessage:
  50. """ARK消息"""
  51. role: str = "user"
  52. content: List[Union[ArkTextContent, ArkImageContent, Dict[str, Any]]] = None
  53. def __post_init__(self):
  54. """初始化后处理"""
  55. if self.content is None:
  56. self.content = []
  57. def add_text(self, text: str):
  58. """添加文本内容"""
  59. self.content.append(ArkTextContent(text=text))
  60. def add_image(self, image_url: str):
  61. """添加图片内容"""
  62. if "http" not in image_url:
  63. image_base64 = self._encode_image(image_url)
  64. image_url = f"data:image/jpeg;base64,{image_base64}"
  65. self.content.append(ArkImageContent(image_url=image_url))
  66. def add_video(self, video_url: str):
  67. """添加视频内容"""
  68. if "http" not in video_url:
  69. video_base64 = self._encode_video(video_url)
  70. video_url = f"data:video/mp4;base64,{video_base64}"
  71. self.content.append(ArkVideoContent(video_url=video_url))
  72. def to_dict(self) -> Dict[str, Any]:
  73. """转换为字典"""
  74. content_list = []
  75. for item in self.content:
  76. if isinstance(item, (ArkTextContent, ArkImageContent, ArkVideoContent)):
  77. content_list.append(item.to_dict())
  78. else:
  79. content_list.append(item)
  80. return {
  81. "role": self.role,
  82. "content": content_list
  83. }
  84. def _encode_video(self, video_path: str) -> str:
  85. """
  86. 将视频文件转换为base64编码
  87. Args:
  88. video_path: 视频文件路径
  89. Returns:
  90. str: base64编码的视频数据
  91. Raises:
  92. FileNotFoundError: 视频文件不存在
  93. IOError: 读取文件失败
  94. """
  95. if not os.path.exists(video_path):
  96. raise FileNotFoundError(f"Video file not found: {video_path}")
  97. with open(video_path, "rb") as f:
  98. return base64.b64encode(f.read()).decode("utf-8")
  99. def _encode_image(self, image_path: str) -> str:
  100. """
  101. 将图片文件转换为base64编码
  102. Args:
  103. image_path: 图片文件路径
  104. Returns:
  105. str: base64编码的图片数据
  106. Raises:
  107. FileNotFoundError: 图片文件不存在
  108. IOError: 读取或处理图片失败
  109. """
  110. if not os.path.exists(image_path):
  111. raise FileNotFoundError(f"Image file not found: {image_path}")
  112. with Image.open(image_path) as img:
  113. buffered = io.BytesIO()
  114. img.save(buffered, format="JPEG")
  115. return base64.b64encode(buffered.getvalue()).decode("utf-8")
  116. class ArkClient(APIClient):
  117. """
  118. 火山引擎ARK API客户端
  119. 封装ARK API的调用,提供便捷的接口
  120. """
  121. DEFAULT_BASE_URL = "https://ark.cn-beijing.volces.com"
  122. DEFAULT_ENDPOINT = "/api/v3/responses"
  123. DEFAULT_MODEL = "doubao-seed-1-6-251015"
  124. def __init__(
  125. self,
  126. api_key: Optional[str] = None,
  127. base_url: Optional[str] = None,
  128. model: Optional[str] = None,
  129. timeout: int = 300,
  130. **kwargs
  131. ):
  132. """
  133. 初始化ARK API客户端
  134. Args:
  135. api_key: API密钥(如果为None,会尝试从环境变量或配置中获取)
  136. base_url: API基础URL(默认使用官方URL)
  137. model: 模型名称(如果为None,会尝试从配置中获取)
  138. timeout: 请求超时时间(秒,默认300秒,因为AI模型可能需要较长时间)
  139. **kwargs: 传递给APIClient的其他参数
  140. """
  141. # 获取API密钥(优先级:参数 > 环境变量 > 配置)
  142. if api_key is None:
  143. api_key = os.getenv("ARK_API_KEY")
  144. if api_key is None:
  145. config = get_config()
  146. api_key = config.get("api.ark.api_key")
  147. if not api_key:
  148. raise ValueError("ARK API密钥未提供,请通过参数、环境变量ARK_API_KEY或配置文件提供")
  149. # 获取base_url(优先级:参数 > 配置 > 默认值)
  150. if base_url is None:
  151. config = get_config()
  152. base_url = config.get("api.ark.base_url", self.DEFAULT_BASE_URL)
  153. # 获取model(优先级:参数 > 配置 > 默认值)
  154. if model is None:
  155. config = get_config()
  156. model = config.get("api.ark.model", self.DEFAULT_MODEL)
  157. super().__init__(
  158. base_url=base_url,
  159. api_key=api_key,
  160. timeout=timeout,
  161. **kwargs
  162. )
  163. # 保存模型名称
  164. self.model = model
  165. logger.info(f"ARK API客户端初始化完成,模型: {self.model}")
  166. def chat(
  167. self,
  168. model: Optional[str] = None,
  169. messages: List[Union[ArkMessage, Dict[str, Any]]] = None,
  170. system_prompt: Optional[Union[str, ArkMessage, Dict[str, Any]]] = None,
  171. **kwargs
  172. ) -> Dict[str, Any]:
  173. """
  174. 发送聊天请求
  175. Args:
  176. model: 模型名称(如果为None,使用初始化时设置的模型或DEFAULT_MODEL)
  177. messages: 消息列表(ArkMessage对象或字典)
  178. system_prompt: 系统提示词,可以是字符串、ArkMessage对象或字典。
  179. 如果提供,会自动作为第一条消息(role="system")添加到消息列表前。
  180. 如果messages中已存在role="system"的消息,则不会重复添加。
  181. **kwargs: 其他请求参数
  182. Returns:
  183. API响应数据
  184. Raises:
  185. APIError: 如果请求失败
  186. """
  187. # 如果没有提供model参数,使用实例变量中的model
  188. if model is None:
  189. model = getattr(self, 'model', self.DEFAULT_MODEL)
  190. # 转换消息格式
  191. input_messages = []
  192. # 检查是否已有system角色的消息
  193. has_system_message = False
  194. for msg in messages:
  195. if isinstance(msg, ArkMessage):
  196. if msg.role == "system":
  197. has_system_message = True
  198. elif isinstance(msg, dict):
  199. if msg.get("role") == "system":
  200. has_system_message = True
  201. # 处理系统提示词
  202. if system_prompt is not None and not has_system_message:
  203. if isinstance(system_prompt, str):
  204. # 字符串格式,转换为ArkMessage
  205. system_msg = ArkMessage(role="system")
  206. system_msg.add_text(system_prompt)
  207. input_messages.append(system_msg.to_dict())
  208. elif isinstance(system_prompt, ArkMessage):
  209. # 确保role为system
  210. system_prompt.role = "system"
  211. input_messages.append(system_prompt.to_dict())
  212. elif isinstance(system_prompt, dict):
  213. # 字典格式,确保role为system
  214. system_prompt = system_prompt.copy()
  215. system_prompt["role"] = "system"
  216. input_messages.append(system_prompt)
  217. else:
  218. raise ValueError(f"不支持的系统提示词类型: {type(system_prompt)}")
  219. # 添加其他消息
  220. for msg in messages:
  221. if isinstance(msg, ArkMessage):
  222. input_messages.append(msg.to_dict())
  223. else:
  224. input_messages.append(msg)
  225. # 构建请求体
  226. request_data = {
  227. "model": model,
  228. "input": input_messages,
  229. **kwargs
  230. }
  231. logger.info(f"发送聊天请求,模型: {model}, 消息数: {len(input_messages)}")
  232. try:
  233. response = self.post(
  234. endpoint=self.DEFAULT_ENDPOINT,
  235. json=request_data
  236. )
  237. logger.info("聊天请求成功")
  238. return response
  239. except APIError as e:
  240. logger.error(f"聊天请求失败: {e}")
  241. raise
  242. def chat_simple(
  243. self,
  244. model: Optional[str] = None,
  245. text: str = None,
  246. image_url: Optional[str] = None,
  247. system_prompt: Optional[Union[str, ArkMessage, Dict[str, Any]]] = None,
  248. **kwargs
  249. ) -> Dict[str, Any]:
  250. """
  251. 简化的聊天接口(仅文本或文本+图片)
  252. Args:
  253. model: 模型名称(如果为None,使用初始化时设置的模型或DEFAULT_MODEL)
  254. text: 文本内容(必填)
  255. image_url: 可选的图片URL
  256. system_prompt: 系统提示词,可以是字符串、ArkMessage对象或字典。
  257. 如果提供,会自动作为第一条消息(role="system")添加到消息列表前。
  258. **kwargs: 其他请求参数
  259. Returns:
  260. API响应数据
  261. """
  262. # 如果没有提供model参数,使用实例变量中的model
  263. if model is None:
  264. model = getattr(self, 'model', self.DEFAULT_MODEL)
  265. if text is None:
  266. raise ValueError("文本内容不能为空")
  267. message = ArkMessage(role="user")
  268. if image_url:
  269. message.add_image(image_url)
  270. message.add_text(text)
  271. return self.chat(model=model, messages=[message], system_prompt=system_prompt, **kwargs)
  272. def get_response_text(self, response: Dict[str, Any]) -> Optional[str]:
  273. """
  274. 从响应中提取文本内容
  275. Args:
  276. response: API响应数据
  277. Returns:
  278. 提取的文本内容,如果不存在则返回None
  279. """
  280. try:
  281. # 根据ARK API的实际响应结构提取文本
  282. # 这里需要根据实际API响应格式调整
  283. if "output" in response and isinstance(response["output"], list):
  284. for item in response["output"]:
  285. if isinstance(item, dict) and "content" in item:
  286. for content in item.get("content", []):
  287. if content.get("type") == ContentType.OUTPUT_TEXT:
  288. return content.get("text")
  289. # 备用提取方式
  290. if "choices" in response:
  291. for choice in response["choices"]:
  292. if "message" in choice and "content" in choice["message"]:
  293. return choice["message"]["content"]
  294. return None
  295. except Exception as e:
  296. logger.warning(f"提取响应文本失败: {e}")
  297. return None