File size: 280 Bytes
fa64206
 
 
 
 
 
1
2
3
4
5
6
7
from transformers import DistilBertConfig, DistilBertForSequenceClassification

def get_student_model(config, pretrained_name="distilbert-base-uncased"):
    """Build the DistilBERT sequence-classification student model.

    The student width is read from ``config['model']['student']['hidden_size']``.

    If that width equals the checkpoint's own ``dim``, the pretrained weights of
    ``pretrained_name`` are loaded. Otherwise a model of the requested width is
    constructed from the checkpoint's configuration. Because tensors of a
    different width cannot hold the pretrained weights, all of its layers are
    freshly (randomly) initialised.

    Args:
        config: Nested mapping providing ``['model']['student']['hidden_size']``.
        pretrained_name: Hugging Face checkpoint id or local path used for the
            base configuration and, when widths match, the weights.

    Returns:
        A ``DistilBertForSequenceClassification`` whose layers actually have
        the requested hidden size.

    Raises:
        KeyError: If the config lacks the student ``hidden_size`` entry.
    """
    hidden_size = config['model']['student']['hidden_size']
    model_config = DistilBertConfig.from_pretrained(pretrained_name)

    if model_config.dim == hidden_size:
        # Requested width matches the checkpoint, so reuse its pretrained weights.
        return DistilBertForSequenceClassification.from_pretrained(pretrained_name)

    # Assigning ``model.config.hidden_size`` after loading would only change
    # metadata (``hidden_size`` aliases ``dim``): the layers keep the
    # checkpoint's shapes, and the saved config would no longer match the
    # weights. Set the width *before* constructing the model instead. The new
    # width must be divisible by the checkpoint's attention-head count
    # (``n_heads``), or DistilBERT's attention layer rejects it.
    model_config.dim = hidden_size
    return DistilBertForSequenceClassification(model_config)