| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268 |
- """
- API客户端基类
- 提供通用API调用功能,包括错误处理、重试机制、日志记录等
- """
- import time
- import logging
- from typing import Any, Dict, Optional, Callable
- from dataclasses import dataclass
- import requests
- from requests.adapters import HTTPAdapter
- from urllib3.util import Retry
- from taskflow.logger import get_logger
- logger = get_logger("api_modules.base_client")
- @dataclass
- class RetryConfig:
- """
- 重试配置类
- 该类用于配置 API 请求的重试机制,包括最大重试次数、退避因子、需要重试的 HTTP 状态码及需要重试的异常类型。
- 属性:
- max_retries (int): 最大重试次数,默认为 3。当请求失败达到该次数后不再重试。
- backoff_factor (float): 退避因子,控制每次重试间的等待时长。默认为 1.0。
- retry_on_status (tuple): 需要进行重试的 HTTP 状态码,默认为 (500, 502, 503, 504)。
- retry_on_exception (tuple): 需要重试的异常类型,例如连接超时或连接错误,默认为 (requests.exceptions.ConnectionError, requests.exceptions.Timeout)。
- """
- max_retries: int = 3
- backoff_factor: float = 1.0
- retry_on_status: tuple = (500, 502, 503, 504)
- retry_on_exception: tuple = (requests.exceptions.ConnectionError, requests.exceptions.Timeout)
- class APIError(Exception):
- """API调用异常"""
-
- def __init__(self, message: str, status_code: Optional[int] = None, response: Optional[Dict] = None):
- """
- 初始化API错误
-
- Args:
- message: 错误消息
- status_code: HTTP状态码
- response: 响应内容
- """
- super().__init__(message)
- self.message = message
- self.status_code = status_code
- self.response = response
-
- def __str__(self):
- if self.status_code:
- return f"{self.message} (Status: {self.status_code})"
- return self.message
- class APIClient:
- """
- API客户端基类
-
- 提供通用的API调用功能:
- - 统一的请求接口
- - 自动重试机制
- - 错误处理
- - 日志记录
- - 超时控制
-
- 使用示例:
- >>> client = APIClient(base_url="https://api.example.com", api_key="your_key")
- >>> response = client.post("/endpoint", json={"data": "value"})
- """
- def __init__(
- self,
- base_url: str,
- api_key: Optional[str] = None,
- timeout: int = 300,
- retry_config: Optional[RetryConfig] = None,
- headers: Optional[Dict[str, str]] = None
- ):
- """
- 初始化API客户端
-
- Args:
- base_url: API基础URL
- api_key: API密钥(可选,也可以通过headers传入)
- timeout: 请求超时时间(秒)
- retry_config: 重试配置
- headers: 默认请求头
- """
- self.base_url = base_url.rstrip('/')
- self.api_key = api_key
- self.timeout = timeout
- self.retry_config = retry_config or RetryConfig()
-
- # 设置默认请求头
- self.default_headers = {
- "Content-Type": "application/json",
- **({} if headers is None else headers)
- }
-
- if api_key:
- self.default_headers["Authorization"] = f"Bearer {api_key}"
-
- # 创建session并配置重试
- self.session = requests.Session()
- self._setup_retry()
-
- logger.info(f"初始化API客户端: {self.base_url}")
- def _setup_retry(self):
- """配置重试机制"""
- retry = Retry(
- total=self.retry_config.max_retries,
- backoff_factor=self.retry_config.backoff_factor,
- status_forcelist=self.retry_config.retry_on_status,
- raise_on_status=False,
- )
- adapter = HTTPAdapter(max_retries=retry)
- self.session.mount('https://', adapter)
- self.session.mount('http://', adapter)
- def _build_url(self, endpoint: str) -> str:
- """
- 构建完整的URL
-
- Args:
- endpoint: API端点路径
-
- Returns:
- 完整的URL
- """
- endpoint = endpoint.lstrip('/')
- return f"{self.base_url}/{endpoint}"
- def _handle_response(self, response: requests.Response) -> Dict[str, Any]:
- """
- 处理API响应
-
- Args:
- response: requests响应对象
-
- Returns:
- 解析后的响应数据
-
- Raises:
- APIError: 如果请求失败
- """
- try:
- response.raise_for_status()
- except requests.exceptions.HTTPError as e:
- # 尝试解析错误响应
- error_detail = None
- try:
- error_detail = response.json()
- except:
- error_detail = response.text
- raise APIError(
- message=f"API请求失败:{str(e)}",
- status_code=response.status_code,
- response=error_detail
- )
-
- # 解析响应数据
- try:
- return response.json()
- except ValueError:
- return {"content": response.text}
- def _log_request(self, method: str, url: str, **kwargs):
- """记录请求日志"""
- logger.debug(f"{method} {url}")
- if "json" in kwargs:
- logger.debug(f"请求体: {kwargs['json']}")
- def _log_response(self, response: requests.Response):
- """记录响应日志"""
- logger.debug(f"响应状态: {response.status_code}")
- try:
- logger.debug(f"响应体: {response.json()}")
- except:
- logger.debug(f"响应体: {response.text[:200]}")
- def request(
- self,
- method: str,
- endpoint: str,
- headers: Optional[Dict[str, str]] = None,
- **kwargs
- ) -> Dict[str, Any]:
- """
- 发送API请求
-
- Args:
- method: HTTP方法(GET, POST, PUT, DELETE等)
- endpoint: API端点路径
- headers: 额外的请求头(会与默认请求头合并)
- **kwargs: 传递给requests的其他参数
-
- Returns:
- API响应数据
-
- Raises:
- APIError: 如果请求失败
- """
- url = self._build_url(endpoint)
- # 合并请求头
- request_headers = {**self.default_headers}
- if headers:
- request_headers.update(headers)
- # 记录请求
- self._log_request(method, url, **kwargs)
- try:
- response = self.session.request(
- method=method,
- url=url,
- headers=request_headers,
- timeout=self.timeout,
- **kwargs
- )
- # 记录响应
- self._log_response(response)
- return self._handle_response(response)
- except requests.exceptions.RequestException as e:
- logger.error(f"请求异常: {e}")
- raise APIError(f"网络请求失败: {str(e)}")
- def get(self, endpoint: str, **kwargs) -> Dict[str, Any]:
- """发送GET请求"""
- return self.request("GET", endpoint, **kwargs)
-
- def post(self, endpoint: str, **kwargs) -> Dict[str, Any]:
- """发送POST请求"""
- return self.request("POST", endpoint, **kwargs)
-
- def put(self, endpoint: str, **kwargs) -> Dict[str, Any]:
- """发送PUT请求"""
- return self.request("PUT", endpoint, **kwargs)
-
- def delete(self, endpoint: str, **kwargs) -> Dict[str, Any]:
- """发送DELETE请求"""
- return self.request("DELETE", endpoint, **kwargs)
-
- def patch(self, endpoint: str, **kwargs) -> Dict[str, Any]:
- """发送PATCH请求"""
- return self.request("PATCH", endpoint, **kwargs)
-
- def close(self):
- """关闭session"""
- self.session.close()
- logger.info("API客户端已关闭")
-
- def __enter__(self):
- """上下文管理器入口"""
- return self
-
- def __exit__(self, exc_type, exc_val, exc_tb):
- """上下文管理器出口"""
- self.close()
|