llm.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377
  1. """
  2. LLM请求模块
  3. 提供多模态和文本LLM请求功能,支持图片和文本输入
  4. """
  5. import io
  6. import os
  7. import sys
  8. import time
  9. import base64
  10. import logging
  11. import numpy as np
  12. import requests
  13. from PIL import Image
  14. from openai import OpenAI
  15. from requests.adapters import HTTPAdapter
  16. from urllib3.util.retry import Retry
  17. from logger_setup import logger
  18. from conf import *
  19. from tos import HttpMethodType
  20. def image_to_base64(image):
  21. """
  22. 将PIL Image对象转换为base64编码字符串
  23. Args:
  24. image: PIL Image对象
  25. Returns:
  26. base64编码的字符串
  27. """
  28. image_io = io.BytesIO()
  29. image.save(image_io, format='JPEG', quality=95)
  30. image_io.seek(0)
  31. image_base64 = base64.b64encode(image_io.read()).decode('utf-8')
  32. return image_base64
  33. def download_image_with_retry(url, max_retries=3, timeout=30):
  34. """
  35. 下载图片并重试机制
  36. Args:
  37. url: 图片URL
  38. max_retries: 最大重试次数
  39. timeout: 超时时间(秒)
  40. Returns:
  41. PIL Image对象,失败返回None
  42. """
  43. session = requests.Session()
  44. retry_strategy = Retry(
  45. total=max_retries,
  46. backoff_factor=1,
  47. status_forcelist=[429, 500, 502, 503, 504],
  48. )
  49. adapter = HTTPAdapter(max_retries=retry_strategy)
  50. session.mount("http://", adapter)
  51. session.mount("https://", adapter)
  52. try:
  53. logger.info(f"正在下载图片: {url}")
  54. response = session.get(url, timeout=timeout)
  55. response.raise_for_status()
  56. logger.info("图片下载成功")
  57. return Image.open(io.BytesIO(response.content))
  58. except Exception as e:
  59. logger.error(f"下载图片失败: {e}")
  60. return None
  61. def image_reader(image):
  62. """
  63. 图片读取器,将各种格式的图片转换为base64编码的data URI
  64. 支持:
  65. - 本地文件路径(字符串)
  66. - HTTP/HTTPS URL(字符串)
  67. - numpy数组
  68. - PIL Image对象
  69. Args:
  70. image: 图片输入(路径、URL、numpy数组或PIL Image)
  71. Returns:
  72. base64编码的data URI字符串
  73. Raises:
  74. Exception: 如果下载图片失败
  75. """
  76. if isinstance(image, str):
  77. if image.startswith("http"):
  78. # 下载网络图片
  79. out_image = download_image_with_retry(image)
  80. if out_image is None:
  81. raise Exception(f"无法下载图片: {image}")
  82. else:
  83. # 读取本地图片
  84. out_image = Image.open(image)
  85. elif isinstance(image, np.ndarray):
  86. out_image = Image.fromarray(image)
  87. else:
  88. out_image = image
  89. out_image = out_image.convert('RGB')
  90. base64_img = image_to_base64(out_image)
  91. return f"data:image/jpeg;base64,{base64_img}"
  92. def get_lm_text(sys_prompt, user_prompt):
  93. """
  94. 文本LLM请求(已废弃,使用llm_request类替代)
  95. Args:
  96. sys_prompt: 系统提示词
  97. user_prompt: 用户提示词
  98. Returns:
  99. LLM返回的文本
  100. """
  101. completion = LMConfig.lm_client.chat.completions.create(
  102. messages = [
  103. {"role": "system", "content": sys_prompt},
  104. {"role": "user", "content": user_prompt},
  105. ],
  106. model=LMConfig.model,
  107. )
  108. return completion.choices[0].message.content
  109. # ==================== 图片处理工具 ====================
  110. def compress_image(input_path, output_path):
  111. """
  112. 压缩图片到目标大小
  113. Args:
  114. input_path: 输入图片路径
  115. output_path: 输出图片路径
  116. Returns:
  117. 最终使用的压缩质量
  118. """
  119. img = Image.open(input_path)
  120. current_size = os.path.getsize(input_path)
  121. # 粗略的估计压缩质量,也可以从常量开始,逐步减小压缩质量,直到文件大小小于目标大小
  122. image_quality = int(float(MMMConfig.target_size / current_size) * 100)
  123. img.save(output_path, optimize=True, quality=int(float(MMMConfig.target_size / current_size) * 100))
  124. # 如果压缩后文件大小仍然大于目标大小,则继续压缩
  125. # 压缩质量递减,直到文件大小小于目标大小
  126. while os.path.getsize(output_path) > MMMConfig.target_size:
  127. img = Image.open(output_path)
  128. image_quality -= 10
  129. if image_quality <= 0:
  130. break
  131. img.save(output_path, optimize=True, quality=image_quality)
  132. return image_quality
  133. def upload_tos(filename, tos_object_key):
  134. """
  135. 上传文件到TOS并获取预签名URL
  136. Args:
  137. filename: 本地文件路径
  138. tos_object_key: TOS对象键
  139. Returns:
  140. 预签名的URL
  141. Raises:
  142. Exception: 上传失败时抛出异常
  143. """
  144. tos_client, inner_tos_client = MMMConfig.tos_client, MMMConfig.inner_tos_client
  145. try:
  146. # 将本地文件上传到目标桶中, filename为本地压缩后图片的完整路径
  147. tos_client.put_object_from_file(MMMConfig.tos_bucket_name, tos_object_key, filename)
  148. # 获取上传后预签名的 url
  149. return inner_tos_client.pre_signed_url(HttpMethodType.Http_Method_Get, MMMConfig.tos_bucket_name, tos_object_key)
  150. except Exception as e:
  151. if isinstance(e, tos.exceptions.TosClientError):
  152. # 操作失败,捕获客户端异常,一般情况为非法请求参数或网络异常
  153. logger.error('TOS客户端错误, message:{}, cause: {}'.format(e.message, e.cause))
  154. elif isinstance(e, tos.exceptions.TosServerError):
  155. # 操作失败,捕获服务端异常,可从返回信息中获取详细错误信息
  156. logger.error('TOS服务端错误, code: {}'.format(e.code))
  157. # request id 可定位具体问题,强烈建议日志中保存
  158. logger.error('error with request id: {}'.format(e.request_id))
  159. logger.error('error with message: {}'.format(e.message))
  160. logger.error('error with http code: {}'.format(e.status_code))
  161. else:
  162. logger.error('TOS上传失败,未知错误: {}'.format(e))
  163. raise e
  164. # def doubao_MMM_request(pre_signed_url_output, prompt):
  165. # client = MMMConfig.client
  166. # response = client.chat.completions.create(
  167. # model=MMMConfig.model,
  168. # messages=[{"role": "user","content": [
  169. # {"type": "text", "text": prompt},
  170. # {"type": "image_url", "image_url": {"url": pre_signed_url_output.signed_url}}
  171. # ],
  172. # }],
  173. # temperature=0.8,
  174. # extra_headers={"x-ark-beta-vision": "true"}
  175. # )
  176. # result = response.choices[0].message.content
  177. # return result
  178. class llm_request:
  179. """
  180. LLM请求类
  181. 提供多模态和文本LLM请求功能
  182. """
  183. def __init__(self, api_key, base_url, model):
  184. """
  185. 初始化LLM请求客户端
  186. Args:
  187. api_key: API密钥
  188. base_url: API基础URL
  189. model: 模型名称
  190. """
  191. self.api_key = api_key
  192. self.base_url = base_url
  193. self.model = model
  194. def llm_mm_request(self, usr_text, img, sys_text="You are a helpful assistant."):
  195. """
  196. 多模态请求(单张图片)
  197. Args:
  198. usr_text: 用户文本提示
  199. img: 图片(路径、URL、numpy数组或PIL Image)
  200. sys_text: 系统提示词
  201. Returns:
  202. LLM返回的文本内容
  203. """
  204. client = OpenAI(
  205. api_key=self.api_key,
  206. base_url=self.base_url
  207. )
  208. completion = client.chat.completions.create(
  209. model=self.model,
  210. messages=[
  211. {
  212. "role": "system",
  213. "content": [{"type": "text", "text": sys_text}]
  214. },
  215. {
  216. "role": "user",
  217. "content": [
  218. {
  219. "type": "image_url",
  220. "image_url": {"url": image_reader(img)},
  221. },
  222. {"type": "text", "text": usr_text},
  223. ],
  224. }
  225. ],
  226. temperature=0.5,
  227. top_p=0.7,
  228. timeout=120.0
  229. )
  230. return completion.choices[0].message.content
  231. def llm_mm_2_request(self, usr_text, imgs, sys_text="You are a helpful assistant."):
  232. """
  233. 多模态请求(多张图片)
  234. Args:
  235. usr_text: 用户文本提示
  236. imgs: 图片列表(路径、URL、numpy数组或PIL Image)
  237. sys_text: 系统提示词
  238. Returns:
  239. LLM返回的文本内容
  240. """
  241. client = OpenAI(
  242. api_key=self.api_key,
  243. base_url=self.base_url
  244. )
  245. image_content_list = [
  246. {
  247. "type": "image_url",
  248. "image_url": {"url": image_reader(img)},
  249. }
  250. for img in imgs
  251. ]
  252. text_content = {"type": "text", "text": usr_text}
  253. user_content = image_content_list + [text_content]
  254. completion = client.chat.completions.create(
  255. model=self.model,
  256. messages=[
  257. {
  258. "role": "system",
  259. "content": [{"type": "text", "text": sys_text}]
  260. },
  261. {
  262. "role": "user",
  263. "content": user_content,
  264. }
  265. ],
  266. temperature=0.5,
  267. top_p=0.7,
  268. timeout=120.0
  269. )
  270. return completion.choices[0].message.content
  271. def llm_text_request(self, text, sys_text="You are a helpful assistant."):
  272. """
  273. 纯文本LLM请求
  274. Args:
  275. text: 用户文本提示
  276. sys_text: 系统提示词
  277. Returns:
  278. LLM返回的文本内容
  279. """
  280. client = OpenAI(
  281. api_key=self.api_key,
  282. base_url=self.base_url
  283. )
  284. completion = client.chat.completions.create(
  285. model=self.model,
  286. messages=[
  287. {
  288. "role": "system",
  289. "content": sys_text
  290. },
  291. {
  292. "role": "user",
  293. "content": text,
  294. }
  295. ],
  296. temperature=0.9,
  297. timeout=120.0
  298. )
  299. return completion.choices[0].message.content
  300. if __name__=="__main__":
  301. ##ali
  302. # ky="sk-TstsKbfIFjdNpjNGo6uBHzZayp5Bq8FjTV0b6BwyXflaOFLs"
  303. # baseurl="https://api.openaius.com/v1"
  304. # model="gpt-5"
  305. #ali
  306. ky="sk-04b63960983445f980d85ff185a17876"
  307. baseurl="https://dashscope.aliyuncs.com/compatible-mode/v1"
  308. model="qwen3-vl-plus"
  309. ##doubao
  310. # ky='817dff39-5586-4f9b-acba-55004167c0b1'
  311. # baseurl="https://ark.cn-beijing.volces.com/api/v3"
  312. # model="doubao-1-5-vision-pro-32k-250115"
  313. llm=llm_request(ky,baseurl,model)
  314. imgs=r"H:\data\线稿图\S1261A097_S1261A097_concatenated.jpg"
  315. res1=llm.llm_mm_request("判断一下图2是不是图1的平铺图,纽扣数量是否一致",imgs)
  316. print(res1)
  317. # res2=llm.llm_text_request("你好!你是谁")
  318. # print(res2)