awacke1 commited on
Commit
9112592
·
1 Parent(s): cce4162

Create new file

Browse files
Files changed (1) hide show
  1. summarize.py +133 -0
summarize.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+
3
+ import torch
4
+ from tqdm.auto import tqdm
5
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
6
+
7
+
8
+ def load_model_and_tokenizer(model_name):
9
+ """
10
+ load_model_and_tokenizer - a function that loads a model and tokenizer from huggingface
11
+ Args:
12
+ model_name (str): the name of the model to load
13
+ Returns:
14
+ AutoModelForSeq2SeqLM: the model
15
+ AutoTokenizer: the tokenizer
16
+ """
17
+
18
+ model = AutoModelForSeq2SeqLM.from_pretrained(
19
+ model_name,
20
+ # low_cpu_mem_usage=True,
21
+ # use_cache=False,
22
+ )
23
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
24
+ model = model.to("cuda") if torch.cuda.is_available() else model
25
+
26
+ logging.info(f"Loaded model {model_name}")
27
+ return model, tokenizer
28
+
29
+
30
+ def summarize_and_score(ids, mask, model, tokenizer, **kwargs):
31
+ """
32
+ summarize_and_score - given a batch of ids and a mask, return a summary and a score for the summary
33
+ Args:
34
+ ids (): the batch of ids
35
+ mask (): the attention mask for the batch
36
+ model (): the model to use for summarization
37
+ tokenizer (): the tokenizer to use for summarization
38
+ Returns:
39
+ str: the summary of the batch
40
+ """
41
+
42
+ ids = ids[None, :]
43
+ mask = mask[None, :]
44
+
45
+ input_ids = ids.to("cuda") if torch.cuda.is_available() else ids
46
+ attention_mask = mask.to("cuda") if torch.cuda.is_available() else mask
47
+
48
+ global_attention_mask = torch.zeros_like(attention_mask)
49
+ # put global attention on <s> token
50
+ global_attention_mask[:, 0] = 1
51
+
52
+ summary_pred_ids = model.generate(
53
+ input_ids,
54
+ attention_mask=attention_mask,
55
+ global_attention_mask=global_attention_mask,
56
+ output_scores=True,
57
+ return_dict_in_generate=True,
58
+ **kwargs,
59
+ )
60
+ summary = tokenizer.batch_decode(
61
+ summary_pred_ids.sequences,
62
+ skip_special_tokens=True,
63
+ remove_invalid_values=True,
64
+ )
65
+ score = round(summary_pred_ids.sequences_scores.cpu().numpy()[0], 4)
66
+
67
+ return summary, score
68
+
69
+
70
+ def summarize_via_tokenbatches(
71
+ input_text: str,
72
+ model,
73
+ tokenizer,
74
+ batch_length=2048,
75
+ batch_stride=16,
76
+ **kwargs,
77
+ ):
78
+ """
79
+ summarize_via_tokenbatches - a function that takes a string and returns a summary
80
+ Args:
81
+ input_text (str): the text to summarize
82
+ model (): the model to use for summarization
83
+ tokenizer (): the tokenizer to use for summarization
84
+ batch_length (int, optional): the length of each batch. Defaults to 2048.
85
+ batch_stride (int, optional): the stride of each batch. Defaults to 16. The stride is the number of tokens that overlap between batches.
86
+ Returns:
87
+ str: the summary
88
+ """
89
+ # log all input parameters
90
+ if batch_length < 512:
91
+ batch_length = 512
92
+ print("WARNING: batch_length was set to 512")
93
+ print(
94
+ f"input parameters: {kwargs}, batch_length={batch_length}, batch_stride={batch_stride}"
95
+ )
96
+ encoded_input = tokenizer(
97
+ input_text,
98
+ padding="max_length",
99
+ truncation=True,
100
+ max_length=batch_length,
101
+ stride=batch_stride,
102
+ return_overflowing_tokens=True,
103
+ add_special_tokens=False,
104
+ return_tensors="pt",
105
+ )
106
+
107
+ in_id_arr, att_arr = encoded_input.input_ids, encoded_input.attention_mask
108
+ gen_summaries = []
109
+
110
+ pbar = tqdm(total=len(in_id_arr))
111
+
112
+ for _id, _mask in zip(in_id_arr, att_arr):
113
+
114
+ result, score = summarize_and_score(
115
+ ids=_id,
116
+ mask=_mask,
117
+ model=model,
118
+ tokenizer=tokenizer,
119
+ **kwargs,
120
+ )
121
+ score = round(float(score), 4)
122
+ _sum = {
123
+ "input_tokens": _id,
124
+ "summary": result,
125
+ "summary_score": score,
126
+ }
127
+ gen_summaries.append(_sum)
128
+ print(f"\t{result[0]}\nScore:\t{score}")
129
+ pbar.update()
130
+
131
+ pbar.close()
132
+
133
+ return gen_summaries