Blaise-g commited on
Commit
4c95a2f
β€’
1 Parent(s): 84a8fa1

Delete summ.py

Browse files
Files changed (1) hide show
  1. summ.py +0 -133
summ.py DELETED
@@ -1,133 +0,0 @@
1
- import logging
2
-
3
- import torch
4
- from tqdm.auto import tqdm
5
- from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
6
-
7
-
8
- def load_model_and_tokenizer(model_name):
9
- """
10
- load_model_and_tokenizer - a function that loads a model and tokenizer from huggingface
11
- Args:
12
- model_name (str): the name of the model to load
13
- Returns:
14
- AutoModelForSeq2SeqLM: the model
15
- AutoTokenizer: the tokenizer
16
- """
17
-
18
- model = AutoModelForSeq2SeqLM.from_pretrained(
19
- model_name,
20
- # low_cpu_mem_usage=True,
21
- # use_cache=False,
22
- )
23
- tokenizer = AutoTokenizer.from_pretrained(model_name)
24
- model = model.to("cuda") if torch.cuda.is_available() else model
25
-
26
- logging.info(f"Loaded model {model_name}")
27
- return model, tokenizer
28
-
29
-
30
- def summarize_and_score(ids, mask, model, tokenizer, **kwargs):
31
- """
32
- summarize_and_score - given a batch of ids and a mask, return a summary and a score for the summary
33
- Args:
34
- ids (): the batch of ids
35
- mask (): the attention mask for the batch
36
- model (): the model to use for summarization
37
- tokenizer (): the tokenizer to use for summarization
38
- Returns:
39
- str: the summary of the batch
40
- """
41
-
42
- ids = ids[None, :]
43
- mask = mask[None, :]
44
-
45
- input_ids = ids.to("cuda") if torch.cuda.is_available() else ids
46
- attention_mask = mask.to("cuda") if torch.cuda.is_available() else mask
47
-
48
- global_attention_mask = torch.zeros_like(attention_mask)
49
- # put global attention on <s> token
50
- global_attention_mask[:, 0] = 1
51
-
52
- summary_pred_ids = model.generate(
53
- input_ids,
54
- attention_mask=attention_mask,
55
- global_attention_mask=global_attention_mask,
56
- output_scores=True,
57
- return_dict_in_generate=True,
58
- **kwargs,
59
- )
60
- summary = tokenizer.batch_decode(
61
- summary_pred_ids.sequences,
62
- skip_special_tokens=True,
63
- remove_invalid_values=True,
64
- )
65
- score = round(summary_pred_ids.sequences_scores.cpu().numpy()[0], 4)
66
-
67
- return summary, score
68
-
69
-
70
- def summarize_via_tokenbatches(
71
- input_text: str,
72
- model,
73
- tokenizer,
74
- batch_length=2048,
75
- batch_stride=16,
76
- **kwargs,
77
- ):
78
- """
79
- summarize_via_tokenbatches - a function that takes a string and returns a summary
80
- Args:
81
- input_text (str): the text to summarize
82
- model (): the model to use for summarization
83
- tokenizer (): the tokenizer to use for summarization
84
- batch_length (int, optional): the length of each batch. Defaults to 2048.
85
- batch_stride (int, optional): the stride of each batch. Defaults to 16. The stride is the number of tokens that overlap between batches.
86
- Returns:
87
- str: the summary
88
- """
89
- # log all input parameters
90
- if batch_length < 512:
91
- batch_length = 512
92
- print("WARNING: batch_length was set to 512")
93
- print(
94
- f"input parameters: {kwargs}, batch_length={batch_length}, batch_stride={batch_stride}"
95
- )
96
- encoded_input = tokenizer(
97
- input_text,
98
- padding="max_length",
99
- truncation=True,
100
- max_length=batch_length,
101
- stride=batch_stride,
102
- return_overflowing_tokens=True,
103
- add_special_tokens=False,
104
- return_tensors="pt",
105
- )
106
-
107
- in_id_arr, att_arr = encoded_input.input_ids, encoded_input.attention_mask
108
- gen_summaries = []
109
-
110
- pbar = tqdm(total=len(in_id_arr))
111
-
112
- for _id, _mask in zip(in_id_arr, att_arr):
113
-
114
- result, score = summarize_and_score(
115
- ids=_id,
116
- mask=_mask,
117
- model=model,
118
- tokenizer=tokenizer,
119
- **kwargs,
120
- )
121
- score = round(float(score), 4)
122
- _sum = {
123
- "input_tokens": _id,
124
- "summary": result,
125
- "summary_score": score,
126
- }
127
- gen_summaries.append(_sum)
128
- print(f"\t{result[0]}\nScore:\t{score}")
129
- pbar.update()
130
-
131
- pbar.close()
132
-
133
- return gen_summaries