# BounharAbdelaziz's picture
# Update app.py
# e9636ac verified
import gradio as gr
import torch
from transformers import pipeline
import os
import spaces
# Every dialect currently resolves to the same stage-1 pre-decay checkpoint.
# The per-dialect checkpoints in the comments below are the planned targets
# and are kept for reference only.
_SHARED_CHECKPOINT = (
    "BounharAbdelaziz/ModernBERT-Arabic-base-stage-1-pre-decay-predef-bkpt-ini-1024-mx8192"
)

# Dialect name -> Hugging Face model id used to build the fill-mask pipeline.
MODELS = {
    # planned: "BounharAbdelaziz/ModernBERT-Arabic-base-stage-3-decay-mx8192-ARY"
    "Moroccan": _SHARED_CHECKPOINT,
    # planned: "BounharAbdelaziz/ModernBERT-Arabic-base-stage-3-decay-mx8192-MSA"
    # alternative: "BounharAbdelaziz/ModernBERT-Arabic-base-stage-1-pre-decay-ini-8192-mx-8192"
    "Arabic": _SHARED_CHECKPOINT,
    # planned: "BounharAbdelaziz/ModernBERT-Arabic-base-stage-3-decay-mx8192-ARZ"
    "Egyptian": _SHARED_CHECKPOINT,
    # planned: "BounharAbdelaziz/ModernBERT-Arabic-base-stage-3-decay-mx8192-ARY"
    # NOTE(review): ARY is also the Moroccan suffix above -- confirm the
    # intended Tunisian checkpoint.
    "Tunisian": _SHARED_CHECKPOINT,
    # planned: "BounharAbdelaziz/ModernBERT-Arabic-base-stage-3-decay-mx8192-ALGERIAN"
    "Algerian": _SHARED_CHECKPOINT,
}
# Sample prompts offered in the "Select Example" dropdown, grouped by dialect.
# Each prompt contains a single "[MASK]" placeholder for the model to fill.
EXAMPLES = dict(
    Moroccan=[
        "الدار البيضاء [MASK]",
        "المغرب بلاد [MASK]",
        "كناكل [MASK] فالمغرب",
        "العاصمة د [MASK] هي الرباط",
        "المغرب [MASK] زوين",
        "انا سميتي مريم، و كنسكن ف[MASK] العاصمة دفلسطين",
    ],
    Arabic=[
        "العاصمة الرسمية لمصر هي [MASK].",
        "أطول نهر في العالم هو نهر [MASK].",
        "الشاعر العربي المشهور [MASK] كتب قصيدة 'أراك عصي الدمع'.",
        "عندما أستيقظ في الصباح، أشرب فنجان من [MASK].",
        "في التأني [MASK] وفي العجلة الندامة.",
        "معركة [MASK] كانت من أهم المعارك في تاريخ الإسلام.",
        "يعتبر [MASK] من أهم العلماء في مجال الفيزياء.",
        "تقع جبال [MASK] في شمال إفريقيا.",
        "يعتبر [MASK] من أركان الإسلام الخمسة.",
    ],
    Egyptian=[
        "القاهرة مدينة [MASK]",
        "مصر بلاد [MASK]",
        "بنحب [MASK] فمصر",
    ],
    Tunisian=[
        "تونس بلاد [MASK]",
        "المنستير مدينة [MASK]",
        "عيشتي في [MASK]",
    ],
    Algerian=[
        "الجزائر بلاد [MASK]",
        "قسنطينة مدينة [MASK]",
        "نحبو [MASK] ف الجزائر",
    ],
)
# Hugging Face access token passed to `pipeline(...)`. Read with .get() so a
# missing secret no longer aborts the whole app at import time with KeyError;
# presumably a token is only required if the checkpoint is gated/private.
TOKEN = os.environ.get("HF_TOKEN")
# Use the first CUDA device when one is visible, otherwise fall back to CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
def load_model(dialect):
    """Create a fill-mask pipeline for the checkpoint mapped to *dialect*.

    Dialects missing from ``MODELS`` fall back to the "Arabic" checkpoint.
    The pipeline is authenticated with ``TOKEN`` and placed on ``device``.
    """
    checkpoint = MODELS[dialect] if dialect in MODELS else MODELS["Arabic"]
    return pipeline(
        task="fill-mask",
        model=checkpoint,
        token=TOKEN,
        device=device,
    )
