import os
import json
import time
import asyncio
import logging
from pathlib import Path
from tqdm import tqdm
from typing import List, Dict, Union
from PIL import Image

import torch
from transformers import AutoModelForCausalLM
from janus.models import MultiModalityCausalLM, VLChatProcessor
from janus.utils.io import load_pil_images

from utils.logger_config import setup_logger
from utils.llm_outparser import extract_json
from utils.common import read_json_file
from config.image_qa import prompt, double_prompt

# Configure logging
logger = setup_logger(__name__)


class JanusVisualAssistant:
    """A highly optimized visual assistant for multimodal interactions."""

    def __init__(
        self,
        model_path: str = "/data/data/luosy/models/Janus-Pro-7B",
        dtype: torch.dtype = torch.bfloat16,
        device: str = "cuda"
    ):
        """Initialize model components with efficient memory management."""
        self.dtype = dtype
        self.device = device
        self.image_cache = {}

        # Initialize components with memory optimization
        self.processor = VLChatProcessor.from_pretrained(model_path)
        self.tokenizer = self.processor.tokenizer
        self.model = self._load_model(model_path).eval()

    def _load_model(self, model_path: str) -> MultiModalityCausalLM:
        """Load the model with optimized memory allocation."""
        return AutoModelForCausalLM.from_pretrained(
            model_path,
            trust_remote_code=True
        ).to(self.dtype).to(self.device)

    def _monitor_memory(self):
        """Log current GPU memory usage."""
        allocated = torch.cuda.memory_allocated(self.device)
        reserved = torch.cuda.memory_reserved(self.device)
        logger.info(f"Memory allocated: {allocated / 1024 ** 2:.2f} MB")
        logger.info(f"Memory reserved: {reserved / 1024 ** 2:.2f} MB")

    def create_conversation(
        self,
        image_paths: List[str],
        questions: List[str],
        system_prompt: str = "你是一个专业的视频理解助手"
    ) -> List[Dict]:
        """Build the conversation structure, pairing each image with its question."""
        if len(image_paths) != len(questions):
            raise ValueError("The number of images must match the number of questions.")

        conversations = []
        for image_path, question in zip(image_paths, questions):
            conversations.append({
                "role": "<|User|>",
                # The <image_placeholder> tag marks where the processor injects image features
                "content": f"<image_placeholder>\n{question}",
                "images": [self._preprocess_image(image_path)],
            })
        conversations.append({"role": "<|Assistant|>", "content": system_prompt})
        return conversations

    def _preprocess_image(self, image_path: Union[str, Path]) -> str:
        """Validate the image file and return its path as a string, with caching."""
        if image_path in self.image_cache:
            return self.image_cache[image_path]

        # Open the image once to validate it, then cache the path string
        try:
            Image.open(image_path).convert("RGB")
            self.image_cache[image_path] = str(image_path)
            return str(image_path)
        except Exception as e:
            logger.error(f"Error processing image {image_path}: {e}")
            raise

    @torch.inference_mode()
    def generate_response(
        self,
        conversation: List[Dict],
        generation_config: Dict = None
    ) -> str:
        """Optimized generation pipeline with batch processing."""
        # Monitor memory before processing
        self._monitor_memory()

        # Default generation parameters (temperature is ignored while do_sample=False)
        default_config = {
            "max_new_tokens": 512,
            "do_sample": False,
            "use_cache": True,
            "temperature": 0.9,
            "pad_token_id": self.tokenizer.eos_token_id,
            "bos_token_id": self.tokenizer.bos_token_id,
            "eos_token_id": self.tokenizer.eos_token_id,
        }
        config = {**default_config, **(generation_config or {})}

        # Batch processing pipeline
        pil_images = load_pil_images(conversation)
        inputs = self.processor(
            conversations=conversation,
            images=pil_images,
            force_batchify=True
        ).to(self.device)

        # Fuse image features directly into the token embedding sequence
        inputs_embeds = self.model.prepare_inputs_embeds(**inputs)

        # Accelerated generation from the fused embeddings
        outputs = self.model.language_model.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=inputs.attention_mask,
            **config
        )

        return self.tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)


def image_caption(image_list):
    # Initialize the image QA model
    assistant = JanusVisualAssistant()
    # prompt = read_json_file("./config/image_qa.json")["prompt"]

    # Run image understanding on each frame
    for image in tqdm(image_list):
        conversation = assistant.create_conversation(
            image_paths=[image],
            questions=[prompt()]
        )
        clip_name = os.path.splitext(os.path.basename(image))[0]
        response = assistant.generate_response(conversation)
        response_json = json.loads(extract_json(str(response)))
        response_json["视频片段编号"] = clip_name

        # Write valid JSON (not the Python dict repr) so downstream readers can parse it
        with open(f'./data/img_caption/for_understand/{clip_name}.json', 'w', encoding='utf-8') as f:
            json.dump(response_json, f, ensure_ascii=False)


async def main():
    # Benchmark and example usage
    assistant = JanusVisualAssistant()
    start_time = time.time()

    conversation = assistant.create_conversation(
        image_paths=["/data/data/luosy/project/oral/data/key_frame/frame_00000000.jpg"],
        questions=["以JSON格式回复,包含字段【人物数量】【人物服装】【人物配饰】"]
    )

    # generate_response is synchronous, so it is called directly (no await)
    response = assistant.generate_response(conversation)

    print(f"Response: {response}")
    print(f"Total time: {time.time() - start_time:.2f}s")


if __name__ == "__main__":
    asyncio.run(main())
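
# A minimal usage sketch for the batch captioning path, kept as comments so it
# does not run on import. The key-frame directory and glob pattern below are
# assumptions for illustration; the real frame list comes from the surrounding
# pipeline. image_caption() writes one JSON file per frame to
# ./data/img_caption/for_understand/.
#
#   from glob import glob
#
#   key_frames = sorted(glob("./data/key_frame/*.jpg"))
#   image_caption(key_frames)
#
# Sampling can also be enabled per call by overriding the defaults, e.g.
#   assistant.generate_response(conversation, generation_config={"do_sample": True})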