# janus_vllm.py

# Commented-out single-image, synchronous variant of JanusVisualAssistant:
#
# import time
# from pathlib import Path
# from typing import List, Dict, Union
# from PIL import Image
# import torch
# from transformers import AutoModelForCausalLM
# from janus.models import MultiModalityCausalLM, VLChatProcessor
# from janus.utils.io import load_pil_images
#
#
# class JanusVisualAssistant:
#     """A highly optimized visual assistant for multimodal interactions
#
#     Attributes:
#         model: Pretrained language model
#         processor: Text and image processor
#         tokenizer: Text tokenizer
#         device: Current computation device
#         dtype: Data type for model parameters
#     """
#
#     def __init__(
#         self,
#         model_path: str = "/data/data/luosy/models/Janus-Pro-7B",
#         dtype: torch.dtype = torch.bfloat16,
#         device: str = "cuda"
#     ):
#         """Initialize model components with efficient memory management"""
#         self.dtype = dtype
#         self.device = device
#
#         # Initialize components with memory optimization
#         self.processor = VLChatProcessor.from_pretrained(model_path)
#         self.tokenizer = self.processor.tokenizer
#         self.model = self._load_model(model_path).eval()
#
#     def _load_model(self, model_path: str) -> MultiModalityCausalLM:
#         """Load model with optimized memory allocation"""
#         return AutoModelForCausalLM.from_pretrained(
#             model_path,
#             trust_remote_code=True
#         ).to(self.dtype).to(self.device)
#
#     def create_conversation(
#         self,
#         image_path: str,
#         question: str,
#         system_prompt: str = "你是一个专业的视频理解助手"
#     ) -> List[Dict]:
#         """Build conversation structure with efficient image handling"""
#         return [
#             {
#                 "role": "<|User|>",
#                 "content": f"<image_placeholder>\n{question}",
#                 "images": [self._preprocess_image(image_path)],
#             },
#             {"role": "<|Assistant|>", "content": system_prompt},
#         ]
#
#     def _preprocess_image(self, image_path: Union[str, Path]) -> str:
#         """Validate and standardize image input format"""
#         # Remove actual resizing to use processor's native handling
#         # Can add caching mechanism here for frequently used images
#         return str(image_path)
#
#     @torch.inference_mode()
#     def generate_response(
#         self,
#         conversation: List[Dict],
#         generation_config: Dict = None
#     ) -> str:
#         """Optimized generation pipeline with batch processing"""
#         # Default generation parameters
#         default_config = {
#             "max_new_tokens": 512,
#             "do_sample": False,
#             "use_cache": True,
#             "temperature": 0.9,
#             "pad_token_id": self.tokenizer.eos_token_id,
#             "bos_token_id": self.tokenizer.bos_token_id,
#             "eos_token_id": self.tokenizer.eos_token_id,
#         }
#         config = {**default_config, **(generation_config or {})}
#
#         # Batch processing pipeline
#         pil_images = load_pil_images(conversation)
#         inputs = self.processor(
#             conversations=conversation,
#             images=pil_images,
#             force_batchify=True
#         ).to(self.device)
#
#         # Direct memory reuse for embeddings
#         inputs_embeds = self.model.prepare_inputs_embeds(**inputs)
#
#         # Accelerated generation
#         outputs = self.model.language_model.generate(
#             inputs_embeds=inputs_embeds,
#             attention_mask=inputs.attention_mask,
#             **config
#         )
#
#         return self.tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
#
#
# def main():
#     # Benchmark and example usage
#     assistant = JanusVisualAssistant()
#
#     start_time = time.time()
#     conversation = assistant.create_conversation(
#         image_path="/data/data/luosy/project/oral/data/key_frame/frame_001.jpg",
#         question="以JSON格式回复,包含字段【人物数量】【人物服装】【人物配饰】"
#     )
#
#     response = assistant.generate_response(conversation)
#     print(f"Response: {response}")
#     print(f"Total time: {time.time() - start_time:.2f}s")
#
#
# if __name__ == "__main__":
#     main()


import time
import asyncio
import logging
from pathlib import Path
from typing import List, Dict, Optional, Union

from PIL import Image
import torch
from transformers import AutoModelForCausalLM

from janus.models import MultiModalityCausalLM, VLChatProcessor
from janus.utils.io import load_pil_images

from utils.logger_config import setup_logger

# Configure the logging system
logger = setup_logger(__name__)


class JanusVisualAssistant:
    """A highly optimized visual assistant for multimodal interactions"""

    def __init__(
        self,
        model_path: str = "/data/data/luosy/models/Janus-Pro-7B",
        dtype: torch.dtype = torch.bfloat16,
        device: str = "cuda"
    ):
        """Initialize model components with efficient memory management"""
        self.dtype = dtype
        self.device = device
        self.image_cache = {}

        # Initialize components with memory optimization
        self.processor = VLChatProcessor.from_pretrained(model_path)
        self.tokenizer = self.processor.tokenizer
        self.model = self._load_model(model_path).eval()

    def _load_model(self, model_path: str) -> MultiModalityCausalLM:
        """Load model with optimized memory allocation"""
        return AutoModelForCausalLM.from_pretrained(
            model_path,
            trust_remote_code=True
        ).to(self.dtype).to(self.device)

    def _monitor_memory(self):
        """Monitor GPU memory usage"""
        # CUDA allocator statistics are only meaningful on a CUDA device
        if not str(self.device).startswith("cuda"):
            return
        allocated = torch.cuda.memory_allocated(self.device)
        reserved = torch.cuda.memory_reserved(self.device)
        logger.info(f"Memory allocated: {allocated / 1024 ** 2:.2f} MB")
        logger.info(f"Memory reserved: {reserved / 1024 ** 2:.2f} MB")

    def create_conversation(
        self,
        image_paths: List[str],
        questions: List[str],
        system_prompt: str = "你是一个专业的视频理解助手"  # "You are a professional video understanding assistant"
    ) -> List[Dict]:
        """Build conversation structure with efficient image handling"""
        if len(image_paths) != len(questions):
            raise ValueError("The number of images must match the number of questions.")

        conversations = []
        for image_path, question in zip(image_paths, questions):
            conversations.append({
                "role": "<|User|>",
                "content": f"<image_placeholder>\n{question}",
                "images": [self._preprocess_image(image_path)],
            })
        conversations.append({"role": "<|Assistant|>", "content": system_prompt})
        return conversations
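
    # The returned conversation for a single (image, question) pair looks
    # roughly like the sketch below; the path and question are placeholders,
    # not values from this project:
    #
    #   [
    #       {
    #           "role": "<|User|>",
    #           "content": "<image_placeholder>\n<question>",
    #           "images": ["/path/to/frame.jpg"],
    #       },
    #       {"role": "<|Assistant|>", "content": system_prompt},
    #   ]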

    def _preprocess_image(self, image_path: Union[str, Path]) -> str:
        """Validate and standardize image input format with caching"""
        if image_path in self.image_cache:
            return self.image_cache[image_path]

        # Open the image once purely as a validation step; only the path is
        # cached and returned, because the processor loads images itself later
        try:
            Image.open(image_path).convert('RGB')
            self.image_cache[image_path] = str(image_path)
            return str(image_path)
        except Exception as e:
            logger.error(f"Error processing image {image_path}: {e}")
            raise

    async def generate_response(
        self,
        conversation: List[Dict],
        generation_config: Optional[Dict] = None
    ) -> str:
        """Optimized generation pipeline with batch processing"""
        # Monitor memory before processing
        self._monitor_memory()

        # Default generation parameters (greedy decoding by default; the
        # temperature only takes effect if do_sample is overridden to True)
        default_config = {
            "max_new_tokens": 512,
            "do_sample": False,
            "use_cache": True,
            "temperature": 0.9,
            "pad_token_id": self.tokenizer.eos_token_id,
            "bos_token_id": self.tokenizer.bos_token_id,
            "eos_token_id": self.tokenizer.eos_token_id,
        }
        config = {**default_config, **(generation_config or {})}

        # inference_mode is entered inside the coroutine body (rather than as
        # a decorator) so that it is active while the model actually runs
        with torch.inference_mode():
            # Batch processing pipeline
            pil_images = load_pil_images(conversation)
            inputs = self.processor(
                conversations=conversation,
                images=pil_images,
                force_batchify=True
            ).to(self.device)

            # Direct memory reuse for embeddings
            inputs_embeds = self.model.prepare_inputs_embeds(**inputs)

            # Accelerated generation
            outputs = self.model.language_model.generate(
                inputs_embeds=inputs_embeds,
                attention_mask=inputs.attention_mask,
                **config
            )

        return self.tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
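
    # Sketch: callers can override any of the defaults above by passing a
    # partial generation_config; the values below are illustrative only.
    #
    #   response = await assistant.generate_response(
    #       conversation,
    #       generation_config={"max_new_tokens": 256, "do_sample": True, "temperature": 0.7},
    #   )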


async def main():
    # Benchmark and example usage
    assistant = JanusVisualAssistant()

    start_time = time.time()
    # Question (Chinese): reply in JSON with fields for the number of people,
    # their clothing, and their accessories
    conversation = assistant.create_conversation(
        image_paths=["/data/data/luosy/project/oral/data/key_frame/frame_014.jpg"],
        questions=["以JSON格式回复,包含字段【人物数量】【人物服装】【人物配饰】"]
    )

    response = await assistant.generate_response(conversation)
    print(f"Response: {response}")
    print(f"Total time: {time.time() - start_time:.2f}s")


if __name__ == "__main__":
    asyncio.run(main())
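

# Sketch (commented out, not executed): batching several key frames into one
# conversation. The frame paths are hypothetical placeholders; each frame needs
# a matching question.
#
#   async def describe_frames():
#       assistant = JanusVisualAssistant()
#       conversation = assistant.create_conversation(
#           image_paths=["/path/to/frame_001.jpg", "/path/to/frame_002.jpg"],
#           questions=["以JSON格式回复,包含字段【人物数量】【人物服装】【人物配饰】"] * 2,
#       )
#       return await assistant.generate_response(conversation)
#
#   print(asyncio.run(describe_frames()))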