test_chat.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393
  1. import sys
  2. import time
  3. import os
  4. from PIL import Image
  5. import requests
  6. from prompt import *
  7. from llm import *
  8. import json
  9. from conf import *
  10. import re
MAX_RETRIES = 5  # Number of LLM attempts generate_text() makes before giving up
MAX_HISTORY = 20  # NOTE(review): not referenced in the visible code; add_history() defaults to a literal 20
MAX_CHAR_LIMIT = 400  # Upper bound on len() of the generated "en" text
MIN_CHAR_LIMIT = 150 # Assumed minimum length limit for the "en" text
history_list=[]  # Opening snippets recorded by add_history(), most recent first
# Registry used by gen_title(): plugins[method](info) is called and its result is
# unpacked as (system prompt, user prompt). Disabled entries are kept commented out.
# NOTE(review): get_ch_en_selling_title presumably comes from `from prompt import *` — confirm.
plugins = {
    # "ch_en_selling_points":get_ch_en_selling_points,
    # "en_ch_selling_points":get_en_ch_selling_points,
    "ch_en_selling_title":get_ch_en_selling_title,
    # "en_ch_selling_points_his":get_en_ch_selling_points_his,
    # "TextControl_his":TextControl_his,
    # "TextControl":TextControl
}
  24. def contains_chinese(text):
  25. pattern = re.compile(r'[\u4e00-\u9fa5]')
  26. return bool(pattern.search(text))
  27. def format_history(strings, indent=" "):
  28. result = ""
  29. for i, string in enumerate(strings, start=1):
  30. # 拼接序号、缩进和字符串,并添加换行符
  31. result += f"{indent}{i}. {string}\n"
  32. return result
  33. def get_history():
  34. """获取格式化的历史记录(用于原始prompt)"""
  35. global history_list
  36. if len(history_list)==0:
  37. history=''
  38. else:
  39. history=format_history(history_list)
  40. return history
  41. def add_history(input,max_num=20):
  42. global history_list
  43. text = re.split(r'[,\.\!\?\;\:]+', input)
  44. text=text[0].strip()
  45. history_list.insert(0, text)
  46. if len(history_list)>max_num:
  47. history_list=history_list[:max_num]
  48. def generate_text(plm_info,img,graphic_label=None,plat="ali",model_name="mm_qwen"):
  49. history_string=get_history()
  50. # print(his)
  51. if graphic_label:
  52. tags_sen=",".join(graphic_label)
  53. plm_info+="\n' '以下是该衣服的关键点:"+tags_sen
  54. if plat=="ali":
  55. key=ali_ky
  56. model=ali_model[model_name]
  57. else:
  58. key=doubao_ky
  59. model=doubao_model[model_name]
  60. llm=llm_request(*key,model)
  61. en,kw='',''
  62. result_json = None
  63. for attempt in range(MAX_RETRIES):
  64. # --- 构造Prompt ---
  65. if attempt == 0:
  66. # 第一次尝试:使用您的主Prompt
  67. usrp = user_prompt.format(basic_info_string=plm_info,history_string=history_string)
  68. else:
  69. # 后续尝试:使用“修正Prompt”
  70. usrp = get_refinement_prompt(plm_info, history_string, result_json)
  71. print(f"--- 尝试 {attempt + 1} ---")
  72. # print(prompt) # Debug: 打印Prompt
  73. response_text = llm.llm_mm_request(usrp,img,sys_text=system_prompt)
  74. try:
  75. is_valid, validation_error, result_json = validate_response(response_text)
  76. if is_valid:
  77. # 成功!
  78. print("生成成功!")
  79. en,kw=result_json['en'],result_json['kw']
  80. add_history(en)
  81. break
  82. else:
  83. # 失败,记录错误,循环将继续
  84. print(f"尝试 {attempt + 1} 失败: {validation_error}")
  85. # result_json 已经包含了失败的文本和错误信息,将用于下一次修正
  86. continue
  87. except Exception as e:
  88. # API调用本身失败
  89. print(f"API 调用失败: {e}")
  90. result_json = {"error": "API_FAILURE", "raw_response": str(e)}
  91. continue
  92. if result_json and result_json.get("error") == "EN_TOO_LONG":
  93. # 如果是因为超长而失败,且 raw_response 有效
  94. try:
  95. failed_data = json.loads(result_json.get("raw_response", "{}"))
  96. long_en_text = failed_data.get("en")
  97. if long_en_text and len(long_en_text) > MAX_CHAR_LIMIT:
  98. en = smart_truncate_by_sentence(long_en_text, max_chars=MAX_CHAR_LIMIT)
  99. kw = failed_data.get("kw", '')
  100. print("执行智能截断成功,返回修正后的文案。")
  101. add_history(en)
  102. except (json.JSONDecodeError, KeyError, TypeError):
  103. # 无法截断,进入最终错误
  104. pass
  105. if isinstance(kw,str):
  106. kw = [item.strip() for item in kw.split('.') if item.strip()]
  107. return en,kw
  108. def validate_response(response_text):
  109. """验证模型的输出是否符合所有规则"""
  110. try:
  111. # 规则1: 是否是有效JSON?
  112. data = json.loads(response_text.strip())
  113. except json.JSONDecodeError:
  114. return False, "INVALID_JSON", {"error": "INVALID_JSON", "raw_response": response_text}
  115. # 规则2: 键是否齐全?
  116. if not all(k in data for k in ["en", "ch", "kw"]):
  117. return False, "MISSING_KEYS", {"error": "MISSING_KEYS", "raw_response": json.dumps(data)}
  118. en_text = data.get("en", "")
  119. # 规则3: 长度是否超标?
  120. if len(en_text) > MAX_CHAR_LIMIT:
  121. return False, "EN_TOO_LONG", {"error": "EN_TOO_LONG", "raw_response": json.dumps(data)}
  122. # 规则4: 长度是否太短?
  123. if len(en_text) < MIN_CHAR_LIMIT:
  124. return False, "EN_TOO_SHORT", {"error": "EN_TOO_SHORT", "raw_response": json.dumps(data)}
  125. if contains_chinese(en_text):
  126. return False, "EN_CONTAINS_CHINESE", {"error": "EN_CONTAINS_CHINESE", "raw_response": json.dumps(data)}
  127. return True, "SUCCESS", data
  128. def smart_truncate_by_sentence(text, max_chars=MAX_CHAR_LIMIT):
  129. if len(text) <= max_chars:
  130. return text
  131. sentence_endings = re.compile(r'[.!?](?:\s+|$)')
  132. best_truncate_point = -1
  133. for match in sentence_endings.finditer(text):
  134. end_position = match.end()
  135. if end_position <= max_chars:
  136. best_truncate_point = end_position
  137. else:
  138. break
  139. if best_truncate_point > 0:
  140. truncated_text = text[:best_truncate_point].strip()
  141. if not truncated_text.endswith(('.', '!', '?')):
  142. truncated_text += '.'
  143. return truncated_text.strip()
  144. else:
  145. return text[:max_chars-3].strip() + '...'
  146. def get_refinement_prompt(basic_info_string, history_string, failed_result):
  147. """
  148. 根据上一次的失败原因,生成一个“引导式修正”的Prompt
  149. """
  150. failure_reason = failed_result.get("error", "UNKNOWN")
  151. raw_response = failed_result.get("raw_response", "")
  152. feedback = ""
  153. # 尝试提取上次失败的文案
  154. last_text_en = ""
  155. try:
  156. if raw_response:
  157. last_text_en = json.loads(raw_response).get("en", "")
  158. except json.JSONDecodeError:
  159. pass # 无法解析,last_text_en 保持空
  160. if failure_reason == "INVALID_JSON":
  161. feedback = f"你上次的输出不是一个有效的JSON。请【严格】按照JSON格式输出。你上次的错误输出是:\n{raw_response}"
  162. elif failure_reason == "EN_TOO_LONG":
  163. feedback = f"""
  164. 你上次生成的 "en" 描述【超过了{MAX_CHAR_LIMIT}个字符】!
  165. 【你生成的超长原文】:\n{last_text_en}
  166. 【修正任务】: 请【大幅精简】上述原文,保留核心卖点,使其长度【绝对】在{MIN_CHAR_LIMIT}-{MAX_CHAR_LIMIT}字符以内。
  167. """
  168. elif failure_reason == "EN_TOO_SHORT":
  169. feedback = f"""
  170. 你上次生成的 "en" 描述太短了(小于{MIN_CHAR_LIMIT}字符)。
  171. 【你生成的原文】:\n{last_text_en}
  172. 【修正任务】: 请在原文案基础上,围绕核心卖点再丰富一些细节,使其达到{MIN_CHAR_LIMIT}-{MAX_CHAR_LIMIT}字符。
  173. """
  174. elif failure_reason == "MISSING_KEYS":
  175. feedback = f"你上次输出的JSON缺少 'en', 'ch' 或 'kw' 键。请确保三者齐全。"
  176. elif failure_reason == "TOO_SIMILAR":
  177. feedback = "你上次生成的文案与历史记录太相似了。请换一个角度(比如从'材质'或'穿搭场景')重新构思,字数保持在要求的范围内。"
  178. elif failure_reason == "EN_CONTAINS_CHINESE":
  179. feedback = f"""
  180. 你上次生成的 "en" 描述中包含了中文汉字(例如:{last_text_en})。
  181. 【修正任务】: "en" 字段【必须是纯英文】,【绝对禁止】出现任何中文字符。请严格修正并重新输出。
  182. """
  183. else:
  184. feedback = "你上次的生成失败了。请重新严格按照所有规则生成一次。"
  185. # 修正Prompt模板
  186. refinement_prompt = f"""## 角色
  187. 你是一个文案修正专家。
  188. ## 原始任务
  189. 根据以下信息和随消息传入的图片生成文案:{basic_info_string}
  190. ## 上次失败的反馈 (你必须修正!)
  191. {feedback}
  192. ## 核心规则 (必须再次遵守)
  193. 1. 【必须】输出严格的JSON格式。
  194. 2. "en" 描述【必须严格在{MAX_CHAR_LIMIT}字符以内】。
  195. 3. 【不要】使用历史开篇:\n{history_string}
  196. ## 最终输出
  197. 请直接输出修正后的、严格符合要求的JSON字典。
  198. """
  199. return refinement_prompt
  200. def gen_title(info,tags=None,referencr_title=None,method="ch_en_selling_title",plat="ali",model_name="text_dsv3"):
  201. if tags:
  202. tags_sen=",".join(tags)
  203. info="\n' '以下是该衣服的关键点:"+tags_sen
  204. if referencr_title:
  205. info="\n' '请以这条标题样例的结构作为借鉴来写这条标题:"+referencr_title
  206. sysp,usrp = plugins[method](info)
  207. if plat=="ali":
  208. key=ali_ky
  209. model=ali_model[model_name]
  210. else:
  211. key=doubao_ky
  212. model=doubao_model[model_name]
  213. llm=llm_request(*key,model)
  214. res=llm.llm_text_request(usrp,sysp)
  215. res_dict = json.loads(res)
  216. return {"title":res_dict["en_tile"]}
