# Esm2Text-Base-v1-1 / modeling_prot2text.py
from transformers import AutoTokenizer, GPT2Config
from transformers import PretrainedConfig, PreTrainedModel
import transformers
from typing import Optional, Tuple, Callable, List
import torch
import torch.nn as nn
from .utils import CABlock, _GPT2LMHeadModel
from .configuration_prot2text import Prot2TextConfig
from transformers.generation.configuration_utils import GenerationConfig
from transformers.generation.logits_process import LogitsProcessorList
from transformers.generation.stopping_criteria import StoppingCriteriaList
from transformers.generation.streamers import BaseStreamer


class Prot2TextModel(PreTrainedModel):
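    """Protein-to-text model that encodes an amino-acid sequence with ESM, projects the
    ESM hidden states to the GPT-2 embedding width, and decodes a natural-language
    protein description with a GPT-2 decoder that cross-attends over those states."""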
config_class = Prot2TextConfig
_keys_to_ignore_on_load_missing = [r"transformer"]
base_model_prefix = "decoder"
def __init__(self, config):
super().__init__(config)
self.gpt_config = GPT2Config.from_dict(config.gpt_config)
# define the GPT2 decoder
self.decoder = _GPT2LMHeadModel(self.gpt_config)
# if using ESM to encode protein's sequence, define the ESM layer, the Projection layer and the fusion layer
if config.esm:
self.esm_config = PretrainedConfig.from_dict(config.esm_config)
self.esm = transformers.EsmModel(self.esm_config)
self.to_embedding = nn.Linear(self.esm_config.hidden_size, self.gpt_config.n_embd)
if config.cross_esm_graph and config.rgcn:
self.h = nn.ModuleList([CABlock(self.gpt_config, layer_idx=i) for i in range(4)])
self.ln_f = nn.LayerNorm(self.gpt_config.n_embd, eps=self.gpt_config.layer_norm_epsilon)
self.config = config
def get_encoder(self):
return self.encoder
def get_decoder(self):
return self.decoder
def get_input_embeddings(self):
if hasattr(self, "transformer"):
return self.transformer.wte
return self.decoder.transformer.wte
def warm_up(self, gpt_model=None, esm_model=None):
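        """Load pretrained weights into the ESM encoder and/or the GPT-2 decoder.

        The decoder is loaded with cross-attention enabled, its token embeddings are
        resized to the configured vocabulary size, and its config is replaced by
        ``self.gpt_config``."""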
if esm_model is not None:
self.esm = transformers.EsmModel.from_pretrained(esm_model)
if gpt_model is not None:
self.decoder = _GPT2LMHeadModel.from_pretrained(gpt_model, add_cross_attention=True, use_cache=False)
self.decoder.resize_token_embeddings(self.gpt_config.vocab_size)
self.decoder.config = self.gpt_config
def forward(self,
encoder_input_ids: Optional[torch.LongTensor] = None,
edge_index: Optional[torch.LongTensor] = None,
batch: Optional[torch.LongTensor] = None,
x: Optional[torch.FloatTensor] = None,
edge_type: Optional[torch.LongTensor] = None,
decoder_input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
past_key_values_graph_esm: Optional[Tuple[Tuple[torch.Tensor]]] = None,
decoder_attention_mask: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
get_graph_emb: Optional[bool] = False,
**delete_args,
):
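        """Run the encoder-decoder forward pass.

        The amino-acid sequence is encoded with ESM, projected to the decoder width, and
        passed to the GPT-2 decoder as cross-attention states. If ``get_graph_emb`` is
        True, the projected encoder embeddings are returned directly."""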
use_cache = use_cache if use_cache is not None else self.gpt_config.use_cache
return_dict = return_dict if return_dict is not None else self.gpt_config.use_return_dict
if decoder_input_ids is not None and len(decoder_input_ids.size()) == 3:
decoder_input_ids = decoder_input_ids.squeeze(0)
if self.config.esm:
if self.config.prot2text_version=='1.0':
                if encoder_input_ids.size(1) != 1021:
                    raise ValueError("For this version of the model you need to pad/truncate the amino-acid sequence to length 1021 for the ESM model")
esm_emb = self.esm(input_ids=encoder_input_ids, attention_mask=attention_mask, return_dict=return_dict).last_hidden_state
esm_emb = self.to_embedding(esm_emb)
graph_emb = esm_emb
else:
attention_mask = None
if self.config.prot2text_version=='1.0':
attention_mask = None
if get_graph_emb:
return graph_emb
transformer_outputs = self.decoder(input_ids=decoder_input_ids,
past_key_values=past_key_values,
attention_mask=decoder_attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
encoder_hidden_states=graph_emb,
encoder_attention_mask=attention_mask,
labels=labels,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
return transformer_outputs
@torch.no_grad()
def generate_protein_description(self,
protein_sequence=None,
tokenizer=None,
device='cpu',
streamer=None,
max_new_tokens=None,
do_sample=None,
top_p=None,
top_k=None,
temperature=None,
num_beams=1,
repetition_penalty=None
):
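        """Generate a free-text description for a single amino-acid sequence.

        The sequence is tokenized with the ESM tokenizer (padded/truncated to 1021 tokens),
        encoded once, and decoding starts from the text tokenizer's BOS token. Without a
        ``streamer`` the decoded string is returned; with one, the raw ``generate`` output
        of the decoder is returned while tokens are streamed."""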
        if self.config.esm and not self.config.rgcn and protein_sequence is None:
raise ValueError(
"The model you are trying to use is based only on protein sequence, please provide an amino-acid protein_sequence"
)
if self.config.esm:
esmtokenizer = AutoTokenizer.from_pretrained(self.config.esm_model_name)
seq = esmtokenizer([protein_sequence], add_special_tokens=True, truncation=True, max_length=1021, padding='max_length', return_tensors="pt")
inputs={}
inputs['encoder_input_ids'] = seq['input_ids']
inputs['attention_mask'] = seq['attention_mask']
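            # seed the decoder with a single BOS token from the text tokenizer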
inputs['decoder_input_ids'] = inputs['encoder_input_ids'][:,0:1].clone()
inputs['decoder_input_ids'][:,0] = tokenizer.bos_token_id
self.to(device)
inputs = {k: v.to(device=device, non_blocking=True) if hasattr(v, 'to') else v for k, v in inputs.items()}
encoder_state = dict()
encoder_state['hidden_states'] = self(**inputs, get_graph_emb=True, output_attentions=True)
if streamer is None:
generated = tokenizer.batch_decode(self.decoder.generate(input_ids=inputs['decoder_input_ids'], encoder_outputs=encoder_state, use_cache=True), skip_special_tokens=True)
return generated[0].replace('<|stop_token|>', '').replace('<|graph_token|>', '')
else:
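            # streaming mode: transformers streamers do not support beam search, so decoding stays single-beam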
return self.decoder.generate(input_ids=inputs['decoder_input_ids'],
encoder_outputs=encoder_state,
use_cache=True,
streamer=streamer,
max_new_tokens=max_new_tokens,
do_sample=do_sample,
top_p=top_p,
top_k=top_k,
temperature=temperature,
num_beams=1,
repetition_penalty=repetition_penalty)
@torch.no_grad()
def generate(self,
inputs: Optional[torch.Tensor] = None,
generation_config: Optional[GenerationConfig] = None,
logits_processor: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
synced_gpus: Optional[bool] = None,
assistant_model: Optional["PreTrainedModel"] = None,
streamer: Optional["BaseStreamer"] = None,
**kwargs,
):
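        """Standard ``generate`` entry point used by the Hugging Face generation API.

        The encoder embeddings are computed once, encoder-/graph-specific tensors are
        removed from ``kwargs``, and generation is delegated to the GPT-2 decoder with
        the embeddings passed as ``encoder_outputs``."""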
encoder_state = self(**kwargs, get_graph_emb=True)
input_ids = kwargs['decoder_input_ids']
attention_mask = kwargs['decoder_attention_mask']
kwargs['encoder_attention_mask'] = kwargs['attention_mask']
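        # the encoder attention mask is extended by one leading column below, matching encoder
        # hidden states that carry an extra (pooled graph) embedding when ESM and RGCN are not cross-fused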
if not self.config.cross_esm_graph and self.config.rgcn and self.config.esm:
            t_add = torch.ones((kwargs['encoder_attention_mask'].size(0), 1)).to(kwargs['encoder_attention_mask'].device)
kwargs['encoder_attention_mask'] = torch.cat((t_add, kwargs['encoder_attention_mask']), dim=1)
        for key in ['edge_index', 'edge_type', 'x', 'encoder_input_ids', 'decoder_input_ids', 'decoder_attention_mask', 'batch', 'attention_mask', 'max_length',
                    '_num_nodes', 'node_id', 'name', 'sequence', 'distance_matrix', 'distance', 'coordinates', 'ptr', 'num_nodes']:
            kwargs.pop(key, None)
return self.decoder.generate(input_ids=input_ids,
generation_config=generation_config,
logits_processor=logits_processor,
stopping_criteria=stopping_criteria,
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
synced_gpus=synced_gpus,
assistant_model=assistant_model,
streamer=streamer,
encoder_outputs={'hidden_states': encoder_state, 'attentions':0},
**kwargs
)
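

if __name__ == "__main__":
    # Minimal usage sketch, not part of the model definition. Assumptions: this file is
    # loaded from the Hub with trust_remote_code=True and the checkpoint id below matches
    # this repository; the amino-acid sequence is an arbitrary example.
    from transformers import AutoModel

    model = AutoModel.from_pretrained("habdine/Esm2Text-Base-v1-1", trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained("habdine/Esm2Text-Base-v1-1", trust_remote_code=True)
    description = model.generate_protein_description(
        protein_sequence="MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ",
        tokenizer=tokenizer,
        device="cpu",
    )
    print(description)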