"""Gradio demo for masked-token prediction with ModernBERT-Arabic checkpoints.

The user picks a dialect, types (or loads an example of) a sentence containing
one mask token, and the app shows the fill-mask pipeline's candidate tokens
with their scores.
"""

import os

import gradio as gr
import spaces
import torch
from transformers import pipeline

# Every dialect currently points at the same checkpoint. The trailing comments
# record the other checkpoint ids that were noted next to each entry.
_DEFAULT_MODEL_ID = (
    "BounharAbdelaziz/"
    "ModernBERT-Arabic-base-stage-1-pre-decay-predef-bkpt-ini-1024-mx8192"
)

MODELS = {
    # "BounharAbdelaziz/ModernBERT-Arabic-base-stage-3-decay-mx8192-ARY"
    "Moroccan": _DEFAULT_MODEL_ID,
    # "BounharAbdelaziz/ModernBERT-Arabic-base-stage-3-decay-mx8192-MSA"
    # "BounharAbdelaziz/ModernBERT-Arabic-base-stage-1-pre-decay-ini-8192-mx-8192"
    "Arabic": _DEFAULT_MODEL_ID,
    # "BounharAbdelaziz/ModernBERT-Arabic-base-stage-3-decay-mx8192-ARZ"
    "Egyptian": _DEFAULT_MODEL_ID,
    # "BounharAbdelaziz/ModernBERT-Arabic-base-stage-3-decay-mx8192-ARY"
    # NOTE(review): same "-ARY" id as the Moroccan entry -- confirm intended.
    "Tunisian": _DEFAULT_MODEL_ID,
    # "BounharAbdelaziz/ModernBERT-Arabic-base-stage-3-decay-mx8192-ALGERIAN"
    "Algerian": _DEFAULT_MODEL_ID,
}

# Example prompts offered in the "Select Example" dropdown, keyed by dialect.
EXAMPLES = {
    "Moroccan": [
        "الدار البيضاء [MASK]",
        "المغرب بلاد [MASK]",
        "كناكل [MASK] فالمغرب",
        "العاصمة د [MASK] هي الرباط",
        "المغرب [MASK] زوين",
        "انا سميتي مريم، و كنسكن ف[MASK] العاصمة دفلسطين",
    ],
    "Arabic": [
        "العاصمة الرسمية لمصر هي [MASK].",
        "أطول نهر في العالم هو نهر [MASK].",
        "الشاعر العربي المشهور [MASK] كتب قصيدة 'أراك عصي الدمع'.",
        "عندما أستيقظ في الصباح، أشرب فنجان من [MASK].",
        "في التأني [MASK] وفي العجلة الندامة.",
        "معركة [MASK] كانت من أهم المعارك في تاريخ الإسلام.",
        "يعتبر [MASK] من أهم العلماء في مجال الفيزياء.",
        "تقع جبال [MASK] في شمال إفريقيا.",
        "يعتبر [MASK] من أركان الإسلام الخمسة.",
    ],
    "Egyptian": [
        "القاهرة مدينة [MASK]",
        "مصر بلاد [MASK]",
        "بنحب [MASK] فمصر",
    ],
    "Tunisian": [
        "تونس بلاد [MASK]",
        "المنستير مدينة [MASK]",
        "عيشتي في [MASK]",
    ],
    "Algerian": [
        "الجزائر بلاد [MASK]",
        "قسنطينة مدينة [MASK]",
        "نحبو [MASK] ف الجزائر",
    ],
}

# Dialect used whenever an unknown dialect name is received.
DEFAULT_DIALECT = "Arabic"

# Read with .get() so a missing variable does not crash the app at import
# time; the value (possibly None) is forwarded to `pipeline(token=...)`.
TOKEN = os.environ.get("HF_TOKEN")
device = "cuda:0" if torch.cuda.is_available() else "cpu"


