File size: 3,722 Bytes
87d4fba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import gradio as gr

from transformers import (
    AutoTokenizer,
    AutoModel,
    AutoModelForSeq2SeqLM,
    AutoModelForCausalLM
)

# Load the GODEL v1.1 base checkpoint (seq2seq dialog model) once at import time;
# both objects are used as module-level globals by generate() below.
tokenizer = AutoTokenizer.from_pretrained("microsoft/GODEL-v1_1-base-seq2seq")
model = AutoModelForSeq2SeqLM.from_pretrained("microsoft/GODEL-v1_1-base-seq2seq")

# Preset demo inputs consumed by change_example():
# each tuple is (instruction, knowledge, user query, instruction type).
preset_examples = [
    ('Instruction: given a dialog context, you need to response empathically.',
     '', 'Does money buy happiness?', 'Chitchat'),
]


def generate(instruction, knowledge, dialog, top_p, min_length, max_length):
    """Build a GODEL-formatted prompt and sample a response from the model.

    Args:
        instruction: natural-language instruction prefix for the model.
        knowledge: optional grounding text; prefixed with "[KNOWLEDGE]" when
            non-empty.
        dialog: iterable of dialog turns, joined with " EOS ".
        top_p: nucleus-sampling probability mass passed to ``model.generate``.
        min_length: minimum generation length (coerced to int).
        max_length: maximum generation length (coerced to int).

    Returns:
        The decoded model response as a plain string.
    """
    knowledge_part = f'[KNOWLEDGE] {knowledge}' if knowledge != '' else knowledge
    history = ' EOS '.join(dialog)
    prompt = f"{instruction} [CONTEXT] {history} {knowledge_part}"

    encoded = tokenizer(prompt, return_tensors="pt").input_ids
    generated = model.generate(
        encoded,
        min_length=int(min_length),
        max_length=int(max_length),
        top_p=top_p,
        do_sample=True,
    )
    reply = tokenizer.decode(generated[0], skip_special_tokens=True)
    # Echo the prompt and reply to stdout for demo-side debugging.
    print(prompt)
    print(reply)
    return reply


def api_call_generation(instruction, knowledge, query, top_p, min_length, max_length):
    """Adapter for the Gradio button: wrap a single user query as a
    one-turn dialog and delegate to :func:`generate`.
    """
    return generate(instruction, knowledge, [query],
                    top_p, min_length, max_length)


def change_example(choice):
    """Populate the form widgets from the preset selected in the dropdown.

    ``choice`` looks like "Example N"; the trailing integer (1-based) indexes
    into ``preset_examples``. Returns one ``gr.update`` per output widget:
    instruction, knowledge, query, and instruction-type radio, in that order.
    """
    idx = int(choice.split()[-1]) - 1
    instruction, knowledge, query, instruction_type = preset_examples[idx]
    updates = [
        gr.update(lines=1, visible=True, value=instruction),
        gr.update(visible=True, value=knowledge),
        gr.update(lines=1, visible=True, value=query),
        gr.update(visible=True, value=instruction_type),
    ]
    return updates

def change_textbox(choice):
    """Return an update for the instruction textbox matching the chosen
    instruction type; any unrecognized choice falls back to the
    question-answering instruction.
    """
    # Dispatch table instead of an if/elif chain; the .get() default covers
    # "Conversational Question Answering" and any other value.
    by_type = {
        "Chitchat":
            "Instruction: given a dialog context, you need to response empathically.",
        "Grounded Response Generation":
            "Instruction: given a dialog context and related knowledge, you need to response safely based on the knowledge.",
    }
    fallback = "Instruction: given a dialog context and related knowledge, you need to answer the question based on the knowledge."
    return gr.update(lines=1, visible=True, value=by_type.get(choice, fallback))


# Gradio UI. Widget creation order inside the Blocks context determines the
# page layout, so the statement order here is load-bearing.
with gr.Blocks() as demo:
    gr.Markdown("# The broken God")
    gr.Markdown('''All hail Mekhane. Reject flesh. Embrace metal''')

    # One entry per preset_examples item (currently a single example).
    dropdown = gr.Dropdown(
        [f"Example {i+1}" for i in range(1)], label='Examples')

    radio = gr.Radio(
        ["Conversational Question Answering", "Chitchat", "Grounded Response Generation"], label="Instruction Type", value='Conversational Question Answering'
    )
    # Default instruction matches the radio's default "Conversational
    # Question Answering" selection.
    instruction = gr.Textbox(lines=1, interactive=True, label="Instruction",
                             value="Instruction: given a dialog context and related knowledge, you need to answer the question based on the knowledge.")
    # Changing the instruction type swaps in the corresponding instruction text.
    radio.change(fn=change_textbox, inputs=radio, outputs=instruction)
    knowledge = gr.Textbox(lines=6, label="Knowledge")
    query = gr.Textbox(lines=1, label="User Query")

    # Selecting a preset fills all four widgets at once via change_example.
    dropdown.change(change_example, dropdown, [instruction, knowledge, query, radio])

    with gr.Row():
        with gr.Column(scale=1):
            response = gr.Textbox(label="Response", lines=2)

        with gr.Column(scale=1):
            # Sampling controls forwarded verbatim to model.generate.
            top_p = gr.Slider(0, 1, value=0.9, label='top_p')
            min_length = gr.Number(8, label='min_length')
            max_length = gr.Number(
                64, label='max_length (should be larger than min_length)')

    greet_btn = gr.Button("Generate")
    greet_btn.click(fn=api_call_generation, inputs=[
                    instruction, knowledge, query, top_p, min_length, max_length], outputs=response)

demo.launch()