Spaces:
Runtime error
Runtime error
File size: 3,756 Bytes
59cee12 9177921 59cee12 7d0ffd9 59cee12 9177921 59cee12 9177921 7d0ffd9 59cee12 9177921 59cee12 7d0ffd9 9177921 59cee12 9177921 59cee12 9177921 59cee12 9177921 59cee12 9177921 59cee12 9177921 59cee12 9177921 59cee12 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 |
import streamlit as st
import SessionState
from mtranslate import translate
from prompts import PROMPT_LIST
import random
import time
from transformers import pipeline, set_seed
import tokenizers
# st.set_page_config(page_title="Image Search")
# vector_length = 128
# Model identifier handed to transformers.pipeline() as `model=` inside
# get_generator(); also interpolated into the "Loading ..." status message.
model_name = "cahya/gpt2-small-indonesian-story"
@st.cache(suppress_st_warning=True, allow_output_mutation=True)
def get_generator():
    """Create the text-generation pipeline for ``model_name``.

    A status line is written to the page before the pipeline is built.
    The function is wrapped in ``st.cache`` so that the pipeline is meant
    to be constructed once and reused on later script reruns.

    Returns:
        The object returned by ``pipeline('text-generation', ...)``.
    """
    loading_notice = f"Loading the GPT2 model {model_name}, please wait..."
    st.write(loading_notice)
    return pipeline('text-generation', model=model_name)
@st.cache(suppress_st_warning=True, hash_funcs={tokenizers.Tokenizer: id})
def process(text: str, max_length: int = 100, do_sample: bool = True, top_k: int = 50, top_p: float = 0.95,
            temperature: float = 1.0, max_time: float = 10.0, seed=42):
    """Generate a continuation of ``text`` with the module-level generator.

    Args:
        text: Prompt handed to ``text_generator``.
        max_length: Forwarded as ``max_length``.
        do_sample: Forwarded as ``do_sample``.
        top_k: Forwarded as ``top_k``.
        top_p: Forwarded as ``top_p``.
        temperature: Forwarded as ``temperature``.
        max_time: Forwarded as ``max_time``.
        seed: Value passed to ``set_seed`` before generating.

    Returns:
        Whatever ``text_generator`` returns; the caller in this file reads
        ``result[0]["generated_text"]`` from it.
    """
    # Only reached when st.cache has no stored result for these arguments.
    st.write("Cache miss: process")
    set_seed(seed)
    generation_options = {
        "max_length": max_length,
        "do_sample": do_sample,
        "top_k": top_k,
        "top_p": top_p,
        "temperature": temperature,
        "max_time": max_time,
    }
    # `text_generator` is the module-level pipeline assigned before the first call.
    return text_generator(text, **generation_options)
# Page heading followed by a one-line description of the demo.
st.title("Indonesian Story Generator")
intro_markdown = """
This application is a demo for Indonesian Story Generator using GPT2.
"""
st.markdown(intro_markdown)
# State object holding the selected prompt, the text-area default and the
# entered text (presumably persisted across reruns by SessionState -- confirm
# against that module).
session_state = SessionState.get(prompt=None, prompt_box=None, text=None)

# Named prompt categories plus a trailing "Custom" entry, which is the
# default selection (index of the last element).
ALL_PROMPTS = [*PROMPT_LIST.keys(), "Custom"]
prompt = st.selectbox('Prompt', ALL_PROMPTS, index=len(ALL_PROMPTS)-1)

# Update prompt: when the selection differs from the stored one, drop the
# stored text-area default and text; the selection is stored in every case.
previous_prompt = session_state.prompt
if previous_prompt is not None and prompt != previous_prompt:
    session_state.prompt_box = None
    session_state.text = None
session_state.prompt = prompt

# Update prompt box: fixed placeholder for "Custom", otherwise a random
# example from the chosen category, picked only while no default is stored.
if session_state.prompt == "Custom":
    session_state.prompt_box = "Enter your text here"
elif session_state.prompt is not None and session_state.prompt_box is None:
    session_state.prompt_box = random.choice(PROMPT_LIST[session_state.prompt])

session_state.text = st.text_area("Enter text", session_state.prompt_box)
# Sidebar controls for the generation parameters. Each widget is bounded so
# that values the generation call would reject cannot be entered. Bounds use
# the same numeric type (int or float) as the widget's default value.
max_length = st.sidebar.number_input(
    "Maximum length",
    value=100,
    min_value=1,  # a zero or negative length is not a valid generation length
    max_value=512,
    help="The maximum length of the sequence to be generated."
)

temperature = st.sidebar.slider(
    "Temperature",
    value=1.0,
    # transformers requires a strictly positive temperature when sampling,
    # so the slider must not reach 0.0.
    min_value=0.01,
    max_value=10.0
)

do_sample = st.sidebar.checkbox(
    "Use sampling",
    value=True
)

# Defaults used when sampling is switched off and the inputs are not shown.
top_k = 25
top_p = 0.95
if do_sample:
    top_k = st.sidebar.number_input(
        "Top k",
        value=top_k,
        min_value=0  # negative values are not a valid top-k
    )
    top_p = st.sidebar.number_input(
        "Top p",
        value=top_p,
        min_value=0.0,  # top_p is a cumulative probability, so it lives in [0, 1]
        max_value=1.0
    )

seed = st.sidebar.number_input(
    "Random Seed",
    value=25,
    # set_seed() also seeds NumPy, which only accepts 0 .. 2**32 - 1.
    min_value=0,
    max_value=2**32 - 1,
    help="The number used to initialize a pseudorandom number generator"
)
# Build the generator (or reuse the one cached by st.cache). process() reads
# this module-level name, so it has to be assigned before "Run" is handled.
text_generator = get_generator()

if st.button("Run"):
    with st.spinner(text="Getting results..."):
        st.subheader("Result")

        # Time only the generation step; the translation below is excluded.
        time_start = time.time()
        result = process(text=session_state.text, max_length=int(max_length),
                         temperature=temperature, do_sample=do_sample,
                         top_k=int(top_k), top_p=float(top_p), seed=seed)
        time_end = time.time()
        time_diff = time_end-time_start

        result = result[0]["generated_text"]
        # Markdown only renders a hard line break when the line ends with two
        # (or more) spaces; a single trailing space leaves the newlines
        # collapsed into one paragraph.
        st.write(result.replace("\n", "  \n"))

        st.text("Translation")
        # mtranslate argument order is (text, to_language, from_language).
        translation = translate(result, "en", "id")
        st.write(translation.replace("\n", "  \n"))
        # st.write(f"*do_sample: {do_sample}, top_k: {top_k}, top_p: {top_p}, seed: {seed}*")
        st.write(f"*Text generated in {time_diff:.5} seconds*")

        # Reset state so the next rerun starts from a fresh prompt selection.
        session_state.prompt = None
        session_state.prompt_box = None
        session_state.text = None
|