# cahya's picture
# info to move
# aaf3182
import streamlit as st
# Default address of the mirror deployment. It is shown in the warning below
# and linked from the sidebar.
mirror_url = "https://news-generator.ai-research.id/"

print("Streamlit Version: ", st.__version__)

# The app only runs on Streamlit 1.9.0. On any other version the visitor is
# pointed at the mirror and the script run is halted.
_is_pinned_streamlit = st.__version__ == "1.9.0"
if not _is_pinned_streamlit:
    st.warning(f"We move to: {mirror_url}")
    st.stop()
import SessionState
from mtranslate import translate
from prompts import PROMPT_LIST
import random
import time
import psutil
import os
import requests
# st.set_page_config(page_title="Indonesian GPT-2")
# Deployment settings read from the process environment.

# MIRROR_URL, when present, replaces the mirror address assigned earlier.
_mirror_override = os.environ.get("MIRROR_URL")
if _mirror_override is not None:
    mirror_url = _mirror_override

# Both tokens fall back to the boolean False (not an empty string) when the
# corresponding variable is unset.
hf_auth_token = os.environ.get("HF_AUTH_TOKEN", False)
news_api_auth_token = os.environ.get("NEWS_API_AUTH_TOKEN", False)
# Registry of the models offered in the sidebar, keyed by display label.
#   group       - prompt group; selects the page variant and the PROMPT_LIST entry
#   name        - Hugging Face model id, used to build the model link
#   description - short human-readable summary
#   text_generator / tokenizer - None placeholders
#       NOTE(review): no visible code assigns these; presumably left over from
#       local inference before generation moved to the remote API - confirm.
MODELS = {
    "Indonesian Newspaper - Indonesian GPT-2 Medium": dict(
        group="Indonesian Newspaper",
        name="ai-research-id/gpt2-medium-newspaper",
        description="Newspaper Generator using Indonesian GPT-2 Medium.",
        text_generator=None,
        tokenizer=None,
    ),
}
# --- Sidebar: logo, project blurb, separator and model picker ---------------

# Centering style plus the project logo.
_SIDEBAR_LOGO_HTML = """
<style>
.centeralign {
text-align: center;
}
</style>
<p class="centeralign">
<img src="https://huggingface.co/spaces/flax-community/gpt2-indonesian/resolve/main/huggingwayang.png"/>
</p>
"""

# Credits and links; the mirror link uses the current value of mirror_url.
_SIDEBAR_ABOUT_HTML = f"""
___
<p class="centeralign">
This is a collection of applications that generates sentences using Indonesian GPT-2 models!
</p>
<p class="centeralign">
Created by <a href="https://huggingface.co/indonesian-nlp">Indonesian NLP</a> team @2021
<br/>
<a href="https://github.com/indonesian-nlp/gpt2-app" target="_blank">GitHub</a> | <a href="https://github.com/indonesian-nlp/gpt2-app" target="_blank">Project Report</a>
<br/>
A mirror of the application is available <a href="{mirror_url}" target="_blank">here</a>
</p>
"""

# Horizontal rule separating the blurb from the controls.
_SIDEBAR_RULE = "\n___\n"

# Raw HTML is used in the fragments, hence unsafe_allow_html.
for _fragment in (_SIDEBAR_LOGO_HTML, _SIDEBAR_ABOUT_HTML, _SIDEBAR_RULE):
    st.sidebar.markdown(_fragment, unsafe_allow_html=True)