if __name__ == "__main__":
    # Manual smoke test: call generate_text() three times on one sample product
    # and print each (en, kw) result. Successful runs add to the shared history,
    # which is fed back into later prompts.
    # NOTE(review): indentation was lost in this copy of the file; print(result)
    # is assumed to sit inside the loop — confirm against the original.
    # inf="'Meet your new best friend in fashion—this unisex sweater that whispers comfort and style. Crafted from premium cotton, it feels like a gentle hug on your skin. The heart embroidery adds a touch of whimsy, making you the star of any casual outing. Perfect for layering or wearing solo, this soft companion keeps you cozy all season long."
    # print(gen_title(inf))
    # id_image,id_price, id_color, id_ingredient, id_selling_point, id_details=search_json_files("1A6H4K7V0")
    # id_image=id_image[2:]
    # id_image=os.path.join("/data/data/luosy/project/sku_search",id_image)
    id_image="https://img2.goelia.com.au/prod/product/1ENC6E220/material/main/Shopify/-1/72736752b0ad405382d5ed277dabc660.jpg"
    graphic_label=['-100% Merino wool', '-With pockets', '-H-line fit']
    plm_info='1、手工流苏边设计 \xa0 2、贴袋设计 \xa0 3、金属纽扣'
    # print(id_details,id_image)
    for _ in range(3):
        result=generate_text(plm_info,id_image,graphic_label)
        # result=gen_title("This maxi dress features unparalleled comfort and a unique texture with its <b>tencel blend fabric</b>. The square neckline and smocked bodice create a flattering silhouette, while the layered skirt adds romantic flair. <b>Side pockets and an included scarf scrunchie</b> enhance both style and functionality, elevating its versatility for everyday wear and beyond.")
        print(result)
  231. # from tqdm import tqdm
  232. # def image_to_base64(image):
  233. # # 将Image对象转换为BytesIO对象
  234. # image_io = io.BytesIO()
  235. # image.save(image_io, format='PNG')
  236. # image_io.seek(0)
  237. # # 使用base64编码
  238. # image_base64 = base64.b64encode(image_io.read()).decode('utf-8')
  239. # return image_base64
  240. # def create_html_with_base64_images(root, output_html):
  241. # with open(output_html, 'w', encoding='utf-8') as html_file:
  242. # html_file.write('<!DOCTYPE html>\n<html>\n<head>\n<title>Images in Table</title>\n')
  243. # html_file.write('<meta charset="UTF-8">\n') # 添加字符编码声明
  244. # html_file.write('<style>\n')
  245. # html_file.write('table {\nborder-collapse: collapse;\nwidth: 100%;\n}\n')
  246. # html_file.write('table, th, td {\nborder: 1px solid black;\n}\n')
  247. # html_file.write('img {\nmax-width: 100%;\nheight: auto;\ndisplay: block;\nmargin-left: auto;\nmargin-right: auto;\n}\n')
  248. # html_file.write('</style>\n')
  249. # html_file.write('</head>\n<body>\n')
  250. # html_file.write('<table>\n')
  251. # html_file.write('<tr>\n')
  252. # html_file.write('<th>输入的图片</th>\n') # 第一列:索引
  253. # html_file.write('<th>输入的描述</th>\n') # 第二列:标题
  254. # html_file.write('<th>输出的商品详情</th>\n') # 第二列:标题
  255. # html_file.write('<th>输出的商品详情(翻译)</th>\n') # 第三列:图表
  256. # html_file.write('<th>输出的卖点</th>\n') # 第三列:图表
  257. # # for i in range(1, 100): # 添加序号列1到13
  258. # # html_file.write(f'<th>{i}</th>\n')
  259. # html_file.write('</tr>\n')
  260. # for file in tqdm(os.listdir(root)[:100], desc="Processing", unit="iter"):
  261. # if '.ipynb_checkpoints' in file:
  262. # continue
  263. # file_path = os.path.join(root, file)
  264. # with open(file_path, 'r') as f:
  265. # data = json.load(f)
  266. # if data and "商品图像" in data.keys():
  267. # id_image,id_details=data["商品图像"][2:], data["商品细节"]
  268. # else:
  269. # continue
  270. # id_image=os.path.join("/data/data/luosy/project/sku_search",id_image)
  271. # img_base64 = image_to_base64(Image.open(id_image))
  272. # ch,en,kw=generate_text_new(id_details,id_image)
  273. # html_file.write('<tr>\n')
  274. # # html_file.write(f'<td>{index+1}</td>\n') # 添加序号
  275. # # html_file.write('<td>\n')
  276. # # html_file.write(f'<img src="data:image/png;base64,{frame_title_img}" alt="Image">\n')
  277. # # html_file.write('</td>\n')
  278. # html_file.write('<td>\n')
  279. # html_file.write(f'<img src="data:image/png;base64,{img_base64}" alt="Image">\n')
  280. # html_file.write('</td>\n')
  281. # html_file.write(f'<td>{id_details}</td>\n') # 添加序号
  282. # html_file.write(f'<td>{en}</td>\n') # 添加序号
  283. # html_file.write(f'<td>{ch}</td>\n') # 添加序号
  284. # html_file.write(f'<td>{kw}</td>\n') # 添加序号
  285. # # html_file.write('</td>\n')
  286. # # for img in image_data:
  287. # # html_file.write('<td>\n')
  288. # # html_file.write(f'<img src="data:image/jpeg;base64,{img}" alt="Image" style="max-width: 100px; max-height: 100px; margin: 5px;">\n')
  289. # # html_file.write('</td>\n')
  290. # # html_file.write('</td>\n')
  291. # html_file.write('</tr>\n')
  292. # html_file.write('</table>\n')
  293. # html_file.write('</body>\n</html>')
  294. # root='/data/data/luosy/project/sku_search/database/meta'
  295. # create_html_with_base64_images(root, "out——qw_v6.html")
  296. # app.run(host="0.0.0.0",port=2222,debug=True)
  297. # print(gen_title(info= "This sweatshirt is a wardrobe essential with its simple yet stylish design and 3D heart pattern that adds a fun visual pop. <b>The unisex design is perfect for couples</b>, and it pairs effortlessly with jeans, cargo pants, or a pleated skirt. Ideal for school, work, or casual outings, it's comfortable and trendy all day long!"))
  298. # from PIL import Image
  299. # img1=Image.open("/data/data/luosy/project/sku_search/temp_img/企业微信截图_17372766091671.png")
  300. # ch_sen,en_sen,key_point,id_image,id_price, id_color, id_ingredient, id_selling_point, id_details=generate_text("",img1,"""-With elastic waistband
  301. # -With hairband
  302. # -X-line fit
  303. # 1.腰部橡筋 2.袖子橡
  304. # 筋 3.前中绳子可调
  305. # 节大小""")
  306. # print(len(en_sen),end=" ")
  307. # print(ch_sen,en_sen,key_point)
  308. # ###############################
  309. # img2=Image.open("/data/data/luosy/project/sku_search/temp_img/企业微信截图_17389065463149[1](1).png")
  310. # ch_sen,en_sen,key_point,id_image,id_price, id_color, id_ingredient, id_selling_point, id_details=generate_text("",img2,"""-Washable wool
  311. # -Unisex
  312. # -With silver threads
  313. # 1.后中开衩;2.双扣可调节袖袢;3.暗门筒设计,天然果实扣;4.可水洗羊毛含银葱人字纹面料;5.里面左右两侧均有内袋,左侧最外层内袋是手机袋,防丢失""")
  314. # print(len(en_sen),end=" ")
  315. # print(ch_sen,en_sen,key_point)
  316. # ###############################
  317. # img3=Image.open("/data/data/luosy/project/sku_search/temp_img/企业微信截图_17392379937637.png")
  318. # ch_sen,en_sen,key_point,id_image,id_price, id_color, id_ingredient, id_selling_point, id_details=generate_text("",img3,"""-Acetate
  319. # -With pockets
  320. # -Workwear
  321. # 1.描述二醋酸面料:2.扣子为镶钻布包扣;3.半裙后腰包橡筋;4.半裙有
  322. # 侧插袋;5.半裙有侧开隐形拉链,这是两件套套装""")
  323. # print(len(en_sen),end=" ")
  324. # print(ch_sen,en_sen,key_point)