Prot2Text-Medium-v1-0 / configuration_prot2text.py
habdine's picture
Upload code
d49dad6 verified
""" Prot2Text configuration"""
from transformers.configuration_utils import PretrainedConfig
from transformers import AutoConfig
from transformers.utils import logging
# Module-level logger obtained from transformers' logging utilities; it is not
# referenced by the configuration class below.
logger = logging.get_logger(__name__)
class Prot2TextConfig(PretrainedConfig):
    """Configuration for Prot2Text models (``model_type = "prot2text"``).

    Besides the fields common to every ``PretrainedConfig`` (token ids and
    generation settings), this stores switches and sizes for an ESM branch and
    an RGCN branch, plus the nested configuration dicts of a GPT-style decoder
    (``gpt_config``) and of the ESM model (``esm_config``). How these fields are
    consumed is defined by the modeling code, not here.

    Args:
        cross_esm_graph: Stored as-is. Presumably toggles interaction between
            the ESM and graph branches -- confirm in the modeling code.
        decoder_start_token_id: Forwarded to ``PretrainedConfig`` and copied
            into the decoder config built when ``gpt_config`` is None.
        early_stopping: Generation setting forwarded to ``PretrainedConfig``.
        eos_token_id: Forwarded to ``PretrainedConfig`` and copied into the
            generated decoder config.
        bos_token_id: Forwarded to ``PretrainedConfig`` and copied into the
            generated decoder config.
        esm: Stored as-is; presumably enables the ESM branch.
        esm_model_name: Hub id or local path passed to
            ``AutoConfig.from_pretrained`` when ``esm_config`` is None.
        gpt_model_name: Hub id or local path used to build ``gpt_config`` when
            it is None.
        length_penalty: Generation setting forwarded to ``PretrainedConfig``.
        max_new_tokens: Stored and copied into the generated decoder config.
        no_repeat_ngram_size: Generation setting forwarded to
            ``PretrainedConfig``.
        pad_token_id: Forwarded to ``PretrainedConfig`` and copied into the
            generated decoder config.
        prot2text_version: Version string, stored as-is.
        rgcn: Stored as-is; presumably enables the RGCN branch.
        rgc_input_dim: Stored as-is; presumably the RGCN input feature size.
        rgcn_n_layers: Stored as-is; presumably the number of RGCN layers.
        gpt_config: Decoder configuration as a dict or a ``PretrainedConfig``
            (converted with ``to_dict()``). When None it is built from
            ``gpt_model_name`` via ``AutoConfig.from_pretrained``, which may
            read from the Hub or the local cache.
        esm_config: ESM configuration as a dict or a ``PretrainedConfig``
            (converted with ``to_dict()``). When None it is built from
            ``esm_model_name`` via ``AutoConfig.from_pretrained``.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.
    """

    model_type = "prot2text"
    # Output keys ignored by default when inspecting model outputs at
    # inference time (PretrainedConfig convention).
    keys_to_ignore_at_inference = ["past_key_values"]
    # NOTE(review): transformers normally reads this attribute from
    # PreTrainedModel subclasses rather than from configs -- confirm it is used.
    _keys_to_ignore_on_load_missing = [r"transformer"]

    def __init__(
        self,
        cross_esm_graph=True,
        decoder_start_token_id=50257,
        early_stopping=True,
        eos_token_id=50258,
        bos_token_id=50257,
        esm=True,
        esm_model_name="facebook/esm2_t6_8M_UR50D",
        gpt_model_name="gpt2",
        length_penalty=2.0,
        max_new_tokens=256,
        no_repeat_ngram_size=3,
        pad_token_id=50256,
        prot2text_version="1.1",
        rgcn=True,
        rgc_input_dim=67,
        rgcn_n_layers=6,
        gpt_config=None,
        esm_config=None,
        **kwargs,
    ):
        # PretrainedConfig.__init__ (re)assigns token ids and generation
        # settings from its kwargs, falling back to library defaults (e.g.
        # pad_token_id=None, early_stopping=False). Setting them on self before
        # calling it -- as this class used to -- lets them be silently reset,
        # so they are passed through explicitly instead.
        # NOTE(review): recent transformers versions prefer generation settings
        # in a GenerationConfig and may warn about them on save -- confirm.
        super().__init__(
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            pad_token_id=pad_token_id,
            decoder_start_token_id=decoder_start_token_id,
            early_stopping=early_stopping,
            length_penalty=length_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
            **kwargs,
        )

        # Model-specific fields; assigned after the base initializer so
        # nothing there can overwrite them.
        self.cross_esm_graph = cross_esm_graph
        self.esm = esm
        self.esm_model_name = esm_model_name
        self.gpt_model_name = gpt_model_name
        self.max_new_tokens = max_new_tokens
        self.prot2text_version = prot2text_version
        self.rgcn = rgcn
        self.rgc_input_dim = rgc_input_dim
        self.rgcn_n_layers = rgcn_n_layers

        if gpt_config is None:
            # Decoder config derived from the pretrained GPT model, switched to
            # cross-attention / encoder-decoder use.
            gpt_config = AutoConfig.from_pretrained(
                gpt_model_name,
                _name_or_path=gpt_model_name,
                is_encoder_decoder=True,
                use_cache=False,
                add_cross_attention=True,
                bos_token_id=bos_token_id,
                decoder_start_token_id=decoder_start_token_id,
                eos_token_id=eos_token_id,
                max_new_tokens=max_new_tokens,
                # Use the configured pad id instead of a hard-coded 50256.
                pad_token_id=pad_token_id,
                # Presumably GPT-2's 50257 tokens plus the two extra ids used by
                # default as bos (50257) and eos (50258) -- confirm against the
                # tokenizer.
                vocab_size=50259,
                num_beams=1,
                max_length=256,
                min_length=1,
            )
        self.gpt_config = self._config_as_dict(gpt_config)

        if esm_config is None:
            esm_config = AutoConfig.from_pretrained(esm_model_name)
        # The fetched ESM config used to be overwritten with None right after
        # being loaded; it is now kept.
        self.esm_config = self._config_as_dict(esm_config)

    @staticmethod
    def _config_as_dict(config):
        """Return *config* as a dict, converting a ``PretrainedConfig`` via ``to_dict()``.

        Any other value (a dict, e.g. as read back from a saved config) is
        returned unchanged.
        """
        if isinstance(config, PretrainedConfig):
            return config.to_dict()
        return config