# Spaces: Running (page-header residue from the Hugging Face Spaces listing)
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Hugging Face Hub identifier of the causal language model behind the chatbot.
model_name = "rinna/japanese-gpt2-small"


# Streamlit re-executes this whole script on every user interaction; caching
# the loader keeps the weights in memory instead of reloading them each rerun.
# show_spinner=False keeps the cache from emitting a UI element of its own, so
# st.set_page_config() further down is still the first Streamlit command.
@st.cache_resource(show_spinner=False)
def _load_tokenizer_and_model(name):
    """Load the tokenizer and causal language model identified by *name*.

    Args:
        name: Model identifier passed to ``from_pretrained``.

    Returns:
        A ``(tokenizer, model)`` tuple. Because of ``st.cache_resource`` the
        same objects are returned on later calls with the same *name*.
    """
    loaded_tokenizer = AutoTokenizer.from_pretrained(name)
    loaded_model = AutoModelForCausalLM.from_pretrained(name)
    return loaded_tokenizer, loaded_model


# Module-level handles used by generate_response().
tokenizer, model = _load_tokenizer_and_model(model_name)
# Set the browser tab title for the app.
st.set_page_config(page_title="ChatBot")

# Seed the conversation with a greeting the first time this session runs.
if "messages" not in st.session_state:
    st.session_state.messages = [{"role": "assistant", "content": "How may I help you?"}]

# Replay the stored conversation so earlier turns stay visible after a rerun.
for entry in st.session_state.messages:
    role, content = entry["role"], entry["content"]
    with st.chat_message(role):
        st.write(content)
# Function for generating LLM response | |
# def generate_response(prompt_input): | |
# input = tokenizer.encode(prompt_input, return_tensors="pt") | |
# output = model.generate(input, do_sample=True, max_length=30, num_return_sequences=1) | |
# return tokenizer.batch_decode(output) | |
def generate_response(prompt, max_length=50):
    """Generate text from *prompt* using the module-level model and tokenizer.

    Args:
        prompt: Text to encode and feed to the model.
        max_length: Forwarded to ``model.generate`` as ``max_length``.

    Returns:
        The first sequence returned by ``model.generate``, decoded to a string
        with special tokens skipped.
        NOTE(review): ``generate`` on a causal LM normally returns the prompt
        tokens followed by the continuation, so this text presumably starts
        with the prompt itself -- confirm whether the UI should strip it.
    """
    encoded = tokenizer(prompt, return_tensors="pt")

    # Use the padding id of the tokenizer that is actually loaded instead of a
    # hard-coded 50256 (the English GPT-2 end-of-text id), which is unrelated
    # to this model's vocabulary. Fall back to EOS when no pad token is set.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id

    # Inference only: no gradients are needed.
    with torch.no_grad():
        output = model.generate(
            encoded["input_ids"],
            # Passed explicitly so generate() does not have to infer the mask.
            attention_mask=encoded.get("attention_mask"),
            max_length=max_length,
            num_return_sequences=1,
            pad_token_id=pad_token_id,
        )

    return tokenizer.decode(output[0], skip_special_tokens=True)
# Record and display the prompt the user just submitted, if there is one.
if prompt := st.chat_input():
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.write(prompt)

# Reply whenever the newest stored message is still waiting for an answer.
last_message = st.session_state.messages[-1]
if last_message["role"] != "assistant":
    with st.chat_message("assistant"):
        with st.spinner("Thinking..."):
            # Read the text from the stored history rather than from `prompt`:
            # on a rerun without new input `prompt` is None, yet an unanswered
            # user message can still be pending (e.g. a previous run failed
            # before the reply was stored).
            response = generate_response(last_message["content"])
        st.write(response)
    message = {"role": "assistant", "content": response}
    st.session_state.messages.append(message)