Spaces:
Runtime error
Runtime error
import streamlit as st | |
from transformers import pipeline, set_seed | |
from transformers import AutoTokenizer | |
from mtranslate import translate | |
import random | |
import meta | |
import examples | |
from normalizer import normalize | |
from utils import ( | |
remote_css, | |
local_css | |
) | |
class TextGeneration: | |
def __init__(self): | |
self.debug = False | |
self.dummy_output = "ناف جایی قرار گرفته که در واقع بندناف در داخل رحم در آنجا به شکم جنین وصل بودهاست. " \ | |
"بندناف که جفت را به جنین متصل کرده بعد از تولد از نوزاد جدا میشود. برای جدا کردن بند ناف از دو پنس استفاده میکنند و بین آن دو را میبرند. پنس دیگری نزدیک شکم نوزاد قرار داده میشود که بعد از دو روز برداشته خواهد شد. بندناف باقیمانده طی ۱۵ روز خشک شده و میافتد و به جای آن اسکاری طبیعی به جای میماند. البته بر خلاف تصور عامه مردم شکل ناف در اثر بریدن بند ناف به وجود نمیآید و پیش از این در شکم مادر حالت ناف شکل گرفتهاست. شکل ناف در میان مردم مختلف متفاوت است و اندازه آن بین ۱.۵ تا ۲ سانتیمتر است. تمام پستانداران جفتزیست ناف دارند. ناف در انسانها به سادگی قابل مشاهدهاست." | |
self.tokenizer = None | |
self.generator = None | |
self.task = "text-generation" | |
self.model_name_or_path = "flax-community/gpt2-medium-persian" | |
set_seed(42) | |
def load(self): | |
if not self.debug: | |
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path) | |
self.generator = pipeline(self.task, model=self.model_name_or_path, tokenizer=self.model_name_or_path) | |
def generate(self, prompt, generation_kwargs): | |
if not self.debug: | |
generation_kwargs["num_return_sequences"] = 1 | |
max_length = len(self.tokenizer(prompt)["input_ids"]) + generation_kwargs["max_length"] | |
generation_kwargs["max_length"] = max_length | |
generation_kwargs["return_full_text"] = False | |
return self.generator( | |
prompt, | |
**generation_kwargs, | |
)[0]["generated_text"] | |
return self.dummy_output | |
def load_text_generator(): | |
generator = TextGeneration() | |
generator.load() | |
return generator | |
def main(): | |
st.set_page_config( | |
page_title="GPT2 - Persian", | |
page_icon="🤘", | |
layout="wide", | |
initial_sidebar_state="expanded" | |
) | |
remote_css("https://cdn.jsdelivr.net/gh/rastikerdar/vazir-font/dist/font-face.css") | |
local_css("assets/rtl.css") | |
generator = load_text_generator() | |
st.sidebar.markdown(meta.SIDEBAR_INFO) | |
max_length = st.sidebar.slider( | |
label='Max Length', | |
help="The maximum length of the sequence to be generated.", | |
min_value=1, | |
max_value=128, | |
value=50, | |
step=1 | |
) | |
top_k = st.sidebar.slider( | |
label='Top-k', | |
help="The number of highest probability vocabulary tokens to keep for top-k-filtering", | |
min_value=40, | |
max_value=80, | |
value=50, | |
step=1 | |
) | |
top_p = st.sidebar.slider( | |
label='Top-p', | |
help="Only the most probable tokens with probabilities that add up to `top_p` or higher are kept for " | |
"generation.", | |
min_value=0.0, | |
max_value=1.0, | |
value=0.95, | |
step=0.01 | |
) | |
temperature = st.sidebar.slider( | |
label='Temperature', | |
help="The value used to module the next token probabilities", | |
min_value=0.1, | |
max_value=10.0, | |
value=1.0, | |
step=0.05 | |
) | |
do_sample = st.sidebar.selectbox( | |
label='Sampling ?', | |
options=(True, False), | |
help="Whether or not to use sampling; use greedy decoding otherwise.", | |
) | |
translated = st.sidebar.selectbox( | |
label='Translation ?', | |
options=(True, False), | |
help="Will translate the result in English", | |
) | |
generation_kwargs = { | |
"max_length": max_length, | |
"top_k": top_k, | |
"top_p": top_p, | |
"temperature": temperature, | |
"do_sample": do_sample, | |
} | |
st.markdown(meta.HEADER_INFO) | |
prompts = list(examples.EXAMPLES.keys()) + ["Custom"] | |
prompt = st.selectbox('Examples', prompts, index=len(prompts) - 1) | |
if prompt == "Custom": | |
prompt_box = meta.PROMPT_BOX | |
else: | |
prompt_box = random.choice(examples.EXAMPLES[prompt]) | |
text = st.text_area("Enter text", prompt_box) | |
generation_kwargs_ph = st.empty() | |
if st.button("Generate !"): | |
with st.spinner(text="Generating ..."): | |
generation_kwargs_ph.markdown(", ".join([f"`{k}`: {v}" for k, v in generation_kwargs.items()])) | |
text = normalize(text) | |
if text: | |
generated_text = generator.generate(text, generation_kwargs) | |
st.markdown( | |
f'<p class="rtl rtl-box">' | |
f'<span class="result-text">{text} <span>' | |
f'<span class="result-text generated-text">{generated_text}</span>' | |
f'</p>', | |
unsafe_allow_html=True | |
) | |
if translated: | |
translated_text = translate(text, "en", "fa") | |
translated_generated_text = translate(generated_text, "en", "fa") | |
st.markdown( | |
f'<p class="ltr ltr-box">' | |
f'<span class="result-text">{translated_text} <span>' | |
f'<span class="result-text generated-text">{translated_generated_text}</span>' | |
f'</p>', | |
unsafe_allow_html=True | |
) | |
if __name__ == '__main__': | |
main() | |