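"""Gradio chat demo for the Google Gemma models, served through the
Hugging Face InferenceClient. Streams replies token by token and exposes
sampling controls (seed, temperature, top-p, repetition penalty)."""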
import gradio as gr
from huggingface_hub import InferenceClient
import random

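# Gemma checkpoints queried through the Hugging Face Inference API.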
models = [
    "google/gemma-7b",
    "google/gemma-7b-it",
    "google/gemma-2b",
    "google/gemma-2b-it"
]

# One streaming client per model, in the same order as `models`.
clients = [InferenceClient(model) for model in models]

def format_prompt(message, history):
    # Build the prompt in the Gemma chat format: each turn is wrapped in
    # <start_of_turn>user / <start_of_turn>model control tokens and closed
    # with <end_of_turn>. (The original used the literal strings
    # "usuário"/"modelo", which the model does not recognize as role tokens.)
    prompt = ""
    if history:
        for user_prompt, bot_response in history:
            prompt += f"<start_of_turn>user\n{user_prompt}<end_of_turn>\n"
            prompt += f"<start_of_turn>model\n{bot_response}<end_of_turn>\n"
    prompt += f"<start_of_turn>user\n{message}<end_of_turn>\n<start_of_turn>model\n"
    return prompt

def chat_inf(system_prompt, prompt, history, client_choice, seed, temp, tokens, top_p, rep_p):
    # The dropdown uses type='index', so client_choice is already a 0-based
    # index; the original subtracted 1 and selected the wrong client.
    client = clients[int(client_choice)]
    if not history:
        history = []

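    # Generation parameters forwarded to InferenceClient.text_generation.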
    generate_kwargs = dict(
        temperature=temp,
        max_new_tokens=tokens,
        top_p=top_p,
        repetition_penalty=rep_p,
        do_sample=True,
        seed=seed,
    )
    
    # Prepend the system prompt only when one was given, so an empty field
    # does not produce a leading ", " in the message.
    full_input = f"{system_prompt}, {prompt}" if system_prompt else prompt
    formatted_prompt = format_prompt(full_input, history)
    stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
    output = ""
    
    # Stream partial output into the chatbot while keeping earlier turns
    # visible (the original yielded only the current pair during streaming).
    for response in stream:
        output += response.token.text
        yield history + [(prompt, output)]
    history.append((prompt, output))
    yield history

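# Reset the prompt box, the system prompt box, and the chat history.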
def clear_fn():
    return None, None, None

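# Initial value for the seed slider; check_rand re-rolls it per request
# when "Random Seed" is checked.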
rand_val = random.randint(1, 1111111111111111)

def check_rand(inp, val):
    # Return a fresh random seed when the checkbox is on; otherwise keep the user's value.
    if inp:
        return gr.Slider(label="Seed", minimum=1, maximum=1111111111111111, value=random.randint(1, 1111111111111111))
    else:
        return gr.Slider(label="Seed", minimum=1, maximum=1111111111111111, value=int(val))

with gr.Blocks() as app:
    gr.HTML("""<center><h1 style='font-size:xx-large;'>Google Gemma Models</h1><br><h3>Running on the Huggingface Inference Client</h3><br><h7>EXPERIMENTAL</h7></center>""")
    chat_b = gr.Chatbot(height=500)
    with gr.Group():
        with gr.Row():
            with gr.Column(scale=3):
                inp = gr.Textbox(label="Prompt")
                sys_inp = gr.Textbox(label="System Prompt (optional)")
                with gr.Row():
                    with gr.Column(scale=2):
                        btn = gr.Button("Chat")
                    with gr.Column(scale=1):
                        with gr.Group():
                            stop_btn = gr.Button("Stop")
                            clear_btn = gr.Button("Clear")
                client_choice = gr.Dropdown(label="Models", type='index', choices=models, value=models[0], interactive=True)

            with gr.Column(scale=1):
                with gr.Group():
                    rand = gr.Checkbox(label="Random Seed", value=True)
                    seed = gr.Slider(label="Seed", minimum=1, maximum=1111111111111111, step=1, value=rand_val)
                    tokens = gr.Slider(label="Max new tokens", value=6400, minimum=0, maximum=8000, step=64, interactive=True, visible=True, info="The maximum number of new tokens to generate")
                    temp = gr.Slider(label="Temperature", step=0.01, minimum=0.01, maximum=1.0, value=0.9)
                    top_p = gr.Slider(label="Top-P", step=0.01, minimum=0.01, maximum=1.0, value=0.9)
                    rep_p = gr.Slider(label="Repetition Penalty", step=0.1, minimum=0.1, maximum=2.0, value=1.0)

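    # Resolve the seed first, then stream the model's reply into the chatbot;
    # the Stop button cancels the in-flight generation event.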
    go = btn.click(check_rand, [rand, seed], seed).then(chat_inf, [sys_inp, inp, chat_b, client_choice, seed, temp, tokens, top_p, rep_p], chat_b)
    stop_btn.click(None, None, None, cancels=go)
    clear_btn.click(clear_fn, None, [inp, sys_inp, chat_b])

app.queue(default_concurrency_limit=10).launch()