Peter commited on
Commit
950a38f
1 Parent(s): 203509f

✨ integrate constrained textgen

Browse files

Signed-off-by: Peter <74869040+pszemraj@users.noreply.github.com>

Files changed (2) hide show
  1. app.py +2 -0
  2. converse.py +42 -17
app.py CHANGED
@@ -85,6 +85,7 @@ def ask_gpt(
85
  top_p=0.95,
86
  top_k=25,
87
  temperature=0.5,
 
88
  ) -> str:
89
  """
90
  ask_gpt - helper function that asks the GPT model a question and returns the response
@@ -121,6 +122,7 @@ def ask_gpt(
121
  temperature=temperature,
122
  max_length=max_length,
123
  min_length=min_length,
 
124
  )
125
  gpt_et = time.perf_counter()
126
  gpt_rt = round(gpt_et - st, 2)
 
85
  top_p=0.95,
86
  top_k=25,
87
  temperature=0.5,
88
+ constrained_generation=True,
89
  ) -> str:
90
  """
91
  ask_gpt - helper function that asks the GPT model a question and returns the response
 
122
  temperature=temperature,
123
  max_length=max_length,
124
  min_length=min_length,
125
+ constrained_generation = constrained_generation,
126
  )
127
  gpt_et = time.perf_counter()
128
  gpt_rt = round(gpt_et - st, 2)
converse.py CHANGED
@@ -10,6 +10,7 @@ import time
10
 
11
  from grammar_improve import remove_trailing_punctuation
12
 
 
13
 
14
  def discussion(
15
  prompt_text: str,
@@ -28,6 +29,7 @@ def discussion(
28
  num_return_sequences=1,
29
  device=-1,
30
  verbose=False,
 
31
  ):
32
  """
33
  discussion - a function that takes in a prompt and generates a response. This function is meant to be used in a conversation loop, and is the main function for the bot.
@@ -64,23 +66,46 @@ def discussion(
64
  pp.pprint(this_prompt, indent=4)
65
  # call the model
66
  print("\n... generating...")
67
- bot_dialogue = gen_response(
68
- this_prompt,
69
- pipeline,
70
- speaker,
71
- responder,
72
- timeout=timeout,
73
- max_length=max_length,
74
- top_p=top_p,
75
- top_k=top_k,
76
- temperature=temperature,
77
- full_text=full_text,
78
- no_repeat_ngram_size=no_repeat_ngram_size,
79
- length_penalty=length_penalty,
80
- num_return_sequences=num_return_sequences,
81
- device=device,
82
- verbose=verbose,
83
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
  if isinstance(bot_dialogue, list) and len(bot_dialogue) > 1:
85
  bot_resp = ", ".join(bot_dialogue)
86
  elif isinstance(bot_dialogue, list) and len(bot_dialogue) == 1:
 
10
 
11
  from grammar_improve import remove_trailing_punctuation
12
 
13
+ from constrained_generation import constrained_generation
14
 
15
  def discussion(
16
  prompt_text: str,
 
29
  num_return_sequences=1,
30
  device=-1,
31
  verbose=False,
32
+ constrained_generation=False,
33
  ):
34
  """
35
  discussion - a function that takes in a prompt and generates a response. This function is meant to be used in a conversation loop, and is the main function for the bot.
 
66
  pp.pprint(this_prompt, indent=4)
67
  # call the model
68
  print("\n... generating...")
69
+ if constrained_generation:
70
+ response = constrained_generation(
71
+ prompt=this_prompt,
72
+ pipeline=pipeline,
73
+ no_repeat_ngram_size=no_repeat_ngram_size,
74
+ length_penalty=length_penalty,
75
+ repetition_penalty=1.0,
76
+ num_beams=4,
77
+ timeout=timeout,
78
+ verbose=verbose,
79
+ full_text=full_text,
80
+ speaker_name=speaker,
81
+ responder_name=responder,
82
+ )
83
+
84
+ bot_dialogue = consolidate_texts(
85
+ name_resp=responder,
86
+ model_resp=response,
87
+ name_spk=speaker,
88
+ verbose=verbose,
89
+ print_debug=True,
90
+ )
91
+ else:
92
+ bot_dialogue = gen_response(
93
+ this_prompt,
94
+ pipeline,
95
+ speaker,
96
+ responder,
97
+ timeout=timeout,
98
+ max_length=max_length,
99
+ top_p=top_p,
100
+ top_k=top_k,
101
+ temperature=temperature,
102
+ full_text=full_text,
103
+ no_repeat_ngram_size=no_repeat_ngram_size,
104
+ length_penalty=length_penalty,
105
+ num_return_sequences=num_return_sequences,
106
+ device=device,
107
+ verbose=verbose,
108
+ )
109
  if isinstance(bot_dialogue, list) and len(bot_dialogue) > 1:
110
  bot_resp = ", ".join(bot_dialogue)
111
  elif isinstance(bot_dialogue, list) and len(bot_dialogue) == 1: