image_qa_doubao.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181
  1. import base64
  2. import os
  3. import io
  4. from volcenginesdkarkruntime import Ark
  5. from PIL import Image
  6. from utils.logger_config import setup_logger
  7. from utils.llm_outparser import extract_json
  8. from utils.common import read_json_file
  9. from config.image_qa import prompt, double_prompt, show_prompt
  10. import json
  11. from tqdm import tqdm
  12. logger = setup_logger(__name__)
  13. client = Ark(
  14. base_url="https://ark.cn-beijing.volces.com/api/v3",
  15. api_key="817dff39-5586-4f9b-acba-55004167c0b1",
  16. )
  17. # def encode_image(image_path):
  18. # with open(image_path, "rb") as image_file:
  19. # return base64.b64encode(image_file.read()).decode('utf-8')
  20. def encode_image(image_path, crop_margin=200):
  21. """裁切图像四周边缘像素并返回中间部分的 base64 编码数据"""
  22. with Image.open(image_path) as img:
  23. # 获取图像的宽度和高度
  24. width, height = img.size
  25. # 计算裁切后的区域
  26. left = crop_margin
  27. upper = 0
  28. right = width - crop_margin
  29. lower = height
  30. # 裁切图像
  31. cropped_img = img.crop((left, upper, right, lower))
  32. # 将裁切后的图像转换为 base64 编码
  33. buffered = io.BytesIO()
  34. cropped_img.save(buffered, format="JPEG") # 可以根据需要选择格式
  35. return base64.b64encode(buffered.getvalue()).decode('utf-8')
  36. def analyze_single_image_content(image_path):
  37. base64_image = encode_image(image_path)
  38. response = client.chat.completions.create(
  39. model="doubao-1-5-vision-pro-32k-250115",
  40. temperature=1,
  41. max_tokens=200,
  42. messages=[
  43. {
  44. "role": "user",
  45. "content": [
  46. {
  47. "type": "text",
  48. "text": prompt(),
  49. },
  50. {
  51. "type": "image_url",
  52. "image_url": {
  53. "url": f"data:image/jpg;base64,{base64_image}"
  54. },
  55. },
  56. ],
  57. }
  58. ],
  59. )
  60. return response.choices[0].message.content
  61. def detect_show(image_path):
  62. base64_image = encode_image(image_path)
  63. response = client.chat.completions.create(
  64. model="doubao-1-5-vision-pro-32k-250115",
  65. temperature=1,
  66. max_tokens=200,
  67. messages=[
  68. {
  69. "role": "user",
  70. "content": [
  71. {
  72. "type": "text",
  73. "text": show_prompt(),
  74. },
  75. {
  76. "type": "image_url",
  77. "image_url": {
  78. "url": f"data:image/jpg;base64,{base64_image}"
  79. },
  80. },
  81. ],
  82. }
  83. ],
  84. )
  85. return response.choices[0].message.content
  86. def analyze_double_image_content(image_path_1, image_path_2):
  87. base64_image_1 = encode_image(image_path_1)
  88. base64_image_2 = encode_image(image_path_2)
  89. response = client.chat.completions.create(
  90. model="doubao-1-5-vision-pro-32k-250115",
  91. temperature=1,
  92. max_tokens=200,
  93. messages=[
  94. {
  95. "role": "user",
  96. "content": [
  97. {
  98. "type": "text",
  99. "text": double_prompt(),
  100. },
  101. {
  102. "type": "image_url",
  103. "image_url": {
  104. "url": f"data:image/jpg;base64,{base64_image_1}"
  105. },
  106. },
  107. {
  108. "type": "image_url",
  109. "image_url": {
  110. "url": f"data:image/jpg;base64,{base64_image_2}"
  111. },
  112. },
  113. ],
  114. }
  115. ],
  116. )
  117. return response.choices[0].message.content
  118. def image_caption_doubao(image_list):
  119. # 执行图像理解
  120. logger.info(f"fisrt_cut: 执行单帧图像理解")
  121. for image in tqdm(image_list):
  122. clip_name = os.path.splitext(os.path.basename(image))[0]
  123. response = analyze_single_image_content(image)
  124. response_json = json.loads(extract_json(str(response)))
  125. response_json["视频片段编号"] = clip_name
  126. response_json = str(response_json).replace("'",'"')
  127. with open(f'./data/img_caption/{clip_name}.json', 'w', encoding='utf-8') as f:
  128. f.write(response_json)
  129. def show_detect_doubao(image_list):
  130. # 执行图像理解
  131. logger.info(f"show_cut: 执行单帧图像理解")
  132. for image in tqdm(image_list):
  133. clip_name = os.path.splitext(os.path.basename(image))[0]
  134. response = detect_show(image)
  135. response_json = json.loads(extract_json(str(response)))
  136. response_json["视频片段编号"] = clip_name
  137. response_json = str(response_json).replace("'",'"')
  138. with open(f'./data/img_caption/for_show/{clip_name}.json', 'w', encoding='utf-8') as f:
  139. f.write(response_json)
  140. def image_compare_doubao(image_list):
  141. logger.info(f"first_cut: 执行两帧对比理解")
  142. for i in tqdm(range(len(image_list) - 1)):
  143. image1 = image_list[i]
  144. image2 = image_list[i + 1]
  145. clip1_name = os.path.splitext(os.path.basename(image1))[0]
  146. clip2_name = os.path.splitext(os.path.basename(image2))[0]
  147. clip_name = clip1_name + '-' + clip2_name
  148. similarity = analyze_double_image_content(image1, image2)
  149. similarity_json = json.loads(extract_json(str(similarity)))
  150. similarity_json["对比图像"] = clip_name
  151. similarity_json = str(similarity_json).replace("'",'"')
  152. with open(f'./data/img_caption/for_cut/{clip_name}.json', 'w', encoding='utf-8') as f:
  153. f.write(similarity_json)
  154. # 使用示例
  155. if __name__ == "__main__":
  156. image_path = "/data/data/luosy/project/oral/data/key_frame/frame_00000000.jpg"
  157. result = analyze_image_content(image_path)
  158. print(result)