import datetime

import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

st.title("Japanese Text Generation")


# Cache the tokenizer and model so Streamlit does not reload them on every rerun.
@st.cache_resource
def load_model():
    tokenizer = AutoTokenizer.from_pretrained(
        "rinna/japanese-gpt-neox-3.6b-instruction-ppo", use_fast=False
    )
    model = AutoModelForCausalLM.from_pretrained(
        "rinna/japanese-gpt-neox-3.6b-instruction-ppo"
    )
    return tokenizer, model


tokenizer, model = load_model()

# Keep the generation history in session state; a plain module-level list
# would be reset on every Streamlit rerun.
if "logs" not in st.session_state:
    st.session_state.logs = []


def generate_text(input_prompt):
    token_ids = tokenizer.encode(
        input_prompt, add_special_tokens=False, return_tensors="pt"
    )
    with torch.no_grad():
        output_ids = model.generate(
            token_ids.to("cpu"),
            do_sample=True,
            max_new_tokens=128,
            temperature=0.7,
            repetition_penalty=1.1,
            pad_token_id=tokenizer.pad_token_id,
            bos_token_id=tokenizer.bos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
    # Decode only the newly generated tokens, skipping the prompt.
    generated_text = tokenizer.decode(output_ids.tolist()[0][token_ids.size(1):])
    # The rinna instruction models emit "<NL>" in place of newline characters.
    generated_text = generated_text.replace("<NL>", "\n")
    return generated_text


prompt = st.text_area("Enter the prompt:")
if st.button("Submit"):
    generated_output = generate_text(prompt)
    timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    st.session_state.logs.append((timestamp, prompt, generated_output))

for log_time, log_prompt, log_output in st.session_state.logs:
    # st.beta_container() was removed from Streamlit; st.container() replaces it.
    with st.container():
        st.write("---")
        st.subheader("Time: {}".format(log_time))
        st.write("**Input**: {}".format(log_prompt))
        st.write("**Output**: {}".format(log_output))
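
# Usage note: Streamlit apps are launched with the `streamlit run` CLI rather
# than `python`. Assuming this file is saved as app.py (the filename is an
# assumption, not from the original):
#
#   streamlit run app.py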