Commit 53ebaa0 · 1 Parent(s): 5507a1b
Update app.py
app.py CHANGED
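Summary of the change: the commit adds a Repetition Penalty slider (range 1–2, default 1.5) to the Gradio UI, extends the generate_response signature and the msg.submit input list with the new value, and forwards it to the model's generate call as repetition_penalty.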
@@ -69,6 +69,7 @@ with gr.Blocks(theme='freddyaboulton/dracula_revamped') as demo:
     top_k = gr.Slider(minimum=10, maximum=100, value=50, step=5, interactive=True, label="Top-k", info="Controls the number of highest probability tokens to consider for each step.")
     top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.70, step=0.05, interactive=True, label="Top-p", info="Controls the cumulative probability of the generated tokens.")
     temperature = gr.Slider(minimum=0.1, maximum=2.0, value=0.1, step=0.1, interactive=True, label="Temperature", info="Controls the randomness of the generated tokens.")
+    repetition_penalty = gr.Slider(minimum=1, maximum=2, value=1.5, step=0.1, interactive=True, label="Repetition Penalty", info="Higher values help the model to avoid repetition in text generation.")
     max_length = gr.Slider(minimum=10, maximum=500, value=100, step=10, interactive=True, label="Max Length", info="Controls the maximum length of the generated text.")
     smaple_from = gr.Slider(minimum=2, maximum=10, value=2, step=1, interactive=True, label="Sample From", info="Controls the number of generations that the reward model will sample from.")
 
@@ -78,7 +79,7 @@ with gr.Blocks(theme='freddyaboulton/dracula_revamped') as demo:
     def user(user_message, chat_history):
         return gr.update(value=user_message, interactive=True), chat_history + [["👤 " + user_message, None]]
 
-    def generate_response(user_msg, top_p, temperature, top_k, max_length, smaple_from, safety, chat_history):
+    def generate_response(user_msg, top_p, temperature, top_k, max_length, smaple_from, repetition_penalty, safety, chat_history):
 
         inputs = tokenizer(tokenizer.bos_token + user_msg + tokenizer.eos_token, return_tensors="pt").to(model.device)
 
@@ -86,6 +87,7 @@ with gr.Blocks(theme='freddyaboulton/dracula_revamped') as demo:
             bos_token_id=tokenizer.bos_token_id,
             pad_token_id=tokenizer.pad_token_id,
             eos_token_id=tokenizer.eos_token_id,
+            repetition_penalty=repetition_penalty,
             do_sample=True,
             early_stopping=True,
             top_k=top_k,
@@ -145,7 +147,7 @@ with gr.Blocks(theme='freddyaboulton/dracula_revamped') as demo:
         yield chat_history
 
     response = msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
-        generate_response, [msg, top_p, temperature, top_k, max_length, smaple_from, safety, chatbot], chatbot
+        generate_response, [msg, top_p, temperature, top_k, max_length, smaple_from, repetition_penalty, safety, chatbot], chatbot
     )
     response.then(lambda: gr.update(interactive=True), None, [msg], queue=False)
     msg.submit(lambda x: gr.update(value=''), None,[msg])
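For context, repetition_penalty is a standard argument of transformers' model.generate: values above 1.0 penalize tokens that have already been generated, and 1.0 disables the penalty, which is why the slider bottoms out at 1. Below is a minimal, self-contained sketch of the generation call this commit ends up making, using the slider defaults shown in the diff; the "gpt2" checkpoint and the prompt are placeholders, since the Space's actual model and tokenizer are not visible here.

from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder checkpoint; the Space's real model is not shown in this diff.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("Hello, how are you?", return_tensors="pt")
outputs = model.generate(
    **inputs,
    do_sample=True,
    top_k=50,                # slider default in the diff
    top_p=0.70,              # slider default in the diff
    temperature=0.1,         # slider default in the diff
    repetition_penalty=1.5,  # the parameter this commit exposes; 1.0 = no penalty
    max_length=100,          # slider default in the diff
    pad_token_id=tokenizer.eos_token_id,  # GPT-2 has no pad token of its own
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

The Space additionally passes bos/eos/pad token ids and early_stopping, as the diff shows; they are omitted above for brevity.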