Spaces:
Runtime error
Runtime error
import streamlit as st | |
# Default address of the mirror deployment; it is shown in the warning below.
mirror_url = "https://news-generator.ai-research.id/"
print("Streamlit Version: ", st.__version__)

# The script only continues on Streamlit 1.9.0. On any other version it points
# the visitor at the mirror and halts the run.
_is_supported_streamlit = st.__version__ == "1.9.0"
if not _is_supported_streamlit:
    st.warning(f"We move to: {mirror_url}")
    st.stop()
import SessionState | |
from mtranslate import translate | |
from prompts import PROMPT_LIST | |
import random | |
import time | |
import psutil | |
import os | |
import requests | |
# st.set_page_config(page_title="Indonesian GPT-2")

# An optional MIRROR_URL environment variable replaces the default mirror address.
_mirror_override = os.environ.get("MIRROR_URL")
if _mirror_override is not None:
    mirror_url = _mirror_override

# Both tokens come from the environment and are the boolean False (not None)
# when the corresponding variable is unset.
hf_auth_token = os.getenv("HF_AUTH_TOKEN", False)
news_api_auth_token = os.getenv("NEWS_API_AUTH_TOKEN", False)
# One registry entry per selectable model:
#   group          - prompt group; also used as the key into PROMPT_LIST
#   name           - model id, used to build the huggingface.co link
#   description    - short human-readable summary
#   text_generator - placeholder, left as None here
#   tokenizer      - placeholder, left as None here
_NEWSPAPER_GPT2_MEDIUM = dict(
    group="Indonesian Newspaper",
    name="ai-research-id/gpt2-medium-newspaper",
    description="Newspaper Generator using Indonesian GPT-2 Medium.",
    text_generator=None,
    tokenizer=None,
)

# Registry keyed by the label offered in the sidebar "Model" selectbox.
MODELS = {
    "Indonesian Newspaper - Indonesian GPT-2 Medium": _NEWSPAPER_GPT2_MEDIUM,
}
# Sidebar header: a CSS helper class plus the centred project logo.
_SIDEBAR_LOGO_HTML = """
<style>
.centeralign {
text-align: center;
}
</style>
<p class="centeralign">
<img src="https://huggingface.co/spaces/flax-community/gpt2-indonesian/resolve/main/huggingwayang.png"/>
</p>
"""

# Sidebar credits, including the link to the (possibly overridden) mirror URL.
_SIDEBAR_ABOUT_HTML = f"""
___
<p class="centeralign">
This is a collection of applications that generates sentences using Indonesian GPT-2 models!
</p>
<p class="centeralign">
Created by <a href="https://huggingface.co/indonesian-nlp">Indonesian NLP</a> team @2021
<br/>
<a href="https://github.com/indonesian-nlp/gpt2-app" target="_blank">GitHub</a> | <a href="https://github.com/indonesian-nlp/gpt2-app" target="_blank">Project Report</a>
<br/>
A mirror of the application is available <a href="{mirror_url}" target="_blank">here</a>
</p>
"""

# Horizontal rule separating the credits from the controls below.
_SIDEBAR_RULE = """
___
"""

# Each fragment is emitted as its own sidebar element, in this order.
for _sidebar_fragment in (_SIDEBAR_LOGO_HTML, _SIDEBAR_ABOUT_HTML, _SIDEBAR_RULE):
    st.sidebar.markdown(_sidebar_fragment, unsafe_allow_html=True)

