Commit 66e7228 • Peter committed
1 parent: e05a3b5

:art: apply black

Files changed:
- app.py        +23 -13
- summarize.py  +39 -37
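The commit applies the black code formatter, so every hunk below is a layout change with no behavioral effect: double-quoted strings, a space after every comma, one argument per line for calls that end in a trailing comma, and two blank lines between top-level definitions. A minimal, self-contained before/after pair lifted from the load_examples() hunk in app.py:

text = "example document"
text_examples = []

# pre-black, as deleted in this commit:
# text_examples.append([text, 4, 2048, 0.7,3.5,3])

# post-black, as added (a space after every comma):
text_examples.append([text, 4, 2048, 0.7, 3.5, 3])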
app.py CHANGED
@@ -17,6 +17,7 @@ import transformers
 transformers.logging.set_verbosity_error()
 logging.basicConfig()
 
+
 def truncate_word_count(text, max_words=512):
     """
     truncate_word_count - a helper function for the gradio module
@@ -38,6 +39,7 @@ def truncate_word_count(text, max_words=512):
     processed["truncated_text"] = text
     return processed
 
+
 def proc_submission(
     input_text: str,
     num_beams,
@@ -80,15 +82,15 @@ def proc_submission(
     history["was_truncated"] = False
 
     _summaries = summarize_via_tokenbatches(
-
-
-
-
-
-
+        history["input_text"],
+        model,
+        tokenizer,
+        batch_length=token_batch_length,
+        **settings,
+    )
+    sum_text = [s["summary"][0] for s in _summaries]
     sum_scores = [f"\n - {round(s['summary_score'],4)}" for s in _summaries]
 
-
     history["Input"] = input_text
     history["Summary Text"] = "\n\t".join(sum_text)
     history["Summary Scores"] = "\n".join(sum_scores)
@@ -104,7 +106,8 @@ def proc_submission(
 
     return html
 
-def load_examples(examples_dir='examples'):
+
+def load_examples(examples_dir="examples"):
     src = _here / examples_dir
     src.mkdir(exist_ok=True)
     examples = [f for f in src.glob("*.txt")]
@@ -113,15 +116,18 @@ def load_examples(examples_dir='examples'):
     for example in examples:
         with open(example, "r") as f:
             text = f.read()
-        text_examples.append([text, 4, 2048, 0.7,3.5,3])
+        text_examples.append([text, 4, 2048, 0.7, 3.5, 3])
 
     return text_examples
 
+
 if __name__ == "__main__":
 
-    model, tokenizer = load_model_and_tokenizer(
+    model, tokenizer = load_model_and_tokenizer("pszemraj/led-large-book-summary")
     title = "Long-form text summarization with LED on the BookSumm dataset"
-    description =
+    description = (
+        "This is a simple example of using the LED model to summarize a long-form text."
+    )
 
     gr.Interface(
         proc_submission,
@@ -130,7 +136,11 @@ if __name__ == "__main__":
             gr.inputs.Slider(
                 minimum=4, maximum=10, label="num_beams", default=4, step=1
            ),
-            gr.Dropdown(
+            gr.Dropdown(
+                choices=[512, 1024, 2048, 4096],
+                label="token_batch_length",
+                default=2048,
+            ),
             gr.inputs.Slider(
                 minimum=0.5, maximum=1.1, label="length_penalty", default=0.7, step=0.05
             ),
@@ -150,4 +160,4 @@ if __name__ == "__main__":
         title=title,
         description=description,
         examples=load_examples(),
-    ).launch(enable_queue=True, share=True)
+    ).launch(enable_queue=True, share=True)
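For orientation, a trimmed-down sketch of how the reformatted pieces of app.py fit together. The handler body, the Textbox input, and the outputs mode are assumptions (the hunks show only fragments of the real proc_submission and input list); the component calls mirror the legacy gr.inputs-style API used in the diff:

import gradio as gr

from summarize import load_model_and_tokenizer, summarize_via_tokenbatches

model, tokenizer = load_model_and_tokenizer("pszemraj/led-large-book-summary")


def demo_submission(input_text, num_beams, token_batch_length, length_penalty):
    # hypothetical, simplified stand-in for the real proc_submission
    summaries = summarize_via_tokenbatches(
        input_text,
        model,
        tokenizer,
        batch_length=token_batch_length,
        num_beams=num_beams,
        length_penalty=length_penalty,
    )
    return "\n\t".join(s["summary"][0] for s in summaries)


gr.Interface(
    demo_submission,
    inputs=[
        gr.inputs.Textbox(label="input_text"),  # assumed; not shown in the hunks
        gr.inputs.Slider(minimum=4, maximum=10, label="num_beams", default=4, step=1),
        gr.Dropdown(choices=[512, 1024, 2048, 4096], label="token_batch_length", default=2048),
        gr.inputs.Slider(
            minimum=0.5, maximum=1.1, label="length_penalty", default=0.7, step=0.05
        ),
    ],
    outputs="text",  # the real app builds HTML from the summaries
    title="Long-form text summarization with LED on the BookSumm dataset",
).launch(enable_queue=True, share=True)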
summarize.py CHANGED
@@ -2,6 +2,7 @@ import torch
 from tqdm.auto import tqdm
 from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
 
+
 def load_model_and_tokenizer(model_name):
     """
     load_model_and_tokenizer - a function that loads a model and tokenizer from huggingface
@@ -14,14 +15,15 @@ def load_model_and_tokenizer(model_name):
     """
 
     model = AutoModelForSeq2SeqLM.from_pretrained(
-
-
-
-    )
+        model_name,
+        low_cpu_mem_usage=True,
+        use_cache=False,
+    )
     tokenizer = AutoTokenizer.from_pretrained(model_name)
     model = model.to("cuda") if torch.cuda.is_available() else model
     return model, tokenizer
 
+
 def summarize_and_score(ids, mask, model, tokenizer, **kwargs):
     """
     summarize_and_score - given a batch of ids and a mask, return a summary and a score for the summary
@@ -36,43 +38,43 @@ def summarize_and_score(ids, mask, model, tokenizer, **kwargs):
         str: the summary of the batch
     """
 
-
     ids = ids[None, :]
     mask = mask[None, :]
 
     input_ids = ids.to("cuda") if torch.cuda.is_available() else ids
     attention_mask = mask.to("cuda") if torch.cuda.is_available() else mask
 
-
     attention_mask = mask.to("cuda")
     global_attention_mask = torch.zeros_like(attention_mask)
     # put global attention on <s> token
     global_attention_mask[:, 0] = 1
 
     summary_pred_ids = model.generate(
-
-
-
-
-
-
-
+        input_ids,
+        attention_mask=attention_mask,
+        global_attention_mask=global_attention_mask,
+        output_scores=True,
+        return_dict_in_generate=True,
+        **kwargs,
+    )
     summary = tokenizer.batch_decode(
-
-
-
-
+        summary_pred_ids.sequences,
+        skip_special_tokens=True,
+        remove_invalid_values=True,
+    )
     score = round(summary_pred_ids.sequences_scores.cpu().numpy()[0], 4)
 
     return summary, score
 
+
 def summarize_via_tokenbatches(
-
-
-
-
-
-
+    input_text: str,
+    model,
+    tokenizer,
+    batch_length=2048,
+    batch_stride=16,
+    **kwargs,
+):
     """
     summarize_via_tokenbatches - a function that takes a string and returns a summary
 
@@ -88,15 +90,15 @@
     """
 
     encoded_input = tokenizer(
-
-
-
-
-
-
-
-
-
+        input_text,
+        padding="max_length",
+        truncation=True,
+        max_length=batch_length,
+        stride=batch_stride,
+        return_overflowing_tokens=True,
+        add_special_tokens=False,
+        return_tensors="pt",
+    )
 
     in_id_arr, att_arr = encoded_input.input_ids, encoded_input.attention_mask
     gen_summaries = []
@@ -112,11 +114,11 @@
             tokenizer=tokenizer,
             **kwargs,
         )
-        score = round(float(score),4)
+        score = round(float(score), 4)
         _sum = {
-            "input_tokens":_id,
-            "summary":result,
-            "summary_score":score,
+            "input_tokens": _id,
+            "summary": result,
+            "summary_score": score,
         }
         gen_summaries.append(_sum)
         print(f"\t{result[0]}\nScore:\t{score}")
@@ -124,4 +126,4 @@
 
     pbar.close()
 
-    return gen_summaries
+    return gen_summaries
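Finally, a short usage sketch for the two summarize.py helpers as they stand after this commit. The input string is a placeholder, and any extra keyword arguments pass through **kwargs to model.generate:

from summarize import load_model_and_tokenizer, summarize_via_tokenbatches

model, tokenizer = load_model_and_tokenizer("pszemraj/led-large-book-summary")

# placeholder: any document longer than one 2048-token window
long_text = "..."

# the tokenizer call inside splits the text into overlapping windows of
# batch_length tokens (overlap set by batch_stride); each window is
# summarized and scored independently
summaries = summarize_via_tokenbatches(
    long_text,
    model,
    tokenizer,
    batch_length=2048,
    batch_stride=16,
    num_beams=4,  # forwarded to model.generate via **kwargs
)

for s in summaries:
    print(s["summary"][0], s["summary_score"])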