gpt-2-chatbot / app.py
jdoexbox360's picture
Update app.py
305a962
raw
history blame
No virus
1.21 kB
import gradio as gr
from gradio.components import Slider, Textbox, Radio
import tensorflow as tf
from transformers import GPT2LMHeadModel, GPT2Tokenizer
# Hugging Face Hub id of the dialogue-tuned GPT-2 checkpoint. Loading both the
# tokenizer and the model from this single constant keeps them in sync.
MODEL_NAME = "ethzanalytics/ai-msgbot-gpt2-XL-dialogue"
tokenizer = GPT2Tokenizer.from_pretrained(MODEL_NAME)
# GPT-2 defines no pad token, so the EOS id is reused as the padding id for
# generation.
model = GPT2LMHeadModel.from_pretrained(MODEL_NAME, pad_token_id=tokenizer.eos_token_id)
# Speaker labels used to format each turn of the transcript that `output` sends
# to the model.
script_speaker_name = "person alpha"
script_responder_name = "person beta"
def output(prompt, output_length):
    """Add one user turn to the running conversation and generate a reply.

    The module-level ``convo`` transcript is extended with ``prompt`` (attributed
    to ``script_speaker_name``) followed by an open ``script_responder_name`` turn.
    The model then generates a continuation, and the updated transcript is stored
    back into ``convo`` and returned.

    Args:
        prompt: Text the user typed.
        output_length: Maximum number of *new* tokens to generate. The value may
            arrive as a float (e.g. from a Gradio slider), so it is rounded and
            clamped to at least 1 and to less than the model's position limit.

    Returns:
        The full conversation transcript, including the newly generated reply.
    """
    global convo
    sentence = convo + '\n' + script_speaker_name + ': ' + prompt + '\n' + script_responder_name + ': '
    input_ids = tokenizer.encode(sentence, return_tensors='pt')

    # `generate` needs a positive integer `max_new_tokens`; a float or 0 makes it
    # stop after ~1 token or raise.
    context_size = model.config.n_positions
    max_new_tokens = min(max(1, int(round(output_length))), context_size - 1)

    # The transcript grows every turn, but the model cannot see more than
    # `context_size` positions. Feed it only the newest tokens that still leave
    # room for the reply.
    input_ids = input_ids[:, -(context_size - max_new_tokens):]

    generated = model.generate(
        input_ids,
        max_new_tokens=max_new_tokens,
        num_beams=5,
        no_repeat_ngram_size=2,
        early_stopping=True,
    )
    # Decode only the tokens produced after the prompt and append them to the
    # original text. History trimmed from the model input is therefore still
    # kept verbatim in the transcript.
    reply = tokenizer.decode(generated[0, input_ids.shape[-1]:], skip_special_tokens=True)
    convo = sentence + reply
    return convo
# Running transcript that `output` extends and returns on every call.
# NOTE(review): this is a single module-level variable, so every visitor of the
# app appears to share one conversation — confirm that is intended.
convo = ''
# The slider value is passed to `output` as its `output_length` argument, which
# is used as the number of new tokens to generate. The slider therefore needs a
# positive integer range; the old 0.0-1.0 range could not express a token count.
# `value=` sets the slider's initial value in Gradio 3+; the older `default=`
# keyword was deprecated there and later removed.
iface = gr.Interface(
    fn=output,
    inputs=["text", Slider(minimum=1, maximum=200, step=1, value=50, label="Output Length")],
    outputs="text",
)
iface.launch()