hungdungn47 commited on
Commit
d76f6bc
·
1 Parent(s): 9bf0efa

fix line 99

Browse files
Files changed (1) hide show
  1. infer_concat.py +1 -2
infer_concat.py CHANGED
@@ -96,8 +96,7 @@ def infer_2_hier(model, data_loader, device, tokenizer):
96
  summaries.append(summary)
97
  summaries = torch.cat(summaries, dim = 1)
98
 
99
- all_summaries.append(tokenizer.decode(summaries, skip_special_tokens=True))
100
-
101
 
102
  end = time.time()
103
  print(f"Time: {end-start}")
 
96
  summaries.append(summary)
97
  summaries = torch.cat(summaries, dim = 1)
98
 
99
+ all_summaries.append(tokenizer.decode(summaries.squeeze(), skip_special_tokens=True))
 
100
 
101
  end = time.time()
102
  print(f"Time: {end-start}")