""" 火山引擎ARK API异步客户端 封装ARK API的异步调用,提供类型安全的接口 """ import os import asyncio from typing import List, Optional, Dict, Any, Union from dataclasses import dataclass, asdict from enum import Enum from .base_client_async import AsyncAPIClient, APIError from .ark_client import ArkMessage, ContentType # 复用同步版本的消息类 from taskflow.logger import get_logger from taskflow.config import get_config logger = get_logger("api_modules.ark_client_async") class AsyncArkClient(AsyncAPIClient): """ 火山引擎ARK API异步客户端 封装ARK API的异步调用,提供便捷的接口 """ DEFAULT_BASE_URL = "https://ark.cn-beijing.volces.com" DEFAULT_ENDPOINT = "/api/v3/responses" DEFAULT_MODEL = "doubao-seed-1-6-251015" def __init__( self, api_key: Optional[str] = None, base_url: Optional[str] = None, model: Optional[str] = None, timeout: int = 300, **kwargs ): """ 初始化ARK API异步客户端 Args: api_key: API密钥(如果为None,会尝试从环境变量或配置中获取) base_url: API基础URL(默认使用官方URL) model: 模型名称(如果为None,会尝试从配置中获取) timeout: 请求超时时间(秒,默认300秒,因为AI模型可能需要较长时间) **kwargs: 传递给AsyncAPIClient的其他参数 """ # 获取API密钥(优先级:参数 > 环境变量 > 配置) if api_key is None: api_key = os.getenv("ARK_API_KEY") if api_key is None: config = get_config() api_key = config.get("api.ark.api_key") if not api_key: raise ValueError("ARK API密钥未提供,请通过参数、环境变量ARK_API_KEY或配置文件提供") # 获取base_url(优先级:参数 > 配置 > 默认值) if base_url is None: config = get_config() base_url = config.get("api.ark.base_url", self.DEFAULT_BASE_URL) # 获取model(优先级:参数 > 配置 > 默认值) if model is None: config = get_config() model = config.get("api.ark.model", self.DEFAULT_MODEL) super().__init__( base_url=base_url, api_key=api_key, timeout=timeout, **kwargs ) # 保存模型名称 self.model = model logger.info(f"ARK API异步客户端初始化完成,模型: {self.model}") async def chat( self, model: Optional[str] = None, messages: List[Union[ArkMessage, Dict[str, Any]]] = None, system_prompt: Optional[Union[str, ArkMessage, Dict[str, Any]]] = None, **kwargs ) -> Dict[str, Any]: """ 发送异步聊天请求 Args: model: 模型名称(如果为None,使用初始化时设置的模型或DEFAULT_MODEL) messages: 消息列表(ArkMessage对象或字典) system_prompt: 系统提示词,可以是字符串、ArkMessage对象或字典。 如果提供,会自动作为第一条消息(role="system")添加到消息列表前。 如果messages中已存在role="system"的消息,则不会重复添加。 **kwargs: 其他请求参数 Returns: API响应数据 Raises: APIError: 如果请求失败 """ # 如果没有提供model参数,使用实例变量中的model if model is None: model = getattr(self, 'model', self.DEFAULT_MODEL) # 转换消息格式 input_messages = [] # 检查是否已有system角色的消息 has_system_message = False for msg in messages: if isinstance(msg, ArkMessage): if msg.role == "system": has_system_message = True elif isinstance(msg, dict): if msg.get("role") == "system": has_system_message = True # 处理系统提示词 if system_prompt is not None and not has_system_message: if isinstance(system_prompt, str): # 字符串格式,转换为ArkMessage system_msg = ArkMessage(role="system") system_msg.add_text(system_prompt) input_messages.append(system_msg.to_dict()) elif isinstance(system_prompt, ArkMessage): # 确保role为system system_prompt.role = "system" input_messages.append(system_prompt.to_dict()) elif isinstance(system_prompt, dict): # 字典格式,确保role为system system_prompt = system_prompt.copy() system_prompt["role"] = "system" input_messages.append(system_prompt) else: raise ValueError(f"不支持的系统提示词类型: {type(system_prompt)}") # 添加其他消息 for msg in messages: if isinstance(msg, ArkMessage): input_messages.append(msg.to_dict()) else: input_messages.append(msg) # 构建请求体 request_data = { "model": model, "input": input_messages, **kwargs } logger.info(f"发送异步聊天请求,模型: {model}, 消息数: {len(input_messages)}") try: response = await self.post( endpoint=self.DEFAULT_ENDPOINT, json=request_data ) logger.info("异步聊天请求成功") return response except APIError as e: logger.error(f"异步聊天请求失败: {e}") raise async def chat_simple( self, model: Optional[str] = None, text: str = None, image_url: Optional[str] = None, system_prompt: Optional[Union[str, ArkMessage, Dict[str, Any]]] = None, **kwargs ) -> Dict[str, Any]: """ 简化的异步聊天接口(仅文本或文本+图片) Args: model: 模型名称(如果为None,使用初始化时设置的模型或DEFAULT_MODEL) text: 文本内容(必填) image_url: 可选的图片URL system_prompt: 系统提示词,可以是字符串、ArkMessage对象或字典。 如果提供,会自动作为第一条消息(role="system")添加到消息列表前。 **kwargs: 其他请求参数 Returns: API响应数据 """ # 如果没有提供model参数,使用实例变量中的model if model is None: model = getattr(self, 'model', self.DEFAULT_MODEL) if text is None: raise ValueError("文本内容不能为空") message = ArkMessage(role="user") if image_url: message.add_image(image_url) message.add_text(text) return await self.chat(model=model, messages=[message], system_prompt=system_prompt, **kwargs) def get_response_text(self, response: Dict[str, Any]) -> Optional[str]: """ 从响应中提取文本内容 Args: response: API响应数据 Returns: 提取的文本内容,如果不存在则返回None """ try: # 根据ARK API的实际响应结构提取文本 # 这里需要根据实际API响应格式调整 if "output" in response and isinstance(response["output"], list): for item in response["output"]: if isinstance(item, dict) and "content" in item: for content in item.get("content", []): if content.get("type") == ContentType.OUTPUT_TEXT: return content.get("text") # 备用提取方式 if "choices" in response: for choice in response["choices"]: if "message" in choice and "content" in choice["message"]: return choice["message"]["content"] return None except Exception as e: logger.warning(f"提取响应文本失败: {e}") return None