# The chosen label is a key of MODELS.
model_type = st.sidebar.selectbox('Model', MODELS.keys())
# Disable the st.cache for this function due to issue on newer version of streamlit
# @st.cache(suppress_st_warning=True, hash_funcs={tokenizers.Tokenizer: id})
def process(title: str, keywords: str, text: str,
            max_length: int = 250, do_sample: bool = True, top_k: int = 50, top_p: float = 0.95,
            temperature: float = 1.0, max_time: float = 120.0, seed=42, repetition_penalty=1.0,
            penalty_alpha=0.6):
    """Request a generated news article from the remote text-generator API.

    All arguments are forwarded unchanged as form fields of a single POST
    request, authenticated with the module-level ``news_api_auth_token``.

    Args:
        title: Headline of the article.
        keywords: Keywords to steer the generation.
        text: Prompt text.
        max_length: Maximum length of the generated sequence.
        do_sample: Whether sampling is enabled.
        top_k: Top-k setting forwarded to the API.
        top_p: Top-p setting forwarded to the API.
        temperature: Sampling temperature.
        max_time: Time budget forwarded to the API (presumably seconds);
            also used to derive the client-side request timeout.
        seed: Random seed forwarded to the API.
        repetition_penalty: Repetition penalty forwarded to the API.
        penalty_alpha: Contrastive-search penalty forwarded to the API.

    Returns:
        The decoded JSON body on HTTP 200. On any failure, a string that
        starts with ``"Error: "``. Failures are a missing token, a network
        error or timeout, a non-200 status, or a body that is not JSON.
        Callers must check the type before indexing the result.
    """
    # st.write("Cache miss: process")
    url = 'https://news-api.uncool.ai/api/text_generator/v1'
    # url = 'http://localhost:8000/api/text_generator/v1'

    # The token defaults to False when the environment variable is unset;
    # concatenating it below would raise TypeError, so report it instead.
    if not news_api_auth_token:
        return "Error: NEWS_API_AUTH_TOKEN is not set"

    headers = {'Authorization': 'Bearer ' + news_api_auth_token}
    data = {
        "title": title,
        "keywords": keywords,
        "text": text,
        "max_length": max_length,
        "do_sample": do_sample,
        "top_k": top_k,
        "top_p": top_p,
        "temperature": temperature,
        "max_time": max_time,
        "seed": seed,
        "repetition_penalty": repetition_penalty,
        "penalty_alpha": penalty_alpha,
    }

    # Without a timeout, requests waits forever on a stalled server. Allow
    # the server its own time budget plus some slack for transfer overhead.
    grace_seconds = 30.0
    try:
        r = requests.post(url, headers=headers, data=data,
                          timeout=max_time + grace_seconds)
    except requests.RequestException as exc:
        return f"Error: {exc}"

    if r.status_code != 200:
        return "Error: " + r.text

    try:
        return r.json()
    except ValueError:
        # A 200 response whose body is not valid JSON.
        return "Error: " + r.text
# --- Main page heading for the model chosen in the sidebar -------------------
st.title("Indonesian GPT-2 Applications")

_selected_model = MODELS[model_type]

# The group name is the section header and also selects the page variant below.
prompt_group_name = _selected_model["group"]
st.header(prompt_group_name)

description = (
    "This is a news generator using Indonesian GPT-2 Medium. We finetuned the pre-trained model with 1.4M "
    "articles of the Indonesian online newspaper dataset."
)
st.markdown(description)

