base_client_async.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321
  1. """
  2. 异步API客户端基类
  3. 提供通用异步API调用功能,包括错误处理、重试机制、日志记录等
  4. 使用 aiohttp 实现异步HTTP请求
  5. """
  6. import asyncio
  7. import logging
  8. from typing import Any, Dict, Optional, Callable
  9. from dataclasses import dataclass
  10. import aiohttp
  11. from aiohttp import ClientSession, ClientTimeout
  12. from taskflow.logger import get_logger
  13. logger = get_logger("api_modules.base_client_async")
  14. @dataclass
  15. class RetryConfig:
  16. """
  17. 重试配置类
  18. 该类用于配置 API 请求的重试机制,包括最大重试次数、退避因子、需要重试的 HTTP 状态码及需要重试的异常类型。
  19. 属性:
  20. max_retries (int): 最大重试次数,默认为 3。当请求失败达到该次数后不再重试。
  21. backoff_factor (float): 退避因子,控制每次重试间的等待时长。默认为 1.0。
  22. retry_on_status (tuple): 需要进行重试的 HTTP 状态码,默认为 (500, 502, 503, 504)。
  23. retry_on_exception (tuple): 需要重试的异常类型,例如连接超时或连接错误,默认为 (aiohttp.ClientError, asyncio.TimeoutError)。
  24. """
  25. max_retries: int = 3
  26. backoff_factor: float = 1.0
  27. retry_on_status: tuple = (500, 502, 503, 504)
  28. retry_on_exception: tuple = (aiohttp.ClientError, asyncio.TimeoutError)
  29. class APIError(Exception):
  30. """API调用异常"""
  31. def __init__(self, message: str, status_code: Optional[int] = None, response: Optional[Dict] = None):
  32. """
  33. 初始化API错误
  34. Args:
  35. message: 错误消息
  36. status_code: HTTP状态码
  37. response: 响应内容
  38. """
  39. super().__init__(message)
  40. self.message = message
  41. self.status_code = status_code
  42. self.response = response
  43. def __str__(self):
  44. if self.status_code:
  45. return f"{self.message} (Status: {self.status_code})"
  46. return self.message
  47. class AsyncAPIClient:
  48. """
  49. 异步API客户端基类
  50. 提供通用的异步API调用功能:
  51. - 统一的请求接口
  52. - 自动重试机制
  53. - 错误处理
  54. - 日志记录
  55. - 超时控制
  56. 使用示例:
  57. >>> async with AsyncAPIClient(base_url="https://api.example.com", api_key="your_key") as client:
  58. ... response = await client.post("/endpoint", json={"data": "value"})
  59. """
  60. def __init__(
  61. self,
  62. base_url: str,
  63. api_key: Optional[str] = None,
  64. timeout: int = 300,
  65. retry_config: Optional[RetryConfig] = None,
  66. headers: Optional[Dict[str, str]] = None
  67. ):
  68. """
  69. 初始化异步API客户端
  70. Args:
  71. base_url: API基础URL
  72. api_key: API密钥(可选,也可以通过headers传入)
  73. timeout: 请求超时时间(秒)
  74. retry_config: 重试配置
  75. headers: 默认请求头
  76. """
  77. self.base_url = base_url.rstrip('/')
  78. self.api_key = api_key
  79. self.timeout = timeout
  80. self.retry_config = retry_config or RetryConfig()
  81. # 设置默认请求头
  82. self.default_headers = {
  83. "Content-Type": "application/json",
  84. **({} if headers is None else headers)
  85. }
  86. if api_key:
  87. self.default_headers["Authorization"] = f"Bearer {api_key}"
  88. # 创建session(在异步上下文中创建)
  89. self._session: Optional[ClientSession] = None
  90. logger.info(f"初始化异步API客户端: {self.base_url}")
  91. async def _get_session(self) -> ClientSession:
  92. """获取或创建session"""
  93. if self._session is None or self._session.closed:
  94. timeout = ClientTimeout(total=self.timeout)
  95. self._session = ClientSession(
  96. timeout=timeout,
  97. headers=self.default_headers
  98. )
  99. return self._session
  100. def _build_url(self, endpoint: str) -> str:
  101. """
  102. 构建完整的URL
  103. Args:
  104. endpoint: API端点路径
  105. Returns:
  106. 完整的URL
  107. """
  108. endpoint = endpoint.lstrip('/')
  109. return f"{self.base_url}/{endpoint}"
  110. async def _handle_response(self, response: aiohttp.ClientResponse) -> Dict[str, Any]:
  111. """
  112. 处理API响应
  113. Args:
  114. response: aiohttp响应对象
  115. Returns:
  116. 解析后的响应数据
  117. Raises:
  118. APIError: 如果请求失败
  119. """
  120. try:
  121. response.raise_for_status()
  122. except aiohttp.ClientResponseError as e:
  123. # 尝试解析错误响应
  124. error_detail = None
  125. try:
  126. error_detail = await response.json()
  127. logger.error(f"API错误响应 (JSON): {error_detail}")
  128. except:
  129. try:
  130. error_detail = await response.text()
  131. logger.error(f"API错误响应 (Text): {error_detail}")
  132. except:
  133. error_detail = str(e)
  134. logger.error(f"API错误响应 (String): {error_detail}")
  135. raise APIError(
  136. message=f"API请求失败:{str(e)}",
  137. status_code=response.status,
  138. response=error_detail
  139. )
  140. # 解析响应数据
  141. try:
  142. return await response.json()
  143. except aiohttp.ContentTypeError:
  144. text = await response.text()
  145. return {"content": text}
  146. def _log_request(self, method: str, url: str, **kwargs):
  147. """记录请求日志"""
  148. logger.debug(f"{method} {url}")
  149. if "json" in kwargs:
  150. logger.debug(f"请求体: {kwargs['json']}")
  151. def _log_response(self, status: int, response_data: Any = None):
  152. """记录响应日志"""
  153. logger.debug(f"响应状态: {status}")
  154. if response_data:
  155. logger.debug(f"响应体: {response_data}")
  156. async def _request_with_retry(
  157. self,
  158. method: str,
  159. url: str,
  160. headers: Optional[Dict[str, str]] = None,
  161. **kwargs
  162. ) -> Dict[str, Any]:
  163. """
  164. 带重试机制的请求
  165. Args:
  166. method: HTTP方法
  167. url: 完整URL
  168. headers: 额外的请求头
  169. **kwargs: 传递给aiohttp的其他参数
  170. Returns:
  171. API响应数据
  172. """
  173. session = await self._get_session()
  174. # 合并请求头
  175. request_headers = {**self.default_headers}
  176. if headers:
  177. request_headers.update(headers)
  178. # 记录请求
  179. self._log_request(method, url, **kwargs)
  180. last_exception = None
  181. for attempt in range(self.retry_config.max_retries + 1):
  182. try:
  183. async with session.request(
  184. method=method,
  185. url=url,
  186. headers=request_headers,
  187. **kwargs
  188. ) as response:
  189. response_data = await self._handle_response(response)
  190. self._log_response(response.status, response_data)
  191. return response_data
  192. except Exception as e:
  193. last_exception = e
  194. # 检查是否需要重试
  195. should_retry = False
  196. # 检查状态码
  197. if isinstance(e, APIError) and e.status_code:
  198. if e.status_code in self.retry_config.retry_on_status:
  199. should_retry = True
  200. # 检查异常类型
  201. if isinstance(e, self.retry_config.retry_on_exception):
  202. should_retry = True
  203. # 如果不需要重试或已达到最大重试次数,直接抛出异常
  204. if not should_retry or attempt >= self.retry_config.max_retries:
  205. break
  206. # 计算退避时间
  207. wait_time = self.retry_config.backoff_factor * (2 ** attempt)
  208. logger.warning(f"请求失败,{wait_time}秒后重试 (尝试 {attempt + 1}/{self.retry_config.max_retries + 1}): {e}")
  209. await asyncio.sleep(wait_time)
  210. # 所有重试都失败
  211. logger.error(f"请求异常: {last_exception}")
  212. if isinstance(last_exception, APIError):
  213. raise last_exception
  214. raise APIError(f"网络请求失败: {str(last_exception)}")
  215. async def request(
  216. self,
  217. method: str,
  218. endpoint: str,
  219. headers: Optional[Dict[str, str]] = None,
  220. **kwargs
  221. ) -> Dict[str, Any]:
  222. """
  223. 发送异步API请求
  224. Args:
  225. method: HTTP方法(GET, POST, PUT, DELETE等)
  226. endpoint: API端点路径
  227. headers: 额外的请求头(会与默认请求头合并)
  228. **kwargs: 传递给aiohttp的其他参数
  229. Returns:
  230. API响应数据
  231. Raises:
  232. APIError: 如果请求失败
  233. """
  234. url = self._build_url(endpoint)
  235. return await self._request_with_retry(method, url, headers, **kwargs)
  236. async def get(self, endpoint: str, **kwargs) -> Dict[str, Any]:
  237. """发送异步GET请求"""
  238. return await self.request("GET", endpoint, **kwargs)
  239. async def post(self, endpoint: str, **kwargs) -> Dict[str, Any]:
  240. """发送异步POST请求"""
  241. return await self.request("POST", endpoint, **kwargs)
  242. async def put(self, endpoint: str, **kwargs) -> Dict[str, Any]:
  243. """发送异步PUT请求"""
  244. return await self.request("PUT", endpoint, **kwargs)
  245. async def delete(self, endpoint: str, **kwargs) -> Dict[str, Any]:
  246. """发送异步DELETE请求"""
  247. return await self.request("DELETE", endpoint, **kwargs)
  248. async def patch(self, endpoint: str, **kwargs) -> Dict[str, Any]:
  249. """发送异步PATCH请求"""
  250. return await self.request("PATCH", endpoint, **kwargs)
  251. async def close(self):
  252. """关闭session"""
  253. if self._session and not self._session.closed:
  254. await self._session.close()
  255. logger.info("异步API客户端已关闭")
  256. async def __aenter__(self):
  257. """异步上下文管理器入口"""
  258. await self._get_session()
  259. return self
  260. async def __aexit__(self, exc_type, exc_val, exc_tb):
  261. """异步上下文管理器出口"""
  262. await self.close()