File size: 1,896 Bytes
cc28aad b73e8d0 e7899d4 b73e8d0 c4fe73d 20cda87 c4fe73d b73e8d0 cc28aad 20cda87 b73e8d0 cc28aad c4fe73d 20cda87 b73e8d0 c4fe73d cc28aad c4fe73d b73e8d0 cc28aad b73e8d0 cc28aad 979f78b 20cda87 cc28aad 20cda87 cc28aad c4fe73d b73e8d0 c4fe73d cc28aad b73e8d0 20cda87 cc28aad b73e8d0 979f78b cc28aad 20cda87 cc28aad |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 |
import html
import logging

import gradio as gr
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BlenderbotForConditionalGeneration,
)
# Checkpoint name passed to from_pretrained() below. An alternative that was
# tried, "facebook/blenderbot-400M-distill", is noted here for reference.
# NOTE(review): that checkpoint would presumably need
# BlenderbotForConditionalGeneration instead of AutoModelForCausalLM —
# confirm before switching.
model_name = "microsoft/DialoGPT-medium"
# Tokenizer and causal language model for the chosen checkpoint. Both are
# created once, at import time, and shared by every call to converse().
chat_token = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
def _render_chat_html(messages):
    """Render alternating user/bot messages as one HTML fragment.

    Even-indexed entries are styled as user turns, odd-indexed entries as bot
    turns (prefixed with "BOT: "). Every message is HTML-escaped so text typed
    by the user or produced by the model is displayed literally instead of
    being interpreted as markup.
    """
    # The wrapper class must be "mychat": that is the selector the stylesheet
    # uses for the flex-column container.
    parts = ["<div class='mychat'>"]
    for index, text in enumerate(messages):
        if index % 2:
            css_class = "bot"
            text = "BOT: " + text
        else:
            css_class = "user"
        parts.append(
            "<div class='mesg {}'> {}</div>".format(css_class, html.escape(text))
        )
    parts.append("</div>")
    return "".join(parts)


def converse(user_input, chat_history=None):
    """Generate the bot's reply to ``user_input`` and render the conversation.

    Args:
        user_input: Text entered by the user for this turn.
        chat_history: Token ids of the conversation so far, in the nested-list
            form returned by a previous call. ``None`` or an empty list starts
            a new conversation.

    Returns:
        A ``(html, chat_history)`` tuple: the rendered conversation as an HTML
        string, and the updated token-id history to pass to the next call.
    """
    logger = logging.getLogger(__name__)

    # Each turn is terminated by the EOS token so turns can be split apart
    # again after decoding.
    user_input_ids = chat_token(
        user_input + chat_token.eos_token, return_tensors="pt"
    ).input_ids

    # Prepend the earlier turns, if any, so the model sees the full context.
    if chat_history:
        bot_input_ids = torch.cat(
            [torch.LongTensor(chat_history), user_input_ids], dim=-1
        )
    else:
        bot_input_ids = user_input_ids

    # NOTE(review): max_length appears to cap prompt + reply together, so a
    # very long conversation may leave no room for a reply — confirm and
    # consider trimming old turns.
    chat_history = model.generate(
        bot_input_ids, max_length=1000, pad_token_id=chat_token.eos_token_id
    ).tolist()
    logger.debug("chat history token ids: %s", chat_history)

    turns = chat_token.decode(chat_history[0]).split(chat_token.eos_token)
    # The transcript ends with an EOS token, so the split leaves a trailing
    # empty string. Drop only that one entry so user/bot parity is preserved.
    if turns and turns[-1] == "":
        turns.pop()
    logger.debug("decoded turns: %s", turns)

    return _render_chat_html(turns), chat_history
# Stylesheet handed to the Gradio interface.
# - `.mychat` lays its children out as a flex column; `align-self` on
#   `.mesg.bot` only takes effect for elements inside such a flex container.
# - `.mesg` is the shared bubble style; `.user` / `.bot` set the colours.
# - `.footer` is hidden (presumably Gradio's own page footer — confirm).
css = """
.mychat {display:flex;flex-direction:column}
.mesg {padding:5px;margin-bottom:5px;border-radius:5px;width:75%}
.mesg.user {background-color:lightblue;color:white}
.mesg.bot {background-color:orange;color:white;align-self:self-end}
.footer {display:none !important}
"""
# Text box for the user's message.
# NOTE(review): the `gr.inputs` namespace is believed to be deprecated in
# newer Gradio releases — confirm against the installed version before
# upgrading.
text = gr.inputs.Textbox(label="User Input", placeholder="Let's start a chat...")

# "state" appears in both inputs and outputs so the token history returned by
# converse() is passed back in on the next call.
gr.Interface(
    fn=converse,
    theme="default",
    inputs=[text, "state"],
    outputs=["html", "state"],
    css=css,
).launch()