soarhigh commited on
Commit
a721924
·
1 Parent(s): 50361dc

print tabulated results

Browse files
Files changed (1) hide show
  1. nextus_regressor_class.py +11 -4
nextus_regressor_class.py CHANGED
@@ -18,13 +18,20 @@ class NextUsRegressor(nn.Module):
18
  return
19
 
20
  def forward(self, txts):
21
- # expects a list of strings
22
  if type(txts) == str:
23
  txts = [txts]
24
  embedded = self.embedder.encode(np.array(txts))
 
 
25
  embedded_tensor = torch.tensor(embedded, dtype=torch.float32)
26
  regressed = self.regressor(embedded_tensor)
27
- vals = regressed.flatten().tolist()
28
- return str(vals)
29
- #return "\n".join([str(v) for v in vals])
30
 
 
 
 
 
 
 
 
 
 
 
18
  return
19
 
20
  def forward(self, txts):
 
21
  if type(txts) == str:
22
  txts = [txts]
23
  embedded = self.embedder.encode(np.array(txts))
24
+ # embedded_tensor = self.embedder(np.array(txts))
25
+
26
  embedded_tensor = torch.tensor(embedded, dtype=torch.float32)
27
  regressed = self.regressor(embedded_tensor)
 
 
 
28
 
29
+ # return regressed.tolist()
30
+ # TODO: actually handle list of strings
31
+ vals = regressed.flatten().tolist()
32
+ # must return the whole thing, not just the 0-th element
33
+ strs = list()
34
+ for t, v in list(zip(txts, vals)):
35
+ strs.append(str(round(v, 4)) + "\t" + t[:20])
36
+ return "\n".join(strs)
37
+ # return torch.tensor(val).unsqueeze(1)