|
import torch |
|
import tqdm |
|
from torch import nn |
|
from transformers import T5PreTrainedModel, T5EncoderModel |
|
|
|
class T5EncoderWithProjection(T5PreTrainedModel):
    """T5 encoder with a linear projection head for producing sequence embeddings.

    Wraps a ``T5EncoderModel`` and maps the first-token hidden state through a
    bias-free linear layer, yielding one ``d_model``-sized embedding per input
    sequence.
    """

    def __init__(self, config):
        """Build the encoder and projection head from a T5 config.

        Args:
            config: A ``T5Config`` instance; ``config.d_model`` sets both the
                input and output width of the projection layer.
        """
        super().__init__(config)
        # NOTE: `T5PreTrainedModel.__init__` already stores `config` as
        # `self.config`, so no explicit re-assignment is needed here.

        self.t5_encoder = T5EncoderModel(config)
        # Bias-free projection keeps the embedding a pure linear map of the
        # pooled hidden state.
        self.projection = nn.Linear(config.d_model, config.d_model, bias=False)

        # Standard HF hook: weight init + any final setup.
        self.post_init()

    def forward(self, **input_args):
        """Encode inputs and return one projected embedding per sequence.

        Args:
            **input_args: Keyword arguments forwarded verbatim to the
                underlying ``T5EncoderModel`` (e.g. ``input_ids``,
                ``attention_mask``).

        Returns:
            Tensor of shape ``(batch, d_model)``: the projection of the first
            token's hidden state for each sequence.
        """
        hidden_states = self.t5_encoder(**input_args).last_hidden_state
        # First-token pooling: assumes last_hidden_state is
        # (batch, seq, d_model), the standard HF encoder output layout.
        pooled = hidden_states[:, 0, :]
        return self.projection(pooled)
|
|