import torch | |
from torch import nn | |
from sentence_transformers import SentenceTransformer | |
from regressor import * | |
import numpy as np | |
import os | |
# Name or path of the SentenceTransformer model, taken from the ENCODER
# environment variable at import time (None when the variable is unset).
ENCODER = os.environ.get("ENCODER")
class NextUsRegressor(nn.Module):
    """Score text with a sentence-embedding model followed by a regression head.

    The embedder is a ``SentenceTransformer`` built from the module-level
    ``ENCODER`` name/path; the head is a ``WRegressor`` instance applied to
    the embeddings.
    """

    def __init__(self):
        super().__init__()
        # NOTE(review): ENCODER is read from an environment variable and is None
        # when that variable is unset; confirm deployments always set it.
        self.embedder = SentenceTransformer(ENCODER)
        self.regressor = WRegressor()

    def forward(self, txts):
        """Return the regression score for the first text in ``txts``.

        Args:
            txts: A single string, or an iterable (list, tuple, ...) of
                strings. A bare string is treated as a one-element batch.

        Returns:
            The first element of the flattened regressor output as a Python
            float rounded to 4 decimal places. Every text is embedded and
            passed through the regressor, but only the first score is returned.

        Raises:
            ValueError: If ``txts`` contains no strings.
        """
        # isinstance (not ``type(...) ==``) so str subclasses are treated as one
        # text instead of being iterated character by character by list().
        texts = [txts] if isinstance(txts, str) else list(txts)
        if not texts:
            raise ValueError("txts must contain at least one string")

        # The result leaves as a plain float, so no gradient can flow back
        # through it; don't build an autograd graph for this pass.
        with torch.no_grad():
            # Pass a plain list: encode() accepts str / list[str], whereas
            # np.array(...) would silently strip trailing "\x00" characters.
            embedded = self.embedder.encode(texts)
            # NOTE(review): this tensor is created on the CPU; if the module is
            # moved to another device the regressor input needs a matching
            # .to(device) — confirm how the model is deployed.
            embedded_tensor = torch.as_tensor(embedded, dtype=torch.float32)
            regressed = self.regressor(embedded_tensor)

        # Only the first score is needed; avoid converting the whole batch.
        val = regressed.flatten()[0].item()
        return round(val, 4)