# dummy_m4/m4/models/custom_modules.py
# Uploaded by ysharma (HF staff)
# Duplicate from HuggingFaceM4/m4-dialogue, e7d3e35
import os
import torch
import torch.nn.functional as F
from torch import nn
from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel
from transformers.utils import ContextManagers
from m4.training.setup_vision_model import vision_model_name_to_model
from m4.training.utils import (
deepspeed_zero_init_disabled_context_manager,
is_deepspeed_zero_init_enabled,
load_state_dict_into_model,
)
# from pathlib import Path
class VLOOMPreTrainedModelBase(PreTrainedModel):
    """
    Base class for vloom models that combine a main model with a separately created `vision_model`.

    All constructors here build the vision model outside of deepspeed's zero.Init context and hand it
    to the main class through the `vision_model` keyword argument.
    """

    # The problem we are trying to solve is 2 nested zero.Init thanks to fetching from_pretrained(vision_model_name)
    # and then one more zero.Init to override from_pretrained(vision_model_name) once again as it was done in the
    # original - this breaks deepspeed zero3 w/ zero.Init
    # So one solution is this:
    # a. replace from_pretrained(vision_model_name) with from_config(vision_model_name) while hacking to disable
    #    zero.Init context
    # b. instead of straight replacement of model.vision_model = from_pretrained(vision_model_name) when it gets
    #    updated, we first do from_pretrained(vision_model_name) and then update the existing model with weights
    #    using the already zero.Init'ed pre-sharded weights
    #
    # there are a few variations to get_vision_model_from_config - all need to bypass zero.Init under zero3
    # 1. one variant is to hack into accelerate's deepspeed_plugin and turn off zero.Init while loading the
    #    vision model
    # 2. the other variant is to override _from_config method with our version that doesn't do zero.Init

    @classmethod
    def override_vision_model(cls, model, vision_model_name, vision_model_params, torch_dtype):
        """
        Replace the weights of `model.vision_model` with the pretrained weights of `vision_model_name`.

        Args:
            model: the already instantiated model whose `vision_model` gets overridden.
            vision_model_name: name or path given to `AutoModel.from_pretrained`.
            vision_model_params: dict of extra keyword arguments for `AutoModel.from_pretrained`.
            torch_dtype: dtype forwarded to `AutoModel.from_pretrained`.
        """
        # 1. fetch the pretrained vision model w/o zero.Init
        with ContextManagers(deepspeed_zero_init_disabled_context_manager()):
            vision_model = AutoModel.from_pretrained(vision_model_name, **vision_model_params, torch_dtype=torch_dtype)

        # this extracts the desired submodule if the part we want is nested (e.g. as in clip)
        real_vision_model = vision_model_name_to_model(vision_model_name, vision_model)

        # 2. now override the weights already sharded by zero.Init with the weights from the real_vision_model
        # by gradually gathering sharded weights and replacing with new weights
        if is_deepspeed_zero_init_enabled():
            state_dict = real_vision_model.state_dict()
            load_state_dict_into_model(model.vision_model, state_dict, start_prefix="")
        else:
            model.vision_model = real_vision_model

    @classmethod
    def _build_uninitialized_vision_model(cls, vision_model_name, vision_model_params, torch_dtype):
        """Create the vision model from its config only (no pretrained weights), with zero.Init disabled."""
        # It has to be created outside lm's `from_pretrained` and w/o zero.Init so that zero3+zero.Init works
        with ContextManagers(deepspeed_zero_init_disabled_context_manager()):
            vision_model_config = AutoConfig.from_pretrained(vision_model_name, **vision_model_params)
            vision_model_from_config = AutoModel.from_config(vision_model_config, torch_dtype=torch_dtype)

        # this extracts the desired submodule if the part we want is nested (e.g. as in clip)
        return vision_model_name_to_model(vision_model_name, vision_model_from_config)

    @classmethod
    def from_config(cls, config, **kwargs):
        """
        Instantiate the model from `config`, with a vision model created from its config.

        Args:
            config: configuration providing `vision_model_name` and `vision_model_params`.
            **kwargs: forwarded to the class constructor; `torch_dtype` is also used for the vision model.

        Returns:
            The instantiated model.
        """
        # torch_dtype is crucial for using the minimal amount of memory at load time
        torch_dtype = kwargs.get("torch_dtype", None)

        vision_model_name = config.vision_model_name
        # NOTE(security): `eval` executes whatever expression the config string holds - only use configs
        # from trusted sources. If the params are always plain literals, `ast.literal_eval` would be safer.
        vision_model_params = eval(config.vision_model_params)

        # 1. create an uninitialized vision_model to insert into the main model.
        kwargs["vision_model"] = cls._build_uninitialized_vision_model(
            vision_model_name, vision_model_params, torch_dtype
        )

        # 2. create the main class's model, passing the uninitialized vision_model to it
        return cls(config, **kwargs)

    @classmethod
    def from_pretrained_models(cls, *args, **kwargs):
        """
        Use this method when creating a new vloom model that hasn't been yet trained and it'll be
        composed of 2 pre-trained models - hence `pretrained_models`.
        """
        return cls.from_pretrained(*args, **kwargs, new_model=True)

    @classmethod
    def from_pretrained(cls, *model_args, is_resume=False, new_model=False, **kwargs):
        """
        Use this method when loading an already pretrained vloom model - either from a checkpoint or from hub.
        For creating an untrained model use `from_pretrained_models` instead.

        Args:
            *model_args: positional arguments forwarded to the parent `from_pretrained`.
            is_resume: if True (and `new_model` is False), an empty model is created with `state_dict={}`.
            new_model: if True, the vision model weights are overridden with pretrained ones after creation.
                Takes precedence over `is_resume`.
            **kwargs: forwarded to the parent `from_pretrained`. `config` may be omitted, a `PretrainedConfig`,
                or a path (`str` or `os.PathLike`) to a config.

        Returns:
            The instantiated model.

        Raises:
            TypeError: if `config` is neither a `PretrainedConfig` nor a path.
        """
        # we have 3 use cases:
        # 1. new_model - a totally new vloom model
        # 2. is_resume - a pretrained vloom model being resumed from a
        #    checkpoint (instantiate a random empty model in this case)
        # 3. neither - a pretrained vloom model loaded from hub or local path
        is_untrained_vloom_model = bool(new_model)
        is_pretrained_vloom_model_resumed = bool(is_resume) and not is_untrained_vloom_model

        # torch_dtype is crucial for using the minimal amount of memory at load time
        torch_dtype = kwargs.get("torch_dtype", None)

        # config is:
        # 1. either not passed and then we use the model's default config (used by tests)
        # 2. passed and in which case it's one of:
        #   2a. `PretrainedConfig` (a new m4 model)
        #   2b. path to a json config (an already pretrained m4 model, usually resumed training)
        config = kwargs.get("config", None)
        if config is None:
            config = cls.config_class.from_pretrained(*model_args, **kwargs, return_unused_kwargs=False)
        elif not isinstance(config, PretrainedConfig):
            # adapted from https://github.com/huggingface/transformers/blob/d0acc9537829e7d067edbb791473bbceb2ecf056/src/transformers/modeling_utils.py#L1920
            # Explicit check instead of `assert`: it survives `python -O` and also accepts plain `str` paths.
            if not isinstance(config, (str, os.PathLike)):
                raise TypeError(
                    f"`config` must be a PretrainedConfig or a path (str or os.PathLike), got {type(config).__name__}"
                )
            config = cls.config_class.from_pretrained(
                str(config),
                return_unused_kwargs=False,
                **kwargs,
            )

        vision_model_name = config.vision_model_name
        # NOTE(security): `eval` executes whatever expression the config string holds - only use configs
        # from trusted sources. If the params are always plain literals, `ast.literal_eval` would be safer.
        vision_model_params = eval(config.vision_model_params)

        # 1. create an uninitialized vision_model to insert into the main model.
        kwargs["vision_model"] = cls._build_uninitialized_vision_model(
            vision_model_name, vision_model_params, torch_dtype
        )

        # 2. create the vloom model
        if is_pretrained_vloom_model_resumed:
            # in the case of resume under deepspeed we create an empty model, and get deepspeed
            # to load the weights from the checkpoint
            # but not all models have these keys so handle the case they don't have them
            kwargs.pop("config", None)
            model = super().from_pretrained(None, config=config, state_dict={}, **kwargs)
        else:
            # new model, or pretrained model loaded from hub or local path
            model = super().from_pretrained(*model_args, **kwargs)

        # 3. if is_untrained_vloom_model, now override the uninitialized vision_model with one with pretrained weights
        if is_untrained_vloom_model:
            # NOTE(review): `override_vision_model_wrapper` is not defined on this base class;
            # presumably concrete subclasses provide it - confirm.
            cls.override_vision_model_wrapper(model, config, vision_model_name, vision_model_params, torch_dtype)

        return model
