"""
Volcano Engine ARK API client.

Wraps calls to the ARK API and exposes a typed interface for building
multimodal messages and sending chat requests.
"""

import io
import os
import base64
import asyncio
from PIL import Image
from typing import List, Optional, Dict, Any, Union
from dataclasses import dataclass, asdict
from enum import Enum

from .base_client import APIClient, APIError
from taskflow.logger import get_logger
from taskflow.config import get_config

logger = get_logger("api_modules.ark_client")

# A media reference starting with one of these prefixes is forwarded to the
# API unchanged; anything else is treated as a local file path and inlined
# as a base64 data URI.
_PASSTHROUGH_URL_PREFIXES = ("http://", "https://", "data:")

# Pillow image modes that the JPEG writer accepts directly. Images in any
# other mode (RGBA, P, LA, ...) must be converted before saving as JPEG.
_JPEG_WRITABLE_MODES = ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr")


def _is_passthrough_url(value: str) -> bool:
    """Return True if *value* is an http(s) URL or a ``data:`` URI."""
    # Only the first few characters matter; slicing avoids lower-casing a
    # potentially very large data URI.
    return value[:8].lower().startswith(_PASSTHROUGH_URL_PREFIXES)


class ContentType(str, Enum):
    """Content type identifiers used in message content entries."""

    INPUT_TEXT = "input_text"
    INPUT_IMAGE = "input_image"
    INPUT_VIDEO = "input_video"
    OUTPUT_TEXT = "output_text"
    OUTPUT_IMAGE = "output_image"


@dataclass
class ArkTextContent:
    """A text content entry of an ARK message."""

    type: str = ContentType.INPUT_TEXT
    text: str = ""

    def to_dict(self) -> Dict[str, Any]:
        """Return this entry as a plain dictionary."""
        return {"type": self.type, "text": self.text}


@dataclass
class ArkImageContent:
    """An image content entry of an ARK message."""

    type: str = ContentType.INPUT_IMAGE
    image_url: str = ""

    def to_dict(self) -> Dict[str, Any]:
        """Return this entry as a plain dictionary."""
        return {"type": self.type, "image_url": self.image_url}


@dataclass
class ArkVideoContent:
    """A video content entry of an ARK message."""

    type: str = ContentType.INPUT_VIDEO
    video_url: str = ""

    def to_dict(self) -> Dict[str, Any]:
        """Return this entry as a plain dictionary."""
        return {"type": self.type, "video_url": self.video_url}


@dataclass
class ArkMessage:
    """A single ARK message: a role plus an ordered list of content entries."""

    role: str = "user"
    # None is replaced by a fresh list in __post_init__, so instances never
    # share one list object.
    content: Optional[
        List[
            Union[
                ArkTextContent,
                ArkImageContent,
                ArkVideoContent,
                Dict[str, Any],
            ]
        ]
    ] = None

    def __post_init__(self):
        """Replace a missing content list with a new empty list."""
        if self.content is None:
            self.content = []

    def add_text(self, text: str):
        """Append a text entry to the message."""
        self.content.append(ArkTextContent(text=text))

    def add_image(self, image_url: str):
        """
        Append an image entry to the message.

        Args:
            image_url: An http(s) URL or ``data:`` URI (used as-is), or a
                local file path (re-encoded as JPEG and inlined as a base64
                data URI).

        Raises:
            FileNotFoundError: A local path was given and does not exist.
        """
        if not _is_passthrough_url(image_url):
            image_base64 = self._encode_image(image_url)
            image_url = f"data:image/jpeg;base64,{image_base64}"
        self.content.append(ArkImageContent(image_url=image_url))

    def add_video(self, video_url: str):
        """
        Append a video entry to the message.

        Args:
            video_url: An http(s) URL or ``data:`` URI (used as-is), or a
                local file path (inlined as a base64 data URI labelled
                ``video/mp4``).

        Raises:
            FileNotFoundError: A local path was given and does not exist.
        """
        if not _is_passthrough_url(video_url):
            video_base64 = self._encode_video(video_url)
            video_url = f"data:video/mp4;base64,{video_base64}"
        self.content.append(ArkVideoContent(video_url=video_url))

    def to_dict(self) -> Dict[str, Any]:
        """
        Return the message as a plain dictionary.

        Typed content entries are converted with their own ``to_dict``;
        any other entry (e.g. a raw dict) is passed through unchanged.
        """
        typed_entries = (ArkTextContent, ArkImageContent, ArkVideoContent)
        content_list = [
            item.to_dict() if isinstance(item, typed_entries) else item
            for item in self.content
        ]
        return {
            "role": self.role,
            "content": content_list
        }

    def _encode_video(self, video_path: str) -> str:
        """
        Read a video file and return its bytes base64-encoded.

        Args:
            video_path: Path of the video file.

        Returns:
            str: Base64 text of the raw file contents.

        Raises:
            FileNotFoundError: The video file does not exist.
            IOError: Reading the file failed.
        """
        if not os.path.exists(video_path):
            raise FileNotFoundError(f"Video file not found: {video_path}")

        with open(video_path, "rb") as f:
            return base64.b64encode(f.read()).decode("utf-8")

    def _encode_image(self, image_path: str) -> str:
        """
        Load an image, re-encode it as JPEG and return it base64-encoded.

        Args:
            image_path: Path of the image file.

        Returns:
            str: Base64 text of the JPEG-encoded image.

        Raises:
            FileNotFoundError: The image file does not exist.
            IOError: Reading or processing the image failed.
        """
        if not os.path.exists(image_path):
            raise FileNotFoundError(f"Image file not found: {image_path}")

        with Image.open(image_path) as img:
            # JPEG cannot store alpha or palette data; without this
            # conversion, saving e.g. an RGBA PNG raises OSError. Any
            # transparency is discarded by the conversion.
            jpeg_ready = img
            if img.mode not in _JPEG_WRITABLE_MODES:
                jpeg_ready = img.convert("RGB")
            buffered = io.BytesIO()
            jpeg_ready.save(buffered, format="JPEG")
            return base64.b64encode(buffered.getvalue()).decode("utf-8")


