Peter committed on
Commit
203509f
1 Parent(s): a04dbc6

✨ add constrained gen script

Browse files

Signed-off-by: Peter <74869040+pszemraj@users.noreply.github.com>

Files changed (1) hide show
  1. constrained_generation.py +255 -0
constrained_generation.py ADDED
@@ -0,0 +1,255 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ constrained_generation.py - use constrained beam search to generate text from a model with entered constraints
3
+ """
4
+
5
+ import copy
6
+ import logging
7
+ import time
8
+ from pathlib import Path
9
+
10
+ import yake
11
+ from transformers import AutoTokenizer, PhrasalConstraint
12
+
13
def get_tokenizer(model_name="gpt2", verbose=False):
    """
    get_tokenizer - load and return a tokenizer for the given model.

    :param model_name: name of the model to use, default gpt2
    :param verbose: if True, print a confirmation once loaded
    :return: the configured tokenizer (pad token set to EOS)
    """
    loaded = AutoTokenizer.from_pretrained(
        model_name,
        add_special_tokens=False,
        padding=True,
        truncation=True,
    )
    # GPT-style models ship without a pad token; reuse EOS for padding
    loaded.pad_token = loaded.eos_token
    if verbose:
        print(f"loaded tokenizer {model_name}")
    return loaded
27
+
28
+
29
def unique_words(list_of_strings):
    """
    unique_words - keep only the strings whose words have not appeared in any
    earlier string.

    Note: words of a rejected string that were scanned before the duplicate was
    found stay registered as "seen" and can disqualify later strings.

    :param list_of_strings: strings to filter
    :return: filtered list, in original order
    """
    seen_words = []
    kept = []
    for text in list_of_strings:
        is_new = True
        for token in text.split():
            if token in seen_words:
                is_new = False
                break
            seen_words.append(token)
        if is_new:
            kept.append(text)

    return kept
50
+
51
+
52
def create_kw_extractor(
    language="en",
    max_ngram_size=3,
    deduplication_algo="seqm",
    windowSize=10,
    numOfKeywords=10,
    ddpt=0.7,
):
    """
    Create a configured yake.KeywordExtractor.

    :param language: language of the text
    :param max_ngram_size: max ngram size
    :param deduplication_algo: deduplication algorithm
    :param windowSize: window size
    :param numOfKeywords: number of keywords
    :param ddpt: deduplication percentage threshold, must be in [0, 1]

    :return: keyword extractor object
    :raises ValueError: if ddpt is outside [0, 1]
    """
    # raise instead of assert so the validation survives `python -O`
    if not 0 <= ddpt <= 1:
        raise ValueError(f"need 0<thresh<1, got {ddpt}")
    return yake.KeywordExtractor(
        lan=language,
        n=max_ngram_size,
        dedupLim=ddpt,
        dedupFunc=deduplication_algo,
        windowsSize=windowSize,  # NOTE: yake's parameter really is 'windowsSize'
        top=numOfKeywords,
        features=None,
    )
82
+
83
+
84
def simple_kw(body_text: str, yake_ex=None, max_kw=10, verbose=False):
    """
    simple_kw - extract keywords from a text using yake.

    Args:
        body_text (str): text to extract keywords from
        yake_ex (yake.KeywordExtractor, optional): extractor to use; when None,
            a default one is created (settings per optuna study). Defaults to None.
        max_kw (int, optional): maximum number of keywords to extract. Defaults to 10.
        verbose (bool, optional): print the results. Defaults to False.

    Returns:
        list: lowercased keyword strings, at most max_kw entries
    """
    # fall back to a pre-tuned extractor when the caller does not supply one
    yake_ex = yake_ex or create_kw_extractor(
        max_ngram_size=2,
        ddpt=0.8,
        windowSize=10,
        deduplication_algo="seqm",
        numOfKeywords=max_kw,
    )  # per optuna study

    extracted = yake_ex.extract_keywords(body_text)
    keywords_list = [str(pair[0]).lower() for pair in extracted]
    logging.info(
        f"YAKE: found {len(keywords_list)} keywords, the top {max_kw} are: {keywords_list[:max_kw]}"
    )

    if verbose:
        print(f"found {len(keywords_list)} keywords, the top {max_kw} are:")
        print(keywords_list[:max_kw])
        logging.info(f"found {len(keywords_list)} keywords, the top {max_kw} are:")

    return keywords_list[:max_kw]
118
+
119
+
120
def constrained_generation(
    prompt: str,
    pipeline,
    tokenizer=None,
    no_repeat_ngram_size=2,
    length_penalty=0.7,
    repetition_penalty=3.5,
    num_beams=4,
    max_generated_tokens=48,
    min_generated_tokens=2,
    timeout=300,
    num_return_sequences=1,
    verbose=False,
    full_text=False,
    force_word: str = None,
    speaker_name: str = "Person Alpha",
    responder_name: str = "Person Beta",
    **kwargs,
):
    """
    constrained_generation - generate text based on prompt and constraints

    USAGE
    -----
    response = constrained_generation("hey man - how have you been lately?",
                                      my_chatbot, verbose=True,
                                      force_word=" meme", num_beams=32)

    Parameters
    ----------
    prompt : str, prompt to use for generation
    pipeline : transformers.pipeline, must be compatible with tokenizer & text2text model
    tokenizer : transformers.PreTrainedTokenizer, optional, default=None;
        a deep copy of the pipeline's own tokenizer is used when omitted
    no_repeat_ngram_size : int, optional, default=2
    num_beams : int, optional, default=4
    max_generated_tokens : int, optional, default=48
    min_generated_tokens : int, optional, default=2
    verbose : bool, optional, default=False, print output
    force_word : str, optional, default=None, word forced to appear in generation
    speaker_name : str, optional, default="Person Alpha", name of speaker
    responder_name : str, optional, default="Person Beta", name of responder

    Returns
    -------
    response : str, generated text (a canned fallback string if generation fails)
    """
    st = time.perf_counter()
    # deep-copy so tweaking tokenizer flags does not mutate the pipeline's own
    tokenizer = tokenizer or copy.deepcopy(pipeline.tokenizer)
    tokenizer.add_prefix_space = True
    tokenizer.add_special_tokens = False

    prompt_length = len(tokenizer(prompt, truncation=True).input_ids)
    if responder_name.lower() not in prompt.lower():
        prompt = f"{prompt}\n\n{responder_name}:\n"
    key_prompt_phrases = simple_kw(prompt)

    try:
        responder_name_words = responder_name.lower().split()
        speaker_name_words = speaker_name.lower().split()
    except Exception as e:
        responder_name_words = []
        speaker_name_words = []
        logging.info(f"could not split names: {e}")

    # drop keyphrases that contain either participant's name so the model is
    # not forced to parrot the speakers back
    key_prompt_phrases = [
        p
        for p in key_prompt_phrases
        if not any(name in p for name in responder_name_words)
        and not any(name in p for name in speaker_name_words)
    ]
    force_flexible = unique_words(key_prompt_phrases)
    logging.info(f"found the following keywords: {force_flexible}")
    if verbose:
        # gated behind verbose: was an unconditional leftover debug print
        print(f"found keywords: {force_flexible}")
        if force_word is not None:
            logging.info(f"forcing the word: {force_word}")

    if len(force_flexible) == 0:
        force_flexible = None
    # hard constraint: force_word must appear verbatim in the output
    constraints = (
        [
            PhrasalConstraint(
                tokenizer(force_word, add_special_tokens=False).input_ids,
            ),
        ]
        if force_word is not None
        else None
    )
    # soft/disjunctive constraint: at least one of the keyword phrases
    force_words_ids = (
        [
            tokenizer(
                force_flexible,
            ).input_ids,
        ]
        if force_flexible is not None
        else None
    )

    try:
        logging.info("generating text..")
        result = pipeline(
            prompt,
            constraints=constraints,
            force_words_ids=force_words_ids,
            max_length=None,
            max_new_tokens=max_generated_tokens,
            # when returning the full text, the minimum must cover the prompt too
            min_length=min_generated_tokens + prompt_length
            if full_text
            else min_generated_tokens,
            num_beams=num_beams,
            no_repeat_ngram_size=no_repeat_ngram_size,
            num_return_sequences=num_return_sequences,
            max_time=timeout,
            length_penalty=length_penalty,
            repetition_penalty=repetition_penalty,
            return_full_text=full_text,
            remove_invalid_values=True,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
            early_stopping=True,
            do_sample=False,
            **kwargs,
        )
        response = result[0]["generated_text"]
        rt = round((time.perf_counter() - st) / 60, 3)
        logging.info(f"generated response in {rt} minutes")
        if verbose:
            print(f"input prompt:\n\t{prompt}")
            print(f"response:\n\t{response}")
    except Exception as e:
        # deliberate best-effort: log the failure and return a fallback reply
        logging.info(f"could not generate response: {e}")
        response = "Sorry, I don't know how to respond to that."
    return response