Spaces:
Runtime error
Runtime error
import os | |
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:4096' # do this before importing pytorch | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import math | |
from transformers import EsmModel | |
import torch | |
import numpy as np | |
from lightning.pytorch import seed_everything | |
from typing import Tuple | |
import torch | |
import gc | |
from torch.optim.lr_scheduler import _LRScheduler | |
from transformers import EsmModel, PreTrainedModel | |
from configuration import MetaLATTEConfig | |
seed_everything(42) | |
class GELU(nn.Module):
    """Exact (erf-based) Gaussian Error Linear Unit.

    Evaluates ``0.5 * x * (1 + erf(x / sqrt(2)))`` element-wise.

    For reference, OpenAI GPT uses a tanh approximation that gives slightly
    different results:
    0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
    """

    def forward(self, x):
        """Apply the activation to ``x`` and return a tensor of the same shape."""
        half_x = 0.5 * x
        one_plus_erf = 1.0 + torch.erf(x / math.sqrt(2.0))
        return half_x * one_plus_erf
def rotate_half(x):
    """Return ``x`` with the halves of its last axis swapped and the moved-forward half negated.

    The last dimension is split into two chunks ``(a, b)`` and the result is
    ``concat(-b, a)`` along that same dimension. In this file the callers pass
    tensors laid out as (B, L, H, hidden).
    """
    first_half, second_half = torch.chunk(x, 2, dim=-1)
    pieces = (-second_half, first_half)
    return torch.cat(pieces, dim=-1)
def apply_rotary_pos_emb(x, cos, sin):
    """Apply a rotary position rotation to ``x`` using precomputed tables.

    Expected layouts (per the original author's notes):
      * ``x``:   (B, L, H, HIDDEN_DIM)
      * ``cos``: (1, L, HIDDEN_DIM)
      * ``sin``: (1, L, HIDDEN_DIM)

    Returns ``x * cos + rotate_half(x) * sin`` with the tables broadcast over
    the head axis.
    """
    # Insert a singleton axis at position 2 so the tables become
    # (1, L, 1, HIDDEN_DIM) and broadcast across every head.
    cos_per_head = cos.unsqueeze(2)
    sin_per_head = sin.unsqueeze(2)

    straight_part = x * cos_per_head
    rotated_part = rotate_half(x) * sin_per_head
    return straight_part + rotated_part
