import streamlit as st
import torch
from transformers import pipeline, set_seed
from transformers import AutoTokenizer
from transformers import GPT2LMHeadModel
from mtranslate import translate
import random

import meta
from normalizer import normalize
from utils import (
    remote_css,
    local_css,
    load_json
)

EXAMPLES = load_json("examples.json")

# Persian markers used to delimit the prompt segments.
CK = "متن"    # "context"
QK = "پرسش"   # "question"
AK = "پاسخ"   # "answer"
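# Each entry of examples.json is expected to look like the record below;
# the keys are inferred from how EXAMPLES is consumed in main():
#   {"title": "...", "context": "...", "question": "...", "answer": "..."}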
class TextGeneration:
    def __init__(self):
        self.debug = False
        self.dummy_output = "مخلوطی از ایتالیایی و انگلیسی"  # "a mix of Italian and English"
        self.tokenizer = None
        self.model = None
        self.model_name_or_path = "m3hrdadfi/gpt2-persian-qa"
        self.length_margin = 100  # extra tokens allowed beyond the prompt
        set_seed(42)

    def load(self):
        if not self.debug:
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path)
            self.model = GPT2LMHeadModel.from_pretrained(self.model_name_or_path)

    def generate(self, prompt, generation_kwargs):
        if not self.debug:
            input_ids = self.tokenizer([prompt], return_tensors="pt")["input_ids"]
            max_length = len(input_ids[0]) + self.length_margin
            generation_kwargs["max_length"] = max_length

            generated = self.model.generate(
                input_ids,
                **generation_kwargs,
            )[0]
            answer = self.tokenizer.decode(generated, skip_special_tokens=True)

            # str.find returns -1 when the answer marker is missing.
            found = answer.find(f"{AK}: ")
            if found < 0:
                return ""

            # Keep the first non-empty segment that follows the answer marker.
            answer = [a.strip() for a in answer[found:].split(f"{AK}: ") if a.strip()]
            answer = answer[0] if len(answer) > 0 else ""
            return answer

        return self.dummy_output
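# Usage sketch, kept in a comment so nothing runs at import time: driving
# TextGeneration directly, outside Streamlit. The prompt layout matches the
# one built in main() below; the kwargs values are illustrative, not tuned.
#
#   generator = TextGeneration()
#   generator.load()
#   prompt = f"<Persian context> {QK}: <Persian question> {AK}:"
#   kwargs = {
#       "num_beams": 5,
#       "early_stopping": True,
#       "repetition_penalty": 1.0,
#       "length_penalty": 1.0,
#   }
#   print(generator.generate(prompt, kwargs))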
@st.cache(allow_output_mutation=True)
def load_text_generator():
    generator = TextGeneration()
    generator.load()
    return generator
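# Note: st.cache(allow_output_mutation=True) is deprecated on newer Streamlit
# releases; if you run this on Streamlit >= 1.18, st.cache_resource is the
# closest drop-in replacement for caching the loaded model:
#
#   @st.cache_resource
#   def load_text_generator():
#       generator = TextGeneration()
#       generator.load()
#       return generator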
def main():
    st.set_page_config(
        page_title="GPT2 QA - Persian",
        page_icon="⁉️",
        layout="wide",
        initial_sidebar_state="expanded"
    )
    remote_css("https://cdn.jsdelivr.net/gh/rastikerdar/vazir-font/dist/font-face.css")
    local_css("assets/rtl.css")
    generator = load_text_generator()

    st.sidebar.markdown(meta.SIDEBAR_INFO)
    num_beams = st.sidebar.slider(
        label='Number of Beams',
        help="Number of beams for beam search",
        min_value=4,
        max_value=15,
        value=5,
        step=1
    )
    repetition_penalty = st.sidebar.slider(
        label='Repetition Penalty',
        help="Penalty applied to repeated tokens; 1.0 means no penalty",
        min_value=1.0,
        max_value=10.0,
        value=1.0,
        step=0.1
    )
    length_penalty = st.sidebar.slider(
        label='Length Penalty',
        help="Exponential penalty applied to the sequence length",
        min_value=1.0,
        max_value=10.0,
        value=1.0,
        step=0.1
    )
    early_stopping = st.sidebar.selectbox(
        label='Early Stopping?',
        options=(True, False),
        help="Whether to stop beam search once num_beams finished sentences exist per batch",
    )
    translated = st.sidebar.selectbox(
        label='Translation?',
        options=(True, False),
        help="Translate the result into English",
    )
    generation_kwargs = {
        "num_beams": num_beams,
        "early_stopping": early_stopping,
        "repetition_penalty": repetition_penalty,
        "length_penalty": length_penalty,
    }
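    # These kwargs go straight through to GPT2LMHeadModel.generate();
    # TextGeneration.generate() adds "max_length" on top of them
    # (prompt length plus its fixed length_margin of 100 tokens).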
    st.markdown(meta.HEADER_INFO)
    prompts = [e["title"] for e in EXAMPLES] + ["Custom"]
    prompt = st.selectbox('Examples', prompts, index=len(prompts) - 1)
    if prompt == "Custom":
        prompt_box = {
            "context": meta.C_PROMPT_BOX,
            "question": meta.Q_PROMPT_BOX,
            "answer": meta.A_PROMPT_BOX,
        }
    else:
        prompt_box = next(e for e in EXAMPLES if e["title"] == prompt)

    context = st.text_area("Enter context", prompt_box["context"], height=250)
    question = st.text_area("Enter question", prompt_box["question"], height=100)
    answer = "پاسخ درست: " + prompt_box["answer"]  # "correct answer: ..."

    # Wrapper markup for right-to-left / left-to-right display; the "rtl" and
    # "ltr" classes are assumed to be defined in assets/rtl.css.
    st.markdown(
        f'<div class="rtl">'
        f'{answer}'
        f'</div>',
        unsafe_allow_html=True
    )
    if translated:
        # mtranslate.translate(text, to_language, from_language): Persian -> English
        translated_answer = translate(answer, "en", "fa")
        st.markdown(
            f'<div class="ltr">'
            f'{translated_answer}'
            f'</div>',
            unsafe_allow_html=True
        )

    generation_kwargs_ph = st.empty()
    if st.button("Find the answer 🔎"):
        with st.spinner(text="Searching ..."):
            generation_kwargs_ph.markdown(
                ", ".join([f"`{k}`: {v}" for k, v in generation_kwargs.items()])
            )
            context = normalize(context)
            question = normalize(question)
            if context and question:
                # Prompt layout the model expects: "<context> پرسش: <question> پاسخ:"
                text = f"{context} {QK}: {question} {AK}:"
                generated_answer = generator.generate(text, generation_kwargs)
                generated_answer = f"{AK}: {generated_answer}".strip()
                context = f"{CK}: {context}".strip()
                question = f"{QK}: {question}".strip()

                st.markdown(
                    f'<div class="rtl">'
                    f'{context}<br/>'
                    f'{question}<br/>'
                    f'{generated_answer} '
                    f'</div>',
                    unsafe_allow_html=True
                )
                if translated:
                    translated_context = translate(context, "en", "fa")
                    translated_question = translate(question, "en", "fa")
                    translated_generated_answer = translate(generated_answer, "en", "fa")
                    st.markdown(
                        f'<div class="ltr">'
                        f'{translated_context}<br/>'
                        f'{translated_question}<br/>'
                        f'{translated_generated_answer}'
                        f'</div>',
                        unsafe_allow_html=True
                    )


if __name__ == '__main__':
    main()
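# Launch locally with the Streamlit CLI (assuming this file is saved as
# app.py next to the meta/normalizer/utils modules and the assets/ folder):
#   streamlit run app.py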