from email import message  # NOTE(review): unused in this file — confirm before removing
import io
import os
import base64
from PIL import Image
from dotenv import load_dotenv
from volcenginesdkarkruntime import Ark
from volcenginesdkarkruntime.resources import content_generation  # NOTE(review): unused in this file — confirm before removing
from config.prompt import RECOMMEND_PROMPT, INTENT_PROMPT, ANSWER_PROMPT
from utils.logger_config import setup_logger

# Module-level logger for this file.
logger = setup_logger(__name__)

# Load variables from a .env file so ARK_API_KEY can be read from the environment.
load_dotenv()
ARK_API_KEY = os.getenv("ARK_API_KEY")  # None if the variable is absent

# Shared Ark (Volcengine) client used by all QA helpers below.
client = Ark(
    api_key=ARK_API_KEY,
    base_url="https://ark.cn-beijing.volces.com/api/v3",
)
# model: deepseek-r1-250528
-
def encode_image(pil_image, image_format="JPEG"):
    """Serialize a PIL Image to a base64-encoded string.

    Args:
        pil_image: PIL.Image.Image instance (or any object exposing a
            compatible ``save(buffer, format=...)`` method).
        image_format: output format passed to ``save``; defaults to "JPEG"
            for backward compatibility with the original behavior.

    Returns:
        str: base64 text of the encoded image bytes.
    """
    # JPEG cannot store an alpha channel / palette; convert such images
    # first, otherwise Pillow's save() raises OSError.
    if image_format.upper() == "JPEG" and getattr(pil_image, "mode", "") in ("RGBA", "P", "LA"):
        pil_image = pil_image.convert("RGB")
    buffered = io.BytesIO()
    pil_image.save(buffered, format=image_format)
    return base64.b64encode(buffered.getvalue()).decode("utf-8")
def image_qa(image_path, sys_prompt):
    """Ask the vision model a question about an image.

    Args:
        image_path: filesystem path to an image file, or an already-loaded
            PIL.Image.Image. (The original passed a path straight into
            ``encode_image``, which calls ``.save()`` and therefore only
            worked with a PIL image despite the parameter name.)
        sys_prompt: instruction text sent alongside the image.

    Returns:
        str: the model's reply text.
    """
    # Accept either a path (open it) or an already-loaded PIL image.
    if isinstance(image_path, (str, os.PathLike)):
        image = Image.open(image_path)
    else:
        image = image_path
    base64_image = encode_image(image)
    response = client.chat.completions.create(
        model="doubao-seed-1-6-251015",
        temperature=1,
        max_tokens=200,
        messages=[
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": sys_prompt,
                    },
                    {
                        "type": "image_url",
                        "image_url": {
                            # the registered MIME type is image/jpeg, not image/jpg
                            "url": f"data:image/jpeg;base64,{base64_image}"
                        },
                    },
                ],
            }
        ],
    )
    return response.choices[0].message.content
def text_qa(query, sys_prompt=RECOMMEND_PROMPT):
    """Answer a plain-text query; defaults to the recommendation system prompt.

    Args:
        query: user query text.
        sys_prompt: system prompt guiding the model (default: RECOMMEND_PROMPT).

    Returns:
        str: the model's reply text.
    """
    chat_messages = [
        {"role": "system", "content": sys_prompt},
        {"role": "user", "content": query},
    ]
    completion = client.chat.completions.create(
        model="doubao-seed-1-6-251015",
        temperature=1,
        max_tokens=500,
        messages=chat_messages,
    )
    return completion.choices[0].message.content
def intent_reg(query, sys_prompt=INTENT_PROMPT):
    """Classify the intent of a user query; defaults to the intent prompt.

    Args:
        query: user query text.
        sys_prompt: system prompt guiding the model (default: INTENT_PROMPT).

    Returns:
        str: the model's reply text.
    """
    chat_messages = [
        {"role": "system", "content": sys_prompt},
        {"role": "user", "content": query},
    ]
    completion = client.chat.completions.create(
        model="doubao-seed-1-6-251015",
        temperature=1,
        max_tokens=500,
        messages=chat_messages,
    )
    return completion.choices[0].message.content
def large_order_qa(query, context, sys_prompt=None):
    """Answer an order-lookup query using retrieved context.

    Args:
        query: user's order query (e.g. an order number).
        context: retrieved context text relevant to the query.
        sys_prompt: optional explicit system prompt. When omitted,
            ANSWER_PROMPT is formatted with the actual query/context at
            call time. (The original default,
            ``ANSWER_PROMPT.format(query=None, answer=None)``, was
            evaluated once at import time and baked the literal string
            "None" into the prompt.)

    Returns:
        str: the model's reply text.
    """
    if sys_prompt is None:
        sys_prompt = ANSWER_PROMPT.format(query=query, answer=context)
    response = client.chat.completions.create(
        model="doubao-seed-1-6-251015",
        temperature=1,
        max_tokens=500,
        messages=[
            {"role": "system", "content": sys_prompt},
            {"role": "user", "content": query},
            {"role": "user", "content": f"上下文信息:\n{context}"},
        ],
    )
    return response.choices[0].message.content
if __name__ == "__main__":
    # Smoke test: look up a sample order number against canned context.
    sample_query = "1A200987A"
    sample_context = "1A200987A的搭配结果:https://www.123456.xdba.cn"
    reply = large_order_qa(
        sample_query,
        sample_context,
        ANSWER_PROMPT.format(answer=sample_context),
    )
    print(reply)