# cde-small-v1 / model.py
from typing import Dict, Optional, Union
import copy
import torch
import torch.nn as nn
import transformers
from cde.lib.dist import print0
from cde.lib.tensor import mean_pool, mean_pool_3d, mean_pool_weighted, last_token_pool
from cde.lib import load_embedder_and_tokenizer, ContextualModelConfig
def limit_layers(model: transformers.PreTrainedModel, n_layers: int) -> None:
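    """Truncate `model` in place so that only its first `n_layers` transformer layers are kept."""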
if hasattr(model, 'transformer'):
if hasattr(model.transformer, 'h'):
# gpt2
model.transformer.h = model.transformer.h[:n_layers]
else:
model.transformer.layer = model.transformer.layer[:n_layers]
elif hasattr(model, 'encoder'):
if hasattr(model.encoder, 'layers'):
model.encoder.layers = model.encoder.layers[:n_layers]
else:
model.encoder.layer = model.encoder.layer[:n_layers]
else:
        raise RuntimeError(f"don't know how to limit layers of model {type(model)}")
def disable_dropout(model: torch.nn.Module):
dropout_modules = [m for m in model.modules() if isinstance(m, torch.nn.Dropout)]
for m in dropout_modules:
m.p = 0.0
print0(
f"Disabled {len(dropout_modules)} dropout modules from model type {type(model)}"
)
def disable_causality(model: torch.nn.Module):
disabled_modules = 0
for m in model.modules():
if hasattr(m, "is_causal"):
m.is_causal = False
disabled_modules += 1
print0(
f"Set is_causal=False in {disabled_modules} modules from model type {type(model)}"
)
class ContextualModelMixin(nn.Module):
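    """Shared logic for the second-stage models: builds the contextual "soft prompt" that is
    prepended to the input, consisting of the first-stage dataset (corpus) embeddings followed
    by learned prompt tokens, with optional sequence dropout to a learned null embedding.
    """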
@property
def num_corpus_tokens(self) -> int:
return self.transductive_corpus_size * self.transductive_tokens_per_document
def contextual_init(self):
self.n_soft_prompt = 8
self.prompt_projection = torch.nn.Sequential(
torch.nn.Linear(self.hidden_size, self.hidden_size),
torch.nn.ReLU(),
torch.nn.Linear(self.hidden_size, self.hidden_size * self.n_soft_prompt)
)
self.transductive_corpus_size = vars(self.config).get("transductive_corpus_size", 1)
self.transductive_tokens_per_document = vars(self.config).get("transductive_tokens_per_document", 1)
self.randomize_dataset_sequence_order = True
self.sequence_dropout_prob = vars(self.config).get("transductive_sequence_dropout_prob", 0.0)
if self.sequence_dropout_prob > 0.0:
self.sequence_dropout_null_embedding = torch.nn.Parameter(
torch.randn(self.hidden_size) * 0.01,
requires_grad = True
)
self.output_projection = torch.nn.Sequential(
torch.nn.Linear(self.hidden_size, self.hidden_size),
torch.nn.ReLU(),
torch.nn.Linear(self.hidden_size, self.hidden_size)
)
def _prepare_dataset_embeddings(
self,
input_ids: torch.Tensor, dataset_embeddings: torch.Tensor,
null_dataset_embedding: bool = False,
) -> torch.Tensor:
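        """Builds the contextual prefix for one batch.

        Moves `dataset_embeddings` to the input device, expands or truncates it to
        `num_corpus_tokens` vectors per example, optionally replaces vectors with the learned
        null embedding (sequence dropout during training, or all of them when
        `null_dataset_embedding=True`), appends the learned soft-prompt tokens, and, during
        training, shuffles the order of the dataset vectors. Returns a tensor of shape
        (batch, corpus_tokens + n_soft_prompt, hidden).
        """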
if not isinstance(dataset_embeddings, torch.Tensor):
dataset_embeddings = torch.tensor(dataset_embeddings)
if len(dataset_embeddings.shape) == 2:
# Auto-expand for a batch.
dataset_embeddings = dataset_embeddings[None, :, :] # (b, d) -> (1, b, d)
dataset_embeddings = dataset_embeddings.to(input_ids.device)
batch_size = input_ids.shape[0]
if (self.transductive_tokens_per_document > 1):
if self.training:
# Choose N random documents to fill our context window with.
# This logic is a little confusing but allows us to sample a
# different batch *per-document*
assert dataset_embeddings.shape[1] == self.transductive_tokens_per_document
R = torch.randint(
low=0,
high=len(dataset_embeddings),
size=(batch_size, self.config.transductive_corpus_size),
device=dataset_embeddings.device
)
# TODO make this deterministic somehow for evaluation?
dataset_embeddings = dataset_embeddings[R].reshape((batch_size, self.num_corpus_tokens, self.hidden_size))
else:
dataset_embeddings = dataset_embeddings.reshape((1, self.num_corpus_tokens, self.hidden_size))
# print("reshaped to dataset_embeddings.shape =", dataset_embeddings.shape)
if dataset_embeddings.shape[1] > self.num_corpus_tokens:
# If too many dataset embeddings are passed in, just take the first N until
# we have the proper number.
dataset_embeddings = dataset_embeddings[:, :self.num_corpus_tokens, :]
        num_dataset_batches, corpus_size, _hidden_size = dataset_embeddings.shape
        if num_dataset_batches == 1:
            # Auto-expand for a batch.
            dataset_embeddings = dataset_embeddings.expand((batch_size, -1, -1))
if self.training and self.sequence_dropout_prob > 0.0:
sequence_dropout_mask = (
torch.rand((batch_size, corpus_size), device=dataset_embeddings.device) < self.sequence_dropout_prob
)
null_embeddings = self.sequence_dropout_null_embedding[None, None].expand(batch_size, corpus_size, -1)
dataset_embeddings = torch.where(
sequence_dropout_mask[..., None], null_embeddings, dataset_embeddings
)
elif null_dataset_embedding:
null_embeddings = self.sequence_dropout_null_embedding[None, None].expand(batch_size, corpus_size, -1)
dataset_embeddings = null_embeddings
# print(f"[ContextualModelMixin] dataset_embeddings.shape = {dataset_embeddings.shape}")
# backbone_max_seq_length = self.backbone.config.max_trained_positions
# assert batch_size + (2 * self.n_soft_prompt + corpus_size) <= backbone_max_seq_length, "too many hard negatives for backbone model"
soft_prompt = torch.ones((1, self.hidden_size), device=dataset_embeddings.device, dtype=dataset_embeddings.dtype)
soft_prompt = self.prompt_projection(soft_prompt).reshape((1, self.n_soft_prompt, self.hidden_size))
soft_prompt = soft_prompt.expand((len(dataset_embeddings), -1, -1)) # -> (b, 4+b, d) # soft_prompt.repeat((len(input_ids), 1, 1))
soft_prompt = torch.cat((dataset_embeddings, soft_prompt), dim=1)
# print(f"[ContextualModelMixin] soft_prompt.shape = {soft_prompt.shape}")
if self.training and self.randomize_dataset_sequence_order:
randomized_order = torch.stack(
[
torch.cat(
(
torch.randperm(corpus_size, device=soft_prompt.device),
torch.arange(self.n_soft_prompt, device=soft_prompt.device) + corpus_size
), dim=0)
for _ in range(batch_size)])
randomized_order = randomized_order.to(soft_prompt.device)
soft_prompt = soft_prompt.gather(1, randomized_order[..., None].expand_as(soft_prompt))
return soft_prompt
class BiEncoder(transformers.PreTrainedModel):
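    """First-stage encoder: embeds each document independently with a pretrained backbone,
    pools the token states (mean or max, or mean-pools into several vectors when
    `transductive_tokens_per_document > 1`), and projects the result through a small MLP.
    """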
embedder: transformers.PreTrainedModel
def __init__(
self,
config, #: transformers.PreTrainedConfig,
):
super().__init__(config=config)
embedder, _ = load_embedder_and_tokenizer(
config.embedder,
)
if config.limit_layers:
print0(f"Limiting layers to {config.limit_layers}")
limit_layers(embedder, config.limit_layers)
self.embedder = embedder
# if ("t5" in embedder.config.model_type):
# print0(f"using torch.compile() on embedder of type `{embedder.config.model_type}`")
# self.embedder = torch.compile(self.embedder)
self.hidden_size = self.embedder.config.hidden_size
# Allow pooling to multiple tokens per document
self.transductive_tokens_per_document = vars(self.config).get("transductive_tokens_per_document", 1)
self.mlp = torch.nn.Sequential(
torch.nn.Linear(self.hidden_size, self.hidden_size),
torch.nn.GELU(),
torch.nn.Linear(self.hidden_size, self.config.embedding_output_dim or self.hidden_size),
)
self.temp = config.logit_scale
if config.disable_dropout:
disable_dropout(self)
self.pooling_strategy = vars(config).get("pooling_strategy", "mean")
def forward(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
dataset_input_ids: Optional[torch.Tensor] = None,
dataset_attention_mask: Optional[torch.Tensor] = None,
token_type_ids = None,
output_hidden_states: bool = False,
) -> torch.Tensor:
"""
query_embedding (float torch.Tensor) - shape (batch_size, embedding_dim)
document_embeddings (float torch.Tensor) - shape (corpus_size, embedding_dim)
where the corpus_size >= batch_size and is structured like this:
[d1, d2, d3, hn1_1, hn1_2, hn2_1, hn2_2, hn3_1, hn3_2]
for a corpus with three documents and two hard negatives per document
"""
# del dataset_input_ids
# del dataset_attention_mask
del token_type_ids
# from cde.lib.dist import get_rank
# tokenizer = transformers.AutoTokenizer.from_pretrained("bert-base-uncased")
# if get_rank() == 0:
# breakpoint()
# torch.distributed.barrier()
outputs = (
self.embedder(
input_ids=input_ids,
attention_mask=attention_mask,
).last_hidden_state
)
if self.transductive_tokens_per_document > 1:
document_embeddings = None
batch_size, seq_length, output_dim = outputs.shape
if seq_length % self.transductive_tokens_per_document != 0:
# Pad to nearest multiple
n_extra_embeds = self.transductive_tokens_per_document - (seq_length % self.transductive_tokens_per_document)
                outputs = torch.cat(
                    (outputs, torch.zeros((batch_size, n_extra_embeds, output_dim), dtype=outputs.dtype, device=outputs.device)),
                    dim=1
                )
                attention_mask = torch.cat(
                    (attention_mask, torch.zeros((batch_size, n_extra_embeds), dtype=attention_mask.dtype, device=attention_mask.device)),
                    dim=1
                )
seq_length += n_extra_embeds
print(f"Added {n_extra_embeds} padding tokens to input_ids and attention_mask")
            # print(f"transductive_tokens_per_document {self.transductive_tokens_per_document} outputs.shape =", outputs.shape)
outputs = outputs.reshape(
(batch_size, self.transductive_tokens_per_document, seq_length // self.transductive_tokens_per_document, output_dim)
)
attention_mask = attention_mask.reshape((batch_size, self.transductive_tokens_per_document, -1))
document_embeddings = mean_pool_3d(outputs, attention_mask)
document_embeddings = document_embeddings.reshape((batch_size, self.transductive_tokens_per_document, output_dim))
else:
if self.pooling_strategy == "mean":
document_embeddings = mean_pool(outputs, attention_mask)
else:
                # Max-pool over the sequence dimension; torch.max returns a
                # (values, indices) tuple, so keep only the values.
                document_embeddings = outputs.max(dim=1).values
output = self.mlp(document_embeddings)
if output_hidden_states:
return {
"hidden_states": outputs,
"pooled": output,
}
else:
return output
class DatasetConditionedAutoregressive(transformers.PreTrainedModel, ContextualModelMixin):
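    """Second-stage model with an autoregressive (decoder-style) backbone. The first-stage
    dataset embeddings are flattened and repacked into vectors of the backbone's hidden size,
    prepended to the token embeddings, and causality is disabled so every position can attend
    to the contextual prefix.
    """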
def __init__(
self,
config,
dataset_backbone: transformers.PreTrainedModel,
first_stage_hidden_size: int,
):
super().__init__(config=config)
self.backbone = dataset_backbone
self.backbone_hidden_size = self.backbone.config.hidden_size
self.hidden_size = first_stage_hidden_size # Input token size
self.contextual_init()
disable_causality(self.backbone)
self.input_ln = torch.nn.LayerNorm(
self.backbone_hidden_size,
eps=1e-5
)
# Override contextual init
self.output_projection = torch.nn.Sequential(
torch.nn.Linear(self.backbone_hidden_size, self.backbone_hidden_size),
torch.nn.ReLU(),
torch.nn.Linear(self.backbone_hidden_size, self.backbone_hidden_size)
)
self._shift_rotary_embedding()
@property
def num_corpus_tokens(self) -> int:
return self.config.transductive_corpus_size * self.transductive_tokens_per_document
@property
def corpus_token_ratio(self) -> float:
# How many tokens from the first stage make one token in the second
# stage?
return self.backbone_hidden_size / self.hidden_size
def corpus_token_pad_size(self, n_tokens: int) -> int:
return self.hidden_size % self.backbone_hidden_size
def _shift_rotary_embedding(self) -> None:
disable_transductive_rotary_embedding = vars(self.config).get("disable_transductive_rotary_embedding", True)
# TODO: Can we do this for LLAMA?
print("Warning: Positional embedding disabling not implemented for LLAMA.")
def forward(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
dataset_embeddings: torch.Tensor,
output_hidden_states: bool = False,
null_dataset_embedding: bool = False,
) -> torch.Tensor:
soft_prompt = self._prepare_dataset_embeddings(
input_ids=input_ids,
dataset_embeddings=dataset_embeddings,
null_dataset_embedding=null_dataset_embedding,
)
# Reshape for this model.
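        # The first-stage embeddings have width `self.hidden_size`, which may differ from the
        # backbone's hidden size. Flatten the whole prefix, pad it up to a multiple of
        # `backbone_hidden_size`, and reinterpret it as a sequence of backbone-width vectors.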
# print("[DatasetConditionedAutoregressive] 1 -> soft_prompt.shape =", soft_prompt.shape)
num_soft_elements = torch.prod(torch.tensor(soft_prompt.shape[1:])).item()
soft_prompt = soft_prompt.reshape((soft_prompt.shape[0], num_soft_elements))
num_padding_elements = self.backbone_hidden_size - (num_soft_elements % self.backbone_hidden_size)
        padding = torch.ones(
            (soft_prompt.shape[0], num_padding_elements),
            device=soft_prompt.device,
            dtype=soft_prompt.dtype,
        )
soft_prompt = torch.cat((soft_prompt, padding), dim=1)
soft_prompt = soft_prompt.reshape(
(soft_prompt.shape[0], -1, self.backbone_hidden_size)
)
soft_prompt = self.input_ln(soft_prompt)
# print("[DatasetConditionedAutoregressive] 2 -> soft_prompt.shape =", soft_prompt.shape)
backbone_attention_mask = torch.ones(
soft_prompt.shape[0:2],
dtype=torch.long,
device=soft_prompt.device,
)
token_embeddings = self.backbone.get_input_embeddings()
inputs_embeds = token_embeddings(input_ids) # (b, s) -> (b, s, d)
# print("[2] inputs_embeds.shape =", inputs_embeds.shape)
inputs_embeds = torch.cat((soft_prompt, inputs_embeds), dim=1) # (v, 4+b+s, d)
# print("[3.a] inputs_embeds.shape =", inputs_embeds.shape)
input_attention_mask = torch.cat((backbone_attention_mask, attention_mask), dim=1)
# print("[3.b] attention_mask.shape =", attention_mask.shape)
output = self.backbone(
inputs_embeds=inputs_embeds,
attention_mask=input_attention_mask,
output_hidden_states=True,
) # (1, 4 + b + s, d)
# trim soft prompt
last_hidden_state = output.hidden_states[-1]
n_soft_prompt_tokens = soft_prompt.shape[1]
output_vectors = last_hidden_state[:, n_soft_prompt_tokens:, :]
output_attention_mask = input_attention_mask[:, n_soft_prompt_tokens:]
# Take last token position
if vars(self.config).get("pooling_strategy") == "last_token":
output_pooled = last_token_pool(output_vectors, output_attention_mask)
elif vars(self.config).get("pooling_strategy") == "mean":
output_pooled = mean_pool(output_vectors, output_attention_mask)
else:
output_pooled = mean_pool_weighted(output_vectors, output_attention_mask)
# average with original vectors
# TODO: Argparse for pooling strategy.
output = self.output_projection(output_pooled) # (b, 2d) -> (b, d)
if output_hidden_states:
return {
"hidden_states": output_vectors,
"pooled": output,
}
else:
return output
class DatasetConditionedBiencoder(transformers.PreTrainedModel, ContextualModelMixin):
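    """Second-stage model with a bidirectional (encoder-style) backbone: the contextual soft
    prompt is prepended to the token embeddings, the combined sequence is encoded, and only
    the text positions are mean-pooled into the final embedding.
    """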
def __init__(
self,
config,
dataset_backbone: transformers.PreTrainedModel,
):
super().__init__(config=config)
self.backbone = dataset_backbone
        self.hidden_size = self.backbone.config.hidden_size
# self.input_ln = torch.nn.LayerNorm(
# self.hidden_size,
# eps=self.backbone.config.layer_norm_epsilon
# )
self.contextual_init()
self._shift_rotary_embedding()
@property
def num_corpus_tokens(self) -> int:
return self.config.transductive_corpus_size * self.transductive_tokens_per_document
def _shift_rotary_embedding(self) -> None:
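        """For nomic-style backbones, set `rotary_start_pos` to `num_corpus_tokens` so that
        rotary positional information is applied to the text tokens rather than to the
        prepended contextual dataset tokens.
        """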
disable_transductive_rotary_embedding = vars(self.config).get("disable_transductive_rotary_embedding", True)
if self.backbone.config.model_type.startswith("nomic") and disable_transductive_rotary_embedding:
# We only want to apply positional embeddings to the
# *text* portion of the backbone network.
self.backbone.config.rotary_start_pos = 0.0
rotary_disabled = 0
rotary_start_pos = self.num_corpus_tokens
for module in self.backbone.modules():
if hasattr(module, "rotary_emb_dim"):
module.rotary_start_pos = rotary_start_pos
rotary_disabled += 1
print0(f"modified {rotary_disabled} rotary modules – set rotary_start_pos to {rotary_start_pos}")
def forward(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
dataset_embeddings: torch.Tensor,
output_hidden_states: bool = False,
null_dataset_embedding: bool = False,
) -> torch.Tensor:
# print(f"[DatasetConditionedBiencoder - 0] input_ids.shape => {input_ids.shape} // dataset_embeddings.shape =", dataset_embeddings.shape)
soft_prompt = self._prepare_dataset_embeddings(
input_ids=input_ids,
dataset_embeddings=dataset_embeddings,
null_dataset_embedding=null_dataset_embedding,
)
# print(f"[DatasetConditionedBiencoder - 1] soft_prompt.shape => {soft_prompt.shape}")
backbone_attention_mask = torch.ones(
soft_prompt.shape[0:2],
dtype=torch.long,
device=soft_prompt.device,
)
inputs_embeds = self.backbone.embeddings(input_ids) # (b, s) -> (b, s, d)
# print("[2] inputs_embeds.shape =", inputs_embeds.shape)
inputs_embeds = torch.cat((soft_prompt, inputs_embeds), dim=1) # (v, 4+b+s, d)
# print("[3.a] inputs_embeds.shape =", inputs_embeds.shape)
attention_mask = torch.cat((backbone_attention_mask, attention_mask), dim=1)
# print("[3.b] attention_mask.shape =", attention_mask.shape)
output = self.backbone(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
) # (1, 4 + b + s, d)
# trim soft prompt
output_vectors = output.last_hidden_state
# use only these tokens
n_soft_prompt_tokens = soft_prompt.shape[1]
# print("n_soft_prompt_tokens =", n_soft_prompt_tokens)
output_vectors = output.last_hidden_state[:, n_soft_prompt_tokens:, :]
output_attention_mask = attention_mask[:, n_soft_prompt_tokens:]
# print("pooling output_vectors.shape =", output_vectors.shape, "and output_attention_mask.shape =", output_attention_mask.shape)
output_pooled = mean_pool(output_vectors, output_attention_mask)
# average with original vectors
# TODO: Argparse for pooling strategy.
# output_vectors = torch.cat((soft_prompt_pooled, output_pooled), dim=1) # (b, d) + (b, d) -> (b, 2d)
# print("output_pooled.shape =", output_pooled.shape)
output = self.output_projection(output_pooled) # (b, 2d) -> (b, d)
# print("returning output.shape =", output.shape)
if output_hidden_states:
return {
"hidden_states": output_vectors,
"pooled": output,
}
else:
return output
class DatasetPrefixBiencoder(transformers.PreTrainedModel, ContextualModelMixin):
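    """Baseline second-stage model that conditions on the dataset by prefixing raw corpus token
    ids (one randomly chosen document per example) to the input, rather than prepending
    embedded dataset vectors; pooling is restricted to the non-prefix positions.
    """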
def __init__(
self,
config, #: transformers.PreTrainedConfig,
embedder: transformers.PreTrainedModel,
):
super().__init__(config=config)
self.embedder = embedder
self.hidden_size = self.embedder.config.hidden_size
self.contextual_init()
def forward(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
dataset_input_ids: torch.Tensor,
dataset_attention_mask: torch.Tensor,
output_hidden_states: bool = False,
) -> torch.Tensor:
R = torch.randint(low=0, high=len(dataset_input_ids), size=(len(input_ids),), device=dataset_input_ids.device)
dataset_input_ids = dataset_input_ids[R]
input_ids = torch.cat((dataset_input_ids, input_ids), dim=1)
        # Attend to every position of the (re-sampled) dataset prefix; shapes follow
        # the batch-aligned dataset_input_ids selected above.
        dataset_attention_mask = torch.ones_like(dataset_input_ids)
input_attention_mask = torch.cat((dataset_attention_mask, attention_mask), dim=1)
output_attention_mask = torch.cat(
(torch.zeros_like(dataset_input_ids), attention_mask), dim=1
)
output = self.embedder(
input_ids=input_ids,
attention_mask=input_attention_mask,
)
output_vectors = output.last_hidden_state
output_pooled = mean_pool(output_vectors, output_attention_mask)
output = self.output_projection(output_pooled) # (b, 2d) -> (b, d)
if output_hidden_states:
S_d = dataset_attention_mask.shape[1]
output_vectors = output_vectors[:, S_d:, :]
return {
"hidden_states": output_vectors,
"pooled": output,
}
else:
return output
class DatasetTransformer(transformers.PreTrainedModel):
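    """Full two-stage contextual embedding model: a first-stage BiEncoder embeds a sample of
    corpus documents into `dataset_embeddings`, and a second-stage dataset-conditioned encoder
    embeds the actual input conditioned on those embeddings.
    """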
config_class = ContextualModelConfig
embedder: transformers.PreTrainedModel
dataset_backbone: transformers.PreTrainedModel
def __init__(
self,
config,
):
super().__init__(config=config)
dataset_backbone, _ = load_embedder_and_tokenizer(
vars(config).get("dataset_backbone", config.embedder)
)
if config.limit_layers:
print0(f"Limiting layers to {config.limit_layers}")
limit_layers(dataset_backbone, config.limit_layers)
biencoder_config = copy.deepcopy(config)
biencoder_config.embedding_output_dim = None
biencoder_config.limit_layers = vars(self.config).get("limit_layers_first_stage", None)
self.first_stage_model = BiEncoder(
config=biencoder_config,
)
if vars(config).get("autoregressive_backbone", False):
self.second_stage_model = DatasetConditionedAutoregressive(
config=config,
dataset_backbone=dataset_backbone,
first_stage_hidden_size=self.first_stage_model.hidden_size,
)
else:
self.second_stage_model = DatasetConditionedBiencoder(
config=config,
dataset_backbone=dataset_backbone
)
self.temp = config.logit_scale
if config.disable_dropout:
disable_dropout(self)
transductive_tie_token_embeddings = vars(self.config).get("transductive_tie_token_embeddings", False)
if transductive_tie_token_embeddings:
self.second_stage_model.backbone.embeddings.word_embeddings.weight = (
self.first_stage_model.embedder.embeddings.word_embeddings.weight
)
def forward(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
dataset_input_ids: Optional[torch.Tensor],
dataset_attention_mask: Optional[torch.Tensor],
output_hidden_states: bool = False,
) -> torch.Tensor:
"""
input_ids (long torch.Tensor) – ids of input tokens
attention_mask (bool torch.Tensor)
"""
dataset_embeddings = self.first_stage_model(
input_ids=dataset_input_ids,
attention_mask=dataset_attention_mask
)
return self.second_stage_model(
input_ids=input_ids,
attention_mask=attention_mask,
dataset_embeddings=dataset_embeddings,
output_hidden_states=output_hidden_states,
)
def get_model_class(name: str):
    if name == 'transductive':
return DatasetTransformer
elif name == 'biencoder':
return BiEncoder
elif name == "dataset_prefix_biencoder":
return DatasetPrefixBiencoder
else:
raise ValueError(f'unknown model cls {name}')
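

if __name__ == "__main__":
    # --- Usage sketch (illustration only, not part of the training code) ---
    # A minimal end-to-end example of the two-stage model, assuming this file is the custom
    # model code for the "jxm/cde-small-v1" checkpoint loaded with trust_remote_code=True,
    # and that the checkpoint ships a compatible tokenizer. The real checkpoint may also
    # expect task-specific text prefixes on queries and documents; check its model card.
    model = transformers.AutoModel.from_pretrained(
        "jxm/cde-small-v1", trust_remote_code=True
    )
    tokenizer = transformers.AutoTokenizer.from_pretrained("jxm/cde-small-v1")
    model.eval()

    corpus = ["a first corpus document", "a second corpus document"]
    queries = ["a query about the corpus"]

    corpus_inputs = tokenizer(corpus, padding=True, truncation=True, return_tensors="pt")
    query_inputs = tokenizer(queries, padding=True, truncation=True, return_tensors="pt")

    with torch.no_grad():
        # Stage 1: embed each corpus document independently.
        dataset_embeddings = model.first_stage_model(**corpus_inputs)
        # Stage 2: embed the queries, conditioned on the corpus embeddings.
        query_embeddings = model.second_stage_model(
            input_ids=query_inputs["input_ids"],
            attention_mask=query_inputs["attention_mask"],
            dataset_embeddings=dataset_embeddings,
        )
    print("query_embeddings.shape =", query_embeddings.shape)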