import gradio as gr
import torch
from transformers import pipeline
import os
import spaces

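# Dialect -> fill-mask checkpoint on the Hugging Face Hub.
# The commented-out paths are alternative dialect-specific stage-3 decay checkpoints.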
MODELS = {
    "Moroccan": "BounharAbdelaziz/ModernBERT-Arabic-base-stage-1-pre-decay-predef-bkpt-ini-1024-mx8192", # "BounharAbdelaziz/ModernBERT-Arabic-base-stage-3-decay-mx8192-ARY",
    "Arabic": "BounharAbdelaziz/ModernBERT-Arabic-base-stage-1-pre-decay-predef-bkpt-ini-1024-mx8192", # "BounharAbdelaziz/ModernBERT-Arabic-base-stage-3-decay-mx8192-MSA", # "BounharAbdelaziz/ModernBERT-Arabic-base-stage-1-pre-decay-ini-8192-mx-8192
    "Egyptian": "BounharAbdelaziz/ModernBERT-Arabic-base-stage-1-pre-decay-predef-bkpt-ini-1024-mx8192", # "BounharAbdelaziz/ModernBERT-Arabic-base-stage-3-decay-mx8192-ARZ",
    "Tunisian": "BounharAbdelaziz/ModernBERT-Arabic-base-stage-1-pre-decay-predef-bkpt-ini-1024-mx8192", # "BounharAbdelaziz/ModernBERT-Arabic-base-stage-3-decay-mx8192-ARY",
    "Algerian": "BounharAbdelaziz/ModernBERT-Arabic-base-stage-1-pre-decay-predef-bkpt-ini-1024-mx8192", # "BounharAbdelaziz/ModernBERT-Arabic-base-stage-3-decay-mx8192-ALGERIAN",
}

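# Ready-made [MASK] prompts per dialect, offered in the example dropdown.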
EXAMPLES = {
    "Moroccan": [
        "الدار البيضاء [MASK]", 
        "المغرب بلاد [MASK]",
        "كناكل [MASK] فالمغرب",
        "العاصمة د [MASK] هي الرباط",
        "المغرب [MASK] زوين",
        "انا سميتي مريم، و كنسكن ف[MASK] العاصمة دفلسطين"
    ],
    "Arabic": [
        "العاصمة الرسمية لمصر هي [MASK].",
        "أطول نهر في العالم هو نهر [MASK].",
        "الشاعر العربي المشهور [MASK] كتب قصيدة 'أراك عصي الدمع'.",
        "عندما أستيقظ في الصباح، أشرب فنجان من [MASK].",
        "في التأني [MASK] وفي العجلة الندامة.",
        "معركة [MASK] كانت من أهم المعارك في تاريخ الإسلام.",
        "يعتبر [MASK] من أهم العلماء في مجال الفيزياء.",
        "تقع جبال [MASK] في شمال إفريقيا.",
        "يعتبر [MASK] من أركان الإسلام الخمسة."
    ],
    "Egyptian": [
        "القاهرة مدينة [MASK]", 
        "مصر بلاد [MASK]", 
        "بنحب [MASK] فمصر"
    ],
    "Tunisian": [
        "تونس بلاد [MASK]", 
        "المنستير مدينة [MASK]", 
        "عيشتي في [MASK]"
    ],
    "Algerian": [
        "الجزائر بلاد [MASK]", 
        "قسنطينة مدينة [MASK]", 
        "نحبو [MASK] ف الجزائر"
    ],
}

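# Hugging Face access token (raises KeyError if HF_TOKEN is unset) and default device selection.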
TOKEN = os.environ["HF_TOKEN"]
device = "cuda:0" if torch.cuda.is_available() else "cpu"

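# Build a fill-mask pipeline for the requested dialect, falling back to the Arabic checkpoint.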
def load_model(dialect):
    model_path = MODELS.get(dialect, MODELS["Arabic"])
    return pipeline(task="fill-mask", model=model_path, token=TOKEN, device=device)

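# Lazily-created pipeline, rebuilt inside predict() whenever the selected dialect changes.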
pipe = None

@spaces.GPU
def predict(text, dialect):
    global pipe
    if pipe is None or dialect != predict.current_dialect:  # Reload the pipeline when the dialect changes
        pipe = load_model(dialect)
        predict.current_dialect = dialect
    if not text or pipe.tokenizer.mask_token not in text:  # Guard against empty input or input without a [MASK] token
        return {}
    outputs = pipe(text)
    if outputs and isinstance(outputs[0], list):  # With several masks, keep the predictions for the first one
        outputs = outputs[0]
    scores = [x["score"] for x in outputs]
    tokens = [x["token_str"] for x in outputs]
    return {token: float(score) for token, score in zip(tokens, scores)}

# Initialize current dialect
predict.current_dialect = None

# Create Gradio interface
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            # Dropdown for dialect selection
            dialect_dropdown = gr.Dropdown(
                choices=["Arabic", "Tunisian", "Moroccan", "Algerian", "Egyptian"],
                label="Select Dialect",
                value="Arabic"
            )
            
            # Input text box
            input_text = gr.Textbox(
                label="Input",
                placeholder="Enter text here...",
                rtl=True
            )
            
            # Button row
            with gr.Row():
                clear_btn = gr.Button("Clear")
                submit_btn = gr.Button("Submit", variant="primary")
            
            # Examples section; gr.State holds the current dialect's examples (not yet read by any callback)
            examples_state = gr.State(EXAMPLES["Arabic"])
            
            example_dropdown = gr.Dropdown(
                choices=EXAMPLES["Arabic"],
                label="Select Example",
                interactive=True
            )
            
            load_example_btn = gr.Button("Load Example")

        with gr.Column():
            # Output probabilities
            output_labels = gr.Label(
                label="Prediction Results",
                show_label=False
            )

    # Function to refresh the example dropdown when the dialect changes
    def update_example_choices(dialect):
        examples = EXAMPLES.get(dialect, EXAMPLES["Arabic"])
        return gr.update(choices=examples, value=examples[0])

    # Function to load selected example into the text box
    def load_example(selected_example):
        return selected_example

    # Update example dropdown choices dynamically
    dialect_dropdown.change(
        update_example_choices,
        inputs=dialect_dropdown,
        outputs=example_dropdown
    )

    # Load selected example into the input text box
    load_example_btn.click(
        load_example,
        inputs=example_dropdown,
        outputs=input_text
    )

    # Button actions
    submit_btn.click(
        predict,
        inputs=[input_text, dialect_dropdown],
        outputs=output_labels
    )
    
    clear_btn.click(
        lambda: "",
        outputs=input_text
    )

# Launch the app
demo.launch()
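
# Minimal local run, assuming this file is saved as app.py and the gradio, torch,
# transformers, and spaces packages are installed:
#   HF_TOKEN=<your token> python app.py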