Spaces:
Runtime error
Runtime error
import json | |
import requests | |
from mtranslate import translate | |
from prompts import PROMPT_LIST | |
import streamlit as st | |
import random | |
import transformers | |
from transformers import GPT2Tokenizer, GPT2LMHeadModel | |
import fasttext | |
import SessionState | |
# Image file shown at the top of the sidebar.
LOGO = "huggingwayang.png"

# UI display name -> Hugging Face Hub model id loaded by from_pretrained().
MODELS = {
    "GPT-2 Small": "flax-community/gpt2-small-indonesian",
    "GPT-2 Medium": "flax-community/gpt2-medium-indonesian",
    "GPT-2 Small finetuned on Indonesian academic journals": "Galuh/id-journal-gpt2",
}

# HTTP headers sent with the image-search request (currently empty).
headers = {}
# Streamlit re-executes this whole script on every widget interaction, so
# without a cache the model weights would be reloaded from disk on every
# "Run" click. allow_output_mutation=True stops Streamlit from hashing the
# returned torch model on each cache hit.
# NOTE(review): st.cache is deprecated in newer Streamlit releases in favour
# of st.cache_resource; switch decorators if Streamlit is upgraded.
@st.cache(allow_output_mutation=True)
def load_gpt(model_type):
    """Load the GPT-2 LM-head model registered under *model_type*.

    Args:
        model_type: A key of ``MODELS`` (the UI display name).

    Returns:
        The ``GPT2LMHeadModel`` for the Hub id ``MODELS[model_type]``.

    Raises:
        KeyError: If *model_type* is not a key of ``MODELS``.
    """
    return GPT2LMHeadModel.from_pretrained(MODELS[model_type])
# Cached for the same reason as the model: Streamlit reruns the script on
# every interaction, and reloading the tokenizer files each time is wasted work.
# NOTE(review): st.cache is deprecated in newer Streamlit releases in favour
# of st.cache_resource; switch decorators if Streamlit is upgraded.
@st.cache(allow_output_mutation=True)
def load_gpt_tokenizer(model_type):
    """Load the GPT-2 tokenizer registered under *model_type*.

    Args:
        model_type: A key of ``MODELS`` (the UI display name).

    Returns:
        The ``GPT2Tokenizer`` for the Hub id ``MODELS[model_type]``.

    Raises:
        KeyError: If *model_type* is not a key of ``MODELS``.
    """
    return GPT2Tokenizer.from_pretrained(MODELS[model_type])
def get_image(text: str,
              url: str = "https://wikisearch.uncool.ai/get_image/",
              image_width: int = 400,
              timeout: float = 10.0) -> str:
    """Ask the wikisearch image service for an image URL matching *text*.

    POSTs the JSON body ``{"text": text, "image_width": image_width}`` to
    *url* (with the module-level ``headers``) and reads the ``"url"`` field
    of the JSON response.

    Args:
        text: Query text sent to the service.
        url: Endpoint of the image-search service.
        image_width: Requested image width, forwarded as-is to the service
            (presumably pixels; confirm against the service).
        timeout: Seconds to wait for the service before giving up.

    Returns:
        The image URL, or ``""`` if the request fails, returns an error
        status, is not valid JSON, or carries no string ``"url"`` field.
    """
    payload = {"text": text, "image_width": image_width}
    try:
        # Without a timeout a stalled service would block the app indefinitely.
        response = requests.post(url, headers=headers, data=json.dumps(payload),
                                 timeout=timeout)
        response.raise_for_status()
        image = json.loads(response.content.decode("utf-8"))["url"]
    except (requests.RequestException, ValueError, KeyError, TypeError):
        # Best effort: network/HTTP errors (RequestException), undecodable or
        # non-JSON bodies (ValueError), missing "url" (KeyError) or a non-object
        # JSON body (TypeError) all mean "no image" rather than a crash.
        return ""
    # Guard against e.g. {"url": null} being reported as a found image.
    return image if isinstance(image, str) else ""
st.set_page_config(page_title="Indonesian GPT-2 Demo")
st.title("Indonesian GPT-2")
# fastText language-identification model, used to detect the prompt's language.
ft_model = fasttext.load_model('lid.176.ftz')

# Sidebar
st.sidebar.image(LOGO)
st.sidebar.subheader("Configurable parameters")
# Bounds keep users from entering values generation cannot use: a
# non-positive length, a negative top-k, or a top-p outside [0, 1]
# (transformers rejects the latter two when sampling).
max_len = st.sidebar.number_input(
    "Maximum length",
    value=100,
    min_value=1,
    help="The maximum length of the sequence to be generated."
)
temp = st.sidebar.slider(
    "Temperature",
    value=1.0,
    min_value=0.0,
    max_value=100.0,
    help="The value used to module the next token probabilities."
)
top_k = st.sidebar.number_input(
    "Top k",
    value=50,
    min_value=0,  # 0 disables top-k filtering in transformers
    help="The number of highest probability vocabulary tokens to keep for top-k-filtering."
)
top_p = st.sidebar.number_input(
    "Top p",
    value=1.0,
    min_value=0.0,
    max_value=1.0,
    help=" If set to float < 1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation."
)

