|
import streamlit as st |
|
from transformers import AutoTokenizer, AutoModelForCausalLM |
|
|
|
@st.cache_resource
def _load_tokenizer(model_name):
    """Load the tokenizer for *model_name* once and reuse it across reruns.

    Streamlit re-executes this whole script on every widget interaction, so an
    uncached module-level ``from_pretrained`` call would rebuild the tokenizer
    each time. ``st.cache_resource`` returns one shared instance instead.

    Args:
        model_name: Hugging Face model identifier or local path passed to
            ``AutoTokenizer.from_pretrained``.

    Returns:
        The loaded tokenizer.
    """
    return AutoTokenizer.from_pretrained(model_name)


tokenizer = _load_tokenizer("gpt2-large")
|
@st.cache_resource
def load_model(model_name):
    """Load a causal language model by name and cache it across reruns.

    ``st.cache_resource`` (not ``st.cache_data``) is the decorator meant for
    models: ``cache_data`` pickles the return value and hands back a fresh
    copy on every call, which is slow and memory-hungry for a model of this
    size, whereas ``cache_resource`` returns the same shared object.

    Args:
        model_name: Hugging Face model identifier or local path passed to
            ``AutoModelForCausalLM.from_pretrained``.

    Returns:
        The loaded model instance.
    """
    return AutoModelForCausalLM.from_pretrained(model_name)
|
|
|
# Module-level model handle used by infer(); load_model handles caching.
model = load_model(model_name="gpt2-large")
|
|
|
def infer(sent, max_length, temperature, top_k, top_p):
    """Sample one continuation of *sent* from the module-level model.

    Args:
        sent: Prompt text to continue.
        max_length: Maximum total length, in tokens, of the returned sequence
            with the prompt included (``generate``'s ``max_length`` semantics,
            not a count of newly generated tokens).
        temperature: Softmax temperature used when sampling.
        top_k: Number of most-likely tokens kept at each sampling step.
        top_p: Nucleus-sampling cumulative-probability cutoff.

    Returns:
        The tensor of token ids returned by ``model.generate`` (one row, since
        a single prompt is sampled once). For a decoder-only model each row
        begins with the prompt's token ids, followed by the continuation.
    """
    # Calling the tokenizer (rather than ``encode``) also yields the attention
    # mask, which ``generate`` otherwise has to infer and warns about.
    encoded = tokenizer(sent, return_tensors="pt")
    output_sequences = model.generate(
        input_ids=encoded["input_ids"],
        attention_mask=encoded["attention_mask"],
        max_length=max_length,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        do_sample=True,
        num_return_sequences=1,
        # GPT-2 ships without a padding token; set end-of-text explicitly
        # rather than relying on generate()'s fallback, which logs a warning
        # on every call.
        pad_token_id=tokenizer.eos_token_id,
    )
    return output_sequences
|
|
|
default_value = "You: Ask me anything!"

st.title("Chat with GPT-2 💬")

# "gpt2-large" (loaded above) has ~774M parameters; 1.5B is GPT-2 XL.
st.write("GPT-2 Large is a transformer-based language model with 774 million parameters. It is trained to predict the next word in a sentence, given all of the previous words. This makes it great for text generation and for answering questions about the text it's given.")

# Display-only: this system message is never fed to the model (infer() only
# receives the user's text).
messages = [{"role": "system", "content": "You are a helpful assistant."}]

# Show the hint as a placeholder, not a prefilled value. A prefilled value is
# truthy, so every fresh page load would immediately run a full generation on
# the hint text, and users would have to delete it before typing.
user_input = st.text_input("You:", placeholder=default_value)

if user_input.strip():
    messages.append({"role": "user", "content": user_input})

    with st.spinner("Generating..."):
        output_sequences = infer(user_input, max_length=100, temperature=0.7, top_k=40, top_p=0.9)

    # generate() returns prompt + continuation; drop the prompt tokens so the
    # assistant's reply does not echo the user's message back.
    prompt_length = len(tokenizer.encode(user_input))
    generated_sequence = output_sequences[0].tolist()[prompt_length:]
    generated_text = tokenizer.decode(
        generated_sequence,
        skip_special_tokens=True,  # keep "<|endoftext|>" out of the reply
        clean_up_tokenization_spaces=True,
    )

    messages.append({"role": "assistant", "content": generated_text})

# messages is rebuilt from scratch on every Streamlit rerun, so only the
# latest exchange is displayed.
for message in messages:
    st.write(f"{message['role']}: {message['content']}")