import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("gpt2-large")

# st.cache_resource (not st.cache_data) is the right cache for an
# unserializable object like a torch model: it is loaded once per process
# instead of being re-created and re-hashed on every Streamlit rerun.
@st.cache_resource
def load_model(model_name):
    model = AutoModelForCausalLM.from_pretrained(model_name)
    return model

model = load_model("gpt2-large")
def infer(sent, max_length, temperature, top_k, top_p):
    """Sample one continuation of `sent` from the model."""
    input_ids = tokenizer.encode(sent, return_tensors="pt")
    output_sequences = model.generate(
        input_ids=input_ids,
        max_length=max_length,    # total length, prompt tokens included
        temperature=temperature,  # <1.0 sharpens, >1.0 flattens the distribution
        top_k=top_k,              # sample only from the k most likely tokens
        top_p=top_p,              # nucleus sampling cutoff
        do_sample=True,
        num_return_sequences=1,
        pad_token_id=tokenizer.eos_token_id,  # GPT-2 has no pad token; silences a warning
    )
    return output_sequences
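
# A quick, hypothetical smoke test for infer() outside the Streamlit UI
# (not part of the app itself; uncomment to try it from a plain script):
#
#   seqs = infer("The weather today is", max_length=30, temperature=0.7, top_k=40, top_p=0.9)
#   print(tokenizer.decode(seqs[0], skip_special_tokens=True))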
# UI and chat state
st.title("Chat with GPT-2 💬")
st.write("GPT-2 Large is a transformer-based language model with 774 million parameters. It is trained to predict the next word in a sentence, given all of the previous words, which makes it well suited to open-ended text generation.")

# Keep the chat history in st.session_state so it survives Streamlit reruns;
# a plain list re-created at module level loses the history on every interaction.
if "messages" not in st.session_state:
    st.session_state.messages = [{"role": "system", "content": "You are a helpful assistant."}]

user_input = st.text_input("You:", placeholder="Ask me anything!")
# A Send button gates generation so unrelated reruns do not duplicate messages.
if st.button("Send") and user_input:
    st.session_state.messages.append({"role": "user", "content": user_input})
    output_sequences = infer(user_input, max_length=100, temperature=0.7, top_k=40, top_p=0.9)
    # generate() returns the prompt tokens followed by the continuation;
    # drop the prompt so the assistant message holds only the new text.
    prompt_length = tokenizer.encode(user_input, return_tensors="pt").shape[-1]
    generated_sequence = output_sequences[0][prompt_length:].tolist()
    generated_text = tokenizer.decode(generated_sequence, skip_special_tokens=True, clean_up_tokenization_spaces=True)
    st.session_state.messages.append({"role": "assistant", "content": generated_text})
for message in st.session_state.messages:
    st.write(f"{message['role']}: {message['content']}")
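
# To run locally (standard Streamlit usage, assuming the usual dependencies):
#   pip install streamlit transformers torch
#   streamlit run app.py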