class DecoupledEmbedding(nn.Embedding):
    # Derived from https://pytorch.org/docs/stable/_modules/torch/nn/modules/sparse.html#Embedding
    """
    Implements a decoupling of parameters to allow freezing (or not) a subset of the embeddings.
    In practice, the regular `weight` can be trained or frozen (i.e. `partially_freeze=True`), and if
    `num_additional_embeddings` > 0, then it will create `num_additional_embeddings` additional parameters that are
    always trained.
    If `num_additional_embeddings=0`, then the module defaults back to the regular behavior of `nn.Embedding`.
    """

    def __init__(
        self,
        num_embeddings,
        num_additional_embeddings,
        embedding_dim,
        partially_freeze=False,
        device=None,
        dtype=None,
        padding_idx=None,
        **kwargs,
    ) -> None:
        """
        Args:
            num_embeddings: int. Size of the regular (possibly frozen) embedding table.
            num_additional_embeddings: int. Number of additional embeddings. Only useful when you
                `partially_freeze=True`.
            embedding_dim: int. Size of each embedding vector.
            partially_freeze: bool. If True, the regular `weight` will be frozen. `additional_weight` is never
                frozen.
            device, dtype: forwarded to both embedding tables.
            padding_idx: forwarded to the `nn.Embedding` constructor. Note that `forward` below calls
                `F.embedding` without `padding_idx`.
            **kwargs: forwarded to the `nn.Embedding` constructor of the regular table only.

        Raises:
            ValueError: if `padding_idx` is greater than or equal to `num_embeddings`.
        """
        # Valid indices of the regular table are 0..num_embeddings-1, so `padding_idx == num_embeddings`
        # is already out of range (hence `>=`, not `>`).
        if padding_idx is not None and padding_idx >= num_embeddings:
            raise ValueError(f"padding_idx must be within num_embeddings. Got {padding_idx} and {num_embeddings}")

        super().__init__(
            num_embeddings=num_embeddings,
            embedding_dim=embedding_dim,
            device=device,
            dtype=dtype,
            padding_idx=padding_idx,
            **kwargs,
        )
        self.num_embeddings = num_embeddings
        # stores the value exactly as passed by the caller
        self.padding_idx = padding_idx
        self.num_additional_embeddings = num_additional_embeddings
        self.partially_freeze = partially_freeze

        if partially_freeze:
            self.weight.requires_grad_(False)

        if self.num_additional_embeddings > 0:
            self.additional_embedding = nn.Embedding(
                num_embeddings=self.num_additional_embeddings,
                embedding_dim=embedding_dim,
                device=device,
                dtype=dtype,
            )

    def forward(self, input_ids):
        """
        we have 2 embeddings, with different indices - one pretrained self.weight and another
        self.additional_embedding.weight that is being trained.

        in order to make a lookup of the input ids, we:
        1. find out the positions of the entries belonging to the 2nd embedding
        2. extract those values while subtracting the size of the first embedding (num_embeddings),
           since the 2nd embedding starts from 0 and not num_embeddings
        3. perform the 2nd embedding lookup
        4. now we handle the 1st embedding, we overwrite indices belonging to the 2nd embedding with index 0
        5. perform the 1st embedding lookup
        6. now we overwrite the values in the 1st embedding lookup with the values of the 2nd embedding lookup

        note: for the 1st embedding lookup we could have looked up only the low indices and not do
        the placeholder replacement, but then we have to create a new tensor and populate it with 2 tensors
        that are spread out across various indices - i.e. not a simple concat - I haven't benchmarked the
        complex case if it's any faster, given that seqlens are usually relatively short it's
        probably not faster or if faster not by much - but might be a good idea to measure.

        The caller's `input_ids` tensor is never modified.
        """
        if self.num_additional_embeddings == 0:
            return F.embedding(input_ids, self.weight)

        # positions whose id falls into the additional vocabulary
        is_additional = input_ids >= self.num_embeddings
        additional_embeddings = self.additional_embedding(input_ids[is_additional] - self.num_embeddings)

        # for successful lookup replace those ids with 0 (in a new tensor, so the input stays untouched);
        # the results of these lookups will be discarded anyway
        regular_ids = input_ids.masked_fill(is_additional, 0)
        full_vector = F.embedding(regular_ids, self.weight)

        # overwrite the records with high indices
        full_vector[is_additional] = additional_embeddings

        return full_vector

    def extra_repr(self) -> str:
        """Return the summary of this module's settings shown in its repr."""
        return "num_embeddings={}, num_additional_embeddings={}, embedding_dim={}, partially_freeze={}".format(
            self.num_embeddings,
            self.num_additional_embeddings,
            self.embedding_dim,
            self.partially_freeze,
        )

    @classmethod
    def from_pretrained(cls, embeddings, freeze=True, **kwargs):
        """Not supported for this class; always raises `NotImplementedError`."""
        raise NotImplementedError
