Spaces:
Runtime error
Runtime error
Slider for repetition_penalty (#4)
Browse files- Updated app.py (bffdc1ac1b3e6973adfed829958385f0312ad1a2)
app.py
CHANGED
@@ -14,13 +14,8 @@ checkpoint = "CohereForAI/aya-101"
|
|
14 |
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
|
15 |
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True, device_map=device)
|
16 |
|
17 |
-
#Set a the value of the repetition penalty
|
18 |
-
#The higher the value, the less repetitive the generated text will be
|
19 |
-
#Note that `repetition_penalty` has to be a strictly positive float
|
20 |
-
repetition_penalty = 1.8
|
21 |
-
|
22 |
@spaces.GPU
|
23 |
-
def aya(text, max_new_tokens):
|
24 |
model.to(device)
|
25 |
inputs = tokenizer.encode(text, return_tensors="pt").to(device)
|
26 |
outputs = model.generate(inputs, max_new_tokens=max_new_tokens, repetition_penalty=repetition_penalty)
|
@@ -33,9 +28,10 @@ def main():
|
|
33 |
gr.Markdown(description)
|
34 |
input_text = gr.Textbox(label="🗣️Input Text")
|
35 |
max_new_tokens_slider = gr.Slider(minimum=150, maximum=1648, step=1, value=250, label="Size of your inputs and answer")
|
|
|
36 |
submit_button = gr.Button("Use🌐Aya")
|
37 |
output_text = gr.Textbox(label="🌐Aya", interactive=False)
|
38 |
-
submit_button.click(fn=aya, inputs=[input_text, max_new_tokens_slider], outputs=output_text)
|
39 |
|
40 |
demo.launch()
|
41 |
|
|
|
14 |
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
|
15 |
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True, device_map=device)
|
16 |
|
|
|
|
|
|
|
|
|
|
|
17 |
@spaces.GPU
|
18 |
+
def aya(text, max_new_tokens, repetition_penalty):
|
19 |
model.to(device)
|
20 |
inputs = tokenizer.encode(text, return_tensors="pt").to(device)
|
21 |
outputs = model.generate(inputs, max_new_tokens=max_new_tokens, repetition_penalty=repetition_penalty)
|
|
|
28 |
gr.Markdown(description)
|
29 |
input_text = gr.Textbox(label="🗣️Input Text")
|
30 |
max_new_tokens_slider = gr.Slider(minimum=150, maximum=1648, step=1, value=250, label="Size of your inputs and answer")
|
31 |
+
repetition_penalty_slider = gr.Slider(minimum=1.0, maximum=4.0, step=0.1, value=1.8, label="Repetition Penalty")
|
32 |
submit_button = gr.Button("Use🌐Aya")
|
33 |
output_text = gr.Textbox(label="🌐Aya", interactive=False)
|
34 |
+
submit_button.click(fn=aya, inputs=[input_text, max_new_tokens_slider, repetition_penalty_slider], outputs=output_text)
|
35 |
|
36 |
demo.launch()
|
37 |
|