Spaces: Runtime error
import datetime

import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

st.title("Japanese Text Generation")

# Cache the tokenizer and model so they are loaded once, not on every rerun.
@st.cache_resource
def load_model():
    tokenizer = AutoTokenizer.from_pretrained(
        "rinna/japanese-gpt-neox-3.6b-instruction-ppo", use_fast=False
    )
    model = AutoModelForCausalLM.from_pretrained(
        "rinna/japanese-gpt-neox-3.6b-instruction-ppo"
    )
    return tokenizer, model

tokenizer, model = load_model()
# Streamlit reruns the entire script on each interaction, so a module-level
# list would be cleared every time; keep the history in st.session_state.
if "logs" not in st.session_state:
    st.session_state.logs = []
def generate_text(input_prompt):
    token_ids = tokenizer.encode(input_prompt, add_special_tokens=False, return_tensors="pt")
    with torch.no_grad():
        output_ids = model.generate(
            token_ids.to(model.device),
            do_sample=True,
            max_new_tokens=128,
            temperature=0.7,
            repetition_penalty=1.1,
            pad_token_id=tokenizer.pad_token_id,
            bos_token_id=tokenizer.bos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
    # Decode only the newly generated tokens, dropping the echoed prompt.
    generated_text = tokenizer.decode(output_ids.tolist()[0][token_ids.size(1):], skip_special_tokens=True)
    # rinna's instruction-tuned models emit newlines as the literal "<NL>" token.
    generated_text = generated_text.replace("<NL>", "\n")
    return generated_text
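
# Assumption (sketch, not part of the original app): the model card for
# rinna/japanese-gpt-neox-3.6b-instruction-ppo documents a conversation
# format of "speaker: text" turns joined with the literal "<NL>" token,
# which is also why generate_text() maps "<NL>" back to "\n" above. This
# hypothetical helper shows how a raw user message could be wrapped in
# that format before being passed to generate_text().
def build_prompt(user_text):
    turns = [f"ユーザー: {user_text}"]        # a single user turn
    prompt = "<NL>".join(turns)               # turns are separated by <NL>
    return prompt + "<NL>" + "システム: "      # leave the system reply open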
prompt = st.text_area("Enter the prompt:")
if st.button("Submit"):
    generated_output = generate_text(prompt)
    timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    st.session_state.logs.append((timestamp, prompt, generated_output))

for timestamp, input_text, output_text in st.session_state.logs:
    # st.beta_container() was removed in Streamlit 1.x; use st.container().
    with st.container():
        st.write("---")
        st.subheader("Time: {}".format(timestamp))
        st.write("**Input**: {}".format(input_text))
        st.write("**Output**: {}".format(output_text))