File size: 2,668 Bytes
6666c9d
6ddba8c
 
f8daa7e
ebea7a7
f8daa7e
ebea7a7
6ddba8c
cdb7c8c
e7cefea
d287f72
cdb7c8c
6666c9d
cdb7c8c
32b8908
e7cefea
2970a90
 
e7cefea
2970a90
e7cefea
cdb7c8c
25a4e57
 
 
 
0b16a1a
 
fd7a80a
cdb7c8c
6666c9d
 
6ddba8c
 
 
 
 
 
 
 
6666c9d
 
 
 
6ddba8c
 
 
 
 
6666c9d
 
6ddba8c
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import os

from transformers import AutoTokenizer, AutoModelForCausalLM
import gradio as gr

# SECURITY: never hard-code API tokens in source files. The previous revision
# embedded a Hugging Face access token here; once committed, that token must be
# treated as compromised — revoke it and issue a new one. Read the credential
# from the environment instead (None is fine for public models).
HF_TOKEN = os.environ.get("HF_TOKEN")

# `token=` is the current spelling; `use_auth_token=` is deprecated in
# recent versions of transformers.
tokenizer = AutoTokenizer.from_pretrained('nicholasKluge/Aira-Instruct-124M',
    token=HF_TOKEN)
model = AutoModelForCausalLM.from_pretrained('nicholasKluge/Aira-Instruct-124M',
    token=HF_TOKEN)

# Markdown shown beneath the chat UI; research-use-only notice for the demo.
disclaimer = """**`Disclaimer`:** This demo should be used for research purposes only. Commercial use is strictly **prohibited**. The model output is not censored and the authors do not endorse the opinions in the generated content. **Use at your own risk.**"""

# Build the Gradio UI: a chatbot pane on the left, sampling-parameter sliders
# on the right, and a message textbox + clear button underneath.
# NOTE(review): `.style(height=...)` and `Row(scale=...)` are APIs from older
# Gradio releases — this layout is pinned to whatever version the demo was
# written against; confirm before upgrading Gradio.
with gr.Blocks(theme='freddyaboulton/dracula_revamped') as demo:
    
    gr.Markdown("""<h1><center>🔥Aira Demo 🤓🚀</h1></center>""")
    
    with gr.Row(scale=1, equal_height=True):
        
        # Chat history display (wider column).
        with gr.Column(scale=5):
            chatbot = gr.Chatbot(label="Aira").style(height=300)
        
        # Generation controls (narrower column). These slider values are
        # forwarded verbatim to `model.generate` on every message submit.
        with gr.Column(scale=2):
           
            with gr.Tab(label="Parameters ⚙️"):
                top_k = gr.Slider( minimum=10, maximum=100, value=50, step=5, interactive=True, label="Top-k",)
                top_p = gr.Slider( minimum=0.1, maximum=1.0, value=0.70, step=0.05, interactive=True, label="Top-p",)
                temperature = gr.Slider( minimum=0.001, maximum=2.0, value=0.1, step=0.1, interactive=True, label="Temperature",)
                # NOTE(review): this bounds the TOTAL sequence length (prompt +
                # reply), since it is passed as `max_length` to generate().
                max_length = gr.Slider( minimum=10, maximum=500, value=100, step=10, interactive=True, label="Max Length",)

    # User input box and a button that wipes the conversation.
    msg = gr.Textbox(label="Write a question or comment to Aira", placeholder="Hi Aira, how are you?")
    clear = gr.Button("Clear Conversation 🧹")
    gr.Markdown(disclaimer)

    def generate_response(message, chat_history, top_k, top_p, temperature, max_length):
        """Generate a model reply to `message` and append the exchange to the chat.

        Parameters mirror the UI sliders and are forwarded to `model.generate`.
        Returns ("", chat_history): an empty string to clear the textbox, and
        the updated history for the Chatbot component.
        """
        # Wrap the user message in BOS/EOS markers — presumably the format the
        # model was fine-tuned on (TODO: confirm against the model card).
        inputs = tokenizer(tokenizer.bos_token + message + tokenizer.eos_token, return_tensors="pt")

        # `early_stopping` was dropped: it only applies to beam search and
        # merely emits a warning under `do_sample=True`.
        response = model.generate(**inputs,
            bos_token_id=tokenizer.bos_token_id,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            do_sample=True,
            top_k=top_k,
            max_length=max_length,
            top_p=top_p,
            temperature=temperature,
            num_return_sequences=1)

        # Decode only the newly generated tokens. The previous implementation
        # decoded the full sequence and then did `.replace(message, "")`, which
        # deletes EVERY occurrence of the user's text from the reply — so a
        # short prompt legitimately echoed inside the answer was stripped too.
        prompt_length = inputs["input_ids"].shape[1]
        reply = tokenizer.decode(response[0][prompt_length:], skip_special_tokens=True)

        chat_history.append((f"👤 {message}", f"🤖 {reply}"))

        return "", chat_history
    
    # Pressing Enter in the textbox triggers generation; outputs clear the
    # textbox ("" return) and refresh the chatbot with the new history.
    msg.submit(generate_response, [msg, chatbot, top_k, top_p, temperature, max_length], [msg, chatbot])
    # Clear button resets the chatbot component to empty (no queuing needed).
    clear.click(lambda: None, None, chatbot, queue=False)

demo.launch()