File size: 3,392 Bytes
35269fa
07eb905
35269fa
2053f34
fd8aaeb
ebb732a
9468cf9
ebb732a
 
 
3b634cc
ebb732a
 
d83fd00
ebb732a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85329fe
ebb732a
 
5a86693
 
ebb732a
 
5a86693
ebb732a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import pip
import os
# Install runtime dependencies before the third-party imports below.
# pip has no supported in-process API (``pip.main`` is an internal entry
# point), so pip is run as a subprocess of the *current* interpreter.
# Failures are deliberately non-fatal, matching the previous behaviour: if a
# package is already installed the app can still start, and a genuinely
# missing package will surface as an ImportError on the imports that follow.
import subprocess
import sys

for _package in ("transformers", "torch", "pymongo"):
    subprocess.run([sys.executable, "-m", "pip", "install", _package], check=False)
import gradio as gr
from transformers import pipeline
import pymongo

# MongoDB database that stores generated outputs and arena votes.
# DB_URI must be set in the environment; a missing variable raises KeyError.
mongo_client = pymongo.MongoClient(os.environ['DB_URI'])
db = mongo_client["eagle"]

# Reusable Gradio update payloads used to toggle the arena vote buttons.
# ``gr.update`` works on Gradio 3.x and 4.x, whereas the per-component
# ``gr.Button.update`` class method was removed in Gradio 4.0.
btn_disable = gr.update(interactive=False)
btn_enable = gr.update(interactive=True)

# Text-generation pipeline for the "dgnk007/eagle" model, built once at import.
generator = pipeline("text-generation", model="dgnk007/eagle")

def store_in_mongodb(collection_name, data):
    """Insert a single document into one collection of the module's ``db``.

    Args:
        collection_name: Name of the target collection inside ``db``.
        data: The document (a dict) to insert.

    Returns:
        The pymongo ``InsertOneResult``; callers read ``.inserted_id`` from it.
    """
    target = db[collection_name]
    outcome = target.insert_one(data)
    return outcome
    

def generate_text(message, sequences):
    """Run the text-generation pipeline on *message* wrapped in the instruction prompt.

    Args:
        message: The user's instruction, inserted into the prompt template.
        sequences: Number of completions to request (``num_return_sequences``).

    Returns:
        The pipeline's raw output: a list of dicts whose ``"generated_text"``
        holds only the completion, because ``return_full_text=False`` drops
        the prompt.
    """
    prompt = (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
        f"### Instruction:\n {message}\n\n"
        "### Response:\n\n"
    )
    settings = dict(
        max_length=1024,  # cap on total tokens, prompt included
        return_full_text=False,
        # NOTE(review): presumably the tokenizer id of "###", so generation
        # stops at the next section header -- confirm against the tokenizer.
        eos_token_id=21017,
        # NOTE(review): presumably the end-of-text id reused for padding -- confirm.
        pad_token_id=50256,
        num_return_sequences=sequences,
    )
    return generator(prompt, **settings)

def general_function(input_text):
    """Handle the "General" tab: generate one reply, log it, and return it.

    The (input, output) pair is written to ``general_collection`` before the
    reply text is returned for display.
    """
    completions = generate_text(input_text, 1)
    reply = completions[0]['generated_text']
    record = {"input": input_text, "output": reply}
    store_in_mongodb("general_collection", record)
    return reply

def arena_function(input_text):
    """Handle an "Arena" submission: produce two replies and store the pair.

    Generates two responses to the same prompt with two separate
    single-sequence calls. It stores ``{"input", "r1", "r2"}`` in
    ``arena_collection`` and returns the values for the arena's outputs.

    Returns:
        A 7-tuple: response 1, response 2, the inserted document's id (kept
        in session state for voting), then ``btn_enable`` four times to
        re-enable the vote buttons.
    """
    # NOTE(review): if the model's generation config decodes greedily (no
    # sampling), these two calls return identical text -- confirm sampling.
    first = generate_text(input_text, 1)[0]['generated_text']
    second = generate_text(input_text, 1)[0]['generated_text']

    record = {
        "input": input_text,
        "r1": first,
        "r2": second,
    }
    # Named ``insert_result`` rather than ``id`` to avoid shadowing the builtin.
    insert_result = store_in_mongodb("arena_collection", record)
    return (
        first,
        second,
        insert_result.inserted_id,
        btn_enable,
        btn_enable,
        btn_enable,
        btn_enable,
    )

# "General" tab: a single prompt box whose text is answered by general_function.
general_interface = gr.Interface(
    fn=general_function,
    inputs=gr.Textbox(label="Enter your text here:", min_width=600),
    outputs="text",
)

def reward_click(id, reward):
    """Record a vote on a stored arena pair, then lock the vote buttons.

    Args:
        id: ``_id`` of the arena document to update (held in session state).
        reward: Vote label written to the document's ``reward`` field.

    Returns:
        ``btn_disable`` four times, one per vote button, so each pair can be
        voted on only once from the UI.
    """
    arena = db["arena_collection"]
    arena.update_one({"_id": id}, {"$set": {"reward": reward}})
    return (btn_disable,) * 4

# "Arena" tab: show two responses to the same prompt side by side and let the
# user vote; the vote is written back onto the stored pair via reward_click.
with gr.Blocks() as arena_interface:
    # Per-session holder for the database id of the pair currently shown.
    obid = gr.State([])

    with gr.Row():
        with gr.Column():
            input_box = gr.Textbox(label="Enter your text here:", min_width=600)
            prompt = gr.Button("Submit", variant="primary")

    with gr.Row():
        gr.Examples(['what is google?', 'what is youtube?'], input_box)

    with gr.Row():
        output_block = [
            gr.Textbox(label="Response 1", interactive=False),
            gr.Textbox(label="Response 2", interactive=False),
            obid,
        ]

    # Vote buttons start disabled; the submit handler's outputs re-enable them.
    with gr.Row():
        tie = gr.Button(value="Tie", size='sm', interactive=False)
        r1 = gr.Button(value="Response 1 Wins", variant='primary', interactive=False)
        r2 = gr.Button(value="Response 2 Wins", variant='primary', interactive=False)
        bad = gr.Button(value="Both are Bad", variant='secondary', interactive=False)
        buttonGroup = [tie, r1, r2, bad]

    prompt.click(fn=arena_function, inputs=input_box, outputs=output_block + buttonGroup)

    # Each vote button sends its fixed label together with the current pair id.
    for vote_button, vote_label in zip(buttonGroup, ('tie', 'r1', 'r2', 'bad')):
        vote_button.click(
            fn=reward_click,
            inputs=[obid, gr.State(vote_label)],
            outputs=buttonGroup,
        )
# Combine both tabs into one app and start serving it.
demo = gr.TabbedInterface(
    [general_interface, arena_interface],
    tab_names=["General", "Arena"],
)

demo.launch()