print tabulated results
nextus_regressor_class.py (+11 -4)
@@ -18,13 +18,20 @@ class NextUsRegressor(nn.Module):
         return
 
     def forward(self, txts):
-        # expects a list of strings
         if type(txts) == str:
             txts = [txts]
         embedded = self.embedder.encode(np.array(txts))
+        # embedded_tensor = self.embedder(np.array(txts))
+
         embedded_tensor = torch.tensor(embedded, dtype=torch.float32)
         regressed = self.regressor(embedded_tensor)
-        vals = regressed.flatten().tolist()
-        return str(vals)
-        #return "\n".join([str(v) for v in vals])
 
+        # return regressed.tolist()
+        # TODO: actually handle list of strings
+        vals = regressed.flatten().tolist()
+        # must return the whole thing, not just the 0-th element
+        strs = list()
+        for t, v in list(zip(txts, vals)):
+            strs.append(str(round(v, 4)) + "\t" + t[:20])
+        return "\n".join(strs)
+        # return torch.tensor(val).unsqueeze(1)