from transformers import BertForSequenceClassification, LoRAConfig

def get_lora_model(config):
    """Load a pre-trained BERT classifier and attach a trainable LoRA adapter.

    Args:
        config: Nested configuration mapping. Reads
            ``config['model']['lora']['r']`` (LoRA rank) and
            ``config['model']['lora']['alpha']`` (LoRA scaling factor).
            Missing keys raise ``KeyError``.

    Returns:
        A ``BertForSequenceClassification`` instance with a LoRA adapter named
        ``'imdb_lora'`` added and set as the active, trainable adapter.

    NOTE(review): ``LoRAConfig`` with ``r``/``alpha`` kwargs is the
    adapter-transformers API (usually imported from ``transformers.adapters``);
    confirm the project pins that fork rather than vanilla ``transformers``.
    """
    model = BertForSequenceClassification.from_pretrained('bert-base-uncased')
    lora_cfg = config['model']['lora']
    lora_config = LoRAConfig(r=lora_cfg['r'], alpha=lora_cfg['alpha'])
    # BUG FIX: the model exposes add_adapter()/train_adapter(), not
    # add_lora()/train_lora() — the original calls raised AttributeError.
    model.add_adapter('imdb_lora', config=lora_config)
    # train_adapter() freezes the base model weights and marks only the
    # 'imdb_lora' adapter parameters as trainable.
    model.train_adapter('imdb_lora')
    return model