File size: 3,649 Bytes
a602c25
dbd9204
a602c25
dbd9204
a602c25
f80d34c
c0860ca
 
 
a602c25
78cc221
 
 
a602c25
78cc221
 
a602c25
78cc221
d1c541a
 
 
 
 
 
 
 
 
78cc221
 
0334332
 
91f14e5
7608b65
 
 
 
 
 
 
0334332
 
7608b65
5ca7e26
 
 
b730ad7
 
5ca7e26
 
78cc221
 
0334332
e1772ff
13c135d
443f792
0334332
7608b65
f0b1c44
0334332
f0b1c44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0334332
91f14e5
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import transformers
import gradio as gr
import torch

from transformers import GPT2LMHeadModel, GPT2Tokenizer

# Load (downloading on first run, then from the HF cache) the DialoGPT-large
# conversational model.  DialoGPT is served through the GPT-2 classes here,
# hence GPT2Tokenizer/GPT2LMHeadModel rather than the Auto* classes.
model_name = 'microsoft/DialoGPT-large'
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = GPT2LMHeadModel.from_pretrained(model_name)

def predict(input, history=[]):
    """Generate the next chatbot reply for *input*, continuing *history*.

    Parameters
    ----------
    input : str
        The new user message.
    history : list
        Token-id history from the previous call (Gradio "state").  On the
        first call this is the empty default; afterwards it is the nested
        ``[[ids...]]`` list returned by this function.

    Returns
    -------
    tuple
        ``(pairs, history)`` where ``pairs`` is a list of
        ``(user_text, bot_text)`` tuples for the "chatbot" output component
        and ``history`` is the updated token-id list for the next call.

    NOTE(review): the mutable default ``[]`` is normally a bug, but
    ``history`` is only rebound (never mutated in place) and the empty
    list is needed so ``torch.LongTensor([])`` produces an empty tensor
    on the first turn — left as-is to preserve Gradio's initial state.
    """
    # Tokenize the new user message, terminated by the EOS token so the
    # model knows the turn is over.
    new_user_input_ids = tokenizer.encode(input + tokenizer.eos_token, return_tensors='pt')

    # Append the new user tokens to the chat history.  torch.cat ignores
    # the empty 1-D tensor produced on the first call.
    bot_input_ids = torch.cat([torch.LongTensor(history), new_user_input_ids], dim=-1)

    # Generate a response; sampling knobs (top_p/top_k/no_repeat_ngram_size)
    # keep replies varied and non-repetitive.
    history = model.generate(
        bot_input_ids,
        max_length=1000,
        pad_token_id=tokenizer.eos_token_id,
        no_repeat_ngram_size=3,
        top_p=0.92,
        top_k=50,
    ).tolist()

    # Decode the full conversation and split it back into turns on the
    # EOS marker; the trailing empty string after the final marker is
    # dropped by the `len(response) - 1` bound below.
    response = tokenizer.decode(history[0]).split("<|endoftext|>")

    # Pair consecutive turns as (user, bot) tuples, the format the Gradio
    # "chatbot" output component expects.
    response = [(response[i], response[i + 1]) for i in range(0, len(response) - 1, 2)]

    return response, history
    
# Inline CSS for the hand-rolled HTML chat rendering (currently unused:
# the css= argument to gr.Interface below is commented out).
# NOTE(review): the selector is '.chatbox' while the HTML built in the old
# predict() used class 'chatbot' — confirm which name is intended before
# re-enabling this stylesheet.
css = """
.chatbox {display:flex;flex-direction:column}
.msg {padding:4px;margin-bottom:4px;border-radius:4px;width:80%}
.msg.user {background-color:cornflowerblue;color:white;margin-right:10px}
.msg.bot {background-color:lightgray;align-self:self-end;margin-left:10px}
.footer {display:none !important}
"""

# Build and launch the Gradio UI: a text box plus hidden conversation
# state as inputs, a "chatbot" widget plus updated state as outputs.
# The custom stylesheet is currently disabled (css= commented out).
gr.Interface(fn=predict,
    theme="grass",
    title="DialoGPT-large",
    inputs=["text", "state"],
    outputs=["chatbot", "state"],
    #css=css
    ).launch()

# Dead code: an earlier DialoGPT-medium / HTML-output version of this app,
# kept as a module-level string literal (evaluated and discarded at import
# time, never executed).  Retained verbatim for reference; consider moving
# to version control history instead.
'''
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")

def predict(input, history=[]):
    # tokenize the new input sentence
    new_user_input_ids = tokenizer.encode(input + tokenizer.eos_token, return_tensors='pt')

    # append the new user input tokens to the chat history
    bot_input_ids = torch.cat([torch.LongTensor(history), new_user_input_ids], dim=-1)

    # generate a response 
    history = model.generate(bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id).tolist()

    # convert the tokens to text, and then split the responses into lines
    response = tokenizer.decode(history[0]).split("<|endoftext|>")
    response.remove("")
    
    # write some HTML
    html = "<div class='chatbot'>"
    for m, msg in enumerate(response):
        cls = "user" if m%2 == 0 else "bot"
        html += "<div class='msg {}'> {}</div>".format(cls, msg)
    html += "</div>"
    
    return html, history

import gradio as gr

css = """
.chatbox {display:flex;flex-direction:column}
.msg {padding:4px;margin-bottom:4px;border-radius:4px;width:80%}
.msg.user {background-color:cornflowerblue;color:white}
.msg.bot {background-color:lightgray;align-self:self-end}
.footer {display:none !important}
"""

gr.Interface(fn=predict,
             theme="default",
             inputs=[gr.inputs.Textbox(placeholder="How are you?"), "state"],
             outputs=["html", "state"],
             css=css).launch()
'''