Tonic committed on
Commit
b9faabf
1 Parent(s): af27f87

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -8
app.py CHANGED
@@ -29,7 +29,7 @@ class TuluChatBot:
29
  prompt = f"<|assistant|>\n {self.system_message}\n\n <|user|>{user_message}\n\n<|assistant|>\n"
30
  return prompt
31
 
32
- def predict(self, user_message, temperature=0.4, max_new_tokens=70, top_p=0.99, repetition_penalty=1.9):
33
  prompt = self.format_prompt(user_message)
34
  inputs = self.tokenizer(prompt, return_tensors='pt', add_special_tokens=False)
35
  input_ids = inputs["input_ids"].to(self.model.device)
@@ -43,16 +43,21 @@ class TuluChatBot:
43
  temperature=temperature,
44
  top_p=top_p,
45
  repetition_penalty=repetition_penalty,
46
- do_sample=True
47
  )
48
 
49
  response = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
50
  return response
51
 
52
- def gradio_predict(user_message, system_message, max_new_tokens, temperature, top_p, repetition_penalty):
53
  Tulu_bot.set_system_message(system_message)
54
- response = Tulu_bot.predict(user_message, temperature, max_new_tokens, top_p, repetition_penalty)
 
 
 
 
55
  return response
 
56
 
57
  Tulu_bot = TuluChatBot(model, tokenizer)
58
 
@@ -63,10 +68,11 @@ iface = gr.Interface(
63
  inputs=[
64
  gr.Textbox(label="Your Message", type="text", lines=3),
65
  gr.Textbox(label="Introduce a Character Here or Set a Scene (system prompt)", type="text", lines=2),
66
- gr.Slider(label="Max new tokens", value=1269, minimum=550, maximum=3200, step=1),
67
- gr.Slider(label="Temperature", value=1.2, minimum=0.05, maximum=4.0, step=0.05),
68
- gr.Slider(label="Top-p (nucleus sampling)", value=0.90, minimum=0.01, maximum=0.99, step=0.05),
69
- gr.Slider(label="Repetition penalty", value=1.9, minimum=1.0, maximum=2.0, step=0.05)
 
70
  ],
71
  outputs="text",
72
  theme="ParityError/Anime"
 
29
  prompt = f"<|assistant|>\n {self.system_message}\n\n <|user|>{user_message}\n\n<|assistant|>\n"
30
  return prompt
31
 
32
+ def predict(self, user_message, temperature, max_new_tokens, top_p, repetition_penalty, do_sample):
33
  prompt = self.format_prompt(user_message)
34
  inputs = self.tokenizer(prompt, return_tensors='pt', add_special_tokens=False)
35
  input_ids = inputs["input_ids"].to(self.model.device)
 
43
  temperature=temperature,
44
  top_p=top_p,
45
  repetition_penalty=repetition_penalty,
46
+ do_sample=do_sample
47
  )
48
 
49
  response = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
50
  return response
51
 
52
+ def gradio_predict(user_message, system_message, max_new_tokens, temperature, top_p, repetition_penalty, do_sample):
53
  Tulu_bot.set_system_message(system_message)
54
+ if not do_sample:
55
+ temperature = 1.2 # Default value
56
+ top_p = 0.9 # Default value
57
+ repetition_penalty = 0.9 # Default value
58
+ response = Tulu_bot.predict(user_message, temperature, max_new_tokens, top_p, repetition_penalty, do_sample)
59
  return response
60
+
61
 
62
  Tulu_bot = TuluChatBot(model, tokenizer)
63
 
 
68
  inputs=[
69
  gr.Textbox(label="Your Message", type="text", lines=3),
70
  gr.Textbox(label="Introduce a Character Here or Set a Scene (system prompt)", type="text", lines=2),
71
+ gr.Checkbox(label="Advanced", value=True, id="do_sample"), # Add an id to the checkbox
72
+ gr.Slider(label="Max new tokens", value=1269, minimum=550, maximum=3200, step=1, visible=gr.Visibility(id="do_sample", value=True)),
73
+ gr.Slider(label="Temperature", value=1.2, minimum=0.05, maximum=4.0, step=0.05, visible=gr.Visibility(id="do_sample", value=True)),
74
+ gr.Slider(label="Top-p (nucleus sampling)", value=0.90, minimum=0.01, maximum=0.99, step=0.05, visible=gr.Visibility(id="do_sample", value=True)),
75
+ gr.Slider(label="Repetition penalty", value=1.9, minimum=1.0, maximum=2.0, step=0.05, visible=gr.Visibility(id="do_sample", value=True))
76
  ],
77
  outputs="text",
78
  theme="ParityError/Anime"