class RotaryEmbedding(torch.nn.Module):
    """
    Rotary position embeddings as introduced by RoFormer_ (Su et al.).

    The key idea is that queries and keys are transformed by rotation matrices
    that depend on their positions, so that their dot product depends on the
    relative position.

    Other implementations are available in the Rotary Transformer repo_ and in
    GPT-NeoX_; GPT-NeoX was an inspiration.

    .. _RoFormer: https://arxiv.org/abs/2104.09864
    .. _repo: https://github.com/ZhuiyiTechnology/roformer
    .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox

    .. warning: This embedding is deliberately not registered, because it is
        transformative (it does not create the embedding dimension) and will
        likely be picked up (imported) on an ad-hoc basis.
    """

    def __init__(self, dim: int, *_, **__):
        """Build the inverse-frequency buffer for a rotation over ``dim`` channels.

        Extra positional and keyword arguments are accepted and ignored.
        """
        super().__init__()
        # One inverse frequency per even channel index; kept as a
        # non-trainable buffer so it moves with the module.
        exponents = torch.arange(0, dim, 2).float() / dim
        self.register_buffer("inv_freq", 1.0 / (10000 ** exponents))

        # Lazily filled cos/sin lookup tables and the sequence length they cover.
        self._seq_len_cached = None
        self._cos_cached = None
        self._sin_cached = None

    def _update_cos_sin_tables(self, x, seq_dimension=1):
        """Return ``(cos, sin)`` tables of shape (1, L, 2 * len(inv_freq)) for ``x``.

        ``L`` is ``x.shape[seq_dimension]``. The tables are rebuilt only when
        the sequence length differs from the cached one or ``x`` lives on a
        different device than the cached tables (e.g. because of tracing).
        """
        seq_len = x.shape[seq_dimension]

        # The length comparison comes first so the device of a not-yet-built
        # (None) table is never inspected.
        cache_is_current = (
            seq_len == self._seq_len_cached
            and self._cos_cached.device == x.device
        )
        if not cache_is_current:
            self._seq_len_cached = seq_len
            positions = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
            # Outer product position x frequency -> (L, len(inv_freq)).
            angles = torch.einsum("i,j->ij", positions, self.inv_freq)
            # Duplicate so the table spans both halves of the channel axis.
            table = torch.cat((angles, angles), dim=-1).to(x.device)
            self._cos_cached = table.cos().unsqueeze(0)
            self._sin_cached = table.sin().unsqueeze(0)

        return self._cos_cached, self._sin_cached

    def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Rotate ``q`` and ``k`` using tables sized from ``k``'s sequence axis.

        Returns the rotated ``(q, k)`` pair; the author's notes give the layout
        as (B, L, H, hidden).
        """
        cos_table, sin_table = self._update_cos_sin_tables(k)
        self._cos_cached, self._sin_cached = cos_table, sin_table

        rotated_q = apply_rotary_pos_emb(q, cos_table, sin_table)
        rotated_k = apply_rotary_pos_emb(k, cos_table, sin_table)
        return rotated_q, rotated_k
def macro_f1(y_true, y_pred, thresholds):
    """Compute the macro-averaged F1 score of thresholded predictions.

    Args:
        y_true: Ground-truth indicator tensor; counts are accumulated over
            dimension 0, so each remaining position is scored separately.
        y_pred: Predicted scores, compared against ``thresholds`` with ``>=``.
        thresholds: Decision threshold(s), broadcast against ``y_pred``.

    Returns:
        A scalar tensor: the mean of the per-position F1 values. A small
        epsilon (1e-7) in every denominator keeps empty classes at 0 instead
        of dividing by zero.
    """
    eps = 1e-7
    predicted_pos = (y_pred >= thresholds).float()
    predicted_neg = 1 - predicted_pos

    true_pos = (y_true * predicted_pos).sum(dim=0)
    false_pos = ((1 - y_true) * predicted_pos).sum(dim=0)
    false_neg = (y_true * predicted_neg).sum(dim=0)

    precision = true_pos / (true_pos + false_pos + eps)
    recall = true_pos / (true_pos + false_neg + eps)
    per_class_f1 = 2 * precision * recall / (precision + recall + eps)
    return per_class_f1.mean()
def safeguard_softmax(logits, dim=-1):
    """Numerically guarded softmax along ``dim``.

    The per-slice maximum is subtracted before exponentiation so ``exp`` cannot
    overflow, and a small epsilon (1e-7) is added to the normaliser so the
    division never hits zero. Because of that epsilon the outputs sum to
    slightly less than 1.

    A slice whose entries are all ``-inf`` (for example a fully masked
    attention row) yields all zeros rather than NaN.

    Args:
        logits: Floating-point tensor of unnormalised scores.
        dim: Dimension along which to normalise.

    Returns:
        Tensor of the same shape as ``logits``.
    """
    max_logits, _ = logits.max(dim=dim, keepdim=True)
    # If every entry of a slice is -inf its maximum is -inf too, and
    # (-inf) - (-inf) would be NaN. Shifting by 0 instead leaves exp(-inf) = 0
    # for the whole slice. Slices with a finite maximum are unaffected.
    max_logits = max_logits.masked_fill(max_logits == float("-inf"), 0.0)

    exp_logits = torch.exp(logits - max_logits)
    exp_sum = exp_logits.sum(dim=dim, keepdim=True)
    # Epsilon guards against division by zero (e.g. the all-masked case above).
    return exp_logits / (exp_sum + 1e-7)
class PositionalAttentionHead(nn.Module):
    """Multi-head self-attention over per-head slices with rotary position embeddings.

    The input's last dimension (``hidden_dim``) is viewed as ``n_heads`` slices
    of ``head_dim = hidden_dim // n_heads`` channels. LayerNorm and the Q/K/V
    projections act on the ``head_dim`` axis, so the same weights are applied
    to every head.
    """

    def __init__(self, hidden_dim, n_heads):
        """Create the per-head LayerNorm, bias-free Q/K/V projections and rotary embedding."""
        super().__init__()
        self.n_heads = n_heads
        self.hidden_dim = hidden_dim
        self.head_dim = hidden_dim // n_heads

        self.preattn_ln = nn.LayerNorm(self.head_dim)
        self.Q = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.K = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.V = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.rot_emb = RotaryEmbedding(self.head_dim)

    def forward(self, x, attention_mask):
        """Run masked self-attention.

        Args:
            x: 3-D tensor (batch, seq_len, hidden_dim).
            attention_mask: Mask whose falsy entries mark key positions to
                ignore; it is expanded with two singleton axes after the batch
                axis before being applied to the (B, H, L, L) scores.

        Returns:
            ``(context, attn_probs)`` where ``context`` is reshaped back to
            (batch, seq_len, hidden_dim) and ``attn_probs`` holds the attention
            weights laid out as (batch, heads, query, key).
        """
        batch_size, seq_len, _ = x.shape

        # Split channels into heads, then normalise each head slice.
        per_head = x.view(batch_size, seq_len, self.n_heads, self.head_dim)
        per_head = self.preattn_ln(per_head)

        q, k, v = self.Q(per_head), self.K(per_head), self.V(per_head)
        q, k = self.rot_emb(q, k)

        # Release cached memory before building the L x L score matrix.
        gc.collect()
        torch.cuda.empty_cache()

        scale = math.sqrt(self.head_dim)
        attn_scores = torch.einsum('bqhd,bkhd->bhqk', q, k) / scale

        # Key positions with a falsy mask entry get -inf so they receive no weight.
        ignore_keys = torch.logical_not(attention_mask.unsqueeze(1).unsqueeze(1))
        attn_scores = attn_scores.masked_fill(ignore_keys, float("-inf"))  # B, H, L, L
        attn_probs = safeguard_softmax(attn_scores, dim=-1)

        context = torch.einsum('bhqk,bkhd->bqhd', attn_probs, v)
        context = context.reshape(batch_size, seq_len, self.hidden_dim)

        gc.collect()
        torch.cuda.empty_cache()
        return context, attn_probs
class CosineAnnealingWithWarmup(_LRScheduler):
    """Linear warmup followed by cosine decay down to a fixed fraction of the base LR.

    Implemented after the schedule described in the LLaMA paper
    (https://arxiv.org/abs/2302.13971).

    * For ``last_epoch < warmup_steps`` the LR grows linearly from 0 to the
      base LR.
    * From ``warmup_steps`` to ``total_steps`` it follows half a cosine period
      from the base LR down to ``eta_ratio * base_lr``.
    * After ``total_steps`` it stays at ``eta_ratio * base_lr``.

    Args:
        optimizer: Wrapped optimizer.
        warmup_steps: Number of scheduler steps spent in linear warmup.
        total_steps: Step at which the cosine decay reaches its minimum.
        eta_ratio: Ratio of minimum to maximum learning rate.
        last_epoch: Index of the last step, forwarded to the base class.
    """

    def __init__(self, optimizer, warmup_steps, total_steps, eta_ratio=0.1, last_epoch=-1):
        # These must be set before the base constructor runs: it performs an
        # initial step, which calls get_lr().
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.eta_ratio = eta_ratio  # The ratio of minimum to maximum learning rate
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return the learning rate of every parameter group for ``self.last_epoch``."""
        step = self.last_epoch
        if step < self.warmup_steps:
            return [base_lr * step / self.warmup_steps for base_lr in self.base_lrs]

        decay_span = self.total_steps - self.warmup_steps
        if decay_span > 0:
            # Clamp at 1.0: past total_steps the cosine would otherwise start
            # climbing again and raise the learning rate.
            progress = min(1.0, (step - self.warmup_steps) / decay_span)
        else:
            # No room for a decay phase (total_steps <= warmup_steps): go
            # straight to the floor instead of dividing by zero.
            progress = 1.0

        cosine_decay = 0.5 * (1 + np.cos(np.pi * progress))
        decayed_lr = (1 - self.eta_ratio) * cosine_decay + self.eta_ratio
        return [decayed_lr * base_lr for base_lr in self.base_lrs]
class RobertaLMHead(nn.Module):
    """Head for masked language modeling.

    Pipeline: dense layer -> GELU -> LayerNorm -> linear projection with an
    externally supplied weight matrix, plus a learned output bias.
    """

    def __init__(self, embed_dim, output_dim, weight):
        """Build the head.

        Args:
            embed_dim: Width of the incoming features (and of the dense layer).
            output_dim: Number of output logits; sizes the bias vector.
            weight: Projection matrix used by ``F.linear`` in ``forward``. It is
                supplied by the caller rather than created here.
        """
        super().__init__()
        self.dense = nn.Linear(embed_dim, embed_dim)
        self.layer_norm = nn.LayerNorm(embed_dim)
        self.weight = weight
        self.gelu = GELU()
        self.bias = nn.Parameter(torch.zeros(output_dim))

    def forward(self, features):
        """Map ``features`` (..., embed_dim) to logits (..., output_dim)."""
        hidden = self.layer_norm(self.gelu(self.dense(features)))
        # Project back to the size of the vocabulary, then add the bias.
        logits = F.linear(hidden, self.weight)
        return logits + self.bias
class MultitaskProteinModel(PreTrainedModel):
    """ESM-backed multi-label classifier with an auxiliary masked-LM head.

    Token embeddings from a (mostly frozen) ESM encoder go through a rotary
    self-attention head, a stack of residual linear layers and a LayerNorm.
    The result feeds two outputs: per-token masked-LM logits, and per-sequence
    classification logits computed from the mask-weighted mean over tokens.
    """

    config_class = MetaLATTEConfig
    base_model_prefix = "metalatte"

    def __init__(self, config):
        """Build the model from ``config``.

        Reads ``esm_model_name``, ``num_layers_to_finetune``, ``hidden_size``,
        ``num_attention_heads``, ``num_linear_layers``, ``hidden_dim`` and
        ``num_labels`` from the config.
        """
        super().__init__(config)
        self.config = config
        self.esm_model = EsmModel.from_pretrained(self.config.esm_model_name)

        # Freeze the whole ESM backbone first ...
        for param in self.esm_model.parameters():
            param.requires_grad = False
        # ... then unfreeze only the last `num_layers_to_finetune` encoder layers.
        for i in range(config.num_layers_to_finetune):
            for param in self.esm_model.encoder.layer[-i - 1].parameters():
                param.requires_grad = True

        # The LM head projects with the ESM input-embedding matrix. Its sizes
        # are taken from the config and from that matrix rather than being
        # hard-coded. forward() feeds it hidden_size-wide features, so
        # embed_dim must equal hidden_size, and the bias must match the number
        # of embedding rows.
        word_embedding_weight = self.esm_model.embeddings.word_embeddings.weight
        self.lm_head = RobertaLMHead(
            embed_dim=self.config.hidden_size,
            output_dim=word_embedding_weight.shape[0],
            weight=word_embedding_weight,
        )

        self.attn_head = PositionalAttentionHead(self.config.hidden_size, self.config.num_attention_heads)
        self.attn_ln = nn.LayerNorm(self.config.hidden_size)
        self.attn_skip = nn.Linear(self.config.hidden_size, self.config.hidden_size)

        # Residual linear layers applied after the attention head.
        self.linear_layers = nn.ModuleList(
            [
                nn.Linear(self.config.hidden_size, self.config.hidden_size)
                for _ in range(self.config.num_linear_layers)
            ]
        )

        self.reduction_layers = nn.Sequential(
            nn.Linear(self.config.hidden_size, self.config.hidden_dim),
            GELU(),
            nn.Linear(self.config.hidden_dim, self.config.num_labels),
        )
        self.clf_ln = nn.LayerNorm(self.config.hidden_size)
        # One decision threshold per label, used by predict(); starts at 0.5.
        self.classification_thresholds = nn.Parameter(torch.tensor([0.5] * self.config.num_labels))

        # Initialize weights and apply final processing
        self.post_init()

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Instantiate the model and load weights from a local checkpoint directory.

        Args:
            pretrained_model_name_or_path: Directory that contains
                ``pytorch_model.bin`` (and the config, unless ``config`` is
                passed as a keyword argument).
            *model_args: Accepted for signature compatibility; unused here.
            **kwargs: ``config`` is honoured if present; other keys are unused.

        Returns:
            The model with the checkpoint's ``'state_dict'`` entry loaded
            non-strictly, so missing or unexpected keys are ignored silently.
        """
        config = kwargs.pop("config", None)
        if config is None:
            config = MetaLATTEConfig.from_pretrained(pretrained_model_name_or_path)
        model = cls(config)
        # SECURITY NOTE: torch.load unpickles the file; only load checkpoints
        # from sources you trust.
        checkpoint = torch.load(
            f"{pretrained_model_name_or_path}/pytorch_model.bin",
            map_location=torch.device('cpu'),
        )
        model.load_state_dict(checkpoint['state_dict'], strict=False)
        return model

    def forward(self, input_ids, attention_mask=None):
        """Run the full model.

        Args:
            input_ids: Token ids, (B, L).
            attention_mask: Optional (B, L) mask with non-zero entries for
                tokens to attend to. When omitted, every position is treated
                as a real token.

        Returns:
            Tuple ``(x_pred, x_attns, x_combined_masked, mlm_logits)``:
              * ``x_pred``: classification logits, (B, num_labels).
              * ``x_attns``: attention weights from the attention head.
              * ``x_combined_masked``: mask-weighted mean token representation,
                (B, hidden_size).
              * ``mlm_logits``: per-token masked-LM logits.
        """
        if attention_mask is None:
            # The downstream attention head and pooling need a real tensor.
            attention_mask = torch.ones_like(input_ids)

        # NOTE: all hidden states are requested, but only last_hidden_state is
        # consumed below.
        outputs = self.esm_model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
        embeddings = outputs.last_hidden_state

        x_pool, x_attns = self.attn_head(embeddings, attention_mask)
        # Skip connection around a linear map of the attention output.
        x_pool = self.attn_ln(x_pool + self.attn_skip(x_pool))

        for linear_layer in self.linear_layers:
            residue = x_pool
            x_pool = F.silu(linear_layer(x_pool))  # hidden_size -> hidden_size
            x_pool = x_pool + residue  # Skip connection

        # NOTE(review): with this subscript string the key axis `k` is summed
        # out of the attention weights before they scale x_pool at the same
        # position `l`. Confirm that 'bhlk,bkd->bhld' was not intended. Left
        # unchanged because trained checkpoints depend on it.
        x_weighted = torch.einsum('bhlk,bld->bhld', x_attns, x_pool)  # (B, H, L, hidden_size)
        x_combined = x_weighted.mean(dim=1)  # Average over heads: (B, L, hidden_size)
        x_combined = self.clf_ln(x_combined)

        mlm_logits = self.lm_head(x_combined)

        # Mask-weighted mean over the sequence axis.
        mask = attention_mask.unsqueeze(-1).float()  # (B, L, 1)
        mask_total = mask.sum(dim=1, keepdim=True)  # (B, 1, 1)
        x_combined_masked = (x_combined * mask).sum(dim=1) / mask_total.squeeze(1)  # (B, hidden_size)

        # Compute classification logits
        x_pred = self.reduction_layers(x_combined_masked)

        gc.collect()
        torch.cuda.empty_cache()
        return x_pred, x_attns, x_combined_masked, mlm_logits

    def predict(self, input_ids, attention_mask=None):
        """Return per-label probabilities and thresholded multi-label predictions.

        Every sample is guaranteed at least one positive label: if no
        probability reaches its threshold, the highest-probability label is
        switched on.

        Returns:
            Tuple ``(classification_output, predictions)``: sigmoid
            probabilities and a 0/1 float tensor of the same shape.
        """
        x_pred, _, _, _ = self.forward(input_ids, attention_mask)
        classification_output = torch.sigmoid(x_pred)
        predictions = (classification_output >= self.classification_thresholds).float()

        # Rows with no label above threshold fall back to their argmax label.
        # Done in one indexed assignment instead of a per-row Python loop.
        no_label = predictions.sum(dim=1) == 0
        best_label = classification_output.argmax(dim=1)
        predictions[no_label, best_label[no_label]] = 1.0
        return classification_output, predictions