# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only MegrezO model compatible with HuggingFace weights."""
from functools import lru_cache
from functools import partial
from typing import Any, Callable, Iterable, List, Literal, Mapping, Optional, Tuple, TypedDict, Union

import numpy as np
import torch
import torch.nn.functional as F
import torch.types
from PIL import Image
from torch import Tensor
from torch import nn
from torch.nn.init import trunc_normal_
from transformers import PretrainedConfig

from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig
from vllm.config import MultiModalConfig
from vllm.inputs import INPUT_REGISTRY
from vllm.inputs import DecoderOnlyInputs
from vllm.inputs import InputContext
from vllm.inputs import token_inputs
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.resampler import get_2d_sincos_pos_embed
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models import VllmModelForTextGeneration
from vllm.model_executor.models.idefics2_vision_model import Idefics2VisionTransformer
from vllm.model_executor.models.interfaces import SupportsMultiModal
from vllm.model_executor.models.interfaces import SupportsPP
from vllm.model_executor.models.llama import LlamaModel
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.utils import LLMWrapper
from vllm.model_executor.models.utils import is_pp_missing_parameter
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.base import MultiModalInputs
from vllm.multimodal.utils import cached_get_tokenizer
from vllm.sequence import IntermediateTensors
from vllm.sequence import SequenceData
from vllm.transformers_utils.processor import get_processor

RawImageType = Union[Image.Image, torch.Tensor]
RawAudioType = Union[bytes, torch.Tensor]

cached_get_processor = lru_cache(get_processor)


class MegrezORawImageInput(TypedDict):
    """Input mapper input with auxiliary data for computing image bounds."""

    image: RawImageType


class MegrezOAudioInput(TypedDict):
    type: Literal["audio"]
    data: RawAudioType


class MegrezOAudioTensorInput(TypedDict):
    type: Literal["audio_tensor"]
    input_audios: torch.Tensor
    input_audio_lengths: torch.Tensor
    audio_span_tokens: torch.Tensor


class MegrezOImagePixelInputs(TypedDict):
    type: Literal["pixel_values"]
    pixel_values: torch.Tensor
    """
    Shape: `(batch_size * num_images, num_channels, height, width)`

    Note that the image size may vary, so we pass it as a list
    instead of a batched tensor.
    """

    tgt_sizes: torch.Tensor
    """
    Shape: `(batch_size * num_images, 2)`

    This should be in `(height, width)` format.
    """

    patch_attention_mask: torch.Tensor
    """
    Shape: `(batch_size * num_images, num_patches, num_patches)`
    """


class MegrezOImageEmbeddingInputs(TypedDict):
    type: Literal["image_embeds"]
    data: torch.Tensor
    """
    Shape: `(batch_size * num_images, image_feature_size, hidden_size)`

    `hidden_size` must match the hidden size of the language model backbone.
    """

    image_bounds: torch.Tensor
    """
    Shape: `(batch_size * num_images, 2)`

    This should be in `(start, stop)` format.
    """


def insert_audio_embeddings(text_embeddings, inserted_embeddings, inserted_bounds):
    inserted_bounds = inserted_bounds.long()
    for idx in range(len(inserted_embeddings)):
        bid = inserted_bounds[idx][0]
        start_id = inserted_bounds[idx][1]
        end_id = inserted_bounds[idx][2]
        embedding = inserted_embeddings[idx]
        text_embeddings[start_id + 1 : end_id] = embedding
    return text_embeddings


def insert_image_embeddings(text_embeddings, inserted_embeddings, inserted_bounds):
    inserted_bounds = inserted_bounds.long()
    for idx in range(len(inserted_embeddings)):
        bid = inserted_bounds[idx][0]
        start_id = inserted_bounds[idx][1]
        end_id = inserted_bounds[idx][2]
        embedding = inserted_embeddings[idx]
        text_embeddings[start_id:end_id] = embedding
    return text_embeddings


MegrezOImageInputs = Union[MegrezOImagePixelInputs]
MegrezOAudioInputs = Union[MegrezOAudioTensorInput]

# region: Resampler
DEFAULT_LN = partial(nn.LayerNorm, eps=1e-6)
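
# The Resampler is a perceiver-style cross-attention pooler: a fixed set of
# `num_queries` learned query vectors attends over a variable-length sequence
# of vision patch embeddings (with 2D sin-cos position embeddings added to the
# keys), producing a fixed number of image tokens per slice for the LLM.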


class Resampler(nn.Module):

    def __init__(
        self,
        num_queries: int,
        embed_dim: int,
        num_heads: int,
        kv_dim: Optional[int] = None,
        norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN,
        max_size: Tuple[int, int] = (70, 70),
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.num_queries = num_queries
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim))
        trunc_normal_(self.query, std=0.02)
        if kv_dim is not None and kv_dim != embed_dim:
            self.kv_proj = ReplicatedLinear(kv_dim, embed_dim, bias=False, quant_config=quant_config, prefix=prefix)
        else:
            # Maintain the same return value with ReplicatedLinear.forward
            self.kv_proj = lambda *args, **kwargs: (  # type: ignore # noqa
                nn.Identity()(*args, **kwargs),
                None,
            )
        self.attn = nn.MultiheadAttention(embed_dim, num_heads)
        self.ln_q = norm_layer(embed_dim)
        self.ln_kv = norm_layer(embed_dim)
        self.do_post_projection = True
        self.ln_post = norm_layer(embed_dim)
        self.proj = nn.Parameter((embed_dim**-0.5) * torch.randn(embed_dim, embed_dim))
        self.max_size = max_size
        self._set_2d_pos_cache(self.max_size)
        self.apply(self._init_weights)

    def _init_weights(self, m: nn.Module) -> None:
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def _repeat(self, query, N: int):
        return query.unsqueeze(1).repeat(1, N, 1)

    def _set_2d_pos_cache(self, max_size: Tuple[int, int], device: torch.types.Device = "cpu") -> None:
        pos_embed_arr = get_2d_sincos_pos_embed(self.embed_dim, max_size, version=(2, 5))
        pos_embed = torch.from_numpy(pos_embed_arr).float().to(device)
        self.register_buffer("pos_embed", pos_embed, persistent=False)

    def _adjust_pos_cache(self, tgt_sizes: torch.Tensor, device: torch.types.Device) -> None:
        max_h = tgt_sizes[:, 0].max().item()
        max_w = tgt_sizes[:, 1].max().item()
        assert isinstance(max_h, int) and isinstance(max_w, int)
        if max_h > self.max_size[0] or max_w > self.max_size[1]:
            self.max_size = (
                max(max_h, self.max_size[0]),
                max(max_w, self.max_size[1]),
            )
            self._set_2d_pos_cache(self.max_size, device)

    def forward(self, x: torch.Tensor, tgt_sizes: torch.Tensor) -> torch.Tensor:
        assert x.shape[0] == tgt_sizes.shape[0]
        bs = x.shape[0]
        device = x.device
        dtype = x.dtype
        patch_len = tgt_sizes[:, 0] * tgt_sizes[:, 1]
        self._adjust_pos_cache(tgt_sizes, device=device)
        max_patch_len = patch_len.max().item()
        assert isinstance(max_patch_len, int)
        key_padding_mask = torch.zeros((bs, max_patch_len), dtype=torch.bool, device=device)
        pos_embed = []
        for i in range(bs):
            tgt_h, tgt_w = tgt_sizes[i].tolist()
            pos_embed.append(self.pos_embed[:tgt_h, :tgt_w, :].reshape((tgt_h * tgt_w, -1)).to(dtype))  # patches * D
            key_padding_mask[i, patch_len[i]:] = True
        pos_embed = torch.nn.utils.rnn.pad_sequence(pos_embed, batch_first=True, padding_value=0.0).permute(
            1, 0, 2
        )  # BLD => L * B * D
        x, _ = self.kv_proj(x)  # B * L * D
        x = self.ln_kv(x).permute(1, 0, 2)  # L * B * D
        q = self.ln_q(self.query)  # Q * D
        out = self.attn(
            self._repeat(q, bs),  # Q * B * D
            x + pos_embed,  # L * B * D + L * B * D
            x,
            key_padding_mask=key_padding_mask,
        )[0]
        # out: Q * B * D
        x = out.permute(1, 0, 2)  # B * Q * D
        x = self.ln_post(x)
        x = x @ self.proj
        return x


# endregion


# region: AudioEncoder
class LayerNorm(nn.LayerNorm):

    def forward(self, x: Tensor) -> Tensor:
        # return super().forward(x.float()).type(x.dtype)
        return super().forward(x).type(x.dtype)


class Linear(nn.Linear):

    def forward(self, x: Tensor) -> Tensor:
        return F.linear(
            x,
            self.weight.to(x.dtype),
            None if self.bias is None else self.bias.to(x.dtype),
        )


class Conv1d(nn.Conv1d):

    def _conv_forward(self, x: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor:
        return super()._conv_forward(x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype))


def sinusoids(length, channels, max_timescale=10000):
    """Returns sinusoids for positional embedding"""
    assert channels % 2 == 0
    log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
    inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
    scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
    return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
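
# The attention stack below (MultiHeadAttention, ResidualAttentionBlock,
# AudioEncoder) appears to follow a Whisper-style audio encoder: two Conv1d
# layers (the second strided) downsample the mel spectrogram, the fixed
# sinusoidal table above (sin in the first half of the channels, cos in the
# second) provides positions, and pre-norm residual attention blocks follow.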


class MultiHeadAttention(nn.Module):

    def __init__(self, n_state: int, n_head: int):
        super().__init__()
        self.n_head = n_head
        self.query = Linear(n_state, n_state)
        self.key = Linear(n_state, n_state, bias=False)
        self.value = Linear(n_state, n_state)
        self.out = Linear(n_state, n_state)

    def forward(
        self,
        x: Tensor,
        xa: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
        kv_cache: Optional[dict] = None,
    ):
        q = self.query(x)
        if kv_cache is None or xa is None or self.key not in kv_cache:
            # hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
            # otherwise, perform key/value projections for self- or cross-attention as usual.
            k = self.key(x if xa is None else xa)
            v = self.value(x if xa is None else xa)
        else:
            # for cross-attention, calculate keys and values once and reuse in subsequent calls.
            k = kv_cache[self.key]
            v = kv_cache[self.value]
        wv, qk = self.qkv_attention(q, k, v, mask)
        return self.out(wv), qk

    def qkv_attention(self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None):
        n_batch, n_ctx, n_state = q.shape
        scale = (n_state // self.n_head) ** -0.25
        q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
        k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale
        v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
        qk = q @ k
        if mask is not None:
            qk += mask
        w = F.softmax(qk, dim=-1).to(q.dtype)
        return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()


class ResidualAttentionBlock(nn.Module):

    def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
        super().__init__()
        self.attn = MultiHeadAttention(n_state, n_head)
        self.attn_ln = LayerNorm(n_state)
        self.cross_attn = MultiHeadAttention(n_state, n_head) if cross_attention else None
        self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None
        n_mlp = n_state * 4
        self.mlp = nn.Sequential(Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state))
        self.mlp_ln = LayerNorm(n_state)

    def forward(
        self,
        x: Tensor,
        xa: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
        kv_cache: Optional[dict] = None,
    ):
        x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0]
        if self.cross_attn:
            x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)[0]
        x = x + self.mlp(self.mlp_ln(x))
        return x


class AudioEncoder(nn.Module):

    def __init__(
        self,
        n_mels: int,
        n_ctx: int,
        n_state: int,
        n_head: int,
        n_layer: int,
        output_dim: int = 512,
        avg_pool: bool = True,
        add_audio_bos_eos_token: bool = True,
        **kwargs,
    ):
        super().__init__()
        self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
        self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
        # self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))
        self.positional_embedding = nn.Parameter(sinusoids(n_ctx, n_state), requires_grad=False)
        self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
            [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
        )
        self.ln_post = LayerNorm(n_state)
        if avg_pool:
            self.avg_pooler = nn.AvgPool1d(2, stride=2)
        else:
            self.avg_pooler = None
        self.proj = nn.Linear(n_state, output_dim)
        if add_audio_bos_eos_token:
            self.audio_bos_eos_token = nn.Embedding(2, output_dim)
        else:
            self.audio_bos_eos_token = None
        self.output_dim = output_dim
        self.n_head = n_head

    def forward(self, x: Tensor, padding_mask: Optional[Tensor] = None, audio_lengths: Optional[Tensor] = None):
        """
        x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
            the mel spectrogram of the audio
        """
        x = x.to(dtype=self.conv1.weight.dtype, device=self.conv1.weight.device)
        if audio_lengths is not None:
            input_mel_len = audio_lengths[:, 0] * 2
            max_mel_len_in_batch = input_mel_len.max()
            x = x[:, :, :max_mel_len_in_batch]
        x = F.gelu(self.conv1(x))
        x = F.gelu(self.conv2(x))
        x = x.permute(0, 2, 1)  # B, L, D
        bsz = x.size(0)
        src_len = x.size(1)
        self.input_positional_embedding = self.positional_embedding[:src_len]
        assert (
            x.shape[1:] == self.input_positional_embedding.shape
        ), f"incorrect audio shape: {x.shape[1:], self.input_positional_embedding.shape}"
        x = (x + self.input_positional_embedding).to(x.dtype)
        if padding_mask is not None:
            padding_mask = padding_mask.to(dtype=self.conv1.weight.dtype, device=self.conv1.weight.device)
            batch_src_len = padding_mask.size(1)
            x = x[:, :batch_src_len, :]
            padding_mask = padding_mask.view(bsz, -1, batch_src_len)
            padding_mask_ = padding_mask.all(1)
            x[padding_mask_] = 0
            key_padding_mask = (
                padding_mask_.view(bsz, 1, 1, batch_src_len)
                .expand(-1, self.n_head, -1, -1)
                .reshape(bsz, self.n_head, 1, batch_src_len)
            )
            new_padding_mask = torch.zeros_like(key_padding_mask, dtype=x.dtype)
            padding_mask = new_padding_mask.masked_fill(key_padding_mask, float("-inf"))
        for block in self.blocks:
            x = block(x, mask=padding_mask)
        if self.avg_pooler:
            x = x.permute(0, 2, 1)
            x = self.avg_pooler(x)
            x = x.permute(0, 2, 1)
        x = self.ln_post(x)
        x = self.proj(x)
        if self.audio_bos_eos_token is not None:
            bos = self.audio_bos_eos_token.weight[0][None, :]
            eos = self.audio_bos_eos_token.weight[1][None, :]
        else:
            bos, eos = None, None
        return x, bos, eos

    def encode(
        self,
        input_audios: Tensor,
        input_audio_lengths: Tensor,
        audio_span_tokens: List,
    ):
        real_input_audio_lens = input_audio_lengths[:, 0].tolist()
        max_len_in_batch = max(real_input_audio_lens)
        padding_mask = torch.ones([input_audios.size(0), max_len_in_batch]).to(
            dtype=self.conv1.weight.dtype, device=self.conv1.weight.device
        )
        for index in range(len(input_audios)):
            padding_mask[index, : input_audio_lengths[index][0].item()] = 0
        x, bos, eos = self(input_audios, padding_mask, input_audio_lengths)
        output_audios = []
        for i in range(len(audio_span_tokens)):
            audio_span = audio_span_tokens[i]
            audio = x[i][: audio_span - 2]
            if bos is not None:
                audio = torch.concat([bos, audio, eos])
            assert len(audio) == audio_span
            output_audios.append(audio)
        return output_audios


class AudioModel(torch.nn.Module):

    def __init__(self, config):
        super(AudioModel, self).__init__()
        self.config = config
        self.audio = AudioEncoder(**config.audio_config.to_dict())

    def forward(self, audio_info):
        audios = audio_info["input_audios"][0]
        input_audio_lengths = audio_info["input_audio_lengths"][0]
        audio_span_tokens = audio_info["audio_span_tokens"][0]
        audios_features = self.audio.encode(audios, input_audio_lengths, audio_span_tokens)
        return audios_features


# endregion
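
# Image token budget registered below: each image contributes `query_num`
# resampler tokens per slice; the extra factor of 10 presumably leaves room
# for multiple slices per image (the multiplier is not documented here).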


def get_max_megrezo_image_tokens(ctx: InputContext):
    hf_config = ctx.get_hf_config()
    return getattr(hf_config, "query_num", 64) * 10


def dummy_seq_data_for_minicpmv(seq_len: int, num_images: int):
    return SequenceData.from_prompt_token_counts((0, seq_len))


def dummy_image_for_minicpmv(ctx: InputContext, hf_config: PretrainedConfig, num_images: int):
    width = height = hf_config.vision_config.image_size
    imgs = [MegrezORawImageInput(image=Image.new("RGB", (width, height), color=0)) for _ in range(num_images)]
    return {"image": imgs}


def dummy_data_for_minicpmv(ctx: InputContext, seq_len: int, mm_counts: Mapping[str, int]):
    hf_config = ctx.get_hf_config()
    num_images = mm_counts["image"]
    seq_data = dummy_seq_data_for_minicpmv(seq_len, num_images)
    mm_data = dummy_image_for_minicpmv(ctx, hf_config, num_images)  # skip audio for now
    return (seq_data, mm_data)


def input_processor_for_megrezo(ctx: InputContext, inputs: DecoderOnlyInputs):
    multi_modal_data = inputs.get("multi_modal_data")
    if multi_modal_data is None or ("image" not in multi_modal_data and "audio" not in multi_modal_data):
        return inputs

    model_config = ctx.model_config
    tokenizer = cached_get_tokenizer(model_config.tokenizer, trust_remote_code=model_config.trust_remote_code)
    processor = cached_get_processor(model_config.model, trust_remote_code=model_config.trust_remote_code)

    prompt = inputs.get("prompt")
    token_ids = inputs.get("prompt_token_ids")
    if prompt is None:
        prompt = tokenizer.decode(token_ids)

    images = multi_modal_data.get("image")
    audios = multi_modal_data.get("audio")
    prompt, multimodal_inputs = processor.process_multimodal_inputs(
        prompt,
        images=images,
        audios=audios,
        return_tensors="pt",
    )
    text_encodings = tokenizer(
        prompt,
        return_tensors="pt",
        padding=True,
        padding_side="left",
    )
    encodings = processor.merge_encodings(text_encodings, multimodal_inputs)
    data = processor.data_collator([encodings])

    new_prompt = tokenizer.decode(data["input_ids"][0])
    new_multi_modal_data = {
        "image": data["image_encoding"],
        "audio": data["audio_encoding"],
    }
    return token_inputs(
        prompt_token_ids=data["input_ids"][0],
        prompt=new_prompt,
        multi_modal_data=new_multi_modal_data,
    )


def input_mapper_for_megrezo(ctx: InputContext, data: object):
    return MultiModalInputs(data)
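
# The decorators below wire MegrezO into vLLM's registries: the multimodal
# registry receives the pass-through input mapper for both image and audio
# data plus the per-modality token budgets, and the input registry receives
# the prompt processor that merges text, image, and audio encodings.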


@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_megrezo)
@MULTIMODAL_REGISTRY.register_input_mapper("audio", input_mapper_for_megrezo)
@MULTIMODAL_REGISTRY.register_max_multimodal_tokens("audio", 3000)
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_megrezo_image_tokens)
@INPUT_REGISTRY.register_input_processor(input_processor_for_megrezo)
class MegrezOModel(nn.Module, VllmModelForTextGeneration, SupportsMultiModal, SupportsPP):
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    def __init__(
        self,
        config: PretrainedConfig,
        multimodal_config: MultiModalConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        # All MiniCPM-V models disable `tie_word_embeddings`, but
        # `PretrainedConfig.tie_word_embeddings` defaults to True; we cannot
        # check `tie_word_embeddings` until vLLM integrates the MiniCPM-V model
        # and config class.
        self.config = config
        self.multimodal_config = multimodal_config

        self.llm = self.init_llm(config, cache_config, quant_config, prefix="model")
        self.vision = self.init_vision_module(config, quant_config, prefix="vpm")
        param_dtype = torch.get_default_dtype()
        self.vision.to(dtype=param_dtype)
        self.audio = self.init_audio_module(config, quant_config)
        self.audio.to(dtype=param_dtype)
        self.vision_dim = self.vision.embeddings.embed_dim
        self.embed_dim = self.config.hidden_size
        self.resampler = self.init_resampler(
            self.embed_dim, self.vision_dim, quant_config=quant_config, prefix="vision.resampler"
        )
        self.resampler.to(device="cuda", dtype=param_dtype)
        self.lm_head = ParallelLMHead(
            config.vocab_size, config.hidden_size, quant_config=quant_config, prefix="llm.lm_head"
        )
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

        self.make_empty_intermediate_tensors = self.llm.make_empty_intermediate_tensors
        self._called_cnt = 0

    def get_vision_hidden_states(
        self,
        pixel_values,
        tgt_sizes,
        patch_attn_mask,
    ) -> torch.Tensor:
        device = self.vision.embeddings.position_embedding.weight.device
        dtype = self.vision.embeddings.position_embedding.weight.dtype
        pixel_values = torch.stack([(image.to(device) - 127.5) / 127.5 for image in pixel_values]).type(dtype)
        vision_embedding = self.vision(
            pixel_values.type(dtype),
            patch_attention_mask=patch_attn_mask,
            tgt_sizes=tgt_sizes,
        )
        return self.resampler(vision_embedding, tgt_sizes)

    def compose_embeddings(self, mini_batch):
        input_ids = mini_batch["input_ids"]
        image_encoding = mini_batch.get("image_encoding")
        audio_encoding = mini_batch.get("audio_encoding")
        embeddings_text = self.llm.model.embed_tokens(input_ids)
        input_embeds = embeddings_text
        if image_encoding:
            pixel_values = image_encoding["pixel_values"][0]
            tgt_sizes = image_encoding["tgt_sizes"][0]
            patch_attention_mask = image_encoding["patch_attention_mask"][0]
            bounds_image = image_encoding["image_bounds"][0]
            device = self.vision.embeddings.position_embedding.weight.device
            dtype = self.vision.embeddings.position_embedding.weight.dtype
            embeddings_image = self.get_vision_hidden_states(
                pixel_values.to(device, dtype),
                tgt_sizes,
                patch_attention_mask.to(device),
            )
            input_embeds = insert_image_embeddings(embeddings_text, embeddings_image, bounds_image)
        if audio_encoding:
            embeddings_audio = self.audio(audio_encoding)
            bounds_audio = audio_encoding["audio_bounds"][0]
            input_embeds = insert_audio_embeddings(embeddings_text, embeddings_audio, bounds_audio)
        return input_embeds

    def _parse_inputs(self, input_ids: torch.Tensor, **kwargs):
        if kwargs.get("pixel_values") is not None:
            image_encoding = {
                "pixel_values": kwargs.get("pixel_values"),
                "tgt_sizes": kwargs.get("tgt_sizes"),
                "patch_attention_mask": kwargs.get("patch_attention_mask"),
                "image_bounds": kwargs.get("image_bounds"),
            }
        else:
            image_encoding = None
        if kwargs.get("input_audios") is not None:
            audio_encoding = {
                "input_audios": kwargs.get("input_audios"),
                "input_audio_lengths": kwargs.get("input_audio_lengths"),
                "audio_span_tokens": kwargs.get("audio_span_tokens"),
                "audio_bounds": kwargs.get("audio_bounds"),
            }
        else:
            audio_encoding = None
        return {
            "input_ids": input_ids,
            "image_encoding": image_encoding,
            "audio_encoding": audio_encoding,
        }
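
    # forward() receives the multimodal tensors via **kwargs (as emitted by the
    # input mapper), regroups them with _parse_inputs(), composes text, image,
    # and audio embeddings, and always feeds the LLM through `inputs_embeds`.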

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        **kwargs: Any,
    ) -> torch.Tensor:
        if intermediate_tensors is not None:
            embeddings = None
        else:
            mini_batch = self._parse_inputs(input_ids, **kwargs)
            embeddings = self.compose_embeddings(mini_batch)
            # always pass the input via `inputs_embeds`
            # to make sure the computation graph is consistent
            # for `torch.compile` integration
            input_ids = None

        output = self.llm(
            input_ids=input_ids,
            positions=positions,
            kv_caches=kv_caches,
            attn_metadata=attn_metadata,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=embeddings,
        )
        self._called_cnt += 1
        return output

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        keys_to_modify_mapping = {
            "llm.lm_head": "lm_head",
            "vision.resampler": "resampler",
        }
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            for key_to_modify, new_key in keys_to_modify_mapping.items():
                if key_to_modify in name:
                    name = name.replace(key_to_modify, new_key)
            if "rotary_emb.inv_freq" in name:
                continue
            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            # if "audio.positional_embedding" in name:
            #     continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                if name in params_dict:
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(param, loaded_weight, shard_id)
                else:
                    print(f"Skipping loading of {name}")
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if name is None:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                if name in params_dict:
                    param = params_dict[name]
                    weight_loader = getattr(param, "weight_loader", default_weight_loader)
                    weight_loader(param, loaded_weight)
                else:
                    print(f"Skipping loading of {name}")

    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Get the module prefix in multimodal models
        """
        return MultiModelKeys.from_string_field(language_model="llm", connector="resampler", tower_model="vpm")
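
    # The init_* helpers below build the submodules: a LlamaModel backbone
    # wrapped in LLMWrapper, the AudioModel defined above, an
    # Idefics2VisionTransformer vision tower (optionally dropping its last
    # encoder layer), and the Resampler connector for image features.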

    def init_llm(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> nn.Module:
        return LLMWrapper(
            LlamaModel(
                config,
                cache_config=cache_config,
                quant_config=quant_config,
                prefix=prefix,
            ),
            name=prefix,
        )

    def init_audio_module(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig],
        prefix: str = "",
    ) -> nn.Module:
        return AudioModel(config)

    def init_vision_module(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig],
        prefix: str = "",
    ) -> nn.Module:
        model = LLMWrapper(
            Idefics2VisionTransformer(config.vision_config),
            name=prefix,
        )
        if self.config.drop_vision_last_layer:
            model.encoder.layers = model.encoder.layers[:-1]
        return model

    def init_resampler(
        self,
        embed_dim: int,
        vision_dim: int,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> nn.Module:
        resampler = Resampler(
            num_queries=self.config.query_num,
            embed_dim=embed_dim,
            num_heads=embed_dim // 128,
            kv_dim=vision_dim,
            quant_config=quant_config,
            prefix=prefix,
        )
        return resampler
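

# Example (a minimal sketch, not taken from this repository): MegrezOModel is
# intended to be registered with vLLM as an out-of-tree architecture. The exact
# entry point and architecture name depend on the vLLM version and on the
# `architectures` field of the model config; a typical registration looks like:
#
#     from vllm import LLM, ModelRegistry
#
#     ModelRegistry.register_model("MegrezOModel", MegrezOModel)
#     llm = LLM(model="path/to/megrezo-checkpoint", trust_remote_code=True)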