|
from transformers import PreTrainedModel |
|
|
|
from pythainlp import word_vector |
|
import torch |
|
|
|
from .configuration import ThaiLightWeightEncoderConfig |
|
from .projector import Projector |
|
|
|
|
|
class ThaiLightWeightEncoderModel(PreTrainedModel):
    """Sentence encoder: pythainlp mean word vectors followed by a projector.

    A sentence is embedded with pythainlp's ``WordVector.sentence_vectorizer``
    (``use_mean=True``), and the resulting vector is passed through a
    :class:`Projector` configured from ``config.input_embedding_dim``,
    ``config.final_embedding_dim`` and ``config.dropout``.
    """

    config_class = ThaiLightWeightEncoderConfig

    def __init__(self, config):
        """Build the word-vector lookup and the projection head.

        Args:
            config: A ``ThaiLightWeightEncoderConfig``. This constructor reads
                ``word_vector_model_name``, ``input_embedding_dim``,
                ``final_embedding_dim`` and ``dropout`` from it.
        """
        super().__init__(config)
        # NOTE(review): constructing WordVector presumably loads (and may
        # download) the named word-vector model -- confirm in pythainlp docs.
        self.wv = word_vector.WordVector(model_name=config.word_vector_model_name)
        self.projector = Projector(
            input_embedding_dim=config.input_embedding_dim,
            final_embedding_dim=config.final_embedding_dim,
            dropout=config.dropout,
        )

    def forward(self, text: str):
        """Encode one sentence into a projected embedding.

        Args:
            text: The sentence to encode.

        Returns:
            A float32 NumPy array with the projector's output, detached from
            the autograd graph and copied to CPU.
        """
        # Assumes sentence_vectorizer returns a (1, dim) array; [0] keeps its
        # single row -- TODO confirm against pythainlp docs.
        embed = self.wv.sentence_vectorizer(text, use_mean=True)[0]

        # torch.from_numpy always produces a CPU tensor. Move it to wherever the
        # projector's weights live (and match their dtype) so the model still
        # works after .to("cuda") or a half-precision cast.
        inputs = torch.from_numpy(embed).float()
        param = next(self.projector.parameters(), None)
        if param is not None:
            inputs = inputs.to(device=param.device, dtype=param.dtype)

        # The result leaves as a detached NumPy array, so no gradient can flow
        # back through it; skip building the autograd graph.
        # NOTE(review): any dropout in Projector is only disabled when the
        # module is in eval mode -- confirm callers call .eval() (or load via
        # from_pretrained) before encoding.
        with torch.no_grad():
            proj_embed = self.projector(inputs)

        # .float() is a no-op for float32 and makes bf16/fp16 outputs
        # convertible by .numpy().
        return proj_embed.detach().to("cpu").float().numpy()