class ArkClient(APIClient):
    """
    Volcano Engine ARK API client.

    Wraps calls to the ARK API and provides convenience methods for
    sending chat requests and reading text out of responses.
    """

    # Fallbacks used when neither an argument nor the configuration
    # supplies a value.
    DEFAULT_BASE_URL = "https://ark.cn-beijing.volces.com"
    DEFAULT_ENDPOINT = "/api/v3/responses"
    DEFAULT_MODEL = "doubao-seed-1-6-251015"

    def __init__(
        self,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        model: Optional[str] = None,
        timeout: int = 300,
        **kwargs
    ):
        """
        Initialise the ARK API client.

        Args:
            api_key: API key. If None, the ``ARK_API_KEY`` environment
                variable is tried, then the ``api.ark.api_key`` config entry.
            base_url: API base URL. If None, the ``api.ark.base_url`` config
                entry is used, falling back to ``DEFAULT_BASE_URL``.
            model: Model name. If None, the ``api.ark.model`` config entry
                is used, falling back to ``DEFAULT_MODEL``.
            timeout: Request timeout in seconds (default 300).
            **kwargs: Additional arguments forwarded to ``APIClient``.

        Raises:
            ValueError: No API key could be resolved.
        """
        # API key priority: argument > environment variable > configuration.
        if api_key is None:
            api_key = os.getenv("ARK_API_KEY")

        if api_key is None:
            config = get_config()
            api_key = config.get("api.ark.api_key")

        if not api_key:
            raise ValueError("ARK API密钥未提供,请通过参数、环境变量ARK_API_KEY或配置文件提供")

        # Base URL priority: argument > configuration > default.
        if base_url is None:
            config = get_config()
            base_url = config.get("api.ark.base_url", self.DEFAULT_BASE_URL)

        # Model priority: argument > configuration > default.
        if model is None:
            config = get_config()
            model = config.get("api.ark.model", self.DEFAULT_MODEL)

        super().__init__(
            base_url=base_url,
            api_key=api_key,
            timeout=timeout,
            **kwargs
        )

        # Remembered so chat()/chat_simple() can default to it.
        self.model = model

        logger.info(f"ARK API客户端初始化完成,模型: {self.model}")

    @staticmethod
    def _role_of(message: Any) -> Optional[str]:
        """Return the role of an ArkMessage or message dict, else None."""
        if isinstance(message, ArkMessage):
            return message.role
        if isinstance(message, dict):
            return message.get("role")
        return None

    @staticmethod
    def _system_message_dict(
        system_prompt: Union[str, ArkMessage, Dict[str, Any]]
    ) -> Dict[str, Any]:
        """Build a role="system" message dict without mutating the input."""
        if isinstance(system_prompt, str):
            system_msg = ArkMessage(role="system")
            system_msg.add_text(system_prompt)
            return system_msg.to_dict()
        if isinstance(system_prompt, ArkMessage):
            # Override the role on the serialised copy so the caller's
            # ArkMessage object keeps whatever role it had.
            return {**system_prompt.to_dict(), "role": "system"}
        if isinstance(system_prompt, dict):
            return {**system_prompt, "role": "system"}
        raise ValueError(f"不支持的系统提示词类型: {type(system_prompt)}")

    def chat(
        self,
        model: Optional[str] = None,
        messages: List[Union[ArkMessage, Dict[str, Any]]] = None,
        system_prompt: Optional[Union[str, ArkMessage, Dict[str, Any]]] = None,
        **kwargs
    ) -> Dict[str, Any]:
        """
        Send a chat request.

        Args:
            model: Model name. If None, the model stored on the instance is
                used, falling back to ``DEFAULT_MODEL``.
            messages: Messages to send (``ArkMessage`` objects or dicts).
                Required despite the ``None`` default.
            system_prompt: Optional system prompt as a string, ``ArkMessage``
                or dict. When given, it is sent as the first message with
                role="system", unless ``messages`` already contains a
                message with role="system", in which case it is ignored.
            **kwargs: Extra fields merged into the request body.

        Returns:
            The response returned by ``self.post``.

        Raises:
            ValueError: ``messages`` is None, or ``system_prompt`` has an
                unsupported type.
            APIError: The request failed.
        """
        if model is None:
            model = getattr(self, 'model', self.DEFAULT_MODEL)

        if messages is None:
            raise ValueError("消息列表不能为空")

        # Materialise once: the messages are scanned twice below, which
        # would silently drop everything from a one-shot iterable.
        messages = list(messages)

        input_messages = []

        has_system_message = any(
            self._role_of(msg) == "system" for msg in messages
        )
        if system_prompt is not None and not has_system_message:
            input_messages.append(self._system_message_dict(system_prompt))

        input_messages.extend(
            msg.to_dict() if isinstance(msg, ArkMessage) else msg
            for msg in messages
        )

        request_data = {
            "model": model,
            "input": input_messages,
            **kwargs
        }

        logger.info(f"发送聊天请求,模型: {model}, 消息数: {len(input_messages)}")

        try:
            response = self.post(
                endpoint=self.DEFAULT_ENDPOINT,
                json=request_data
            )
            logger.info("聊天请求成功")
            return response
        except APIError as e:
            logger.error(f"聊天请求失败: {e}")
            raise

    def chat_simple(
        self,
        model: Optional[str] = None,
        text: str = None,
        image_url: Optional[str] = None,
        system_prompt: Optional[Union[str, ArkMessage, Dict[str, Any]]] = None,
        **kwargs
    ) -> Dict[str, Any]:
        """
        Simplified chat call: one user message with text and optionally
        one image.

        Args:
            model: Model name. If None, the model stored on the instance is
                used, falling back to ``DEFAULT_MODEL``.
            text: Text content (required).
            image_url: Optional image URL or local image path. When given,
                the image entry is placed before the text entry.
            system_prompt: Optional system prompt (string, ``ArkMessage`` or
                dict), forwarded to ``chat``.
            **kwargs: Extra fields forwarded to ``chat``.

        Returns:
            The response returned by ``chat``.

        Raises:
            ValueError: ``text`` is None.
        """
        if model is None:
            model = getattr(self, 'model', self.DEFAULT_MODEL)

        if text is None:
            raise ValueError("文本内容不能为空")

        message = ArkMessage(role="user")

        if image_url:
            message.add_image(image_url)

        message.add_text(text)

        return self.chat(
            model=model,
            messages=[message],
            system_prompt=system_prompt,
            **kwargs
        )

    @staticmethod
    def _first_output_text(response: Dict[str, Any]) -> Optional[str]:
        """Return the text of the first output_text entry under "output"."""
        output = response.get("output")
        if not isinstance(output, list):
            return None
        for item in output:
            if not isinstance(item, dict):
                continue
            # "content" may be absent or None; both mean nothing to scan.
            for entry in item.get("content") or []:
                if (
                    isinstance(entry, dict)
                    and entry.get("type") == ContentType.OUTPUT_TEXT
                ):
                    return entry.get("text")
        return None

    @staticmethod
    def _first_choice_text(response: Dict[str, Any]) -> Optional[str]:
        """Return the first ``choices[i]["message"]["content"]`` found."""
        for choice in response.get("choices") or []:
            if not isinstance(choice, dict):
                continue
            message = choice.get("message")
            if isinstance(message, dict) and "content" in message:
                return message["content"]
        return None

    def get_response_text(self, response: Dict[str, Any]) -> Optional[str]:
        """
        Extract the text content from a response.

        Looks first for an ``output_text`` content entry under
        ``response["output"]``, then falls back to
        ``response["choices"][i]["message"]["content"]``. Malformed entries
        are skipped so they do not prevent the fallback from being tried.

        NOTE(review): the lookup paths are a best guess and may need
        adjusting to the actual ARK response format.

        Args:
            response: API response data.

        Returns:
            The extracted text, or None if no text is found or extraction
            fails.
        """
        try:
            text = self._first_output_text(response)
            if text is not None:
                return text
            return self._first_choice_text(response)
        except Exception as e:
            # Deliberately best-effort: an unexpected response shape yields
            # None plus a warning rather than an exception.
            logger.warning(f"提取响应文本失败: {e}")
            return None