# Model picker; the options are the labels used as keys of MODELS.
model_type = st.sidebar.selectbox('Model', MODELS.keys())
# Disable the st.cache for this function due to issue on newer version of streamlit
# @st.cache(suppress_st_warning=True, hash_funcs={tokenizers.Tokenizer: id})
def process(title: str, keywords: str, text: str,
            max_length: int = 250, do_sample: bool = True, top_k: int = 50, top_p: float = 0.95,
            temperature: float = 1.0, max_time: float = 120.0, seed=42, repetition_penalty=1.0,
            penalty_alpha=0.6):
    """Request a generated news article from the remote text-generator API.

    All arguments are forwarded unchanged as form fields of a single POST
    request, authenticated with the module-level ``news_api_auth_token``.

    Args:
        title: Article title.
        keywords: Keywords for the article.
        text: Prompt text.
        max_length: Forwarded as ``max_length``.
        do_sample: Forwarded as ``do_sample``.
        top_k: Forwarded as ``top_k``.
        top_p: Forwarded as ``top_p``.
        temperature: Forwarded as ``temperature``.
        max_time: Forwarded as ``max_time``; also used to size the client-side
            read timeout so the request cannot hang forever.
        seed: Forwarded as ``seed``.
        repetition_penalty: Forwarded as ``repetition_penalty``.
        penalty_alpha: Forwarded as ``penalty_alpha``.

    Returns:
        The decoded JSON body when the service answers with HTTP 200,
        otherwise a string starting with ``"Error: "`` describing the failure.
    """
    # The token is read with os.getenv(..., False); concatenating False onto a
    # str would raise TypeError, so report the missing configuration instead.
    if not news_api_auth_token:
        return "Error: NEWS_API_AUTH_TOKEN is not set"

    url = 'https://news-api.uncool.ai/api/text_generator/v1'
    # url = 'http://localhost:8000/api/text_generator/v1'
    headers = {'Authorization': 'Bearer ' + news_api_auth_token}
    data = {
        "title": title,
        "keywords": keywords,
        "text": text,
        "max_length": max_length,
        "do_sample": do_sample,
        "top_k": top_k,
        "top_p": top_p,
        "temperature": temperature,
        "max_time": max_time,
        "seed": seed,
        "repetition_penalty": repetition_penalty,
        "penalty_alpha": penalty_alpha,
    }

    # (connect timeout, read timeout). The read timeout leaves headroom above
    # the max_time value sent to the service.
    connect_timeout = 10.0
    read_timeout = float(max_time) + 30.0
    try:
        r = requests.post(url, headers=headers, data=data,
                          timeout=(connect_timeout, read_timeout))
    except requests.RequestException as err:
        # Connection failures and timeouts use the same error-string contract
        # as non-200 responses.
        return f"Error: {err}"

    if r.status_code != 200:
        return "Error: " + r.text
    try:
        return r.json()
    except ValueError:
        # A 200 response whose body is not valid JSON.
        return "Error: " + r.text
# Page heading for the currently selected model.
st.title("Indonesian GPT-2 Applications")

_selected_model = MODELS[model_type]
prompt_group_name = _selected_model["group"]
st.header(prompt_group_name)

# Static blurb about the model.
description = (
    "This is a news generator using Indonesian GPT-2 Medium. We finetuned the pre-trained model with 1.4M "
    "articles of the Indonesian online newspaper dataset."
)
st.markdown(description)

