from typing import Dict, Optional, Union

import copy

import torch
import torch.nn as nn
import transformers

from cde.lib.dist import print0
from cde.lib.tensor import mean_pool, mean_pool_3d, mean_pool_weighted, last_token_pool

from cde.lib import load_embedder_and_tokenizer, ContextualModelConfig


def limit_layers(model: transformers.PreTrainedModel, n_layers: int) -> None:
    """Truncates `model` in place so that only its first `n_layers` transformer blocks are kept."""
    if hasattr(model, 'transformer'):
        if hasattr(model.transformer, 'h'):
            # GPT-2-style models keep their blocks in `transformer.h`.
            model.transformer.h = model.transformer.h[:n_layers]
        else:
            model.transformer.layer = model.transformer.layer[:n_layers]
    elif hasattr(model, 'encoder'):
        if hasattr(model.encoder, 'layers'):
            model.encoder.layers = model.encoder.layers[:n_layers]
        else:
            model.encoder.layer = model.encoder.layer[:n_layers]
    else:
        raise RuntimeError(f"don't know how to limit layers of model {type(model)}")


def disable_dropout(model: torch.nn.Module):
    """Sets p=0.0 on every `torch.nn.Dropout` module in `model`."""
    dropout_modules = [m for m in model.modules() if isinstance(m, torch.nn.Dropout)]
    for m in dropout_modules:
        m.p = 0.0
    print0(
        f"Disabled {len(dropout_modules)} dropout modules from model type {type(model)}"
    )


def disable_causality(model: torch.nn.Module):
    """Sets `is_causal=False` on every submodule that exposes that flag, so a decoder attends bidirectionally."""
    disabled_modules = 0
    for m in model.modules():
        if hasattr(m, "is_causal"):
            m.is_causal = False
            disabled_modules += 1
    print0(
        f"Set is_causal=False in {disabled_modules} modules from model type {type(model)}"
    )


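# Usage sketch for the two helpers above. Illustrative only: `_toy_model_sketch` is not part of
# the original module and is never called here.
def _toy_model_sketch() -> None:
    toy = nn.Sequential(nn.Linear(8, 8), nn.Dropout(p=0.1))
    disable_dropout(toy)    # sets p=0.0 on the single Dropout module
    disable_causality(toy)  # finds no `is_causal` attribute here, so it reports 0 modules

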
class ContextualModelMixin(nn.Module):
    """Shared logic for second-stage models that condition on a set of first-stage dataset embeddings."""

    @property
    def num_corpus_tokens(self) -> int:
        return self.transductive_corpus_size * self.transductive_tokens_per_document

    def contextual_init(self):
        self.n_soft_prompt = 8
        self.prompt_projection = torch.nn.Sequential(
            torch.nn.Linear(self.hidden_size, self.hidden_size),
            torch.nn.ReLU(),
            torch.nn.Linear(self.hidden_size, self.hidden_size * self.n_soft_prompt)
        )
        self.transductive_corpus_size = vars(self.config).get("transductive_corpus_size", 1)
        self.transductive_tokens_per_document = vars(self.config).get("transductive_tokens_per_document", 1)
        self.randomize_dataset_sequence_order = True
        self.sequence_dropout_prob = vars(self.config).get("transductive_sequence_dropout_prob", 0.0)
        if self.sequence_dropout_prob > 0.0:
            self.sequence_dropout_null_embedding = torch.nn.Parameter(
                torch.randn(self.hidden_size) * 0.01,
                requires_grad=True
            )
        self.output_projection = torch.nn.Sequential(
            torch.nn.Linear(self.hidden_size, self.hidden_size),
            torch.nn.ReLU(),
            torch.nn.Linear(self.hidden_size, self.hidden_size)
        )

    def _prepare_dataset_embeddings(
        self,
        input_ids: torch.Tensor,
        dataset_embeddings: torch.Tensor,
        null_dataset_embedding: bool = False,
    ) -> torch.Tensor:
        if not isinstance(dataset_embeddings, torch.Tensor):
            dataset_embeddings = torch.tensor(dataset_embeddings)

        if len(dataset_embeddings.shape) == 2:
            # Add a batch dimension: (num_tokens, hidden_size) -> (1, num_tokens, hidden_size).
            dataset_embeddings = dataset_embeddings[None, :, :]
        dataset_embeddings = dataset_embeddings.to(input_ids.device)

        batch_size = input_ids.shape[0]
        if self.transductive_tokens_per_document > 1:
            if self.training:
                # Sample `transductive_corpus_size` random documents per example (with replacement)
                # and flatten their per-document tokens into a single contextual prefix.
                assert dataset_embeddings.shape[1] == self.transductive_tokens_per_document
                R = torch.randint(
                    low=0,
                    high=len(dataset_embeddings),
                    size=(batch_size, self.config.transductive_corpus_size),
                    device=dataset_embeddings.device
                )
                dataset_embeddings = dataset_embeddings[R].reshape((batch_size, self.num_corpus_tokens, self.hidden_size))
            else:
                dataset_embeddings = dataset_embeddings.reshape((1, self.num_corpus_tokens, self.hidden_size))

        if dataset_embeddings.shape[1] > self.num_corpus_tokens:
            # If more dataset embeddings were passed in than we can use, simply truncate.
            dataset_embeddings = dataset_embeddings[:, :self.num_corpus_tokens, :]

        num_batches, corpus_size, _hidden_size = dataset_embeddings.shape
        if num_batches == 1:
            # A single shared context is broadcast across the whole batch.
            dataset_embeddings = dataset_embeddings.expand((batch_size, -1, -1))

        if self.training and self.sequence_dropout_prob > 0.0:
            # Sequence dropout: randomly replace some context embeddings with the learned null embedding.
            sequence_dropout_mask = (
                torch.rand((batch_size, corpus_size), device=dataset_embeddings.device) < self.sequence_dropout_prob
            )
            null_embeddings = self.sequence_dropout_null_embedding[None, None].expand(batch_size, corpus_size, -1)
            dataset_embeddings = torch.where(
                sequence_dropout_mask[..., None], null_embeddings, dataset_embeddings
            )
        elif null_dataset_embedding:
            null_embeddings = self.sequence_dropout_null_embedding[None, None].expand(batch_size, corpus_size, -1)
            dataset_embeddings = null_embeddings

        # Project a constant vector into `n_soft_prompt` learned soft-prompt tokens and append them
        # after the dataset embeddings.
        soft_prompt = torch.ones((1, self.hidden_size), device=dataset_embeddings.device, dtype=dataset_embeddings.dtype)
        soft_prompt = self.prompt_projection(soft_prompt).reshape((1, self.n_soft_prompt, self.hidden_size))
        soft_prompt = soft_prompt.expand((len(dataset_embeddings), -1, -1))
        soft_prompt = torch.cat((dataset_embeddings, soft_prompt), dim=1)

        if self.training and self.randomize_dataset_sequence_order:
            # Shuffle the order of the dataset embeddings (but not the trailing soft-prompt tokens)
            # so the model cannot rely on their position.
            randomized_order = torch.stack(
                [
                    torch.cat(
                        (
                            torch.randperm(corpus_size, device=soft_prompt.device),
                            torch.arange(self.n_soft_prompt, device=soft_prompt.device) + corpus_size
                        ),
                        dim=0,
                    )
                    for _ in range(batch_size)
                ]
            )
            randomized_order = randomized_order.to(soft_prompt.device)
            soft_prompt = soft_prompt.gather(1, randomized_order[..., None].expand_as(soft_prompt))

        return soft_prompt


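# The sketch below (illustrative only; not used by the classes in this module) mirrors the shape
# flow of `ContextualModelMixin._prepare_dataset_embeddings` with random tensors standing in for
# real first-stage embeddings. All sizes here are arbitrary.
def _soft_prompt_shape_sketch() -> torch.Size:
    batch_size, corpus_size, hidden_size, n_soft_prompt = 4, 16, 32, 8
    # A single shared context, broadcast across the batch (the `num_batches == 1` case above).
    dataset_embeddings = torch.randn(1, corpus_size, hidden_size).expand(batch_size, -1, -1)
    # A constant vector projected into `n_soft_prompt` learned soft-prompt tokens.
    prompt_projection = nn.Sequential(
        nn.Linear(hidden_size, hidden_size),
        nn.ReLU(),
        nn.Linear(hidden_size, hidden_size * n_soft_prompt),
    )
    soft_prompt = prompt_projection(torch.ones(1, hidden_size)).reshape(1, n_soft_prompt, hidden_size)
    soft_prompt = soft_prompt.expand(batch_size, -1, -1)
    # Dataset embeddings first, soft-prompt tokens last, exactly as in the method above.
    prefix = torch.cat((dataset_embeddings, soft_prompt), dim=1)
    return prefix.shape  # torch.Size([4, 24, 32]) == (batch, corpus_size + n_soft_prompt, hidden)

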
class BiEncoder(transformers.PreTrainedModel):
    embedder: transformers.PreTrainedModel

    def __init__(
        self,
        config,
    ):
        super().__init__(config=config)
        embedder, _ = load_embedder_and_tokenizer(
            config.embedder,
        )

        if config.limit_layers:
            print0(f"Limiting layers to {config.limit_layers}")
            limit_layers(embedder, config.limit_layers)

        self.embedder = embedder
        self.hidden_size = self.embedder.config.hidden_size

        self.transductive_tokens_per_document = vars(self.config).get("transductive_tokens_per_document", 1)
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(self.hidden_size, self.hidden_size),
            torch.nn.GELU(),
            torch.nn.Linear(self.hidden_size, self.config.embedding_output_dim or self.hidden_size),
        )
        self.temp = config.logit_scale

        if config.disable_dropout:
            disable_dropout(self)
        self.pooling_strategy = vars(config).get("pooling_strategy", "mean")

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        dataset_input_ids: Optional[torch.Tensor] = None,
        dataset_attention_mask: Optional[torch.Tensor] = None,
        token_type_ids=None,
        output_hidden_states: bool = False,
    ) -> torch.Tensor:
        """Embeds a batch of token sequences into pooled vectors of shape (batch_size, embedding_dim).

        When used for contrastive training, queries and documents are embedded separately:
            query_embedding (float torch.Tensor) - shape (batch_size, embedding_dim)
            document_embeddings (float torch.Tensor) - shape (corpus_size, embedding_dim)
        where corpus_size >= batch_size and the corpus is structured like
            [d1, d2, d3, hn1_1, hn1_2, hn2_1, hn2_2, hn3_1, hn3_2]
        for a corpus with three documents and two hard negatives per document.
        """
        del token_type_ids

        outputs = (
            self.embedder(
                input_ids=input_ids,
                attention_mask=attention_mask,
            ).last_hidden_state
        )

        if self.transductive_tokens_per_document > 1:
            document_embeddings = None
            batch_size, seq_length, output_dim = outputs.shape

            if seq_length % self.transductive_tokens_per_document != 0:
                # Pad to a multiple of `transductive_tokens_per_document` so the sequence can be
                # split into equal-sized chunks below.
                n_extra_embeds = self.transductive_tokens_per_document - (seq_length % self.transductive_tokens_per_document)
                outputs = torch.cat(
                    (outputs, torch.zeros((batch_size, n_extra_embeds, output_dim), device=outputs.device)),
                    dim=1
                )
                attention_mask = torch.cat(
                    (attention_mask, torch.zeros((batch_size, n_extra_embeds), device=attention_mask.device)),
                    dim=1
                )
                seq_length += n_extra_embeds
                print(f"Added {n_extra_embeds} padding tokens to input_ids and attention_mask")

            # Split each sequence into `transductive_tokens_per_document` chunks and mean-pool each
            # chunk, producing multiple embeddings per document.
            outputs = outputs.reshape(
                (batch_size, self.transductive_tokens_per_document, seq_length // self.transductive_tokens_per_document, output_dim)
            )

            attention_mask = attention_mask.reshape((batch_size, self.transductive_tokens_per_document, -1))
            document_embeddings = mean_pool_3d(outputs, attention_mask)

            document_embeddings = document_embeddings.reshape((batch_size, self.transductive_tokens_per_document, output_dim))
        else:
            if self.pooling_strategy == "mean":
                document_embeddings = mean_pool(outputs, attention_mask)
            else:
                # Max-pool over the sequence dimension (note: does not mask padding positions).
                document_embeddings = outputs.max(dim=1).values
        output = self.mlp(document_embeddings)

        if output_hidden_states:
            return {
                "hidden_states": outputs,
                "pooled": output,
            }
        else:
            return output


class DatasetConditionedAutoregressive(transformers.PreTrainedModel, ContextualModelMixin):
    def __init__(
        self,
        config,
        dataset_backbone: transformers.PreTrainedModel,
        first_stage_hidden_size: int,
    ):
        super().__init__(config=config)
        self.backbone = dataset_backbone
        self.backbone_hidden_size = self.backbone.config.hidden_size
        self.hidden_size = first_stage_hidden_size  # hidden size of the first-stage dataset embeddings
        self.contextual_init()
        disable_causality(self.backbone)

        self.input_ln = torch.nn.LayerNorm(
            self.backbone_hidden_size,
            eps=1e-5
        )

        # Re-create the output projection at the backbone's width, overriding the one built in
        # contextual_init() at the first-stage width.
        self.output_projection = torch.nn.Sequential(
            torch.nn.Linear(self.backbone_hidden_size, self.backbone_hidden_size),
            torch.nn.ReLU(),
            torch.nn.Linear(self.backbone_hidden_size, self.backbone_hidden_size)
        )
        self._shift_rotary_embedding()

    @property
    def num_corpus_tokens(self) -> int:
        return self.config.transductive_corpus_size * self.transductive_tokens_per_document

    @property
    def corpus_token_ratio(self) -> float:
        # How many first-stage tokens are packed into each second-stage (backbone) token.
        return self.backbone_hidden_size / self.hidden_size

    def corpus_token_pad_size(self, n_tokens: int) -> int:
        return self.hidden_size % self.backbone_hidden_size

    def _shift_rotary_embedding(self) -> None:
        disable_transductive_rotary_embedding = vars(self.config).get("disable_transductive_rotary_embedding", True)
        # TODO: implement rotary-embedding shifting for autoregressive (LLAMA-style) backbones;
        # currently this is a no-op.
        print("Warning: Positional embedding disabling not implemented for LLAMA.")

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        dataset_embeddings: torch.Tensor,
        output_hidden_states: bool = False,
        null_dataset_embedding: bool = False,
    ) -> torch.Tensor:
        soft_prompt = self._prepare_dataset_embeddings(
            input_ids=input_ids,
            dataset_embeddings=dataset_embeddings,
            null_dataset_embedding=null_dataset_embedding,
        )

        # Repack the first-stage soft prompt (hidden size `self.hidden_size`) into pseudo-tokens of
        # the backbone's hidden size: flatten, right-pad with ones so the length is a multiple of
        # `backbone_hidden_size`, then reshape and layer-norm.
        num_soft_elements = torch.prod(torch.tensor(soft_prompt.shape[1:])).item()
        soft_prompt = soft_prompt.reshape((soft_prompt.shape[0], num_soft_elements))
        num_padding_elements = self.backbone_hidden_size - (num_soft_elements % self.backbone_hidden_size)
        padding = torch.ones(
            (soft_prompt.shape[0], num_padding_elements),
            dtype=soft_prompt.dtype,  # match the prompt dtype so torch.cat also works in half precision
            device=soft_prompt.device,
        )
        soft_prompt = torch.cat((soft_prompt, padding), dim=1)
        soft_prompt = soft_prompt.reshape(
            (soft_prompt.shape[0], -1, self.backbone_hidden_size)
        )
        soft_prompt = self.input_ln(soft_prompt)

        backbone_attention_mask = torch.ones(
            soft_prompt.shape[0:2],
            dtype=torch.long,
            device=soft_prompt.device,
        )
        token_embeddings = self.backbone.get_input_embeddings()
        inputs_embeds = token_embeddings(input_ids)

        # Prepend the soft prompt to the token embeddings and extend the attention mask accordingly.
        inputs_embeds = torch.cat((soft_prompt, inputs_embeds), dim=1)
        input_attention_mask = torch.cat((backbone_attention_mask, attention_mask), dim=1)

        output = self.backbone(
            inputs_embeds=inputs_embeds,
            attention_mask=input_attention_mask,
            output_hidden_states=True,
        )

        # Pool only over the original input tokens, not the soft-prompt prefix.
        last_hidden_state = output.hidden_states[-1]
        n_soft_prompt_tokens = soft_prompt.shape[1]

        output_vectors = last_hidden_state[:, n_soft_prompt_tokens:, :]
        output_attention_mask = input_attention_mask[:, n_soft_prompt_tokens:]

        if vars(self.config).get("pooling_strategy") == "last_token":
            output_pooled = last_token_pool(output_vectors, output_attention_mask)
        elif vars(self.config).get("pooling_strategy") == "mean":
            output_pooled = mean_pool(output_vectors, output_attention_mask)
        else:
            output_pooled = mean_pool_weighted(output_vectors, output_attention_mask)

        output = self.output_projection(output_pooled)

        if output_hidden_states:
            return {
                "hidden_states": output_vectors,
                "pooled": output,
            }
        else:
            return output


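# Illustrative sketch (not used elsewhere in this module) of how the forward pass above repacks
# the first-stage soft prompt into pseudo-tokens of the backbone width: flatten, right-pad with
# ones up to a multiple of `backbone_hidden_size` (a full extra block when already divisible, as
# above), then reshape. All sizes are arbitrary examples.
def _repack_soft_prompt_sketch() -> torch.Size:
    batch_size, n_prefix_tokens, first_stage_hidden, backbone_hidden = 2, 9, 768, 2048
    soft_prompt = torch.randn(batch_size, n_prefix_tokens, first_stage_hidden)
    flat = soft_prompt.reshape(batch_size, -1)                      # (2, 6912)
    n_pad = backbone_hidden - (flat.shape[1] % backbone_hidden)     # 2048 - 768 = 1280
    flat = torch.cat((flat, torch.ones(batch_size, n_pad)), dim=1)  # (2, 8192)
    return flat.reshape(batch_size, -1, backbone_hidden).shape      # torch.Size([2, 4, 2048])

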
class DatasetConditionedBiencoder(transformers.PreTrainedModel, ContextualModelMixin):
    def __init__(
        self,
        config,
        dataset_backbone: transformers.PreTrainedModel,
    ):
        super().__init__(config=config)
        self.backbone = dataset_backbone
        self.hidden_size = self.backbone.config.hidden_size
        self.contextual_init()
        self._shift_rotary_embedding()

    @property
    def num_corpus_tokens(self) -> int:
        return self.config.transductive_corpus_size * self.transductive_tokens_per_document

    def _shift_rotary_embedding(self) -> None:
        disable_transductive_rotary_embedding = vars(self.config).get("disable_transductive_rotary_embedding", True)
        if self.backbone.config.model_type.startswith("nomic") and disable_transductive_rotary_embedding:
            # Start rotary position counting at `num_corpus_tokens`, i.e. after the contextual prefix.
            self.backbone.config.rotary_start_pos = 0.0
            rotary_disabled = 0

            rotary_start_pos = self.num_corpus_tokens
            for module in self.backbone.modules():
                if hasattr(module, "rotary_emb_dim"):
                    module.rotary_start_pos = rotary_start_pos
                    rotary_disabled += 1
            print0(f"modified {rotary_disabled} rotary modules – set rotary_start_pos to {rotary_start_pos}")

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        dataset_embeddings: torch.Tensor,
        output_hidden_states: bool = False,
        null_dataset_embedding: bool = False,
    ) -> torch.Tensor:
        soft_prompt = self._prepare_dataset_embeddings(
            input_ids=input_ids,
            dataset_embeddings=dataset_embeddings,
            null_dataset_embedding=null_dataset_embedding,
        )

        backbone_attention_mask = torch.ones(
            soft_prompt.shape[0:2],
            dtype=torch.long,
            device=soft_prompt.device,
        )
        inputs_embeds = self.backbone.embeddings(input_ids)

        # Prepend the contextual soft prompt to the token embeddings and extend the attention mask.
        inputs_embeds = torch.cat((soft_prompt, inputs_embeds), dim=1)
        attention_mask = torch.cat((backbone_attention_mask, attention_mask), dim=1)

        output = self.backbone(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
        )

        # Pool only over the original input tokens, skipping the soft-prompt prefix.
        n_soft_prompt_tokens = soft_prompt.shape[1]
        output_vectors = output.last_hidden_state[:, n_soft_prompt_tokens:, :]
        output_attention_mask = attention_mask[:, n_soft_prompt_tokens:]

        output_pooled = mean_pool(output_vectors, output_attention_mask)
        output = self.output_projection(output_pooled)

        if output_hidden_states:
            return {
                "hidden_states": output_vectors,
                "pooled": output,
            }
        else:
            return output


class DatasetPrefixBiencoder(transformers.PreTrainedModel, ContextualModelMixin):
    def __init__(
        self,
        config,
        embedder: transformers.PreTrainedModel,
    ):
        super().__init__(config=config)
        self.embedder = embedder
        self.hidden_size = self.embedder.config.hidden_size
        self.contextual_init()

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        dataset_input_ids: torch.Tensor,
        dataset_attention_mask: torch.Tensor,
        output_hidden_states: bool = False,
    ) -> torch.Tensor:
        # Sample one random dataset document per example and prepend its tokens to the input.
        R = torch.randint(low=0, high=len(dataset_input_ids), size=(len(input_ids),), device=dataset_input_ids.device)

        dataset_input_ids = dataset_input_ids[R]
        input_ids = torch.cat((dataset_input_ids, input_ids), dim=1)

        # Attend over the full prefix, but pool only over the original input tokens.
        dataset_attention_mask = torch.ones_like(dataset_attention_mask, device=dataset_attention_mask.device)
        input_attention_mask = torch.cat((dataset_attention_mask, attention_mask), dim=1)
        output_attention_mask = torch.cat(
            (torch.zeros_like(dataset_input_ids), attention_mask), dim=1
        )

        output = self.embedder(
            input_ids=input_ids,
            attention_mask=input_attention_mask,
        )

        output_vectors = output.last_hidden_state
        output_pooled = mean_pool(output_vectors, output_attention_mask)
        output = self.output_projection(output_pooled)

        if output_hidden_states:
            S_d = dataset_attention_mask.shape[1]
            output_vectors = output_vectors[:, S_d:, :]
            return {
                "hidden_states": output_vectors,
                "pooled": output,
            }
        else:
            return output


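# Illustrative sketch (not used elsewhere in this module) of the masking trick in
# `DatasetPrefixBiencoder.forward`: the encoder attends over [dataset prefix; input], but the
# pooling mask zeroes out the prefix so only the input tokens contribute to the final embedding.
# Sizes are arbitrary, and the return shape assumes `mean_pool` behaves as used above.
def _prefix_pooling_sketch() -> torch.Size:
    batch_size, prefix_len, seq_len, hidden = 2, 5, 7, 16
    hidden_states = torch.randn(batch_size, prefix_len + seq_len, hidden)
    input_mask = torch.ones(batch_size, seq_len, dtype=torch.long)
    pool_mask = torch.cat((torch.zeros(batch_size, prefix_len, dtype=torch.long), input_mask), dim=1)
    pooled = mean_pool(hidden_states, pool_mask)  # averages over the last `seq_len` positions only
    return pooled.shape  # expected (2, 16)

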
class DatasetTransformer(transformers.PreTrainedModel):
    config_class = ContextualModelConfig
    embedder: transformers.PreTrainedModel
    dataset_backbone: transformers.PreTrainedModel

    def __init__(
        self,
        config,
    ):
        super().__init__(config=config)
        dataset_backbone, _ = load_embedder_and_tokenizer(
            vars(config).get("dataset_backbone", config.embedder)
        )

        if config.limit_layers:
            print0(f"Limiting layers to {config.limit_layers}")
            limit_layers(dataset_backbone, config.limit_layers)

        biencoder_config = copy.deepcopy(config)
        biencoder_config.embedding_output_dim = None
        biencoder_config.limit_layers = vars(self.config).get("limit_layers_first_stage", None)
        self.first_stage_model = BiEncoder(
            config=biencoder_config,
        )

        if vars(config).get("autoregressive_backbone", False):
            self.second_stage_model = DatasetConditionedAutoregressive(
                config=config,
                dataset_backbone=dataset_backbone,
                first_stage_hidden_size=self.first_stage_model.hidden_size,
            )
        else:
            self.second_stage_model = DatasetConditionedBiencoder(
                config=config,
                dataset_backbone=dataset_backbone
            )

        self.temp = config.logit_scale
        if config.disable_dropout:
            disable_dropout(self)

        transductive_tie_token_embeddings = vars(self.config).get("transductive_tie_token_embeddings", False)
        if transductive_tie_token_embeddings:
            self.second_stage_model.backbone.embeddings.word_embeddings.weight = (
                self.first_stage_model.embedder.embeddings.word_embeddings.weight
            )

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        dataset_input_ids: Optional[torch.Tensor],
        dataset_attention_mask: Optional[torch.Tensor],
        output_hidden_states: bool = False,
    ) -> torch.Tensor:
        """
        input_ids (long torch.Tensor) – ids of input tokens
        attention_mask (bool torch.Tensor) – attention mask for the input tokens
        dataset_input_ids / dataset_attention_mask – tokenized corpus documents, embedded by the
            first-stage model and passed to the second-stage model as context
        """
        dataset_embeddings = self.first_stage_model(
            input_ids=dataset_input_ids,
            attention_mask=dataset_attention_mask
        )
        return self.second_stage_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            dataset_embeddings=dataset_embeddings,
            output_hidden_states=output_hidden_states,
        )


def get_model_class(name: str):
    if name == 'transductive':
        return DatasetTransformer
    elif name == 'biencoder':
        return BiEncoder
    elif name == "dataset_prefix_biencoder":
        return DatasetPrefixBiencoder
    else:
        raise ValueError(f'unknown model cls {name}')
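

# End-to-end usage sketch. This block is illustrative only: the exact `ContextualModelConfig`
# fields and the checkpoint name below are assumptions inferred from the attribute accesses in
# this module, and the tokenizer is assumed to be a standard Hugging Face tokenizer returned by
# `load_embedder_and_tokenizer`. Guarded so that importing this module stays side-effect free.
if __name__ == "__main__":
    config = ContextualModelConfig(
        embedder="nomic-ai/nomic-bert-2048",  # hypothetical first-stage/backbone checkpoint
        logit_scale=50.0,
        disable_dropout=True,
        limit_layers=None,
        embedding_output_dim=None,
        transductive_corpus_size=2,
    )
    model = get_model_class("transductive")(config)
    model.eval()

    _, tokenizer = load_embedder_and_tokenizer(config.embedder)
    docs = tokenizer(["a corpus document", "another corpus document"], padding=True, return_tensors="pt")
    query = tokenizer(["a query"], padding=True, return_tensors="pt")

    with torch.no_grad():
        embedding = model(
            input_ids=query["input_ids"],
            attention_mask=query["attention_mask"],
            dataset_input_ids=docs["input_ids"],
            dataset_attention_mask=docs["attention_mask"],
        )
    print(embedding.shape)  # (1, hidden_size)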