Peter committed on
Commit 0b3d061
1 Parent(s): 235585a

🐛 fix input len bug


Signed-off-by: Peter <74869040+pszemraj@users.noreply.github.com>

Files changed (3)
  1. app.py +8 -5
  2. converse.py +12 -8
  3. grammar_improve.py +5 -3
app.py CHANGED
@@ -101,11 +101,13 @@ def ask_gpt(
     st = time.perf_counter()
     prompt = clean(message)  # clean user input
     prompt = prompt.strip()  # get rid of any extra whitespace
-    in_len = len(prompt)
+    in_len = len(chat_pipe.tokenizer(prompt).input_ids)
     if in_len > 512:
-        prompt = prompt[-512:]  # truncate to 512 chars
-        print(f"Truncated prompt to last 512 chars: started with {in_len} chars")
-        max_len = min(max_len, 512)
+        # truncate to last 512 tokens
+        tokens = chat_pipe.tokenizer(prompt).input_ids
+        trunc_tokens = tokens[-512:]
+        prompt = chat_pipe.tokenizer.decode(trunc_tokens)
+        print(f"truncated prompt to {len(trunc_tokens)} tokens, input length: {in_len}")
 
     resp = discussion(
         prompt_text=prompt,
@@ -115,7 +117,8 @@ def ask_gpt(
         top_p=top_p,
         top_k=top_k,
         temperature=temperature,
-        max_length=max_len,
+        max_length=max_length,
+        min_length=min_length,
     )
     gpt_et = time.perf_counter()
     gpt_rt = round(gpt_et - st, 2)
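For reference, the change above swaps character-based truncation for token-based truncation. A minimal standalone sketch of the same idea follows; the "gpt2" tokenizer is only an assumption standing in for whatever tokenizer chat_pipe actually wraps in app.py.

    # Sketch: keep only the last N tokens of a prompt, then decode back to text.
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # assumed stand-in model

    def truncate_prompt(prompt: str, max_tokens: int = 512) -> str:
        """Keep only the most recent `max_tokens` tokens of the prompt."""
        token_ids = tokenizer(prompt).input_ids
        if len(token_ids) <= max_tokens:
            return prompt
        # keep the most recent context and turn it back into text
        return tokenizer.decode(token_ids[-max_tokens:])

Counting tokens rather than characters matters because the model's context limit is expressed in tokens, and 512 characters can map to anywhere from a few dozen to several hundred tokens.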
converse.py CHANGED
@@ -17,7 +17,8 @@ def discussion(
     responder: str,
     pipeline,
     timeout=45,
-    max_length=128,
+    min_length=4,
+    max_length=64,
     top_p=0.95,
     top_k=50,
     temperature=0.7,
@@ -104,7 +105,8 @@ def gen_response(
     speaker: str,
     responder: str,
     timeout=45,
-    max_length=128,
+    min_length=4,
+    max_length=64,
     top_p=0.95,
     top_k=50,
     temperature=0.7,
@@ -125,7 +127,8 @@ def gen_response(
     responder : str, the name of the person who is responding to the prompt
     pipeline : transformers.Pipeline, the pipeline to use for generating the response
     timeout : int, optional, the number of seconds to wait before timing out, by default 45
-    max_length : int, optional, the maximum number of tokens to generate, defaults to 128
+    min_length : int, optional, the minimum number of tokens to generate, defaults to 4
+    max_length : int, optional, the maximum number of tokens to generate, defaults to 64
     top_p : float, optional, the top probability to use for sampling, defaults to 0.95
     top_k : int, optional, the top k to use for sampling, defaults to 50
     temperature : float, optional, the temperature to use for sampling, defaults to 0.7
@@ -139,15 +142,16 @@ def gen_response(
         str, the generated text
 
     """
-
-    if max_length > 1024:
-        max_length = 1024
-        print("max_length is too large, setting to 1024")
+    input_len = len(pipeline.tokenizer(query).input_ids)
+    if max_length + input_len > 1024:
+        max_length = max(1024 - input_len, 8)
+        print(f"max_length too large, setting to {max_length}")
     st = time.perf_counter()
 
     response = pipeline(
         query,
-        max_length=max_length,
+        min_length=min_length + input_len,
+        max_length=max_length + input_len,
         temperature=temperature,
         top_k=top_k,
         top_p=top_p,
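The gen_response change accounts for the fact that a text-generation pipeline's max_length and min_length count the prompt tokens as well as the newly generated ones, so the budget has to be offset by the prompt length and clamped to the context window. A rough sketch of that bookkeeping, assuming `pipeline` is a transformers text-generation pipeline and a 1024-token context window as in the diff:

    # Sketch: clamp generation length so prompt + new tokens fit in the context window.
    def clamp_generation_length(pipeline, query: str, min_length: int = 4,
                                max_length: int = 64, context_window: int = 1024):
        input_len = len(pipeline.tokenizer(query).input_ids)
        if max_length + input_len > context_window:
            # leave at least a small budget of new tokens for the reply
            max_length = max(context_window - input_len, 8)
        # pipeline length limits include the prompt, so offset both bounds
        return min_length + input_len, max_length + input_len

The returned pair would then be passed as min_length/max_length to the pipeline call, mirroring what the commit does inline.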
grammar_improve.py CHANGED
@@ -137,10 +137,11 @@ def synthesize_grammar(
     """
     st = time.perf_counter()
     input_text = clean(message, lower=False)
+    input_len = len(corrector.tokenizer(input_text).input_ids)
     results = corrector(
         input_text,
-        max_length=int(1.1 * len(input_text)),
-        min_length=2 if len(input_text) < 64 else int(0.2 * len(input_text)),
+        max_length=int(1.1 * input_len),
+        min_length=2 if input_len < 64 else int(0.2 * input_len),
         num_beams=num_beams,
         repetition_penalty=repetition_penalty,
         length_penalty=length_penalty,
@@ -479,7 +480,8 @@ def correct_grammar(
     """
     st = time.perf_counter()
 
-    if len(input_text) < 5:
+    if len(tokenizer(input_text).input_ids) < 4:
+        print(f"input text of {input_text} is too short to be corrected")
        return input_text
     max_length = min(int(math.ceil(len(input_text) * 1.2)), 128)
     batch = tokenizer(
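Here the same fix is applied to the grammar corrector: length bounds are scaled from the token count of the input rather than its character count. A small illustrative sketch, assuming `corrector` is a transformers pipeline exposing a .tokenizer attribute as in synthesize_grammar:

    # Sketch: derive min/max generation length from token count, not characters.
    def length_bounds_from_tokens(corrector, input_text: str):
        input_len = len(corrector.tokenizer(input_text).input_ids)
        max_length = int(1.1 * input_len)                              # allow slight growth
        min_length = 2 if input_len < 64 else int(0.2 * input_len)     # avoid trivially short output
        return min_length, max_length

Basing the bounds on tokens keeps the correction from being cut off for inputs whose token count is much higher than their character count would suggest.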