image_qa.py
import os
import json
import time
import asyncio
import logging
from pathlib import Path
from tqdm import tqdm
from typing import List, Dict, Union

from PIL import Image
import torch
from transformers import AutoModelForCausalLM
from janus.models import MultiModalityCausalLM, VLChatProcessor
from janus.utils.io import load_pil_images

from utils.logger_config import setup_logger
from utils.llm_outparser import extract_json
from utils.common import read_json_file
from config.image_qa import prompt, double_prompt

# Configure the logging system
logger = setup_logger(__name__)


class JanusVisualAssistant:
    """A highly optimized visual assistant for multimodal interactions."""

    def __init__(
        self,
        model_path: str = "/data/data/luosy/models/Janus-Pro-7B",
        dtype: torch.dtype = torch.bfloat16,
        device: str = "cuda",
    ):
        """Initialize model components with efficient memory management."""
        self.dtype = dtype
        self.device = device
        self.image_cache = {}

        # Initialize components with memory optimization
        self.processor = VLChatProcessor.from_pretrained(model_path)
        self.tokenizer = self.processor.tokenizer
        self.model = self._load_model(model_path).eval()

    def _load_model(self, model_path: str) -> MultiModalityCausalLM:
        """Load model with optimized memory allocation."""
        return AutoModelForCausalLM.from_pretrained(
            model_path,
            trust_remote_code=True,
        ).to(self.dtype).to(self.device)
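
    # Sketch (assumption, not the author's code): passing torch_dtype to
    # from_pretrained loads the weights directly in the target dtype and can
    # avoid a transient full-precision copy before the .to(self.dtype) cast:
    #
    #     AutoModelForCausalLM.from_pretrained(
    #         model_path, trust_remote_code=True, torch_dtype=self.dtype
    #     ).to(self.device)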

    def _monitor_memory(self):
        """Log current GPU memory usage."""
        allocated = torch.cuda.memory_allocated(self.device)
        reserved = torch.cuda.memory_reserved(self.device)
        logger.info(f"Memory allocated: {allocated / 1024 ** 2:.2f} MB")
        logger.info(f"Memory reserved: {reserved / 1024 ** 2:.2f} MB")

    def create_conversation(
        self,
        image_paths: List[str],
        questions: List[str],
        # Default system prompt: "You are a professional video understanding assistant"
        system_prompt: str = "你是一个专业的视频理解助手",
    ) -> List[Dict]:
        """Build the conversation structure with efficient image handling."""
        if len(image_paths) != len(questions):
            raise ValueError("The number of images must match the number of questions.")

        conversations = []
        for image_path, question in zip(image_paths, questions):
            conversations.append({
                "role": "<|User|>",
                "content": f"<image_placeholder>\n{question}",
                "images": [self._preprocess_image(image_path)],
            })
        conversations.append({"role": "<|Assistant|>", "content": system_prompt})
        return conversations
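
    # For a single image/question pair, the method above yields (illustrative):
    #
    #     [
    #         {"role": "<|User|>",
    #          "content": "<image_placeholder>\n<question>",
    #          "images": ["<image path>"]},
    #         {"role": "<|Assistant|>", "content": "<system_prompt>"},
    #     ]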

    def _preprocess_image(self, image_path: Union[str, Path]) -> str:
        """Validate the image and return its path as a string, with caching."""
        if image_path in self.image_cache:
            return self.image_cache[image_path]

        try:
            # Opening the file verifies it is a readable image; only the
            # validated path (not the pixel data) is cached, since the
            # processor reloads images from disk via load_pil_images.
            Image.open(image_path).convert("RGB")
            self.image_cache[image_path] = str(image_path)
            return str(image_path)
        except Exception as e:
            logger.error(f"Error processing image {image_path}: {e}")
            raise

    @torch.inference_mode()
    def generate_response(
        self,
        conversation: List[Dict],
        generation_config: Dict = None,
    ) -> str:
        """Optimized generation pipeline with batch processing."""
        # Monitor memory before processing
        self._monitor_memory()

        # Default generation parameters (greedy decoding by default;
        # temperature only takes effect when do_sample=True)
        default_config = {
            "max_new_tokens": 512,
            "do_sample": False,
            "use_cache": True,
            "temperature": 0.9,
            "pad_token_id": self.tokenizer.eos_token_id,
            "bos_token_id": self.tokenizer.bos_token_id,
            "eos_token_id": self.tokenizer.eos_token_id,
        }
        # Caller-supplied keys override the defaults
        config = {**default_config, **(generation_config or {})}

        # Batch processing pipeline
        pil_images = load_pil_images(conversation)
        inputs = self.processor(
            conversations=conversation,
            images=pil_images,
            force_batchify=True,
        ).to(self.device)

        # Build the multimodal embeddings once and feed them to the language model
        inputs_embeds = self.model.prepare_inputs_embeds(**inputs)

        # Accelerated generation
        outputs = self.model.language_model.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=inputs.attention_mask,
            **config,
        )
        return self.tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
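

# Usage sketch: entries in generation_config override default_config through
# the dict merge above. A hypothetical sampling setup (values illustrative):
#
#     assistant = JanusVisualAssistant()
#     answer = assistant.generate_response(
#         conversation,
#         generation_config={"do_sample": True, "temperature": 0.7},
#     )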


def image_caption(image_list):
    # Initialize the image QA model
    assistant = JanusVisualAssistant()
    # prompt = read_json_file("./config/image_qa.json")["prompt"]

    # Run image understanding over each frame and persist one JSON file per clip
    output_dir = "./data/img_caption/for_understand"
    os.makedirs(output_dir, exist_ok=True)
    for image in tqdm(image_list):
        conversation = assistant.create_conversation(
            image_paths=[image],
            questions=[prompt()],
        )
        clip_name = os.path.splitext(os.path.basename(image))[0]
        response = assistant.generate_response(conversation)
        response_json = json.loads(extract_json(str(response)))
        response_json["视频片段编号"] = clip_name  # "video clip ID"
        # Write valid JSON (str(dict) would produce a Python repr, not JSON)
        with open(os.path.join(output_dir, f"{clip_name}.json"), "w", encoding="utf-8") as f:
            json.dump(response_json, f, ensure_ascii=False)
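
# Example (hypothetical paths): caption every extracted key frame in a directory.
#
#     frames = sorted(str(p) for p in Path("./data/key_frame").glob("*.jpg"))
#     image_caption(frames)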


async def main():
    # Benchmark and example usage
    assistant = JanusVisualAssistant()
    start_time = time.time()

    conversation = assistant.create_conversation(
        image_paths=["/data/data/luosy/project/oral/data/key_frame/frame_00000000.jpg"],
        # "Reply in JSON with the fields [person count] [clothing] [accessories]"
        questions=["以JSON格式回复,包含字段【人物数量】【人物服装】【人物配饰】"],
    )
    # generate_response is synchronous, so it must not be awaited
    response = assistant.generate_response(conversation)

    print(f"Response: {response}")
    print(f"Total time: {time.time() - start_time:.2f}s")


if __name__ == "__main__":
    asyncio.run(main())