class DecoupledLinear(nn.Linear):
    # Derived from https://pytorch.org/docs/stable/_modules/torch/nn/modules/linear.html#Linear
    """
    Linear layer whose output features are split into a regular part and an optional additional part.

    The regular `weight` (and `bias`) are frozen when `partially_freeze=True`. When
    `out_additional_features` > 0, a second linear layer `additional_fc` with
    `out_additional_features * in_features` weights is created; it is never frozen here and its output is
    concatenated after the regular output.
    With `out_additional_features=0` the module computes the same thing as a regular `nn.Linear`.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        out_additional_features: int = 0,
        bias: bool = True,
        partially_freeze: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        """
        Args:
            in_features: size of each input sample.
            out_features: number of regular output features.
            out_additional_features: number of additional trainable output features. Only makes sense when
                `partially_freeze=True`.
            bias: whether the regular layer (and the additional one, if any) has a bias.
            partially_freeze: if True, the regular `weight` and `bias` are frozen while the extra parameters
                (if any) stay trainable. If False, nothing is frozen.
            device, dtype: forwarded to both linear layers.
        """
        super().__init__(in_features, out_features, bias, device, dtype)
        self.in_features = in_features
        self.out_features = out_features
        self.out_additional_features = out_additional_features
        self.partially_freeze = partially_freeze

        if partially_freeze:
            # freeze the regular parameters; the bias only exists when it was requested
            to_freeze = [self.weight]
            if bias:
                to_freeze.append(self.bias)
            for param in to_freeze:
                param.requires_grad_(False)

        has_extra_outputs = out_additional_features > 0
        if has_extra_outputs:
            self.additional_fc = nn.Linear(
                in_features=in_features,
                out_features=out_additional_features,
                bias=bias,
                device=device,
                dtype=dtype,
            )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Apply the regular projection and append the additional features (if any) on the last dim."""
        regular = F.linear(input, self.weight, self.bias)
        if not self.out_additional_features > 0:
            return regular

        extra_layer = self.additional_fc
        extra = F.linear(input, extra_layer.weight, extra_layer.bias)
        return torch.cat((regular, extra), -1)

    def extra_repr(self) -> str:
        """Overwriting `nn.Linear.extra_repr` to include new parameters."""
        has_bias = self.bias is not None
        return (
            f"in_features={self.in_features}, out_features={self.out_features}, "
            f"out_additional_features={self.out_additional_features}, bias={has_bias}, "
            f"partially_freeze={self.partially_freeze}"
        )
