llm.py 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196
  1. import io
  2. from PIL import Image
  3. import os
  4. from openai import OpenAI
  5. from conf import *
  6. from tos import HttpMethodType
  7. import time
  8. from openai import OpenAI
  9. import os
  10. import base64
  11. from PIL import Image
  def image_to_base64(image):
      """Serialise ``image`` as PNG and wrap the bytes in a ``data:`` URL.

      Args:
          image: Any object with a PIL-style ``save(fp, format=...)`` method,
              normally a ``PIL.Image.Image``.

      Returns:
          A string of the form ``data:image/png;base64,<payload>``.
      """
      with io.BytesIO() as buffer:
          image.save(buffer, format='PNG')
          # getvalue() returns everything written, regardless of stream position.
          payload = base64.b64encode(buffer.getvalue()).decode('utf-8')
      return "data:image/png;base64," + payload
  def image_reader(image):
      """Normalise an image input into a value usable as a chat API ``image_url``.

      Args:
          image: One of
              * an ``http://`` / ``https://`` URL string -> returned unchanged;
              * any other string -> treated as a local file path and opened;
              * a ``numpy.ndarray`` -> converted with ``Image.fromarray``;
              * anything else -> assumed to already be a PIL image.

      Returns:
          The URL string unchanged, or a ``data:image/png;base64,...`` URL of
          the image after conversion to RGB.
      """
      if isinstance(image, str):
          # Only real web URLs pass through; a bare "http" prefix would also
          # match local paths such as "http_cache/a.png".
          if image.startswith(("http://", "https://")):
              return image
          # Context manager releases the file handle once convert() has
          # copied the pixels into a new image.
          with Image.open(image) as opened:
              out_image = opened.convert('RGB')
      else:
          # numpy is not imported at module level in this file; import it
          # locally so the ndarray check cannot fail with NameError.
          import numpy as np

          if isinstance(image, np.ndarray):
              image = Image.fromarray(image)
          out_image = image.convert('RGB')
      return image_to_base64(out_image)
  def get_lm_text(sys_prompt, user_prompt):
      """Run one system + user chat turn against the text model in ``LMConfig``.

      Args:
          sys_prompt: Content of the system message.
          user_prompt: Content of the user message.

      Returns:
          The message content of the first completion choice.
      """
      conversation = [
          {"role": "system", "content": sys_prompt},
          {"role": "user", "content": user_prompt},
      ]
      completion = LMConfig.lm_client.chat.completions.create(
          messages=conversation,
          model=LMConfig.model,
      )
      first_choice = completion.choices[0]
      return first_choice.message.content
  44. ## 多模态的输入
  def compress_image(input_path, output_path, target_size=None):
      """Re-encode ``input_path`` into ``output_path`` until it fits a byte budget.

      The first attempt uses a quality proportional to
      ``target_size / size_of(input_path)``. While the written file is still too
      large, the quality drops in steps of 10 and the ORIGINAL image is
      re-encoded again. It is not re-encoded from the already-compressed output.

      Args:
          input_path: Path of the source image.
          output_path: Destination path. PIL picks the output format from its
              extension; ``quality`` only matters for lossy formats such as JPEG.
          target_size: Byte budget. Defaults to ``MMMConfig.target_size``.

      Returns:
          The quality value used for the file finally left at ``output_path``.
          If even the lowest step cannot meet the budget, the output may still
          exceed ``target_size``.
      """
      if target_size is None:
          target_size = MMMConfig.target_size
      with Image.open(input_path) as img:
          current_size = os.path.getsize(input_path)
          # Rough first guess. Capped at 95 because Pillow advises against JPEG
          # quality above 95 (larger files for no visible gain).
          quality = min(int(target_size / current_size * 100), 95)
          img.save(output_path, optimize=True, quality=quality)
          # Keep lowering quality until the output fits; each attempt starts
          # from the source pixels, so artefacts do not compound.
          while os.path.getsize(output_path) > target_size:
              next_quality = quality - 10
              if next_quality <= 0:
                  # Stop without overwriting; `quality` still describes the
                  # file on disk.
                  break
              quality = next_quality
              img.save(output_path, optimize=True, quality=quality)
      return quality
  def upload_tos(filename, tos_object_key):
      """Upload a local file to the configured TOS bucket and return a pre-signed GET URL.

      Args:
          filename: Full local path of the file to upload (per the original
              comment, the compressed image).
          tos_object_key: Object key under ``MMMConfig.tos_bucket_name``.

      Returns:
          The result of ``inner_tos_client.pre_signed_url``. Presumably this is
          the object later passed to ``doubao_MMM_request``, which reads its
          ``signed_url`` attribute.

      Raises:
          Re-raises any exception from the upload or the URL signing, after
          printing diagnostic details.
      """
      # The module-level import only brings in HttpMethodType. The exception
      # classes live under the ``tos`` package itself; importing it here stops
      # a NameError inside the handler from masking the real failure.
      import tos

      tos_client, inner_tos_client = MMMConfig.tos_client, MMMConfig.inner_tos_client
      try:
          bucket = MMMConfig.tos_bucket_name
          # Upload the local file into the target bucket.
          tos_client.put_object_from_file(bucket, tos_object_key, filename)
          # Sign a GET URL for the object just uploaded.
          return inner_tos_client.pre_signed_url(HttpMethodType.Http_Method_Get, bucket, tos_object_key)
      except tos.exceptions.TosClientError as e:
          # Client-side failure: typically invalid request parameters or a network error.
          print('fail with client error, message:{}, cause: {}'.format(e.message, e.cause))
          raise
      except tos.exceptions.TosServerError as e:
          # Server-side failure. The request id identifies the request and is
          # worth keeping in logs.
          print('fail with server error, code: {}'.format(e.code))
          print('error with request id: {}'.format(e.request_id))
          print('error with message: {}'.format(e.message))
          print('error with http code: {}'.format(e.status_code))
          raise
      except Exception as e:
          print('fail with unknown error: {}'.format(e))
          raise
  def doubao_MMM_request(pre_signed_url_output, prompt):
      """Ask the vision model configured in ``MMMConfig`` a question about one image.

      Args:
          pre_signed_url_output: Object exposing a ``signed_url`` attribute that
              points at the image (e.g. the value returned by ``upload_tos``).
          prompt: Text instruction sent alongside the image.

      Returns:
          The message content of the first completion choice.
      """
      ark_client = MMMConfig.client
      text_part = {"type": "text", "text": prompt}
      image_part = {"type": "image_url", "image_url": {"url": pre_signed_url_output.signed_url}}
      user_turn = {"role": "user", "content": [text_part, image_part]}
      reply = ark_client.chat.completions.create(
          model=MMMConfig.model,
          messages=[user_turn],
          temperature=0.8,
          # Header kept from the original integration; presumably switches on
          # the Ark vision beta. TODO confirm against the Ark docs.
          extra_headers={"x-ark-beta-vision": "true"},
      )
      return reply.choices[0].message.content
  class llm_request:
      """Minimal client for an OpenAI-compatible chat-completions endpoint.

      Stores the credentials, endpoint and model. Offers a multimodal
      (image + text) call and a text-only call; each returns the message
      content of the first choice.
      """

      def __init__(self, api_key, base_url, model) -> None:
          self.api_key = api_key      # API key sent to the endpoint
          self.base_url = base_url    # base URL of the OpenAI-compatible API
          self.model = model          # model name used for every request

      def _client(self):
          """Return an ``OpenAI`` client for the current key/URL, built once and reused."""
          settings = (self.api_key, self.base_url)
          cached = getattr(self, "_cached_client", None)
          # Rebuild only if there is no client yet, or if api_key/base_url
          # changed on this instance since the cached client was created.
          if cached is None or cached[0] != settings:
              cached = (settings, OpenAI(api_key=self.api_key, base_url=self.base_url))
              self._cached_client = cached
          return cached[1]

      def llm_mm_request(self, usr_text, img, sys_text="You are a helpful assistant.", *,
                         temperature=1.5, top_p=0.85, presence_penalty=1.5,
                         frequency_penalty=1.5):
          """Send one image plus text to the model and return the reply text.

          Args:
              usr_text: User instruction placed after the image.
              img: Anything ``image_reader`` accepts (URL, file path, array, or
                  PIL image).
              sys_text: System prompt.
              temperature, top_p, presence_penalty, frequency_penalty:
                  Keyword-only sampling overrides. Defaults match the values
                  previously hard-coded here.

          Returns:
              The message content of the first completion choice.
          """
          completion = self._client().chat.completions.create(
              model=self.model,
              messages=[
                  {"role": "system", "content": [{"type": "text", "text": sys_text}]},
                  {
                      "role": "user",
                      "content": [
                          # image_reader returns either a web URL or a
                          # data:image/png;base64 URL. In a data URL the MIME
                          # type must match the encoded image format.
                          {"type": "image_url", "image_url": {"url": image_reader(img)}},
                          {"type": "text", "text": usr_text},
                      ],
                  },
              ],
              temperature=temperature,
              top_p=top_p,
              presence_penalty=presence_penalty,
              frequency_penalty=frequency_penalty,
          )
          return completion.choices[0].message.content

      def llm_text_request(self, text, sys_text="You are a helpful assistant.", *, temperature=0.9):
          """Send a text-only chat turn and return the reply text.

          Args:
              text: User message.
              sys_text: System prompt.
              temperature: Keyword-only sampling temperature. Defaults to the
                  previously hard-coded 0.9.

          Returns:
              The message content of the first completion choice.
          """
          completion = self._client().chat.completions.create(
              model=self.model,
              messages=[
                  {"role": "system", "content": sys_text},
                  {"role": "user", "content": text},
              ],
              temperature=temperature,
          )
          return completion.choices[0].message.content
  if __name__ == "__main__":
      # Smoke test. Credentials come from the environment; never commit API
      # keys. The keys previously hard-coded here should be treated as leaked
      # and revoked.
      ## Aliyun DashScope (OpenAI-compatible mode)
      ky = os.environ.get("DASHSCOPE_API_KEY")
      baseurl = "https://dashscope.aliyuncs.com/compatible-mode/v1"
      model = "qwen-vl-max-latest"
      ## Volcengine Ark (Doubao) alternative:
      # ky = os.environ.get("ARK_API_KEY")
      # baseurl = "https://ark.cn-beijing.volces.com/api/v3"
      # model = "doubao-1-5-vision-pro-32k-250115"
      if not ky:
          raise SystemExit("Set the DASHSCOPE_API_KEY environment variable first.")
      llm = llm_request(ky, baseurl, model)
      res1 = llm.llm_mm_request("描述一下图片中的衣服", "/data/data/Mia/product_env_project/gen_sellpoint/企业微信截图_17372766091671.png")
      print(res1)
      res2 = llm.llm_text_request("你好!你是谁")
      print(res2)