# Markdown link to the model's page on the Hugging Face hub.
_hub_id = _selected_model["name"]
model_name = f"Model name: [{_hub_id}](https://huggingface.co/{_hub_id})"
st.markdown(model_name)
if prompt_group_name in ["Indonesian Newspaper"]:
    # News-generator page: pick or type a prompt, tune the decoding settings
    # in the sidebar, then send everything to process() when "Run" is pressed.
    session_state = SessionState.get(prompt=None, prompt_box=None, text=None)
    ALL_PROMPTS = list(PROMPT_LIST[prompt_group_name].keys()) + ["Custom"]

    # "Custom" is the last entry and is preselected.
    prompt = st.selectbox('Prompt', ALL_PROMPTS, index=len(ALL_PROMPTS) - 1)

    # Update prompt: switching to a different entry clears the cached prompt
    # text so that a fresh example is drawn below.
    if session_state.prompt is not None and prompt != session_state.prompt:
        session_state.prompt_box = None
    session_state.prompt = prompt

    # Update prompt box
    if session_state.prompt == "Custom":
        session_state.prompt_box = ""
        session_state.title = ""
        session_state.keywords = ""
    elif session_state.prompt is not None and session_state.prompt_box is None:
        # Draw one random example from the selected preset.
        choice = random.choice(PROMPT_LIST[prompt_group_name][session_state.prompt])
        session_state.title = choice["title"]
        session_state.keywords = choice["keywords"]
        session_state.prompt_box = choice["text"]

    session_state.title = st.text_input("Title", session_state.title)
    session_state.keywords = st.text_input("Keywords", session_state.keywords)
    session_state.text = st.text_area("Prompt", session_state.prompt_box)

    # --- Generation settings (sidebar) --------------------------------------
    max_length = st.sidebar.number_input(
        "Maximum length",
        value=250,
        max_value=512,
        help="The maximum length of the sequence to be generated."
    )

    decoding_methods = st.sidebar.radio(
        "Set the decoding methods:",
        key="decoding",
        options=["Beam Search", "Sampling", "Contrastive Search"],
        index=2
    )

    temperature = st.sidebar.slider(
        "Temperature",
        value=0.4,
        min_value=0.0,
        max_value=2.0
    )

    # Defaults; the branch for the chosen decoding method overrides some.
    top_k = 30
    top_p = 0.95
    # 0.0 is sent unless the method or the manual slider below overrides it.
    # NOTE(review): presumably the API treats 0.0 as "choose automatically" - confirm.
    repetition_penalty = 0.0
    penalty_alpha = None

    if decoding_methods == "Beam Search":
        do_sample = False
    elif decoding_methods == "Sampling":
        do_sample = True
        top_k = st.sidebar.number_input(
            "Top k",
            value=top_k,
            help="The number of highest probability vocabulary tokens to keep for top-k-filtering."
        )
        top_p = st.sidebar.number_input(
            "Top p",
            value=top_p,
            help="If set to float < 1, only the most probable tokens with probabilities that add up to top_p or higher "
                 "are kept for generation."
        )
    else:
        # Contrastive Search
        do_sample = False
        repetition_penalty = 1.1
        penalty_alpha = st.sidebar.number_input(
            "Penalty alpha",
            value=0.6,
            help="The penalty alpha for contrastive search."
        )
        top_k = st.sidebar.number_input(
            "Top k",
            value=4,
            help="The number of highest probability vocabulary tokens to keep for top-k-filtering."
        )

    seed = st.sidebar.number_input(
        "Random Seed",
        value=25,
        help="The number used to initialize a pseudorandom number generator"
    )

    # Contrastive search uses its fixed penalty; the other methods may set one manually.
    if decoding_methods != "Contrastive Search":
        automatic_repetition_penalty = st.sidebar.checkbox(
            "Automatic Repetition Penalty",
            value=True
        )
        if not automatic_repetition_penalty:
            repetition_penalty = st.sidebar.slider(
                "Repetition Penalty",
                value=1.0,
                min_value=1.0,
                max_value=2.0
            )

    # st.write(f"Generator: {MODELS}'")
    if st.button("Run"):
        with st.spinner(text="Getting results..."):
            # Memory snapshot is taken before the request is sent.
            memory = psutil.virtual_memory()
            # st.subheader("Result")
            time_start = time.time()
            # text_generator = MODELS[model_type]["text_generator"]
            result = process(title=session_state.title,
                             keywords=session_state.keywords,
                             text=session_state.text, max_length=int(max_length),
                             temperature=temperature, do_sample=do_sample, penalty_alpha=penalty_alpha,
                             top_k=int(top_k), top_p=float(top_p), seed=seed,
                             repetition_penalty=repetition_penalty)
            time_diff = time.time() - time_start

            if not isinstance(result, dict):
                # process() reports failures as an "Error: ..." string.
                # Indexing that string like a dict raises TypeError, so show
                # the message instead. The prompt state is left untouched so
                # the user can retry.
                st.error(str(result))
            else:
                # Markdown only renders a hard line break when the newline is
                # preceded by two trailing spaces.
                hard_break = "  \n"

                # result = result[0]["generated_text"]
                title = f"### {session_state.title}"
                tldr = f"*{result['description'].strip()}*"
                caption_text = result['caption'].strip()
                caption = f"*Photo Caption: {caption_text}*" if caption_text != "" else ""
                st.markdown(title)
                st.markdown(tldr)
                st.markdown(result["generated_text"].replace("\n", hard_break))
                st.markdown(caption.replace("\n", hard_break))
                st.markdown("**Translation**")
                # Indonesian ("id") to English ("en").
                translation = translate(result["generated_text"], "en", "id")
                st.write(translation.replace("\n", hard_break))
                # st.write(f"*do_sample: {do_sample}, top_k: {top_k}, top_p: {top_p}, seed: {seed}*")
                gib = 1024 * 1024 * 1024
                info = f"""
*Memory: {memory.total / gib:.2f}GB, used: {memory.percent}%, available: {memory.available / gib:.2f}GB*
*Text generated in {time_diff:.5} seconds*
"""
                st.write(info)

                # Reset state so the next run of a preset draws a new example.
                session_state.prompt = None
                session_state.prompt_box = None