import torch
from torch import nn
from sentence_transformers import SentenceTransformer
from regressor import *  # provides WRegressor
import numpy as np
import os
from typing import Iterable, Optional, Union

# SentenceTransformer model name/path, read once from the environment at import time.
ENCODER = os.getenv("ENCODER")


class NextUsRegressor(nn.Module):
    """Score text by embedding it with a SentenceTransformer and regressing the embedding.

    The embedder is ``SentenceTransformer(encoder)``; the regression head is
    ``WRegressor()`` from the local ``regressor`` module.
    """

    def __init__(self, encoder: Optional[str] = None) -> None:
        """Build the embedder and the regression head.

        Args:
            encoder: SentenceTransformer model name or path. When omitted, the
                module-level ``ENCODER`` (the ``ENCODER`` environment variable
                as read at import time) is used.

        Raises:
            ValueError: if no encoder is given and ``ENCODER`` is unset or empty.
        """
        super().__init__()
        if encoder is None:
            encoder = ENCODER
        if not encoder:
            # SentenceTransformer(None) / SentenceTransformer("") silently builds an
            # empty model that only fails later inside encode(); fail fast instead.
            raise ValueError(
                "No encoder configured: pass `encoder=` or set the ENCODER environment variable."
            )
        self.embedder = SentenceTransformer(encoder)
        self.regressor = WRegressor()

    def forward(self, txts: Union[str, Iterable[str]]) -> float:
        """Return the regression score of the first text in ``txts``.

        Args:
            txts: a single string, or an iterable of strings.

        Returns:
            The first element of the regressor's flattened output, as a Python
            float rounded to 4 decimal places.

        Raises:
            ValueError: if ``txts`` contains no texts.
        """
        # A bare string is one text, not an iterable of characters.
        texts = [txts] if isinstance(txts, str) else list(txts)
        if not texts:
            raise ValueError("forward() requires at least one text")

        # NOTE(review): every text is embedded and regressed, but only the first
        # score is returned; confirm callers never rely on batch input.
        embedded = self.embedder.encode(texts)
        embedded_tensor = torch.tensor(embedded, dtype=torch.float32)
        # NOTE(review): this tensor is created on the default (CPU) device; if the
        # regressor is moved to a GPU it would need `.to(device)`. Confirm CPU-only use.

        # The result is turned into a plain float, so no gradient can flow back
        # through it; skip building the autograd graph.
        with torch.no_grad():
            regressed = self.regressor(embedded_tensor)
        val = regressed.flatten()[0].item()
        return round(val, 4)