indochat / app.py
cahya's picture
add label for dec methode
07c107f
raw
history blame
No virus
4.47 kB
import torch
import gradio as gr
from transformers import pipeline
import os
from mtranslate import translate
# Run on the current CUDA device when one is available; otherwise use the CPU.
if torch.cuda.is_available():
    device = torch.cuda.current_device()
else:
    device = "cpu"

# Access token read from the environment; None when the variable is unset.
HF_AUTH_TOKEN = os.environ.get("HF_AUTH_TOKEN")

# Model identifier handed to the text-generation pipeline below.
text_generation_model = "cahya/indochat-tiny"

# Built once at import time and reused by every request.
text_generation = pipeline(
    task="text-generation",
    model=text_generation_model,
    use_auth_token=HF_AUTH_TOKEN,
    device=device,
)
def _resolve_decoding(decoding_methods, num_beams, penalty_alpha):
    """Return ``(do_sample, num_beams, penalty_alpha)`` for the chosen method."""
    if decoding_methods == "Beam Search":
        # Deterministic beam search: keep the requested beam count and
        # switch the contrastive-search penalty off.
        return False, num_beams, 0
    if decoding_methods == "Sampling":
        # Sampling: single beam, contrastive-search penalty off.
        return True, 1, 0
    # Any other value (the UI offers "Contrastive Search"): single beam,
    # no sampling, and the caller-supplied penalty_alpha is kept.
    return False, 1, penalty_alpha


def _strip_prompt(answer, prompt):
    """Return *answer* without the leading *prompt* and leading whitespace."""
    if answer.startswith(prompt):
        answer = answer[len(prompt):]
    # Strip whitespace instead of blindly dropping one character, so the
    # first letter of the reply is never cut off.
    return answer.lstrip()


def get_answer(user_input, decoding_methods, num_beams, top_k, top_p, temperature, repetition_penalty, penalty_alpha):
    """Generate a chatbot reply for *user_input* and translate both sides.

    Args:
        user_input: The question typed by the user.
        decoding_methods: "Beam Search", "Sampling", or anything else (treated
            as contrastive search).
        num_beams: Beam count; only honoured for "Beam Search", otherwise 1.
        top_k: Forwarded to the generation pipeline.
        top_p: Forwarded to the generation pipeline.
        temperature: Forwarded to the generation pipeline.
        repetition_penalty: Forwarded to the generation pipeline.
        penalty_alpha: Only honoured for contrastive search, otherwise 0.

    Returns:
        A pair of lists of ``(text, label)`` tuples: the first holds the
        original question and the generated answer, the second holds what
        ``translate(..., "en", "id")`` returns for each of them.
    """
    do_sample, num_beams, penalty_alpha = _resolve_decoding(decoding_methods, num_beams, penalty_alpha)
    print(user_input, decoding_methods, do_sample, top_k, top_p, temperature, repetition_penalty, penalty_alpha)

    prompt = f"User: {user_input}\nAssistant: "
    generated_text = text_generation(
        prompt,
        min_length=50,
        max_length=200,
        num_return_sequences=1,
        num_beams=num_beams,
        do_sample=do_sample,
        top_k=top_k,
        top_p=top_p,
        temperature=temperature,
        repetition_penalty=repetition_penalty,
        penalty_alpha=penalty_alpha,
    )

    # The result is sliced as prompt + continuation, so only the
    # continuation is kept for display.
    answer = generated_text[0]["generated_text"]
    answer_without_prompt = _strip_prompt(answer, prompt)

    # NOTE(review): the source language is fixed to "id" even though the UI
    # also invites English questions -- confirm this is intended.
    user_input_en = translate(user_input, "en", "id")
    answer_without_prompt_en = translate(answer_without_prompt, "en", "id")

    # Question carries no label (None); the answer carries the "" label.
    original_pairs = [(f"{user_input} ", None), (answer_without_prompt, "")]
    translated_pairs = [(f"{user_input_en} ", None), (answer_without_prompt_en, "")]
    return original_pairs, translated_pairs
# Gradio UI: generation controls in the left column, the generated answer and
# its translation in the right column.
with gr.Blocks() as demo:
    with gr.Row():
        gr.Markdown(
            "## IndoChat\n"
            "A Proof of Concept of a multilingual Chatbot (in this case a bilingual, English and Indonesian), fine-tuned with\n"
            "multilingual instructions dataset. The base model is a GPT2-Medium (340M params) which was pretrained with 75GB\n"
            "of Indonesian and English dataset, where English part is only less than 1% of the whole dataset.\n"
        )
    with gr.Row():
        with gr.Column():
            # Inputs, declared in the same order as get_answer's parameters.
            user_input = gr.inputs.Textbox(placeholder="",
                                           label="Ask me something in Indonesian or English",
                                           default="Bagaimana cara mendidik anak supaya tidak berbohong?")
            decoding_methods = gr.inputs.Dropdown(["Beam Search", "Sampling", "Contrastive Search"],
                                                  default="Sampling", label="Decoding Method")
            num_beams = gr.inputs.Slider(label="Number of beams for beam search",
                                         default=1, minimum=1, maximum=10, step=1)
            top_k = gr.inputs.Slider(label="Top K",
                                     default=30, maximum=50, minimum=1, step=1)
            top_p = gr.inputs.Slider(label="Top P", default=0.9, step=0.05, minimum=0.1, maximum=1.0)
            temperature = gr.inputs.Slider(label="Temperature", default=0.5, step=0.05, minimum=0.1, maximum=1.0)
            repetition_penalty = gr.inputs.Slider(label="Repetition Penalty", default=1.1, step=0.05, minimum=1.0, maximum=2.0)
            penalty_alpha = gr.inputs.Slider(label="The penalty alpha for contrastive search",
                                             default=0.5, step=0.05, minimum=0.05, maximum=1.0)
            with gr.Row():
                button_generate_story = gr.Button("Submit")
        with gr.Column():
            # get_answer labels each answer with "", which is mapped to blue here.
            generated_answer = gr.HighlightedText(
                label="Generated Text",
                combine_adjacent=True,
            ).style(color_map={"": "blue", "-": "green"})
            generated_answer_en = gr.HighlightedText(
                label="Translation",
                combine_adjacent=True,
            ).style(color_map={"": "blue", "-": "green"})
    with gr.Row():
        gr.Markdown("![visitor badge](https://visitor-badge.glitch.me/badge?page_id=cahya_indochat)")
    # The two values returned by get_answer feed the two HighlightedText
    # outputs, in order.
    button_generate_story.click(get_answer,
                                inputs=[user_input, decoding_methods, num_beams, top_k, top_p, temperature,
                                        repetition_penalty, penalty_alpha],
                                outputs=[generated_answer, generated_answer_en])

demo.launch(enable_queue=False)