st.markdown(
    """
This demo uses the [small](https://huggingface.co/flax-community/gpt2-small-indonesian) and
[medium](https://huggingface.co/flax-community/gpt2-medium-indonesian) Indonesian GPT2 model
trained on the Indonesian [Oscar](https://huggingface.co/datasets/oscar), [MC4](https://huggingface.co/datasets/mc4)
and [Wikipedia](https://huggingface.co/datasets/wikipedia) dataset. We created it as part of the
[Huggingface JAX/Flax event](https://discuss.huggingface.co/t/open-to-the-community-community-week-using-jax-flax-for-nlp-cv/).
The demo supports "multi language" ;-), feel free to try a prompt on your language. We are also experimenting with
the sentence based image search using Wikipedia passages encoded with distillbert, and search the encoded sentence
in the encoded passages using Facebook's Faiss.
"""
)
# PROMPT_LIST group that supplies the example prompts for each model.
PROMPT_GROUPS = {
    "GPT-2 Small": "GPT-2",
    "GPT-2 Medium": "GPT-2",
    "GPT-2 Small finetuned on Indonesian academic journals": "Indonesian Journals",
}

# Offer exactly the models defined in MODELS, in the same order.
model_name = st.selectbox('Model', list(MODELS))
# A KeyError here means a model was added to MODELS without a prompt group
# (previously this left prompt_group_name unbound).
prompt_group_name = PROMPT_GROUPS[model_name]

ALL_PROMPTS = [*PROMPT_LIST[prompt_group_name], "Custom"]
prompt = st.selectbox('Prompt', ALL_PROMPTS, index=len(ALL_PROMPTS) - 1)

# prompt_choice records which (group, prompt) the current prompt_box was drawn
# for. The example text is re-drawn only when that selection changes. Drawing
# on every rerun (including the one triggered by pressing "Run") changes the
# text_area default; Streamlit then treats it as a new widget, so generation
# could start from a different text than the one on screen.
session_state = SessionState.get(prompt_box=None, prompt_choice=None)
prompt_choice = (prompt_group_name, prompt)
if session_state.prompt_choice != prompt_choice:
    if prompt == "Custom":
        session_state.prompt_box = "Enter your text here"
    else:
        session_state.prompt_box = random.choice(PROMPT_LIST[prompt_group_name][prompt])
    session_state.prompt_choice = prompt_choice

text = st.text_area("Enter text", session_state.prompt_box)
def _markdown_lines(s):
    """Return *s* with every newline turned into a Markdown hard line break.

    Markdown renders a bare newline as a space; two trailing spaces before
    the newline force a visible line break.
    """
    return s.replace("\n", "  \n")


if st.button("Run"):
    # `text` is the value of the "Enter text" text_area defined earlier. Do not
    # create a second text_area with the same label and default here:
    # Streamlit rejects such duplicates (DuplicateWidgetID).
    with st.spinner(text="Getting results..."):
        # fastText's predict() rejects text containing newlines, so flatten it.
        lang_predictions, _ = ft_model.predict(text.replace("\n", " "), k=3)
        if "__label__id" in lang_predictions:
            lang = "id"
        else:
            lang = lang_predictions[0].replace("__label__", "")
            # mtranslate.translate(text, to_language, from_language):
            # translate the prompt into Indonesian for the model.
            text = translate(text, "id", lang)

        st.subheader("Result")
        model = load_gpt(model_name)
        tokenizer = load_gpt_tokenizer(model_name)
        input_ids = tokenizer.encode(text, return_tensors='pt')
        # temperature / top_k / top_p only take effect when sampling is enabled;
        # without do_sample=True, generate() decodes greedily and ignores the
        # sidebar settings. Temperature 0 cannot be used for sampling, so treat
        # it as a request for greedy decoding.
        if temp > 0:
            sampling_kwargs = dict(do_sample=True, temperature=temp,
                                   top_k=top_k, top_p=top_p)
        else:
            sampling_kwargs = dict(do_sample=False)
        output = model.generate(input_ids=input_ids,
                                max_length=max_len,
                                repetition_penalty=2.0,
                                **sampling_kwargs)
        text = tokenizer.decode(output[0], skip_special_tokens=True)
        st.write(_markdown_lines(text))

        st.text("Translation")
        # The English translation is always computed because the image search
        # below is queried with it.
        translation = translate(text, "en", "id")
        if lang == "id":
            st.write(_markdown_lines(translation))
        else:
            # Show the result in the language the user wrote the prompt in.
            st.write(_markdown_lines(translate(text, lang, "id")))

        image_cat = "https://media.giphy.com/media/vFKqnCdLPNOKc/giphy.gif"
        image = get_image(translation.replace("\"", "'"))
        if image:
            st.image(image, width=400)
        else:
            # display cat image if no image found
            st.image(image_cat, width=400)