Peter committed on
Commit b4c0306
1 Parent(s): 8d9ed7d

✨ integrate constrained gen


Signed-off-by: Peter <74869040+pszemraj@users.noreply.github.com>

Files changed (3)
  1. app.py +31 -26
  2. constrained_generation.py +3 -5
  3. converse.py +11 -5
app.py CHANGED
@@ -5,6 +5,9 @@ app.py - the main file for the app. This creates the flask app and handles the r
 import argparse
 import logging
+
+logging.basicConfig(level=logging.INFO, format="%(asctime)s %(message)s")
+
 import os
 import sys
 import time
@@ -16,7 +19,7 @@ import gradio as gr
 import nltk
 import torch
 from cleantext import clean
-from gradio.inputs import Slider, Textbox
+from gradio.inputs import Slider, Textbox, Radio
 from transformers import pipeline
 
 from converse import discussion
@@ -40,13 +43,12 @@ warnings.filterwarnings(action="ignore", message=".*gradient_checkpointing*")
 import transformers
 
 transformers.logging.set_verbosity_error()
-logging.basicConfig()
 cwd = Path.cwd()
 my_cwd = str(cwd.resolve())  # string so it can be passed to os.path() objects
 
 
 def chat(
-    prompt_message, temperature: float = 0.5, top_p: float = 0.95, top_k: int = 20
+    prompt_message, temperature: float = 0.5, top_p: float = 0.95, top_k: int = 20, constrained_generation: str = "False"
 ) -> str:
     """
     chat - the main function for the chatbot. This is the function that is called when the user
@@ -55,6 +57,7 @@ def chat(
     :param float temperature: the temperature value for the model, defaults to 0.6
     :param float top_p: the top_p value for the model, defaults to 0.95
     :param int top_k: the top_k value for the model, defaults to 25
+    :param bool constrained_generation: whether to use constrained generation or not, defaults to False
    :return str: the response from the model
     """
     history = []
@@ -64,6 +67,7 @@ def chat(
         top_p=top_p,
         top_k=top_k,
         temperature=temperature,
+        constrained_generation="true" in constrained_generation.lower(),
     )
     history = [prompt_message, response]
     html = ""
@@ -85,7 +89,8 @@ def ask_gpt(
     top_p=0.95,
     top_k=25,
     temperature=0.5,
-    constrained_generation=True,
+    constrained_generation=False,
+    max_input_length=128,
 ) -> str:
     """
     ask_gpt - helper function that asks the GPT model a question and returns the response
@@ -99,19 +104,20 @@ def ask_gpt(
     :param float top_p: the top_p value for the model, defaults to 0.95
     :param int top_k: the top_k value for the model, defaults to 25
     :param float temperature: the temperature value for the model, defaults to 0.6
+    :param bool constrained_generation: whether to use constrained generation or not, defaults to False
     :return str: the response from the model
     """
     st = time.perf_counter()
     prompt = clean(message)  # clean user input
     prompt = prompt.strip()  # get rid of any extra whitespace
     in_len = len(chat_pipe.tokenizer(prompt).input_ids)
-    if in_len > 512:
-        # truncate to last 512 tokens
+    if in_len > max_input_length:
+        # truncate to last max_input_length tokens
         tokens = chat_pipe.tokenizer(prompt).input_ids
-        trunc_tokens = tokens[-512:]
+        trunc_tokens = tokens[-max_input_length:]
         prompt = chat_pipe.tokenizer.decode(trunc_tokens)
         print(f"truncated prompt to {len(trunc_tokens)} tokens, input length: {in_len}")
-
+    logging.info(f"prompt: {prompt}")
     resp = discussion(
         prompt_text=prompt,
         pipeline=chat_pipe,
@@ -122,7 +128,7 @@ def ask_gpt(
         temperature=temperature,
         max_length=max_length,
         min_length=min_length,
-        constrained_generation=constrained_generation,
+        constrained_beam_search=constrained_generation,
     )
     gpt_et = time.perf_counter()
     gpt_rt = round(gpt_et - st, 2)
@@ -134,10 +140,9 @@ def ask_gpt(
     cln_resp = synthesize_grammar(corrector=grammarbot, message=rawtxt)
     bot_resp_a = corr(remove_repeated_words(cln_resp))
     bot_resp = fix_punct_spacing(bot_resp_a)
-    print(f"the prompt was:\n\t{message}\nand the response was:\n\t{bot_resp}\n")
     corr_rt = round(time.perf_counter() - gpt_et, 4)
     print(
-        f"took {gpt_rt + corr_rt} sec to respond, {gpt_rt} for GPT, {corr_rt} for correction\n"
+        f"{gpt_rt + corr_rt} to respond, {gpt_rt} GPT, {corr_rt} for correction\n"
     )
     return remove_trailing_punctuation(bot_resp)
 
@@ -225,7 +230,7 @@ if __name__ == "__main__":
             Textbox(
                 default="Why is everyone here eating chocolate cake?",
                 label="prompt_message",
-                placeholder="Enter a question",
+                placeholder="Start a conversation with the bot",
                 lines=2,
             ),
             Slider(
@@ -233,20 +238,21 @@ if __name__ == "__main__":
             ),
             Slider(minimum=0.0, maximum=1.0, step=0.01, default=0.95, label="top_p"),
             Slider(minimum=0, maximum=100, step=5, default=20, label="top_k"),
+            Radio(choices=["True", "False"], default="False", label="constrained_generation"),
         ],
         outputs="html",
         examples_per_page=8,
         examples=[
-            ["Point Break or Bad Boys II?", 0.75, 0.95, 50],
-            ["So... you're saying this wasn't an accident?", 0.6, 0.95, 40],
-            ["Hi, my name is Reginald", 0.6, 0.95, 100],
-            ["Happy birthday!", 0.9, 0.95, 50],
-            ["I have a question, can you help me?", 0.6, 0.95, 50],
-            ["Do you know a joke?", 0.8, 0.85, 50],
-            ["Will you marry me?", 0.9, 0.95, 100],
-            ["Are you single?", 0.95, 0.95, 100],
-            ["Do you like people?", 0.7, 0.95, 25],
-            ["You never took a shortcut before?", 0.7, 0.95, 100],
+            ["Point Break or Bad Boys II?", 0.75, 0.95, 50, False],
+            ["So... you're saying this wasn't an accident?", 0.6, 0.95, 40, False],
+            ["Hi, my name is Reginald", 0.6, 0.95, 100, False],
+            ["Happy birthday!", 0.9, 0.95, 50, False],
+            ["I have a question, can you help me?", 0.6, 0.95, 50, False],
+            ["Do you know a joke?", 0.8, 0.85, 50, False],
+            ["Will you marry me?", 0.9, 0.95, 100, False],
+            ["Are you single?", 0.95, 0.95, 100, False],
+            ["Do you like people?", 0.7, 0.95, 25, False],
+            ["You never took a shortcut before?", 0.7, 0.95, 100, False],
         ],
         title=f"GPT Chatbot Demo: {default_model} Model",
         description=f"A Demo of a Chatbot trained for conversation with humans. Size XL= 1.5B parameters.\n\n"
@@ -254,20 +260,19 @@ if __name__ == "__main__":
         "You can find a link to the model card **[here](https://huggingface.co/ethzanalytics/ai-msgbot-gpt2-XL-dialogue)**\n\n"
         "1. responses can take up to 60 seconds to respond sometimes, patience is a virtue.\n"
         "2. the model was trained on several different datasets. fact-check responses instead of regarding as a true statement.\n"
-        "3. Try adjusting the **[generation parameters](https://huggingface.co/blog/how-to-generate)** to get a better understanding of how they work!\n",
+        "3. Try adjusting the **[generation parameters](https://huggingface.co/blog/how-to-generate)** to get a better understanding of how they work!\n"
+        "4. New - try using [constrained beam search](https://huggingface.co/blog/constrained-beam-search) decoding to generate more coherent responses. _(experimental, feedback welcome!)_\n",
         css="""
         .chatbox {display:flex;flex-direction:row}
         .user_msg, .resp_msg {padding:4px;margin-bottom:4px;border-radius:4px;width:80%}
         .user_msg {background-color:cornflowerblue;color:white;align-self:start}
         .resp_msg {background-color:lightgray;align-self:self-end}
         """,
-        allow_screenshot=True,
         allow_flagging="never",
         theme="dark",
     )
 
     # launch the gradio interface and start the server
     iface.launch(
-        # prevent_thread_lock=True,
-        enable_queue=True,  # also allows for dealing with multiple users simultaneously (per newer gradio version)
+        enable_queue=True,
     )
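Note on the new `constrained_generation` control: Gradio's `Radio` component returns its choice as a string, so `chat()` converts `"True"`/`"False"` to a bool with a substring check before forwarding it, and `ask_gpt()` now truncates prompts to the last `max_input_length` tokens instead of a hard-coded 512. A minimal standalone sketch of both pieces (the `gpt2` tokenizer checkpoint is an illustrative stand-in, not the model this Space serves):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative checkpoint


def to_bool(radio_value: str) -> bool:
    # Radio(choices=["True", "False"]) yields a string;
    # mirror the diff's `"true" in value.lower()` conversion
    return "true" in radio_value.lower()


def truncate_prompt(prompt: str, max_input_length: int = 128) -> str:
    # keep only the last max_input_length tokens, as ask_gpt() now does
    tokens = tokenizer(prompt).input_ids
    if len(tokens) > max_input_length:
        tokens = tokens[-max_input_length:]
        prompt = tokenizer.decode(tokens)
    return prompt


assert to_bool("True") is True and to_bool("False") is False
```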
constrained_generation.py CHANGED
@@ -4,6 +4,7 @@
 
 import copy
 import logging
+logging.basicConfig(level=logging.INFO)
 import time
 from pathlib import Path
 
@@ -81,7 +82,7 @@ def create_kw_extractor(
     )
 
 
-def simple_kw(body_text: str, yake_ex=None, max_kw=10, verbose=False):
+def simple_kw(body_text: str, yake_ex=None, max_kw=15, verbose=False):
     """
     simple_kw - extract keywords from a text using yake
 
@@ -96,7 +97,7 @@ def simple_kw(body_text: str, yake_ex=None, max_kw=10, verbose=False):
     """
     yake_ex = yake_ex or create_kw_extractor(
         max_ngram_size=2,
-        ddpt=0.8,
+        ddpt=0.9,
         windowSize=10,
         deduplication_algo="seqm",
         numOfKeywords=max_kw,
@@ -219,7 +220,6 @@ def constrained_generation(
         if force_flexible is not None
         else None
     )
-
     try:
         logging.info("generating text..")
         result = pipeline(
@@ -236,8 +236,6 @@
             length_penalty=length_penalty,
             repetition_penalty=repetition_penalty,
             return_full_text=full_text,
-            remove_invalid_values=True,
-            skip_special_tokens=True,
             clean_up_tokenization_spaces=True,
             early_stopping=True,
             do_sample=False,
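For context, `create_kw_extractor` wraps the `yake` keyword extractor, and this change bumps the deduplication threshold (`ddpt`) from 0.8 to 0.9 and the default keyword count from 10 to 15. A rough sketch of the equivalent direct yake call, assuming the wrapper's `ddpt`, `windowSize`, `deduplication_algo`, and `numOfKeywords` map onto yake's `dedupLim`, `windowsSize`, `dedupFunc`, and `top`:

```python
import yake

# assumed mapping of the wrapper's arguments onto yake's API
extractor = yake.KeywordExtractor(
    lan="en",
    n=2,             # max_ngram_size=2
    dedupLim=0.9,    # ddpt=0.9 (was 0.8): tolerate more similar keywords
    dedupFunc="seqm",
    windowsSize=10,
    top=15,          # max_kw now defaults to 15 (was 10)
)

keywords = extractor.extract_keywords("Why is everyone here eating chocolate cake?")
# yake returns (keyword, score) pairs; a lower score means more relevant
print([kw for kw, score in keywords])
```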
converse.py CHANGED
@@ -4,7 +4,8 @@
 https://huggingface.co/docs/transformers/v4.15.0/en/main_classes/model#transformers.generation_utils.GenerationMixin.generate.no_repeat_ngram_size
 """
 
-
+import logging
+logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
 import pprint as pp
 import time
 
@@ -29,7 +30,7 @@ def discussion(
     num_return_sequences=1,
     device=-1,
     verbose=False,
-    constrained_generation=False,
+    constrained_beam_search=False,
 ):
     """
     discussion - a function that takes in a prompt and generates a response. This function is meant to be used in a conversation loop, and is the main function for the bot.
@@ -66,7 +67,8 @@ def discussion(
     pp.pprint(this_prompt, indent=4)
     # call the model
     print("\n... generating...")
-    if constrained_generation:
+    if constrained_beam_search:
+        logging.info("using constrained beam search")
         response = constrained_generation(
             prompt=this_prompt,
             pipeline=pipeline,
@@ -75,7 +77,7 @@ def discussion(
             repetition_penalty=1.0,
             num_beams=4,
             timeout=timeout,
-            verbose=verbose,
+            verbose=False,
             full_text=full_text,
             speaker_name=speaker,
             responder_name=responder,
@@ -83,12 +85,15 @@ def discussion(
 
         bot_dialogue = consolidate_texts(
             name_resp=responder,
-            model_resp=response,
+            model_resp=response.split(
+                "\n"
+            ),
             name_spk=speaker,
             verbose=verbose,
             print_debug=True,
         )
     else:
+        logging.info("using sampling")
         bot_dialogue = gen_response(
             this_prompt,
             pipeline,
@@ -123,6 +128,7 @@ def discussion(
         p_list.append("\n")
 
     print("\nfinished!")
+    logging.info(f"finished generating response:\n\t{bot_resp}")
     # return the bot response and the full conversation
 
     return {"out_text": bot_resp, "full_conv": p_list}