# Lazily-built fill-mask pipeline shared across calls to `predict`.
pipe = None


@spaces.GPU
def predict(text, dialect):
    """Return ``{candidate_token: score}`` for the ``[MASK]`` in *text*.

    The pipeline is cached in the module-level ``pipe``. It is rebuilt only
    when the checkpoint resolved for *dialect* differs from the one already
    loaded. Several dialects currently share one checkpoint, so switching
    between them no longer triggers a reload.

    Blank input returns an empty mapping instead of failing inside the
    pipeline. When the pipeline returns a nested list (transformers does this
    for inputs with more than one mask token), the candidates for the first
    mask are reported.
    """
    global pipe
    if not text or not text.strip():
        return {}

    # Same dialect -> checkpoint resolution as `load_model`.
    model_path = MODELS.get(dialect, MODELS["Arabic"])
    if pipe is None or model_path != predict.current_model_path:
        pipe = load_model(dialect)
        predict.current_model_path = model_path
    predict.current_dialect = dialect

    outputs = pipe(text)
    # Multi-mask input yields one candidate list per mask; without this,
    # indexing x["score"] on an inner list raises TypeError.
    if outputs and isinstance(outputs[0], list):
        outputs = outputs[0]
    return {x["token_str"]: float(x["score"]) for x in outputs}


# Track which dialect / checkpoint the cached pipeline was built for.
predict.current_dialect = None
predict.current_model_path = None
# Create Gradio interface
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            # Dialect selector; drives both the model used and the examples shown.
            dialect_dropdown = gr.Dropdown(
                choices=["Arabic", "Tunisian", "Moroccan", "Algerian", "Egyptian"],
                label="Select Dialect",
                value="Arabic",
            )
            # Input text box, right-to-left for Arabic script.
            input_text = gr.Textbox(
                label="Input",
                placeholder="Enter text here...",
                rtl=True,
            )
            # Button row
            with gr.Row():
                clear_btn = gr.Button("Clear")
                submit_btn = gr.Button("Submit", variant="primary")
            # Examples for the selected dialect. Preselect the first one so
            # "Load Example" has something to load before the user picks.
            example_dropdown = gr.Dropdown(
                choices=EXAMPLES["Arabic"],
                value=EXAMPLES["Arabic"][0],
                label="Select Example",
                interactive=True,
            )
            load_example_btn = gr.Button("Load Example")
        with gr.Column():
            # Output: candidate tokens with their probabilities.
            output_labels = gr.Label(
                label="Prediction Results",
                show_label=False,
            )

    def update_example_choices(dialect):
        """Swap the example list for *dialect* and select its first entry.

        Unknown dialects fall back to the "Arabic" examples for both the
        choices and the selected value.
        """
        examples = EXAMPLES.get(dialect, EXAMPLES["Arabic"])
        return gr.update(choices=examples, value=examples[0])

    def load_example(selected_example):
        """Copy the selected example into the input box unchanged."""
        return selected_example

    # Refresh the example list whenever the dialect changes.
    dialect_dropdown.change(
        update_example_choices,
        inputs=dialect_dropdown,
        outputs=example_dropdown,
    )
    # Load the selected example into the input text box.
    load_example_btn.click(
        load_example,
        inputs=example_dropdown,
        outputs=input_text,
    )
    # Run the fill-mask prediction.
    submit_btn.click(
        predict,
        inputs=[input_text, dialect_dropdown],
        outputs=output_labels,
    )
    # Clear both the input and the previous predictions so stale results
    # are not left on screen.
    clear_btn.click(
        lambda: ("", None),
        outputs=[input_text, output_labels],
    )

# Launch the app
demo.launch()