| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238 |
- """
- 火山引擎ARK API异步客户端
- 封装ARK API的异步调用,提供类型安全的接口
- """
- import os
- import asyncio
- from typing import List, Optional, Dict, Any, Union
- from dataclasses import dataclass, asdict
- from enum import Enum
- from .base_client_async import AsyncAPIClient, APIError
- from .ark_client import ArkMessage, ContentType # 复用同步版本的消息类
- from taskflow.logger import get_logger
- from taskflow.config import get_config
- logger = get_logger("api_modules.ark_client_async")
- class AsyncArkClient(AsyncAPIClient):
- """
- 火山引擎ARK API异步客户端
-
- 封装ARK API的异步调用,提供便捷的接口
- """
-
- DEFAULT_BASE_URL = "https://ark.cn-beijing.volces.com"
- DEFAULT_ENDPOINT = "/api/v3/responses"
- DEFAULT_MODEL = "doubao-seed-1-6-251015"
-
- def __init__(
- self,
- api_key: Optional[str] = None,
- base_url: Optional[str] = None,
- model: Optional[str] = None,
- timeout: int = 300,
- **kwargs
- ):
- """
- 初始化ARK API异步客户端
-
- Args:
- api_key: API密钥(如果为None,会尝试从环境变量或配置中获取)
- base_url: API基础URL(默认使用官方URL)
- model: 模型名称(如果为None,会尝试从配置中获取)
- timeout: 请求超时时间(秒,默认300秒,因为AI模型可能需要较长时间)
- **kwargs: 传递给AsyncAPIClient的其他参数
- """
- # 获取API密钥(优先级:参数 > 环境变量 > 配置)
- if api_key is None:
- api_key = os.getenv("ARK_API_KEY")
- if api_key is None:
- config = get_config()
- api_key = config.get("api.ark.api_key")
-
- if not api_key:
- raise ValueError("ARK API密钥未提供,请通过参数、环境变量ARK_API_KEY或配置文件提供")
-
- # 获取base_url(优先级:参数 > 配置 > 默认值)
- if base_url is None:
- config = get_config()
- base_url = config.get("api.ark.base_url", self.DEFAULT_BASE_URL)
-
- # 获取model(优先级:参数 > 配置 > 默认值)
- if model is None:
- config = get_config()
- model = config.get("api.ark.model", self.DEFAULT_MODEL)
-
- super().__init__(
- base_url=base_url,
- api_key=api_key,
- timeout=timeout,
- **kwargs
- )
-
- # 保存模型名称
- self.model = model
-
- logger.info(f"ARK API异步客户端初始化完成,模型: {self.model}")
-
- async def chat(
- self,
- model: Optional[str] = None,
- messages: List[Union[ArkMessage, Dict[str, Any]]] = None,
- system_prompt: Optional[Union[str, ArkMessage, Dict[str, Any]]] = None,
- **kwargs
- ) -> Dict[str, Any]:
- """
- 发送异步聊天请求
-
- Args:
- model: 模型名称(如果为None,使用初始化时设置的模型或DEFAULT_MODEL)
- messages: 消息列表(ArkMessage对象或字典)
- system_prompt: 系统提示词,可以是字符串、ArkMessage对象或字典。
- 如果提供,会自动作为第一条消息(role="system")添加到消息列表前。
- 如果messages中已存在role="system"的消息,则不会重复添加。
- **kwargs: 其他请求参数
-
- Returns:
- API响应数据
-
- Raises:
- APIError: 如果请求失败
- """
- # 如果没有提供model参数,使用实例变量中的model
- if model is None:
- model = getattr(self, 'model', self.DEFAULT_MODEL)
-
- # 转换消息格式
- input_messages = []
-
- # 检查是否已有system角色的消息
- has_system_message = False
- for msg in messages:
- if isinstance(msg, ArkMessage):
- if msg.role == "system":
- has_system_message = True
- elif isinstance(msg, dict):
- if msg.get("role") == "system":
- has_system_message = True
-
- # 处理系统提示词
- if system_prompt is not None and not has_system_message:
- if isinstance(system_prompt, str):
- # 字符串格式,转换为ArkMessage
- system_msg = ArkMessage(role="system")
- system_msg.add_text(system_prompt)
- input_messages.append(system_msg.to_dict())
- elif isinstance(system_prompt, ArkMessage):
- # 确保role为system
- system_prompt.role = "system"
- input_messages.append(system_prompt.to_dict())
- elif isinstance(system_prompt, dict):
- # 字典格式,确保role为system
- system_prompt = system_prompt.copy()
- system_prompt["role"] = "system"
- input_messages.append(system_prompt)
- else:
- raise ValueError(f"不支持的系统提示词类型: {type(system_prompt)}")
-
- # 添加其他消息
- for msg in messages:
- if isinstance(msg, ArkMessage):
- input_messages.append(msg.to_dict())
- else:
- input_messages.append(msg)
-
- # 构建请求体
- request_data = {
- "model": model,
- "input": input_messages,
- **kwargs
- }
-
- logger.info(f"发送异步聊天请求,模型: {model}, 消息数: {len(input_messages)}")
- try:
- response = await self.post(
- endpoint=self.DEFAULT_ENDPOINT,
- json=request_data
- )
-
- logger.info("异步聊天请求成功")
- return response
-
- except APIError as e:
- logger.error(f"异步聊天请求失败: {e}")
- raise
-
- async def chat_simple(
- self,
- model: Optional[str] = None,
- text: str = None,
- image_url: Optional[str] = None,
- system_prompt: Optional[Union[str, ArkMessage, Dict[str, Any]]] = None,
- **kwargs
- ) -> Dict[str, Any]:
- """
- 简化的异步聊天接口(仅文本或文本+图片)
-
- Args:
- model: 模型名称(如果为None,使用初始化时设置的模型或DEFAULT_MODEL)
- text: 文本内容(必填)
- image_url: 可选的图片URL
- system_prompt: 系统提示词,可以是字符串、ArkMessage对象或字典。
- 如果提供,会自动作为第一条消息(role="system")添加到消息列表前。
- **kwargs: 其他请求参数
-
- Returns:
- API响应数据
- """
- # 如果没有提供model参数,使用实例变量中的model
- if model is None:
- model = getattr(self, 'model', self.DEFAULT_MODEL)
-
- if text is None:
- raise ValueError("文本内容不能为空")
-
- message = ArkMessage(role="user")
-
- if image_url:
- message.add_image(image_url)
- message.add_text(text)
-
- return await self.chat(model=model, messages=[message], system_prompt=system_prompt, **kwargs)
-
- def get_response_text(self, response: Dict[str, Any]) -> Optional[str]:
- """
- 从响应中提取文本内容
-
- Args:
- response: API响应数据
-
- Returns:
- 提取的文本内容,如果不存在则返回None
- """
- try:
- # 根据ARK API的实际响应结构提取文本
- # 这里需要根据实际API响应格式调整
- if "output" in response and isinstance(response["output"], list):
- for item in response["output"]:
- if isinstance(item, dict) and "content" in item:
- for content in item.get("content", []):
- if content.get("type") == ContentType.OUTPUT_TEXT:
- return content.get("text")
-
- # 备用提取方式
- if "choices" in response:
- for choice in response["choices"]:
- if "message" in choice and "content" in choice["message"]:
- return choice["message"]["content"]
-
- return None
-
- except Exception as e:
- logger.warning(f"提取响应文本失败: {e}")
- return None
|