Squid / modeling_dolphin.py
zackli4ai's picture
upload registration code
c516834
raw
history blame
29.4 kB
from transformers import (
AutoTokenizer, AutoModelForCausalLM, AutoConfig, logging
)
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
SequenceClassifierOutputWithPast,
)
from transformers.utils import (ModelOutput)
from transformers.cache_utils import Cache, DynamicCache, StaticCache
from transformers.models.qwen2.modeling_qwen2 import (
Qwen2PreTrainedModel, Qwen2Model, Qwen2RMSNorm
)
from transformers.models.qwen2.modeling_qwen2 import Qwen2DecoderLayer
import torch
import torch.nn as nn
from typing import List, Optional, Tuple, Union
import warnings
from dataclasses import dataclass
from torch.nn import CrossEntropyLoss
from .configuration_dolphin import encoder_config_dict, Qwen2Config
CONTEXT_EMB = 896 # Qwen 0.7B has dimension of 896
HIDDEN_EMB = 3584 # Qwen 7B has dimension of 3584
warnings.filterwarnings("ignore")
MEM_SIZE = 32
logger = logging.get_logger(__name__)
@dataclass
class DolphinMemoryOutput(ModelOutput):
memory_states: Optional[torch.FloatTensor] = None
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
class Qwen2ForMemoryOutput(Qwen2PreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.model = Qwen2Model(config)
self.model.config.pad_token_id = self.model.config.eos_token_id
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.model.embed_tokens
def set_input_embeddings(self, value):
self.model.embed_tokens = value
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, SequenceClassifierOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
transformer_outputs = self.model(
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = transformer_outputs[0]
if input_ids is not None:
batch_size = input_ids.shape[0]
else:
batch_size = inputs_embeds.shape[0]
if self.config.pad_token_id is None and batch_size != 1:
raise ValueError(
"Cannot handle batch sizes > 1 if no padding token is defined."
)
if self.config.pad_token_id is None:
sequence_lengths = -1
else:
if input_ids is not None:
sequence_lengths = (
torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1)
)
sequence_lengths = sequence_lengths % input_ids.shape[-1]
sequence_lengths = sequence_lengths.to(hidden_states.device)
else:
sequence_lengths = -1
# if sequence_lengths != -1:
# assert (sequence_lengths > MEMORY_SIZE).all(), "All sequences must be longer than MEMORY_SIZE"
MEMORY_SIZE = 32
batch_range = torch.arange(batch_size, device=hidden_states.device)
start_indices = sequence_lengths - MEMORY_SIZE
# print(sequence_lengths)
# print(torch.arange(MEMORY_SIZE, device=hidden_states.device)[None, :] + start_indices[:, None])
memory_states = hidden_states[
batch_range[:, None],
torch.arange(MEMORY_SIZE, device=hidden_states.device)[None, :]
+ start_indices[:, None],
]
return DolphinMemoryOutput(
memory_states=memory_states,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
class Projector(nn.Module):
def __init__(self, context_dim: int, hidden_dim: int, projection_cls="linear"):
super().__init__()
self.projection_cls = projection_cls
if projection_cls == "linear":
self.context_projection = nn.Linear(context_dim, hidden_dim)
elif projection_cls == "mlp":
dim_projection = hidden_dim
depth = 2
layers = [
nn.Linear(context_dim, dim_projection),
]
for _ in range(1, depth):
layers.extend(
[
nn.GELU(),
nn.Linear(dim_projection, dim_projection),
]
)
self.context_projection = nn.Sequential(*layers)
else:
raise ValueError(f"Projection class {projection_cls} not supported")
def forward(self, x):
if self.projection_cls == "linear":
return self.context_projection(x)
for layer in self.context_projection:
x = layer(x)
return x
class ContextEmbd(nn.Module):
def __init__(
self, config, context_dim, hidden_dim, MEM_SIZE=32, torch_dtype=torch.bfloat16
):
super().__init__()
self.encoder = Qwen2ForMemoryOutput(config).to(torch_dtype)
self.projector = Projector(context_dim, hidden_dim).to(torch_dtype)
self.MEM_SIZE = MEM_SIZE
def forward(self, context_input_ids, context_attention_mask=None):
memory_slot = self.encoder(
context_input_ids, context_attention_mask, output_hidden_states=True
).memory_states
# now project the memory slot into token space
return self.projector(memory_slot)
class DolphinModel(Qwen2PreTrainedModel):
"""
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Qwen2DecoderLayer`]
Args:
config: DolphinModel
"""
def __init__(self, config: Qwen2Config):
super().__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = nn.Embedding(
config.vocab_size, config.hidden_size, self.padding_idx
)
self.layers = nn.ModuleList(
[
Qwen2DecoderLayer(config, layer_idx)
for layer_idx in range(config.num_hidden_layers)
]
)
self._attn_implementation = config._attn_implementation
self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.gradient_checkpointing = False
if not config.encoder_config:
raise ValueError("Please provide the encoder config")
self.encoder_config = Qwen2Config.from_dict(config.encoder_config)
self.context_encoder = ContextEmbd(
config=self.encoder_config, context_dim=CONTEXT_EMB, hidden_dim=HIDDEN_EMB
)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.embed_tokens
def set_input_embeddings(self, value):
self.embed_tokens = value
# We assume there is only on context, and this function can only support one context
def get_token_embebddings_context(
self,
input_ids: torch.LongTensor,
context_input_ids: torch.LongTensor,
context_attention_mask: torch.LongTensor,
) -> torch.FloatTensor:
# The size is batch_size x memory_size x hidden_dim
context_emb = self.context_encoder(context_input_ids, context_attention_mask)
# Create embeddings for regular tokens
embed_input_ids = input_ids.clone()
embed_input_ids[embed_input_ids < 0] = (
0 # Replace negative values with 0 for embedding
)
hidden_states = self.embed_tokens(embed_input_ids)
batch_size, seq_len, hidden_dim = hidden_states.shape
_, memory_size, _ = context_emb.shape
# Find the start positions of -1 sequences
mask = input_ids == -1
starts = torch.where(mask[:, :-1] < mask[:, 1:])[1]
# Replace -1 spans with context embeddings
for i in range(batch_size):
for start in starts:
if start + memory_size <= seq_len:
hidden_states[i, start : start + memory_size] = context_emb[i]
return hidden_states
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
context_input_ids: Optional[torch.LongTensor] = None,
context_attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError(
"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
)
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
use_legacy_cache = False
if use_cache and not isinstance(past_key_values, Cache):
use_legacy_cache = True
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
logger.warning_once(
"We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. "
"Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)"
)
if inputs_embeds is None:
if context_input_ids is not None:
assert (
context_attention_mask is not None
), "You have to provide the context_attention_mask"
inputs_embeds = self.get_token_embebddings_context(
input_ids, context_input_ids, context_attention_mask
)
else:
inputs_embeds = self.embed_tokens(input_ids)
# We need to update the attention mask if the attention mask is provided
# if attention_mask is not None:
# MEMORY_SIZE = 32
# batch_size = inputs_embeds.shape[0]
# attention_mask = torch.cat(
# (torch.ones(batch_size, MEMORY_SIZE, device=inputs_embeds.device), attention_mask),
# dim=1,
# ).to(attention_mask.dtype).to(attention_mask.device)
if cache_position is None:
past_seen_tokens = (
past_key_values.get_seq_length() if past_key_values is not None else 0
)
cache_position = torch.arange(
past_seen_tokens,
past_seen_tokens + inputs_embeds.shape[1],
device=inputs_embeds.device,
)
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
causal_mask = self._update_causal_mask(
attention_mask,
inputs_embeds,
cache_position,
past_key_values,
output_attentions,
)
hidden_states = inputs_embeds
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = None
for decoder_layer in self.layers:
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
causal_mask,
position_ids,
past_key_values,
output_attentions,
use_cache,
cache_position,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
if output_attentions:
all_self_attns += (layer_outputs[1],)
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = None
if use_cache:
next_cache = (
next_decoder_cache.to_legacy_cache()
if use_legacy_cache
else next_decoder_cache
)
if not return_dict:
return tuple(
v
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
if v is not None
)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
# Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
def _update_causal_mask(
self,
attention_mask: torch.Tensor,
input_tensor: torch.Tensor,
cache_position: torch.Tensor,
past_key_values: Cache,
output_attentions: bool,
):
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
# to infer the attention mask.
past_seen_tokens = (
past_key_values.get_seq_length() if past_key_values is not None else 0
)
using_static_cache = isinstance(past_key_values, StaticCache)
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
if (
self.config._attn_implementation == "sdpa"
and not using_static_cache
and not output_attentions
):
if AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask,
inputs_embeds=input_tensor,
past_key_values_length=past_seen_tokens,
is_training=self.training,
):
return None
dtype, device = input_tensor.dtype, input_tensor.device
min_dtype = torch.finfo(dtype).min
sequence_length = input_tensor.shape[1]
if using_static_cache:
target_length = past_key_values.get_max_length()
else:
target_length = (
attention_mask.shape[-1]
if isinstance(attention_mask, torch.Tensor)
else past_seen_tokens + sequence_length + 1
)
if attention_mask is not None and attention_mask.dim() == 4:
# in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
if attention_mask.max() != 0:
raise ValueError(
"Custom 4D attention mask should be passed in inverted form with max==0`"
)
causal_mask = attention_mask
else:
causal_mask = torch.full(
(sequence_length, target_length),
fill_value=min_dtype,
dtype=dtype,
device=device,
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(
target_length, device=device
) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(
input_tensor.shape[0], 1, -1, -1
)
if attention_mask is not None:
causal_mask = (
causal_mask.clone()
) # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = (
causal_mask[:, :, :, :mask_length]
+ attention_mask[:, None, None, :]
)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[
:, :, :, :mask_length
].masked_fill(padding_mask, min_dtype)
if (
self.config._attn_implementation == "sdpa"
and attention_mask is not None
and attention_mask.device.type == "cuda"
and not output_attentions
):
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
causal_mask = AttentionMaskConverter._unmask_unattended(
causal_mask, min_dtype
)
return causal_mask
class DolphinForCausalLM(Qwen2PreTrainedModel):
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config):
super().__init__(config)
self.model = DolphinModel(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.model.embed_tokens
def set_input_embeddings(self, value):
self.model.embed_tokens = value
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def set_decoder(self, decoder):
self.model = decoder
def get_decoder(self):
return self.model
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
context_input_ids: Optional[torch.LongTensor] = None,
context_attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
```"""
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
context_input_ids=context_input_ids,
context_attention_mask=context_attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
)
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
logits = logits.float()
loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def prepare_inputs_for_generation(
self,
input_ids,
past_key_values=None,
attention_mask=None,
inputs_embeds=None,
cache_position=None,
use_cache=True,
**kwargs,
):
past_length = 0
# Omit tokens covered by past_key_values
if past_key_values is not None:
# Past key values are always initialized with a `Cache` object -> no need for if-else anymore
past_length = (
cache_position[0]
if cache_position is not None
else past_key_values.get_seq_length()
)
max_cache_length = (
torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
if past_key_values.get_max_length() is not None
else None
)
cache_length = (
past_length
if max_cache_length is None
else torch.min(max_cache_length, past_length)
)
# Keep only the unprocessed tokens:
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
# some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
# input)
if (
attention_mask is not None
and attention_mask.shape[1] > input_ids.shape[1]
):
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
# input_ids based on the past_length.
elif past_length < input_ids.shape[1]:
input_ids = input_ids[:, past_length:]
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
# If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
if (
max_cache_length is not None
and attention_mask is not None
and cache_length + input_ids.shape[1] > max_cache_length
):
attention_mask = attention_mask[:, -max_cache_length:]
position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :]
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_length == 0:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
input_length = (
position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
)
if cache_position is None:
cache_position = torch.arange(
past_length, past_length + input_length, device=input_ids.device
)
elif use_cache:
cache_position = cache_position[-input_length:]
model_inputs.update(
{
"position_ids": position_ids,
"past_key_values": past_key_values,
"use_cache": use_cache,
"attention_mask": attention_mask,
"cache_position": cache_position,
}
)
return model_inputs
@staticmethod
def _reorder_cache(past_key_values, beam_idx):
reordered_past = ()
for layer_past in past_key_values:
reordered_past += (
tuple(
past_state.index_select(0, beam_idx.to(past_state.device))
for past_state in layer_past
),
)
return reordered_past
if __name__ == "__main__":
config = Qwen2Config(encoder_config=encoder_config_dict)
dolphin_model = DolphinModel(config)
# AutoConfig.register("dolphin", Qwen2Config)
AutoModelForCausalLM.register(Qwen2Config, DolphinForCausalLM)
tokenizer = AutoTokenizer.from_pretrained('nexa-collaboration/dolphin_instruct_1M_0805', trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained('nexa-collaboration/dolphin_instruct_1M_0805', trust_remote_code=True)