# ChatFinance / app.py
# Joe99's picture
# Update app.py
# 428c8b1
# raw
# history blame
# 1.6 kB
import transformers
import gradio as gr
# import warnings
import torch
# warnings.simplefilter('ignore')
# Prefer the GPU when torch can see one, otherwise fall back to the CPU.
# NOTE(review): `device` is not referenced again in the visible code; the
# checkpoint below is mapped to the CPU explicitly.
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = transformers.GPT2Tokenizer.from_pretrained('gpt2')

# Register the padding, begin-of-string and end-of-string markers as
# special tokens.
tokenizer.add_special_tokens(
    dict(
        pad_token="<pad>",
        bos_token="<startstring>",
        eos_token="<endstring>",
    )
)

# The bot marker is added as a regular (non-special) token.
tokenizer.add_tokens(["<bot>:"])
print("=====Done 1")

model = transformers.GPT2LMHeadModel.from_pretrained('gpt2')
# Grow the embedding matrix to cover the tokens added above; this has to
# happen before the checkpoint is loaded into the model.
model.resize_token_embeddings(len(tokenizer))
# NOTE(review): torch.load unpickles the file; only load checkpoints from a
# trusted source.
model.load_state_dict(
    torch.load(
        './gpt2talk.pt',
        map_location=torch.device('cpu'),
    )
)
print("=====Done 2")

# Switch to evaluation mode for inference.
model.eval()
def inference(quiz):
    """Generate the bot's reply to a single question.

    The question is wrapped in the prompt format
    ``"<startstring>" + quiz + " <bot>:"``, passed to the module-level
    ``model``, and only the newly generated tokens are decoded.

    Args:
        quiz: The user's question as plain text.

    Returns:
        The generated reply, stripped of surrounding whitespace, with a
        trailing ``'.'`` appended.
    """
    prompt = "<startstring>" + quiz + " <bot>:"
    # Put the inputs on whatever device the model currently lives on.
    encoded = tokenizer(prompt, return_tensors='pt').to(model.device)

    # Sampling is not enabled, so the former top_k/top_p arguments had no
    # effect (and top_k=0.7 is not a valid integer count); they are omitted.
    # The end/pad ids are passed explicitly because the tokenizer's
    # "<endstring>" and "<pad>" tokens were registered in this file and are
    # not the ids stored in the base 'gpt2' configuration.
    # NOTE(review): assumes the checkpoint was fine-tuned to emit
    # "<endstring>" at the end of a reply - confirm against training code.
    generated = model.generate(
        **encoded,
        max_length=200,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
    )[0]

    # Drop the prompt by position instead of by string replacement, so a
    # reply that repeats the question (or a prompt that does not decode back
    # to the exact input text) is handled correctly.
    prompt_length = encoded["input_ids"].shape[1]
    reply_ids = generated[prompt_length:]
    answer = tokenizer.decode(reply_ids, skip_special_tokens=True)
    return answer.strip() + '.'
def chatbot(input_text):
    """Gradio callback: return the model's reply to ``input_text``.

    Delegates directly to ``inference`` and returns its result unchanged.
    """
    return inference(input_text)
# Create the Gradio interface
print("=====Done 3")
iface = gr.Interface(
    fn=chatbot,
    inputs=gr.Textbox(),
    outputs=gr.Textbox(),
    # live=False: run the model only when the user submits, not on every
    # change of the input box.
    live=False,
    # The former interpretation="chat" argument was dropped: "chat" is not a
    # supported interpretation mode, and Gradio 4 removed the parameter.
    title="ChatFinance",
    description="Ask the a question and see its response!",
)
print("=====Done 4")
# Launch the Gradio interface
iface.launch()