# qa_robot.py
# --- Module setup: imports, logging, environment, shared API client ---
from email import message  # NOTE(review): unused in this file — confirm before removing
import io
import os
import base64
from PIL import Image
from dotenv import load_dotenv
from volcenginesdkarkruntime import Ark
from volcenginesdkarkruntime.resources import content_generation  # NOTE(review): unused in this file
from config.prompt import RECOMMEND_PROMPT, INTENT_PROMPT, ANSWER_PROMPT
from utils.logger_config import setup_logger

# Module-level logger for this file.
logger = setup_logger(__name__)

# Load a local .env file so ARK_API_KEY can be read from the environment below.
load_dotenv()
ARK_API_KEY = os.getenv("ARK_API_KEY")

# Shared Ark (Volcengine) chat-completions client used by every QA helper below.
client = Ark(
    api_key=ARK_API_KEY,
    base_url="https://ark.cn-beijing.volces.com/api/v3",
)
# model: deepseek-r1-250528
  19. def encode_image(pil_image):
  20. """将PIL Image对象转换为base64编码"""
  21. buffered = io.BytesIO()
  22. pil_image.save(buffered, format="JPEG")
  23. return base64.b64encode(buffered.getvalue()).decode('utf-8')
  24. def image_qa(image_path, sys_prompt):
  25. base64_image = encode_image(image_path)
  26. response = client.chat.completions.create(
  27. model="doubao-seed-1-6-251015",
  28. temperature=1,
  29. max_tokens=200,
  30. messages=[
  31. {
  32. "role": "user",
  33. "content": [
  34. {
  35. "type": "text",
  36. "text": sys_prompt,
  37. },
  38. {
  39. "type": "image_url",
  40. "image_url": {
  41. "url": f"data:image/jpg;base64,{base64_image}"
  42. },
  43. },
  44. ],
  45. }
  46. ],
  47. )
  48. return response.choices[0].message.content
  49. def text_qa(query, sys_prompt = RECOMMEND_PROMPT):
  50. response = client.chat.completions.create(
  51. model="doubao-seed-1-6-251015",
  52. temperature=1,
  53. max_tokens=500,
  54. messages = [
  55. {"role": "system", "content": sys_prompt},
  56. {"role": "user", "content": query}
  57. ],
  58. )
  59. return response.choices[0].message.content
  60. def intent_reg(query, sys_prompt = INTENT_PROMPT):
  61. response = client.chat.completions.create(
  62. model="doubao-seed-1-6-251015",
  63. temperature=1,
  64. max_tokens=500,
  65. messages = [
  66. {"role": "system", "content": sys_prompt},
  67. {"role": "user", "content": query}
  68. ],
  69. )
  70. return response.choices[0].message.content
  71. def large_order_qa(query, context, sys_prompt = ANSWER_PROMPT.format(query=None, answer=None)):
  72. response = client.chat.completions.create(
  73. model="doubao-seed-1-6-251015",
  74. temperature=1,
  75. max_tokens=500,
  76. messages = [
  77. {"role": "system", "content": sys_prompt},
  78. {"role": "user", "content": query},
  79. {"role": "user", "content": f"上下文信息:\n{context}"}
  80. ],
  81. )
  82. return response.choices[0].message.content
  83. if __name__ == "__main__":
  84. query = "1A200987A"
  85. context = "1A200987A的搭配结果:https://www.123456.xdba.cn"
  86. print(large_order_qa(query, context, ANSWER_PROMPT.format(answer=context)))