def load_model(dialect):
    """Build a fill-mask pipeline for *dialect*.

    Args:
        dialect: Key into ``MODELS``; unknown keys fall back to the
            ``"Arabic"`` entry.

    Returns:
        The object returned by ``transformers.pipeline(task="fill-mask", ...)``
        for the selected checkpoint, placed on ``device``.
    """
    model_path = MODELS.get(dialect, MODELS[DEFAULT_DIALECT])
    return pipeline(task="fill-mask", model=model_path, token=TOKEN, device=device)


# Currently loaded pipeline and the checkpoint id it was built from.
pipe = None
_loaded_model_path = None


@spaces.GPU
def predict(text, dialect):
    """Return the fill-mask candidates for *text* as ``{token: score}``.

    The pipeline is (re)built only when none is loaded yet or when the
    checkpoint selected by *dialect* differs from the loaded one, so switching
    between dialects that share a checkpoint does not trigger a reload.

    Args:
        text: Input sentence; must contain exactly one mask token.
        dialect: Key into ``MODELS``; unknown keys fall back to ``"Arabic"``.

    Returns:
        A dict mapping each predicted token string to its score as a float.

    Raises:
        gr.Error: If *text* is empty or does not contain exactly one mask
            token.
    """
    global pipe, _loaded_model_path

    if not text or not text.strip():
        raise gr.Error("Please enter a sentence containing one mask token, e.g. [MASK].")

    model_path = MODELS.get(dialect, MODELS[DEFAULT_DIALECT])
    if pipe is None or model_path != _loaded_model_path:
        pipe = load_model(dialect)
        _loaded_model_path = model_path

    # With zero masks the pipeline raises; with several it returns a list of
    # lists, which the flat {token: score} output below cannot represent.
    mask_token = pipe.tokenizer.mask_token
    if mask_token and text.count(mask_token) != 1:
        raise gr.Error(f"The input must contain exactly one {mask_token} token.")

    outputs = pipe(text)
    return {item["token_str"]: float(item["score"]) for item in outputs}


def update_example_choices(dialect):
    """Return a dropdown update listing the examples for *dialect*.

    Unknown dialects fall back to the ``"Arabic"`` examples for both the
    choices and the preselected value; the first example is preselected when
    the list is non-empty.
    """
    examples = EXAMPLES.get(dialect, EXAMPLES[DEFAULT_DIALECT])
    first_example = examples[0] if examples else None
    return gr.update(choices=examples, value=first_example)


def load_example(selected_example):
    """Return the selected example, or an empty string if nothing is selected."""
    return selected_example or ""


# Create Gradio interface
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            # Dropdown for dialect selection
            dialect_dropdown = gr.Dropdown(
                choices=["Arabic", "Tunisian", "Moroccan", "Algerian", "Egyptian"],
                label="Select Dialect",
                value="Arabic",
            )

            # Input text box
            input_text = gr.Textbox(
                label="Input",
                placeholder="Enter text here...",
                rtl=True,
            )

            # Button row
            with gr.Row():
                clear_btn = gr.Button("Clear")
                submit_btn = gr.Button("Submit", variant="primary")

            # Examples section
            example_dropdown = gr.Dropdown(
                choices=EXAMPLES["Arabic"],
                label="Select Example",
                interactive=True,
            )
            load_example_btn = gr.Button("Load Example")

        with gr.Column():
            # Output probabilities
            output_labels = gr.Label(
                label="Prediction Results",
                show_label=False,
            )

    # Update example dropdown choices dynamically
    dialect_dropdown.change(
        update_example_choices,
        inputs=dialect_dropdown,
        outputs=example_dropdown,
    )

    # Load selected example into the input text box
    load_example_btn.click(
        load_example,
        inputs=example_dropdown,
        outputs=input_text,
    )

    # Button actions
    submit_btn.click(
        predict,
        inputs=[input_text, dialect_dropdown],
        outputs=output_labels,
    )

    clear_btn.click(
        lambda: "",
        outputs=input_text,
    )

# Launch the app
demo.launch()