Update summarize.py
Browse files- summarize.py +5 -10
summarize.py
CHANGED
@@ -27,9 +27,9 @@ def load_model_and_tokenizer(model_name):
|
|
27 |
return model, tokenizer
|
28 |
|
29 |
|
30 |
-
def summarize_and_score(ids, mask, model, tokenizer, **kwargs):
|
31 |
"""
|
32 |
-
|
33 |
Args:
|
34 |
ids (): the batch of ids
|
35 |
mask (): the attention mask for the batch
|
@@ -53,8 +53,6 @@ def summarize_and_score(ids, mask, model, tokenizer, **kwargs):
|
|
53 |
input_ids,
|
54 |
attention_mask=attention_mask,
|
55 |
#global_attention_mask=global_attention_mask,
|
56 |
-
output_scores=True,
|
57 |
-
return_dict_in_generate=True,
|
58 |
**kwargs,
|
59 |
)
|
60 |
summary = tokenizer.batch_decode(
|
@@ -62,9 +60,8 @@ def summarize_and_score(ids, mask, model, tokenizer, **kwargs):
|
|
62 |
skip_special_tokens=True,
|
63 |
remove_invalid_values=True,
|
64 |
)
|
65 |
-
score = round(summary_pred_ids.sequences_scores.cpu().numpy()[0], 4)
|
66 |
len_res = len(summary_pred_ids.sequences.cpu().numpy()[0])
|
67 |
-
return summary, score, len_res
|
68 |
|
69 |
|
70 |
def summarize_via_tokenbatches(
|
@@ -111,23 +108,21 @@ def summarize_via_tokenbatches(
|
|
111 |
|
112 |
for _id, _mask in zip(in_id_arr, att_arr):
|
113 |
|
114 |
-
result, score, l = summarize_and_score(
|
115 |
ids=_id,
|
116 |
mask=_mask,
|
117 |
model=model,
|
118 |
tokenizer=tokenizer,
|
119 |
**kwargs,
|
120 |
)
|
121 |
-
score = round(float(score), 4)
|
122 |
rate = round(float((len(_id)-l)/len(_id)),3)
|
123 |
_sum = {
|
124 |
"input_tokens": _id,
|
125 |
"summary": result,
|
126 |
-
"summary_score": score,
|
127 |
"compression_rate": rate,
|
128 |
}
|
129 |
gen_summaries.append(_sum)
|
130 |
-
print(f"\t{result[0]}\nScore:\t{score}")
|
131 |
pbar.update()
|
132 |
|
133 |
pbar.close()
|
|
|
27 |
return model, tokenizer
|
28 |
|
29 |
|
30 |
+
def summarize(ids, mask, model, tokenizer, **kwargs):
|
31 |
"""
|
32 |
+
summarize - given a batch of ids and a mask, returns a summary and the token length of the output summary
|
33 |
Args:
|
34 |
ids (): the batch of ids
|
35 |
mask (): the attention mask for the batch
|
|
|
53 |
input_ids,
|
54 |
attention_mask=attention_mask,
|
55 |
#global_attention_mask=global_attention_mask,
|
|
|
|
|
56 |
**kwargs,
|
57 |
)
|
58 |
summary = tokenizer.batch_decode(
|
|
|
60 |
skip_special_tokens=True,
|
61 |
remove_invalid_values=True,
|
62 |
)
|
|
|
63 |
len_res = len(summary_pred_ids.sequences.cpu().numpy()[0])
|
64 |
+
return summary, len_res
|
65 |
|
66 |
|
67 |
def summarize_via_tokenbatches(
|
|
|
108 |
|
109 |
for _id, _mask in zip(in_id_arr, att_arr):
|
110 |
|
111 |
+
result, l = summarize(
|
112 |
ids=_id,
|
113 |
mask=_mask,
|
114 |
model=model,
|
115 |
tokenizer=tokenizer,
|
116 |
**kwargs,
|
117 |
)
|
|
|
118 |
rate = round(float((len(_id)-l)/len(_id)),3)
|
119 |
_sum = {
|
120 |
"input_tokens": _id,
|
121 |
"summary": result,
|
|
|
122 |
"compression_rate": rate,
|
123 |
}
|
124 |
gen_summaries.append(_sum)
|
125 |
+
print(f"\t{result[0]}\nCompression:\t{rate}")
|
126 |
pbar.update()
|
127 |
|
128 |
pbar.close()
|