project-baize committed on
Commit e0ec530
1 Parent(s): 9d96329

Update app.py

Files changed (1): app.py (+6, -5)
app.py CHANGED
@@ -35,14 +35,14 @@ def predict(text,
         return
 
     inputs = generate_prompt_with_history(text,history,tokenizer,max_length=max_context_length_tokens)
-    if inputs is False:
-        yield chatbot+[[text,"Sorry, the input is too long."]],history,"Generate Fail"
-        return
+    if inputs is None:
+        yield chatbot,history,"Too Long Input"
+        return
     else:
         prompt,inputs=inputs
         begin_length = len(prompt)
+    input_ids = inputs["input_ids"][:,-max_context_length_tokens:].to(device)
     torch.cuda.empty_cache()
-    input_ids = inputs["input_ids"].to(device)
     global total_count
     total_count += 1
     print(total_count)
@@ -63,6 +63,7 @@ def predict(text,
                 return
             except:
                 pass
+    torch.cuda.empty_cache()
     #print(text)
     #print(x)
     #print("="*80)
@@ -150,7 +151,7 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
             )
             max_context_length_tokens = gr.Slider(
                 minimum=0,
-                maximum=4096,
+                maximum=3072,
                 value=2048,
                 step=128,
                 interactive=True,
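
For context, a minimal sketch of what the new guard and truncation in this commit do, using stand-in values; predict(), generate_prompt_with_history(), the real tokenizer, and device come from app.py and are not reproduced here, so the shapes and token ids below are illustrative only.

import torch

max_context_length_tokens = 2048  # matches the slider default above

# Hypothetical over-long prompt: 5000 token ids for a single batch element.
input_ids = torch.randint(0, 32000, (1, 5000))

# New behavior from this commit: keep only the last max_context_length_tokens
# tokens, so the most recent conversation turns survive the cut.
truncated = input_ids[:, -max_context_length_tokens:]
print(truncated.shape)  # torch.Size([1, 2048])

# New guard: the prompt builder's result is now checked against None (instead
# of False) before unpacking, and the UI status string becomes "Too Long Input".
inputs = None  # stand-in for a history that could not be fit into the context
if inputs is None:
    print("Too Long Input")
else:
    prompt, inputs = inputs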