# GenAI_project/models/pert_model.py
# Uploaded by jaothan ("Upload 24 files"), commit fa64206 (verified), 442 bytes.
from transformers import BertForSequenceClassification, AdapterConfig
def get_pert_model(
    config,
    model_name="bert-base-uncased",
    adapter_name="imdb_adapter",
):
    """Build a BERT sequence classifier with one trainable bottleneck adapter.

    Loads ``model_name`` as a ``BertForSequenceClassification``, registers an
    adapter called ``adapter_name`` and activates it for training with
    ``train_adapter``.

    Args:
        config: Nested mapping. ``config['model']['adapter']['reduction_factor']``
            is required. ``config['model']['adapter']['non_linearity']`` is
            optional and defaults to ``"relu"``.
        model_name: Pretrained checkpoint to load. Defaults to the previously
            hard-coded ``'bert-base-uncased'``.
        adapter_name: Name under which the adapter is added and trained.
            Defaults to the previously hard-coded ``'imdb_adapter'``.

    Returns:
        The model with the adapter added and set up for training.

    Raises:
        KeyError: if ``config`` lacks ``['model']['adapter']['reduction_factor']``.
    """
    # Read the config before loading the model, so a bad config fails fast
    # instead of after downloading or loading the checkpoint.
    adapter_settings = config["model"]["adapter"]

    # NOTE(review): AdapterConfig / add_adapter / train_adapter come from the
    # adapter-transformers fork (stock `transformers` has none of them).
    # In adapter-transformers 3.x, `non_linearity` is a required dataclass
    # field with no default, so it must be passed explicitly. Verify against
    # the installed version.
    adapter_config = AdapterConfig(
        mh_adapter=True,  # adapter after the multi-head attention block
        output_adapter=True,  # adapter after the feed-forward output block
        reduction_factor=adapter_settings["reduction_factor"],
        non_linearity=adapter_settings.get("non_linearity", "relu"),
    )

    model = BertForSequenceClassification.from_pretrained(model_name)
    model.add_adapter(adapter_name, config=adapter_config)
    # Per the adapter-transformers docs, train_adapter freezes the pretrained
    # weights and activates this adapter for training. Confirm for the
    # installed version.
    model.train_adapter(adapter_name)
    return model