Blaise-g committed
Commit 36460c9 • 1 Parent(s): 66a576c

Update summarize.py

Files changed (1)
  1. summarize.py +27 -21
summarize.py CHANGED

@@ -27,7 +27,7 @@ def load_model_and_tokenizer(model_name):
     return model, tokenizer
 
 
-def summarize_and_score(ids, mask, model, tokenizer, **kwargs):
+def summarize(ids, mask, model, tokenizer, model_arch, **kwargs):
     """
     summarize_and_score - given a batch of ids and a mask, return a summary and a score for the summary
     Args:
@@ -35,6 +35,7 @@ def summarize_and_score(ids, mask, model, tokenizer, **kwargs):
         mask (): the attention mask for the batch
         model (): the model to use for summarization
         tokenizer (): the tokenizer to use for summarization
+        model
     Returns:
         str: the summary of the batch
     """
@@ -44,27 +45,32 @@ def summarize_and_score(ids, mask, model, tokenizer, **kwargs):
 
     input_ids = ids.to("cuda") if torch.cuda.is_available() else ids
     attention_mask = mask.to("cuda") if torch.cuda.is_available() else mask
-
-    #global_attention_mask = torch.zeros_like(attention_mask)
-    # put global attention on <s> token
-    #global_attention_mask[:, 0] = 1
-
-    summary_pred_ids = model.generate(
-        input_ids,
-        attention_mask=attention_mask,
-        #global_attention_mask=global_attention_mask,
-        output_scores=True,
-        return_dict_in_generate=True,
-        **kwargs,
-    )
+
+    if model_arch == 'LED':
+        global_attention_mask = torch.zeros_like(attention_mask)
+        # put global attention on <s> token
+        global_attention_mask[:, 0] = 1
+        summary_pred_ids = model.generate(
+            input_ids,
+            attention_mask=attention_mask,
+            global_attention_mask=global_attention_mask,
+            return_dict_in_generate=True,
+            **kwargs,
+        )
+
+    else:
+        summary_pred_ids = model.generate(
+            input_ids,
+            attention_mask=attention_mask,
+            return_dict_in_generate=True,
+            **kwargs,
+        )
     summary = tokenizer.batch_decode(
         summary_pred_ids.sequences,
         skip_special_tokens=True,
         remove_invalid_values=True,
     )
-    score = round(summary_pred_ids.sequences_scores.cpu().numpy()[0], 4)
-
-    return summary, score
+    return summary
 
 
 def summarize_via_tokenbatches(
@@ -111,21 +117,21 @@ def summarize_via_tokenbatches(
 
     for _id, _mask in zip(in_id_arr, att_arr):
 
-        result, score = summarize_and_score(
+        result = summarize(
            ids=_id,
            mask=_mask,
            model=model,
            tokenizer=tokenizer,
            **kwargs,
        )
-        score = round(float(score), 4)
+        rate = round(float(len()), 3)
        _sum = {
            "input_tokens": _id,
            "summary": result,
-            "summary_score": score,
+            "compression_rate": rate,
        }
        gen_summaries.append(_sum)
-        print(f"\t{result[0]}\nScore:\t{score}")
+        print(f"\t{result[0]}\nRate:\t{rate}")
        pbar.update()
 
    pbar.close()
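The substance of the change: instead of leaving the global-attention setup permanently commented out, summarize() now branches on a new model_arch argument. LED's encoder uses windowed local self-attention, and the Hugging Face LED docs recommend putting global attention on the first (<s>) token for summarization, which is what the 'LED' branch restores; dropping output_scores=True and the sequences_scores lookup is consistent with the function no longer returning a score. Two follow-ups worth flagging: the docstring still carries the old summarize_and_score name plus a truncated "model" entry, and model_arch is now a required parameter that the updated call in summarize_via_tokenbatches does not pass explicitly, so it has to arrive via **kwargs or the call raises a TypeError. A minimal sketch of inferring it from the checkpoint config instead (get_model_arch is a hypothetical helper, not part of this commit):

    # Hypothetical helper, not in this commit: infer the architecture from
    # the model config rather than threading a string through every caller.
    # config.model_type is "led" for LED checkpoints such as
    # allenai/led-base-16384.
    def get_model_arch(model) -> str:
        return "LED" if model.config.model_type == "led" else model.config.model_type

    # usage, e.g.:
    # summarize(ids=_id, mask=_mask, model=model, tokenizer=tokenizer,
    #           model_arch=get_model_arch(model), **kwargs)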
 
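One bug to note in the new bookkeeping: rate = round(float(len()), 3) calls the built-in len() with no argument and raises a TypeError on the first loop iteration, and the commit does not show which lengths the compression rate is meant to compare. A sketch of one plausible reading, assuming "compression_rate" is summary tokens over input tokens and that _id is a 1-D tensor of input ids (both assumptions, not stated here):

    # Assumed fix, not from the commit: compression rate as
    # summary tokens / input tokens. result[0] is the first decoded summary
    # string from batch_decode; len(_id) is the input token count for a
    # 1-D tensor of ids.
    summary_tokens = len(tokenizer.encode(result[0]))
    rate = round(summary_tokens / len(_id), 3)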