Commit: Update summarize.py
File changed: summarize.py (+3 −3)
@@ -45,14 +45,14 @@ def summarize_and_score(ids, mask, model, tokenizer, **kwargs):

Before:

45      input_ids = ids.to("cuda") if torch.cuda.is_available() else ids
46      attention_mask = mask.to("cuda") if torch.cuda.is_available() else mask
47
48  -   global_attention_mask = torch.zeros_like(attention_mask)
49      # put global attention on <s> token
50  -   global_attention_mask[:, 0] = 1
51
52      summary_pred_ids = model.generate(
53          input_ids,
54          attention_mask=attention_mask,
55  -       global_attention_mask=global_attention_mask,
56          output_scores=True,
57          return_dict_in_generate=True,
58          **kwargs,
After:

45      input_ids = ids.to("cuda") if torch.cuda.is_available() else ids
46      attention_mask = mask.to("cuda") if torch.cuda.is_available() else mask
47
48  +   #global_attention_mask = torch.zeros_like(attention_mask)
49      # put global attention on <s> token
50  +   #global_attention_mask[:, 0] = 1
51
52      summary_pred_ids = model.generate(
53          input_ids,
54          attention_mask=attention_mask,
55  +       #global_attention_mask=global_attention_mask,
56          output_scores=True,
57          return_dict_in_generate=True,
58          **kwargs,