Blaise-g commited on
Commit
ff43405
β€’
1 Parent(s): 09c590a

Update summarize.py

Browse files
Files changed (1) hide show
  1. summarize.py +3 -3
summarize.py CHANGED
@@ -45,14 +45,14 @@ def summarize_and_score(ids, mask, model, tokenizer, **kwargs):
45
  input_ids = ids.to("cuda") if torch.cuda.is_available() else ids
46
  attention_mask = mask.to("cuda") if torch.cuda.is_available() else mask
47
 
48
- global_attention_mask = torch.zeros_like(attention_mask)
49
  # put global attention on <s> token
50
- global_attention_mask[:, 0] = 1
51
 
52
  summary_pred_ids = model.generate(
53
  input_ids,
54
  attention_mask=attention_mask,
55
- global_attention_mask=global_attention_mask,
56
  output_scores=True,
57
  return_dict_in_generate=True,
58
  **kwargs,
 
45
  input_ids = ids.to("cuda") if torch.cuda.is_available() else ids
46
  attention_mask = mask.to("cuda") if torch.cuda.is_available() else mask
47
 
48
+ #global_attention_mask = torch.zeros_like(attention_mask)
49
  # put global attention on <s> token
50
+ #global_attention_mask[:, 0] = 1
51
 
52
  summary_pred_ids = model.generate(
53
  input_ids,
54
  attention_mask=attention_mask,
55
+ #global_attention_mask=global_attention_mask,
56
  output_scores=True,
57
  return_dict_in_generate=True,
58
  **kwargs,