if __name__ == "__main__":
    # Manual smoke check: build each decoupled module, run one forward/backward pass and print
    # which parameters are trainable and which ones received a gradient.

    # --- DecoupledEmbedding: id 11 falls into the additional table, ids 1 and 3 into the regular one ---
    emb = DecoupledEmbedding(num_embeddings=10, num_additional_embeddings=3, embedding_dim=5, partially_freeze=True)
    for n, p in emb.named_parameters():
        print(n, p.requires_grad)
    idx = torch.tensor([[11, 1, 3]])
    y = emb(idx)
    loss = y.sum()
    loss.backward()
    print(emb.weight, emb.weight.grad)
    # `additional_embedding` is a module, so the gradient lives on its `weight` parameter
    print(emb.additional_embedding.weight, emb.additional_embedding.weight.grad)

    # --- DecoupledLinear ---
    lin = DecoupledLinear(in_features=3, out_features=4, out_additional_features=2, bias=True, partially_freeze=True)
    for n, p in lin.named_parameters():
        print(n, p.requires_grad)
    x = torch.randn(12, 3)
    y = lin(x)
    loss = y.sum()
    loss.backward()
    print("Weight w and grad:", lin.weight, lin.weight.grad)
    print("bias w and grad:", lin.bias, lin.bias.grad)
    print("additional_fc.weight w and grad:", lin.additional_fc.weight, lin.additional_fc.weight.grad)
    print("additional_bias w and grad:", lin.additional_fc.bias, lin.additional_fc.bias.grad)