# Markdown link to the model page on huggingface.co.
_model_id = _selected_model['name']
model_name = f"Model name: [{_model_id}](https://huggingface.co/{_model_id})"
st.markdown(model_name)
# Main page flow for the newspaper generator: choose a prompt, tune the
# decoding settings in the sidebar, call process() and render the article.
if prompt_group_name in ["Indonesian Newspaper"]:
    session_state = SessionState.get(prompt=None, prompt_box=None, text=None)
    ALL_PROMPTS = list(PROMPT_LIST[prompt_group_name].keys()) + ["Custom"]
    # "Custom" is the last entry and is preselected.
    prompt = st.selectbox('Prompt', ALL_PROMPTS, index=len(ALL_PROMPTS) - 1)

    # Update prompt: switching to a different entry clears the prompt box so a
    # fresh example is drawn below.
    if session_state.prompt is not None and prompt != session_state.prompt:
        session_state.prompt_box = None
    session_state.prompt = prompt

    # Update prompt box
    if session_state.prompt == "Custom":
        session_state.prompt_box = ""
        session_state.title = ""
        session_state.keywords = ""
    elif session_state.prompt is not None and session_state.prompt_box is None:
        # Draw one random example from the selected prompt category.
        choice = random.choice(PROMPT_LIST[prompt_group_name][session_state.prompt])
        session_state.title = choice["title"]
        session_state.keywords = choice["keywords"]
        session_state.prompt_box = choice["text"]

    session_state.title = st.text_input("Title", session_state.title)
    session_state.keywords = st.text_input("Keywords", session_state.keywords)
    session_state.text = st.text_area("Prompt", session_state.prompt_box)

    max_length = st.sidebar.number_input(
        "Maximum length",
        value=250,
        max_value=512,
        help="The maximum length of the sequence to be generated."
    )
    decoding_methods = st.sidebar.radio(
        "Set the decoding methods:",
        key="decoding",
        options=["Beam Search", "Sampling", "Contrastive Search"],
        index=2
    )
    temperature = st.sidebar.slider(
        "Temperature",
        value=0.4,
        min_value=0.0,
        max_value=2.0
    )

    # Defaults; the branch for the chosen decoding method overrides some of them.
    top_k = 30
    top_p = 0.95
    repetition_penalty = 0.0
    penalty_alpha = None

    if decoding_methods == "Beam Search":
        do_sample = False
    elif decoding_methods == "Sampling":
        do_sample = True
        top_k = st.sidebar.number_input(
            "Top k",
            value=top_k,
            help="The number of highest probability vocabulary tokens to keep for top-k-filtering."
        )
        top_p = st.sidebar.number_input(
            "Top p",
            value=top_p,
            help="If set to float < 1, only the most probable tokens with probabilities that add up to top_p or higher "
                 "are kept for generation."
        )
    else:
        # Contrastive Search
        do_sample = False
        repetition_penalty = 1.1
        penalty_alpha = st.sidebar.number_input(
            "Penalty alpha",
            value=0.6,
            help="The penalty alpha for contrastive search."
        )
        top_k = st.sidebar.number_input(
            "Top k",
            value=4,
            help="The number of highest probability vocabulary tokens to keep for top-k-filtering."
        )

    seed = st.sidebar.number_input(
        "Random Seed",
        value=25,
        help="The number used to initialize a pseudorandom number generator"
    )

    if decoding_methods != "Contrastive Search":
        automatic_repetition_penalty = st.sidebar.checkbox(
            "Automatic Repetition Penalty",
            value=True
        )
        # With the checkbox ticked, repetition_penalty keeps the 0.0 default set above.
        if not automatic_repetition_penalty:
            repetition_penalty = st.sidebar.slider(
                "Repetition Penalty",
                value=1.0,
                min_value=1.0,
                max_value=2.0
            )

    # st.write(f"Generator: {MODELS}'")
    if st.button("Run"):
        with st.spinner(text="Getting results..."):
            memory = psutil.virtual_memory()
            time_start = time.time()
            result = process(title=session_state.title,
                             keywords=session_state.keywords,
                             text=session_state.text, max_length=int(max_length),
                             temperature=temperature, do_sample=do_sample, penalty_alpha=penalty_alpha,
                             top_k=int(top_k), top_p=float(top_p), seed=seed,
                             repetition_penalty=repetition_penalty)
            time_diff = time.time() - time_start

            if isinstance(result, str):
                # process() reports failures as an "Error: ..." string; indexing
                # it like a dict would raise TypeError, so show the message
                # instead. The prompt state is left untouched so the user can
                # retry with the same input.
                st.error(result)
            else:
                # Markdown needs two trailing spaces before a newline to
                # render a hard line break.
                md_break = "  \n"
                caption_text = result['caption'].strip()
                title = f"### {session_state.title}"
                tldr = f"*{result['description'].strip()}*"
                caption = f"*Photo Caption: {caption_text}*" if caption_text != "" else ""
                st.markdown(title)
                st.markdown(tldr)
                st.markdown(result["generated_text"].replace("\n", md_break))
                st.markdown(caption.replace("\n", md_break))
                st.markdown("**Translation**")
                # Translate the generated text from Indonesian ("id") to English ("en").
                translation = translate(result["generated_text"], "en", "id")
                st.write(translation.replace("\n", md_break))
                # st.write(f"*do_sample: {do_sample}, top_k: {top_k}, top_p: {top_p}, seed: {seed}*")
                info = f"""
*Memory: {memory.total / (1024 * 1024 * 1024):.2f}GB, used: {memory.percent}%, available: {memory.available / (1024 * 1024 * 1024):.2f}GB*
*Text generated in {time_diff:.5} seconds*
"""
                st.write(info)

                # Reset state
                session_state.prompt = None
                session_state.prompt_box = None