- # # coding=utf-8
- # # Adapted from
- # # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
- # # Copyright 2023 The vLLM team.
- # # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
- # #
- # # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
- # # and OPT implementations in this library. It has been modified from its
- # # original forms to accommodate minor architectural differences compared
- # # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
- # #
- # # Licensed under the Apache License, Version 2.0 (the "License");
- # # you may not use this file except in compliance with the License.
- # # You may obtain a copy of the License at
- # #
- # # http://www.apache.org/licenses/LICENSE-2.0
- # #
- # # Unless required by applicable law or agreed to in writing, software
- # # distributed under the License is distributed on an "AS IS" BASIS,
- # # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # # See the License for the specific language governing permissions and
- # # limitations under the License.
- # """Inference-only MegrezO model compatible with HuggingFace weights."""
- # from functools import lru_cache
- # from functools import partial
- # from typing import Any, Callable, Iterable, List, Literal, Mapping, Optional, Tuple, TypedDict, Union
- # import numpy as np
- # import torch
- # import torch.nn.functional as F
- # import torch.types
- # from PIL import Image
- # from torch import Tensor
- # from torch import nn
- # from torch.nn.init import trunc_normal_
- # from transformers import PretrainedConfig
- # from vllm.attention import AttentionMetadata
- # from vllm.config import CacheConfig
- # from vllm.config import MultiModalConfig
- # from vllm.inputs import INPUT_REGISTRY
- # from vllm.inputs import DecoderOnlyInputs
- # from vllm.inputs import InputContext
- # from vllm.inputs import token_inputs
- # from vllm.logger import init_logger
- # from vllm.model_executor.layers.linear import ReplicatedLinear
- # from vllm.model_executor.layers.logits_processor import LogitsProcessor
- # from vllm.model_executor.layers.quantization import QuantizationConfig
- # from vllm.model_executor.layers.resampler import get_2d_sincos_pos_embed
- # from vllm.model_executor.layers.sampler import Sampler
- # from vllm.model_executor.layers.sampler import SamplerOutput
- # from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
- # from vllm.model_executor.model_loader.weight_utils import default_weight_loader
- # from vllm.model_executor.models import VllmModelForTextGeneration
- # from vllm.model_executor.models.idefics2_vision_model import Idefics2VisionTransformer
- # from vllm.model_executor.models.interfaces import SupportsMultiModal
- # from vllm.model_executor.models.interfaces import SupportsPP
- # from vllm.model_executor.models.llama import LlamaModel
- # from vllm.model_executor.models.module_mapping import MultiModelKeys
- # from vllm.model_executor.models.utils import LLMWrapper
- # from vllm.model_executor.models.utils import is_pp_missing_parameter
- # from vllm.model_executor.sampling_metadata import SamplingMetadata
- # from vllm.multimodal import MULTIMODAL_REGISTRY
- # from vllm.multimodal.base import MultiModalInputs
- # from vllm.multimodal.utils import cached_get_tokenizer
- # from vllm.sequence import IntermediateTensors
- # from vllm.sequence import SequenceData
- # from vllm.transformers_utils.processor import get_processor
- # RawImageType = Union[Image.Image, torch.Tensor]
- # RawAudioType = Union[bytes, torch.Tensor]
- # cached_get_processor = lru_cache(get_processor)
- # logger = init_logger(__name__)
- # class MegrezORawImageInput(TypedDict):
- # """Input mapper input with auxiliary data for computing image bounds."""
- # image: RawImageType
- # class MegrezOAudioInput(TypedDict):
- # type: Literal["audio"]
- # data: RawAudioType
- # class MegrezOAudioTensorInput(TypedDict):
- # type: Literal["audio_tensor"]
- # input_audios: torch.Tensor
- # input_audio_lengths: torch.Tensor
- # audio_span_tokens: torch.Tensor
- # class MegrezOImagePixelInputs(TypedDict):
- # type: Literal["pixel_values"]
- # pixel_values: torch.Tensor
- # """
- # Shape: `(batch_size * num_images, num_channels, height, width)`
- # Note that the image size may vary: images are padded to a common
- # patch grid and `patch_attention_mask` marks the valid patches.
- # """
- # tgt_sizes: torch.Tensor
- # """
- # Shape: `(batch_size * num_images, 2)`
- # This should be in `(height, width)` format.
- # """
- # patch_attention_mask: torch.Tensor
- # """
- # Shape: `(batch_size * num_images, num_patches, num_patches)`
- # """
- # class MegrezOImageEmbeddingInputs(TypedDict):
- # type: Literal["image_embeds"]
- # data: torch.Tensor
- # """
- # Shape: `(batch_size * num_images, image_feature_size, hidden_size)`
- # `hidden_size` must match the hidden size of the language model backbone.
- # """
- # image_bounds: torch.Tensor
- # """
- # Shape: `(batch_size * num_images, 2)`
- # This should be in `(start, stop)` format.
- # """
- # def insert_audio_embeddings(text_embeddings, inserted_embeddings, inserted_bounds):
- # inserted_bounds = inserted_bounds.long()
- # for idx in range(len(inserted_embeddings)):
- # bid = inserted_bounds[idx][0]
- # start_id = inserted_bounds[idx][1]
- # end_id = inserted_bounds[idx][2]
- # embedding = inserted_embeddings[idx]
- # text_embeddings[start_id + 1 : end_id] = embedding
- # return text_embeddings
- # def insert_image_embeddings(text_embeddings, inserted_embeddings, inserted_bounds):
- # inserted_bounds = inserted_bounds.long()
- # for idx in range(len(inserted_embeddings)):
- # bid = inserted_bounds[idx][0]
- # start_id = inserted_bounds[idx][1]
- # end_id = inserted_bounds[idx][2]
- # embedding = inserted_embeddings[idx]
- # text_embeddings[start_id:end_id] = embedding
- # return text_embeddings
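- # # A minimal usage sketch for the two splice helpers above (shapes are
- # # illustrative, not part of the model): each row of `inserted_bounds` is
- # # `(batch_idx, start, end)`; the batch index is read but unused because the
- # # embeddings are already flattened per sequence. Note the asymmetry: audio
- # # spans are written to `[start + 1:end]`, image spans to `[start:end]`.
- # #
- # #     text_embeddings = torch.zeros(16, 2560)   # 16 tokens, hidden size 2560
- # #     image_embeddings = [torch.ones(4, 2560)]  # one 4-token image span
- # #     bounds = torch.tensor([[0, 5, 9]])        # fills positions 5..8
- # #     out = insert_image_embeddings(text_embeddings, image_embeddings, bounds)
- # #     assert out is text_embeddings             # the input is mutated in place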
- # MegrezOImageInputs = Union[MegrezOImagePixelInputs]
- # MegrezOAudioInputs = Union[MegrezOAudioTensorInput]
- # # region: Resampler
- # DEFAULT_LN = partial(nn.LayerNorm, eps=1e-6)
- # class Resampler(nn.Module):
- # def __init__(
- # self,
- # num_queries: int,
- # embed_dim: int,
- # num_heads: int,
- # kv_dim: Optional[int] = None,
- # norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN,
- # max_size: Tuple[int, int] = (70, 70),
- # quant_config: Optional[QuantizationConfig] = None,
- # prefix: str = "",
- # ) -> None:
- # super().__init__()
- # self.num_queries = num_queries
- # self.embed_dim = embed_dim
- # self.num_heads = num_heads
- # self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim))
- # trunc_normal_(self.query, std=0.02)
- # if kv_dim is not None and kv_dim != embed_dim:
- # self.kv_proj = ReplicatedLinear(kv_dim, embed_dim, bias=False, quant_config=quant_config, prefix=prefix)
- # else:
- # # Maintain the same return value with ReplicatedLinear.forward
- # self.kv_proj = lambda *args, **kwargs: ( # type: ignore # noqa
- # nn.Identity()(*args, **kwargs),
- # None,
- # )
- # self.attn = nn.MultiheadAttention(embed_dim, num_heads)
- # self.ln_q = norm_layer(embed_dim)
- # self.ln_kv = norm_layer(embed_dim)
- # self.do_post_projection = True
- # self.ln_post = norm_layer(embed_dim)
- # self.proj = nn.Parameter((embed_dim**-0.5) * torch.randn(embed_dim, embed_dim))
- # self.max_size = max_size
- # self._set_2d_pos_cache(self.max_size)
- # self.apply(self._init_weights)
- # def _init_weights(self, m: nn.Module) -> None:
- # if isinstance(m, nn.Linear):
- # trunc_normal_(m.weight, std=0.02)
- # if isinstance(m, nn.Linear) and m.bias is not None:
- # nn.init.constant_(m.bias, 0)
- # elif isinstance(m, nn.LayerNorm):
- # nn.init.constant_(m.bias, 0)
- # nn.init.constant_(m.weight, 1.0)
- # def _repeat(self, query, N: int):
- # return query.unsqueeze(1).repeat(1, N, 1)
- # def _set_2d_pos_cache(self, max_size: Tuple[int, int], device: torch.types.Device = "cpu") -> None:
- # pos_embed_arr = get_2d_sincos_pos_embed(self.embed_dim, max_size, version=(2, 5))
- # pos_embed = torch.from_numpy(pos_embed_arr).float().to(device)
- # self.register_buffer("pos_embed", pos_embed, persistent=False)
- # def _adjust_pos_cache(self, tgt_sizes: torch.Tensor, device: torch.types.Device) -> None:
- # max_h = tgt_sizes[:, 0].max().item()
- # max_w = tgt_sizes[:, 1].max().item()
- # assert isinstance(max_h, int) and isinstance(max_w, int)
- # if max_h > self.max_size[0] or max_w > self.max_size[1]:
- # self.max_size = (
- # max(max_h, self.max_size[0]),
- # max(max_w, self.max_size[1]),
- # )
- # self._set_2d_pos_cache(self.max_size, device)
- # def forward(self, x: torch.Tensor, tgt_sizes: torch.Tensor) -> torch.Tensor:
- # assert x.shape[0] == tgt_sizes.shape[0]
- # bs = x.shape[0]
- # device = x.device
- # dtype = x.dtype
- # patch_len = tgt_sizes[:, 0] * tgt_sizes[:, 1]
- # self._adjust_pos_cache(tgt_sizes, device=device)
- # max_patch_len = patch_len.max().item()
- # assert isinstance(max_patch_len, int)
- # key_padding_mask = torch.zeros((bs, max_patch_len), dtype=torch.bool, device=device)
- # pos_embed = []
- # for i in range(bs):
- # tgt_h, tgt_w = tgt_sizes[i].tolist()
- # pos_embed.append(self.pos_embed[:tgt_h, :tgt_w, :].reshape((tgt_h * tgt_w, -1)).to(dtype)) # patches * D
- # key_padding_mask[i, patch_len[i] :] = True
- # pos_embed = torch.nn.utils.rnn.pad_sequence(pos_embed, batch_first=True, padding_value=0.0).permute(
- # 1, 0, 2
- # ) # BLD => L * B * D
- # x, _ = self.kv_proj(x) # B * L * D
- # x = self.ln_kv(x).permute(1, 0, 2) # L * B * D
- # q = self.ln_q(self.query) # Q * D
- # out = self.attn(
- # self._repeat(q, bs), # Q * B * D
- # x + pos_embed, # L * B * D + L * B * D
- # x,
- # key_padding_mask=key_padding_mask,
- # )[0]
- # # out: Q * B * D
- # x = out.permute(1, 0, 2) # B * Q * D
- # x = self.ln_post(x)
- # x = x @ self.proj
- # return x
- # # endregion
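- # # Shape walkthrough for `Resampler` (a sketch with illustrative sizes): it
- # # cross-attends `num_queries` learned queries against a variable-length
- # # sequence of vision patch features and returns a fixed-size tensor.
- # #
- # #     resampler = Resampler(num_queries=64, embed_dim=2560, num_heads=20, kv_dim=1152)
- # #     x = torch.randn(2, 70 * 70, 1152)               # B * L * kv_dim patch features
- # #     tgt_sizes = torch.tensor([[70, 70], [35, 50]])  # (h, w) patch grid per image
- # #     out = resampler(x, tgt_sizes)                   # -> (2, 64, 2560)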
- # # region: AudioEncoder
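- # # The encoder below appears to be adapted from OpenAI Whisper's audio
- # # encoder (the `sinusoids` / `MultiHeadAttention` / `ResidualAttentionBlock`
- # # stack), with an optional average-pooling head, an output projection, and
- # # learned audio bos/eos embeddings added on top.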
- # class LayerNorm(nn.LayerNorm):
- # def forward(self, x: Tensor) -> Tensor:
- # # return super().forward(x.float()).type(x.dtype)
- # return super().forward(x).type(x.dtype)
- # class Linear(nn.Linear):
- # def forward(self, x: Tensor) -> Tensor:
- # return F.linear(
- # x,
- # self.weight.to(x.dtype),
- # None if self.bias is None else self.bias.to(x.dtype),
- # )
- # class Conv1d(nn.Conv1d):
- # def _conv_forward(self, x: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor:
- # return super()._conv_forward(x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype))
- # def sinusoids(length, channels, max_timescale=10000):
- # """Returns sinusoids for positional embedding"""
- # assert channels % 2 == 0
- # log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
- # inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
- # scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
- # return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
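- # # For example, `sinusoids(1500, 1280)` returns a `(1500, 1280)` table whose
- # # first 640 columns are `sin` terms and last 640 are `cos` terms; row 0 is
- # # all zeros in the `sin` half and all ones in the `cos` half (position 0).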
- # class MultiHeadAttention(nn.Module):
- # def __init__(self, n_state: int, n_head: int):
- # super().__init__()
- # self.n_head = n_head
- # self.query = Linear(n_state, n_state)
- # self.key = Linear(n_state, n_state, bias=False)
- # self.value = Linear(n_state, n_state)
- # self.out = Linear(n_state, n_state)
- # def forward(
- # self,
- # x: Tensor,
- # xa: Optional[Tensor] = None,
- # mask: Optional[Tensor] = None,
- # kv_cache: Optional[dict] = None,
- # ):
- # q = self.query(x)
- # if kv_cache is None or xa is None or self.key not in kv_cache:
- # # hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
- # # otherwise, perform key/value projections for self- or cross-attention as usual.
- # k = self.key(x if xa is None else xa)
- # v = self.value(x if xa is None else xa)
- # else:
- # # for cross-attention, calculate keys and values once and reuse in subsequent calls.
- # k = kv_cache[self.key]
- # v = kv_cache[self.value]
- # wv, qk = self.qkv_attention(q, k, v, mask)
- # return self.out(wv), qk
- # def qkv_attention(self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None):
- # n_batch, n_ctx, n_state = q.shape
- # scale = (n_state // self.n_head) ** -0.25
- # q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
- # k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale
- # v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
- # qk = q @ k
- # if mask is not None:
- # qk += mask
- # w = F.softmax(qk, dim=-1).to(q.dtype)
- # return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()
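- # # Note on the `** -0.25` factor: scaling both `q` and `k` by it multiplies
- # # the logits by `(n_state // n_head) ** -0.5` overall, i.e. the standard
- # # `1 / sqrt(d_head)` attention scaling, while keeping the intermediate
- # # products smaller for numerical stability.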
- # class ResidualAttentionBlock(nn.Module):
- # def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
- # super().__init__()
- # self.attn = MultiHeadAttention(n_state, n_head)
- # self.attn_ln = LayerNorm(n_state)
- # self.cross_attn = MultiHeadAttention(n_state, n_head) if cross_attention else None
- # self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None
- # n_mlp = n_state * 4
- # self.mlp = nn.Sequential(Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state))
- # self.mlp_ln = LayerNorm(n_state)
- # def forward(
- # self,
- # x: Tensor,
- # xa: Optional[Tensor] = None,
- # mask: Optional[Tensor] = None,
- # kv_cache: Optional[dict] = None,
- # ):
- # x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0]
- # if self.cross_attn:
- # x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)[0]
- # x = x + self.mlp(self.mlp_ln(x))
- # return x
- # class AudioEncoder(nn.Module):
- # def __init__(
- # self,
- # n_mels: int,
- # n_ctx: int,
- # n_state: int,
- # n_head: int,
- # n_layer: int,
- # output_dim: int = 512,
- # avg_pool: bool = True,
- # add_audio_bos_eos_token: bool = True,
- # **kwargs,
- # ):
- # super().__init__()
- # self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
- # self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
- # # self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))
- # self.positional_embedding = nn.Parameter(sinusoids(n_ctx, n_state), requires_grad=False)
- # self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
- # [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
- # )
- # self.ln_post = LayerNorm(n_state)
- # if avg_pool:
- # self.avg_pooler = nn.AvgPool1d(2, stride=2)
- # else:
- # self.avg_pooler = None
- # self.proj = nn.Linear(n_state, output_dim)
- # if add_audio_bos_eos_token:
- # self.audio_bos_eos_token = nn.Embedding(2, output_dim)
- # else:
- # self.audio_bos_eos_token = None
- # self.output_dim = output_dim
- # self.n_head = n_head
- # def forward(self, x: Tensor, padding_mask: Optional[Tensor] = None, audio_lengths: Optional[Tensor] = None):
- # """
- # x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
- # the mel spectrogram of the audio
- # padding_mask : torch.Tensor, shape = (batch_size, src_len)
- # nonzero entries mark frames that are padding after conv downsampling
- # audio_lengths : torch.Tensor, shape = (batch_size, 2)
- # column 0 is each clip's length after conv downsampling (half its mel length)
- # """
- # x = x.to(dtype=self.conv1.weight.dtype, device=self.conv1.weight.device)
- # if audio_lengths is not None:
- # input_mel_len = audio_lengths[:, 0] * 2
- # max_mel_len_in_batch = input_mel_len.max()
- # x = x[:, :, :max_mel_len_in_batch]
- # x = F.gelu(self.conv1(x))
- # x = F.gelu(self.conv2(x))
- # x = x.permute(0, 2, 1) # B, L, D
- # bsz = x.size(0)
- # src_len = x.size(1)
- # # Slice the positional table to this batch's sequence length; keep it in a
- # # local variable rather than mutating module state inside forward().
- # input_positional_embedding = self.positional_embedding[:src_len]
- # assert (
- # x.shape[1:] == input_positional_embedding.shape
- # ), f"incorrect audio shape: {x.shape[1:]} vs {input_positional_embedding.shape}"
- # x = (x + input_positional_embedding).to(x.dtype)
- # if padding_mask is not None:
- # padding_mask = padding_mask.to(dtype=self.conv1.weight.dtype, device=self.conv1.weight.device)
- # batch_src_len = padding_mask.size(1)
- # x = x[:, :batch_src_len, :]
- # padding_mask = padding_mask.view(bsz, -1, batch_src_len)
- # padding_mask_ = padding_mask.all(1)
- # x[padding_mask_] = 0
- # key_padding_mask = (
- # padding_mask_.view(bsz, 1, 1, batch_src_len)
- # .expand(-1, self.n_head, -1, -1)
- # .reshape(bsz, self.n_head, 1, batch_src_len)
- # )
- # new_padding_mask = torch.zeros_like(key_padding_mask, dtype=x.dtype)
- # padding_mask = new_padding_mask.masked_fill(key_padding_mask, float("-inf"))
- # for block in self.blocks:
- # x = block(x, mask=padding_mask)
- # if self.avg_pooler:
- # x = x.permute(0, 2, 1)
- # x = self.avg_pooler(x)
- # x = x.permute(0, 2, 1)
- # x = self.ln_post(x)
- # x = self.proj(x)
- # if self.audio_bos_eos_token is not None:
- # bos = self.audio_bos_eos_token.weight[0][None, :]
- # eos = self.audio_bos_eos_token.weight[1][None, :]
- # else:
- # bos, eos = None, None
- # return x, bos, eos
- # def encode(
- # self,
- # input_audios: Tensor,
- # input_audio_lengths: Tensor,
- # audio_span_tokens: List,
- # ):
- # real_input_audio_lens = input_audio_lengths[:, 0].tolist()
- # max_len_in_batch = max(real_input_audio_lens)
- # padding_mask = torch.ones([input_audios.size(0), max_len_in_batch]).to(
- # dtype=self.conv1.weight.dtype, device=self.conv1.weight.device
- # )
- # for index in range(len(input_audios)):
- # padding_mask[index, : input_audio_lengths[index][0].item()] = 0
- # x, bos, eos = self(input_audios, padding_mask, input_audio_lengths)
- # output_audios = []
- # for i in range(len(audio_span_tokens)):
- # audio_span = audio_span_tokens[i]
- # audio = x[i][: audio_span - 2]
- # if bos is not None:
- # audio = torch.concat([bos, audio, eos])
- # assert len(audio) == audio_span
- # output_audios.append(audio)
- # return output_audios
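- # # A hedged usage sketch for `encode` (`audio_model` is assumed to be an
- # # `AudioModel` instance; shapes are inferred from the code above, not a
- # # documented API):
- # #
- # #     input_audios = torch.randn(2, 80, 3000)  # padded mel batch (80 mel bins)
- # #     input_audio_lengths = torch.tensor([[750, 0], [500, 0]])  # post-conv lengths
- # #     audio_span_tokens = [377, 252]  # per-clip span: pooled length + bos/eos
- # #     embeds = audio_model.audio.encode(input_audios, input_audio_lengths, audio_span_tokens)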
- # class AudioModel(torch.nn.Module):
- # def __init__(self, config):
- # super().__init__()
- # self.config = config
- # self.audio = AudioEncoder(**config.audio_config.to_dict())
- # def forward(self, audio_info):
- # audios = audio_info["input_audios"][0]
- # input_audio_lengths = audio_info["input_audio_lengths"][0]
- # audio_span_tokens = audio_info["audio_span_tokens"][0]
- # audios_features = self.audio.encode(audios, input_audio_lengths, audio_span_tokens)
- # return audios_features
- # # endregion
- # def get_max_megrezo_image_tokens(ctx: InputContext):
- # hf_config = ctx.get_hf_config()
- # return getattr(hf_config, "query_num", 64) * 10
- # def dummy_seq_data_for_minicpmv(seq_len: int, num_images: int):
- # return SequenceData.from_prompt_token_counts((0, seq_len))
- # def dummy_image_for_minicpmv(ctx: InputContext, hf_config: PretrainedConfig, num_images: int):
- # width = height = hf_config.vision_config.image_size
- # imgs = [MegrezORawImageInput(image=Image.new("RGB", (width, height), color=0)) for _ in range(num_images)]
- # return {"image": imgs}
- # def dummy_data_for_minicpmv(ctx: InputContext, seq_len: int, mm_counts: Mapping[str, int]):
- # hf_config = ctx.get_hf_config()
- # num_images = mm_counts["image"]
- # seq_data = dummy_seq_data_for_minicpmv(seq_len, num_images)
- # mm_data = dummy_image_for_minicpmv(ctx, hf_config, num_images) # skip audio for now
- # return (seq_data, mm_data)
- # def input_processor_for_megrezo(ctx: InputContext, inputs: DecoderOnlyInputs):
- # multi_modal_data = inputs.get("multi_modal_data")
- # if multi_modal_data is None or ("image" not in multi_modal_data and "audio" not in multi_modal_data):
- # return inputs
- # model_config = ctx.model_config
- # tokenizer = cached_get_tokenizer(model_config.tokenizer, trust_remote_code=model_config.trust_remote_code)
- # processor = cached_get_processor(model_config.model, trust_remote_code=model_config.trust_remote_code)
- # prompt = inputs.get("prompt")
- # token_ids = inputs.get("prompt_token_ids")
- # if prompt is None:
- # prompt = tokenizer.decode(token_ids)
- # images = multi_modal_data.get("image")
- # audios = multi_modal_data.get("audio")
- # prompt, multimodal_inputs = processor.process_multimodal_inputs(
- # prompt,
- # images=images,
- # audios=audios,
- # return_tensors="pt",
- # )
- # text_encodings = tokenizer(
- # prompt,
- # return_tensors="pt",
- # padding=True,
- # padding_side="left",
- # )
- # encodings = processor.merge_encodings(text_encodings, multimodal_inputs)
- # data = processor.data_collator([encodings])
- # new_prompt = tokenizer.decode(data["input_ids"][0])
- # new_multi_modal_data = {
- # "image": data["image_encoding"],
- # "audio": data["audio_encoding"],
- # }
- # return token_inputs(
- # prompt_token_ids=data["input_ids"][0],
- # prompt=new_prompt,
- # multi_modal_data=new_multi_modal_data,
- # )
- # def input_mapper_for_megrezo(ctx: InputContext, data: object):
- # return MultiModalInputs(data)
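- # # The mapper is a deliberate pass-through: `input_processor_for_megrezo`
- # # above already returns fully collated tensors, so nothing is left to map.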
- # @MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_megrezo)
- # @MULTIMODAL_REGISTRY.register_input_mapper("audio", input_mapper_for_megrezo)
- # @MULTIMODAL_REGISTRY.register_max_multimodal_tokens("audio", 3000)
- # @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_megrezo_image_tokens)
- # @INPUT_REGISTRY.register_input_processor(input_processor_for_megrezo)
- # class MegrezOModel(nn.Module, VllmModelForTextGeneration, SupportsMultiModal, SupportsPP):
- # packed_modules_mapping = {
- # "qkv_proj": ["q_proj", "k_proj", "v_proj"],
- # "gate_up_proj": ["gate_proj", "up_proj"],
- # }
- # def __init__(
- # self,
- # config: PretrainedConfig,
- # multimodal_config: MultiModalConfig,
- # cache_config: Optional[CacheConfig] = None,
- # quant_config: Optional[QuantizationConfig] = None,
- # ):
- # super().__init__()
- # # All MiniCPM-V models disable `tie_word_embeddings`, but
- # # `PretrainedConfig.tie_word_embeddings` defaults to True; we cannot
- # # check `tie_word_embeddings` until vLLM integrates the MiniCPM-V model
- # # and config class.
- # self.config = config
- # self.multimodal_config = multimodal_config
- # self.llm = self.init_llm(config, cache_config, quant_config, prefix="model")
- # self.vision = self.init_vision_module(config, quant_config, prefix="vpm")
- # param_dtype = torch.get_default_dtype()
- # self.vision.to(dtype=param_dtype)
- # self.audio = self.init_audio_module(config, quant_config)
- # self.audio.to(dtype=param_dtype)
- # self.vision_dim = self.vision.embeddings.embed_dim
- # self.embed_dim = self.config.hidden_size
- # self.resampler = self.init_resampler(
- # self.embed_dim, self.vision_dim, quant_config=quant_config, prefix="vision.resampler"
- # )
- # self.resampler.to(device="cuda", dtype=param_dtype)
- # self.lm_head = ParallelLMHead(
- # config.vocab_size, config.hidden_size, quant_config=quant_config, prefix="llm.lm_head"
- # )
- # self.logits_processor = LogitsProcessor(config.vocab_size)
- # self.sampler = Sampler()
- # self.make_empty_intermediate_tensors = self.llm.make_empty_intermediate_tensors
- # self._called_cnt = 0
- # def get_vision_hidden_states(
- # self,
- # pixel_values,
- # tgt_sizes,
- # patch_attn_mask,
- # ) -> torch.Tensor:
- # device = self.vision.embeddings.position_embedding.weight.device
- # dtype = self.vision.embeddings.position_embedding.weight.dtype
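- # # Pixel values arrive in the [0, 255] range; `(x - 127.5) / 127.5`
- # # rescales them to [-1, 1] before they enter the vision tower.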
- # pixel_values = torch.stack([(image.to(device) - 127.5) / 127.5 for image in pixel_values]).type(dtype)
- # vision_embedding = self.vision(
- # pixel_values.type(dtype),
- # patch_attention_mask=patch_attn_mask,
- # tgt_sizes=tgt_sizes,
- # )
- # return self.resampler(vision_embedding, tgt_sizes)
- # def compose_embeddings(self, mini_batch):
- # input_ids = mini_batch["input_ids"]
- # image_encoding = mini_batch.get("image_encoding")
- # audio_encoding = mini_batch.get("audio_encoding")
- # embeddings_text = self.llm.model.embed_tokens(input_ids)
- # input_embeds = embeddings_text
- # if image_encoding:
- # pixel_values = image_encoding["pixel_values"][0]
- # tgt_sizes = image_encoding["tgt_sizes"][0]
- # patch_attention_mask = image_encoding["patch_attention_mask"][0]
- # bounds_image = image_encoding["image_bounds"][0]
- # device = self.vision.embeddings.position_embedding.weight.device
- # dtype = self.vision.embeddings.position_embedding.weight.dtype
- # embeddings_image = self.get_vision_hidden_states(
- # pixel_values.to(device, dtype),
- # tgt_sizes,
- # patch_attention_mask.to(device),
- # )
- # input_embeds = insert_image_embeddings(embeddings_text, embeddings_image, bounds_image)
- # if audio_encoding:
- # embeddings_audio = self.audio(audio_encoding)
- # bounds_audio = audio_encoding["audio_bounds"][0]
- # # `insert_image_embeddings` mutates `embeddings_text` in place, so passing
- # # it again here preserves any image embeddings spliced in above.
- # input_embeds = insert_audio_embeddings(embeddings_text, embeddings_audio, bounds_audio)
- # return input_embeds
- # def _parse_inputs(self, input_ids: torch.Tensor, **kwargs):
- # if kwargs.get("pixel_values") is not None:
- # image_encoding = {
- # "pixel_values": kwargs.get("pixel_values"),
- # "tgt_sizes": kwargs.get("tgt_sizes"),
- # "patch_attention_mask": kwargs.get("patch_attention_mask"),
- # "image_bounds": kwargs.get("image_bounds"),
- # }
- # else:
- # image_encoding = None
- # if kwargs.get("input_audios") is not None:
- # audio_encoding = {
- # "input_audios": kwargs.get("input_audios"),
- # "input_audio_lengths": kwargs.get("input_audio_lengths"),
- # "audio_span_tokens": kwargs.get("audio_span_tokens"),
- # "audio_bounds": kwargs.get("audio_bounds"),
- # }
- # else:
- # audio_encoding = None
- # return {
- # "input_ids": input_ids,
- # "image_encoding": image_encoding,
- # "audio_encoding": audio_encoding,
- # }
- # def forward(
- # self,
- # input_ids: torch.Tensor,
- # positions: torch.Tensor,
- # kv_caches: List[torch.Tensor],
- # attn_metadata: AttentionMetadata,
- # intermediate_tensors: Optional[IntermediateTensors] = None,
- # **kwargs: Any,
- # ) -> torch.Tensor:
- # if intermediate_tensors is not None:
- # embeddings = None
- # else:
- # mini_batch = self._parse_inputs(input_ids, **kwargs)
- # embeddings = self.compose_embeddings(mini_batch)
- # # always pass the input via `inputs_embeds`
- # # to make sure the computation graph is consistent
- # # for `torch.compile` integration
- # input_ids = None
- # output = self.llm(
- # input_ids=input_ids,
- # positions=positions,
- # kv_caches=kv_caches,
- # attn_metadata=attn_metadata,
- # intermediate_tensors=intermediate_tensors,
- # inputs_embeds=embeddings,
- # )
- # self._called_cnt += 1
- # return output
- # def compute_logits(
- # self,
- # hidden_states: torch.Tensor,
- # sampling_metadata: SamplingMetadata,
- # ) -> Optional[torch.Tensor]:
- # logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata)
- # return logits
- # def sample(
- # self,
- # logits: torch.Tensor,
- # sampling_metadata: SamplingMetadata,
- # ) -> Optional[SamplerOutput]:
- # next_tokens = self.sampler(logits, sampling_metadata)
- # return next_tokens
- # def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
- # stacked_params_mapping = [
- # # (param_name, shard_name, shard_id)
- # (".qkv_proj", ".q_proj", "q"),
- # (".qkv_proj", ".k_proj", "k"),
- # (".qkv_proj", ".v_proj", "v"),
- # (".gate_up_proj", ".gate_proj", 0),
- # (".gate_up_proj", ".up_proj", 1),
- # ]
- # keys_to_modify_mapping = {
- # "llm.lm_head": "lm_head",
- # "vision.resampler": "resampler",
- # }
- # params_dict = dict(self.named_parameters())
- # for name, loaded_weight in weights:
- # for key_to_modify, new_key in keys_to_modify_mapping.items():
- # if key_to_modify in name:
- # name = name.replace(key_to_modify, new_key)
- # if "rotary_emb.inv_freq" in name:
- # continue
- # if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
- # # Models trained using ColossalAI may include these tensors in
- # # the checkpoint. Skip them.
- # continue
- # # if "audio.positional_embedding" in name:
- # # continue
- # for param_name, weight_name, shard_id in stacked_params_mapping:
- # if weight_name not in name:
- # continue
- # name = name.replace(weight_name, param_name)
- # # Skip loading extra bias for GPTQ models.
- # if name.endswith(".bias") and name not in params_dict:
- # continue
- # if is_pp_missing_parameter(name, self):
- # continue
- # if name in params_dict:
- # param = params_dict[name]
- # weight_loader = param.weight_loader
- # weight_loader(param, loaded_weight, shard_id)
- # else:
- # print(f"Skipping loading of {name}")
- # break
- # else:
- # # Skip loading extra bias for GPTQ models.
- # if name.endswith(".bias") and name not in params_dict:
- # continue
- # if name is None:
- # continue
- # if is_pp_missing_parameter(name, self):
- # continue
- # if name in params_dict:
- # param = params_dict[name]
- # weight_loader = getattr(param, "weight_loader", default_weight_loader)
- # weight_loader(param, loaded_weight)
- # else:
- # print(f"Skipping loading of {name}")
- # def get_mm_mapping(self) -> MultiModelKeys:
- # """
- # Get the module prefix in multimodal models
- # """
- # return MultiModelKeys.from_string_field(language_model="llm", connector="resampler", tower_model="vpm")
- # def init_llm(
- # self,
- # config: PretrainedConfig,
- # cache_config: Optional[CacheConfig] = None,
- # quant_config: Optional[QuantizationConfig] = None,
- # prefix: str = "",
- # ) -> nn.Module:
- # return LLMWrapper(
- # LlamaModel(
- # config,
- # cache_config=cache_config,
- # quant_config=quant_config,
- # prefix=prefix,
- # ),
- # name=prefix,
- # )
- # def init_audio_module(
- # self,
- # config: PretrainedConfig,
- # quant_config: Optional[QuantizationConfig],
- # prefix: str = "",
- # ) -> nn.Module:
- # return AudioModel(config)
- # def init_vision_module(
- # self,
- # config: PretrainedConfig,
- # quant_config: Optional[QuantizationConfig],
- # prefix: str = "",
- # ) -> nn.Module:
- # model = LLMWrapper(
- # Idefics2VisionTransformer(config.vision_config),
- # name=prefix,
- # )
- # if self.config.drop_vision_last_layer:
- # model.encoder.layers = model.encoder.layers[:-1]
- # return model
- # def init_resampler(
- # self,
- # embed_dim: int,
- # vision_dim: int,
- # quant_config: Optional[QuantizationConfig] = None,
- # prefix: str = "",
- # ) -> nn.Module:
- # resampler = Resampler(
- # num_queries=self.config.query_num,
- # embed_dim=embed_dim,
- # num_heads=embed_dim // 128,
- # kv_dim=vision_dim,
- # quant_config=quant_config,
- # prefix=prefix,
- # )
- # return resampler
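- # # End-to-end usage sketch (assumptions: this architecture is registered
- # # with vLLM under a name the checkpoint's config maps to this class; the
- # # model id below is illustrative, and the prompt format depends on the
- # # shipped processor):
- # #
- # #     from vllm import LLM, SamplingParams
- # #     llm = LLM(model="Infinigence/Megrez-3B-Omni", trust_remote_code=True)
- # #     out = llm.generate(
- # #         {"prompt": "<image> Describe the picture.", "multi_modal_data": {"image": pil_image}},
- # #         SamplingParams(max_tokens=64),
- # #     )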