import os
import json
import time
import asyncio
import logging
from pathlib import Path
from tqdm import tqdm
from typing import List, Dict, Union
from PIL import Image
import torch
from transformers import AutoModelForCausalLM
from janus.models import MultiModalityCausalLM, VLChatProcessor
from janus.utils.io import load_pil_images
from utils.logger_config import setup_logger
from utils.llm_outparser import extract_json
from utils.common import read_json_file
from config.image_qa import prompt, double_prompt
# Configure the logging system
logger = setup_logger(__name__)


class JanusVisualAssistant:
    """A highly optimized visual assistant for multimodal interactions"""

    def __init__(
        self,
        model_path: str = "/data/data/luosy/models/Janus-Pro-7B",
        dtype: torch.dtype = torch.bfloat16,
        device: str = "cuda"
    ):
        """Initialize model components with efficient memory management"""
        self.dtype = dtype
        self.device = device
        self.image_cache = {}

        # Initialize components with memory optimization
        self.processor = VLChatProcessor.from_pretrained(model_path)
        self.tokenizer = self.processor.tokenizer
        self.model = self._load_model(model_path).eval()

    def _load_model(self, model_path: str) -> MultiModalityCausalLM:
        """Load model with optimized memory allocation"""
        return AutoModelForCausalLM.from_pretrained(
            model_path,
            trust_remote_code=True
        ).to(self.dtype).to(self.device)

    def _monitor_memory(self):
        """Monitor GPU memory usage"""
        allocated = torch.cuda.memory_allocated(self.device)
        reserved = torch.cuda.memory_reserved(self.device)
        logger.info(f"Memory allocated: {allocated / 1024 ** 2:.2f} MB")
        logger.info(f"Memory reserved: {reserved / 1024 ** 2:.2f} MB")

    def create_conversation(
        self,
        image_paths: List[str],
        questions: List[str],
        system_prompt: str = "你是一个专业的视频理解助手"
    ) -> List[Dict]:
        """Build conversation structure with efficient image handling"""
        if len(image_paths) != len(questions):
            raise ValueError("The number of images must match the number of questions.")

        conversations = []
        for image_path, question in zip(image_paths, questions):
            conversations.append({
                "role": "<|User|>",
                "content": f"<image_placeholder>\n{question}",
                "images": [self._preprocess_image(image_path)],
            })

        conversations.append({"role": "<|Assistant|>", "content": system_prompt})
        return conversations

    def _preprocess_image(self, image_path: Union[str, Path]) -> str:
        """Validate and standardize image input format with caching"""
        if image_path in self.image_cache:
            return self.image_cache[image_path]

        # Open the image once to validate it, then cache the normalized path string
        try:
            Image.open(image_path).convert('RGB')
            self.image_cache[image_path] = str(image_path)
            return str(image_path)
        except Exception as e:
            logger.error(f"Error processing image {image_path}: {e}")
            raise

    @torch.inference_mode()
    def generate_response(
        self,
        conversation: List[Dict],
        generation_config: Dict = None
    ) -> str:
        """Optimized generation pipeline with batch processing"""
        # Monitor memory before processing
        self._monitor_memory()

        # Default generation parameters (temperature only takes effect if do_sample=True)
        default_config = {
            "max_new_tokens": 512,
            "do_sample": False,
            "use_cache": True,
            "temperature": 0.9,
            "pad_token_id": self.tokenizer.eos_token_id,
            "bos_token_id": self.tokenizer.bos_token_id,
            "eos_token_id": self.tokenizer.eos_token_id,
        }
        config = {**default_config, **(generation_config or {})}

        # Batch processing pipeline
        pil_images = load_pil_images(conversation)
        inputs = self.processor(
            conversations=conversation,
            images=pil_images,
            force_batchify=True
        ).to(self.device)

        # Direct memory reuse for embeddings
        inputs_embeds = self.model.prepare_inputs_embeds(**inputs)

        # Accelerated generation
        outputs = self.model.language_model.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=inputs.attention_mask,
            **config
        )

        return self.tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)


def image_caption(image_list):
    # Initialize the image QA model
    assistant = JanusVisualAssistant()
    # prompt = read_json_file("./config/image_qa.json")["prompt"]

    # Run image understanding on every frame
    for image in tqdm(image_list):
        conversation = assistant.create_conversation(
            image_paths=[image],
            questions=[prompt()]
        )
        clip_name = os.path.splitext(os.path.basename(image))[0]
        response = assistant.generate_response(conversation)
        response_json = json.loads(extract_json(str(response)))
        response_json["视频片段编号"] = clip_name
        # Write valid JSON (str(dict) would produce single-quoted, non-JSON output)
        with open(f'./data/img_caption/for_understand/{clip_name}.json', 'w', encoding='utf-8') as f:
            json.dump(response_json, f, ensure_ascii=False)


async def main():
    # Benchmark and example usage
    assistant = JanusVisualAssistant()

    start_time = time.time()
    conversation = assistant.create_conversation(
        image_paths=["/data/data/luosy/project/oral/data/key_frame/frame_00000000.jpg"],
        questions=["以JSON格式回复,包含字段【人物数量】【人物服装】【人物配饰】"]
    )

    # generate_response is synchronous, so it is called directly rather than awaited
    response = assistant.generate_response(conversation)

    print(f"Response: {response}")
    print(f"Total time: {time.time() - start_time:.2f}s")


if __name__ == "__main__":
    asyncio.run(main())