🚧 update for longt5
Signed-off-by: peter szemraj <peterszemraj@gmail.com>
summarize.py  CHANGED  (+23 -13)
@@ -15,20 +15,19 @@ def load_model_and_tokenizer(model_name):
         AutoModelForSeq2SeqLM: the model
         AutoTokenizer: the tokenizer
     """
-
+    device = "cuda" if torch.cuda.is_available() else "cpu"
     model = AutoModelForSeq2SeqLM.from_pretrained(
         model_name,
         # low_cpu_mem_usage=True,
         # use_cache=False,
-    )
+    ).to(device)
     tokenizer = AutoTokenizer.from_pretrained(model_name)
-    model = model.to("cuda") if torch.cuda.is_available() else model
 
-    logging.info(f"Loaded model {model_name}")
+    logging.info(f"Loaded model {model_name} to {device}")
     return model, tokenizer
 
 
-def summarize_and_score(ids, mask, model, tokenizer, **kwargs):
+def summarize_and_score(ids, mask, model, tokenizer, is_general_attention_model=True, **kwargs):
     """
     summarize_and_score - given a batch of ids and a mask, return a summary and a score for the summary
 
@@ -37,6 +36,7 @@ def summarize_and_score(ids, mask, model, tokenizer, **kwargs):
         mask (): the attention mask for the batch
         model (): the model to use for summarization
         tokenizer (): the tokenizer to use for summarization
+        is_general_attention_model (bool, optional): whether the model is a general attention model. Defaults to True.
 
     Returns:
         str: the summary of the batch
@@ -52,14 +52,23 @@ def summarize_and_score(ids, mask, model, tokenizer, **kwargs):
     # put global attention on <s> token
     global_attention_mask[:, 0] = 1
 
-    summary_pred_ids = model.generate(
-        input_ids,
-        attention_mask=attention_mask,
-        global_attention_mask=global_attention_mask,
-        output_scores=True,
-        return_dict_in_generate=True,
-        **kwargs,
-    )
+    if is_general_attention_model:
+        summary_pred_ids = model.generate(
+            input_ids,
+            attention_mask=attention_mask,
+            output_scores=True,
+            return_dict_in_generate=True,
+            **kwargs,
+        )
+    else:
+        summary_pred_ids = model.generate(
+            input_ids,
+            attention_mask=attention_mask,
+            global_attention_mask=global_attention_mask,
+            output_scores=True,
+            return_dict_in_generate=True,
+            **kwargs,
+        )
     summary = tokenizer.batch_decode(
         summary_pred_ids.sequences,
         skip_special_tokens=True,
@@ -70,6 +79,7 @@ def summarize_and_score(ids, mask, model, tokenizer, **kwargs):
     return summary, score
 
 
+
 def summarize_via_tokenbatches(
     input_text: str,
     model,