import torch
import torch.nn as nn
from transformers import AutoModel


class TextEncoder(nn.Module):
    def __init__(self, output_dim=64, lang_model="sentence-transformers/all-MiniLM-L6-v2", unfreeze_n_blocks=4):
        super().__init__()
        self.lang_model = lang_model
        self.encoder = AutoModel.from_pretrained(lang_model)

        # freeze all parameters
        for param in self.encoder.parameters():
            param.requires_grad = False

        # unfreeze the last few encoder layers
        for layer in self.encoder.encoder.layer[-unfreeze_n_blocks:]:
            for param in layer.parameters():
                param.requires_grad = True

        # unfreeze the pooler layer
        for param in self.encoder.pooler.parameters():
            param.requires_grad = True

        # project the pooled embedding down to the desired output size
        self.fc = nn.Linear(self.encoder.config.hidden_size, output_dim)

    def forward(self, input_ids, attention_mask=None):
        # use the hidden state of the first ([CLS]) token as the sentence representation
        x = self.encoder(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state[:, 0]
        x = self.fc(x)
        return x
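To see the encoder in action, here is a minimal usage sketch. The tokenizer loading via AutoTokenizer and the example sentences are illustrative assumptions, not part of the original code; the only requirement is that the tokenizer matches the checkpoint the encoder was built from.

# Minimal usage sketch (assumes the matching tokenizer for the same checkpoint).
from transformers import AutoTokenizer

encoder = TextEncoder(output_dim=64)
tokenizer = AutoTokenizer.from_pretrained(encoder.lang_model)

# Tokenize a small batch of example sentences into padded tensors.
batch = tokenizer(
    ["a photo of a city street", "an aerial view of a forest"],
    padding=True,
    truncation=True,
    return_tensors="pt",
)

with torch.no_grad():  # inference only; during training, gradients flow into the unfrozen layers
    embeddings = encoder(batch["input_ids"], attention_mask=batch["attention_mask"])

print(embeddings.shape)  # torch.Size([2, 64])

One design note: the sentence-transformers all-MiniLM-L6-v2 checkpoint was originally trained with mean pooling over token embeddings, whereas this class takes the first-token hidden state. That choice is reasonable here because the last unfreeze_n_blocks transformer layers, the pooler, and the linear head are all fine-tuned, so the model can adapt to the new pooling scheme.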