File size: 5,049 Bytes
2fb81a4
 
 
 
 
 
 
 
 
 
 
 
 
728e2c0
2fb81a4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
055e9bd
2fb81a4
 
 
 
 
 
 
055e9bd
2fb81a4
 
 
0e0600f
 
2fb81a4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
055e9bd
 
2fb81a4
 
 
 
 
 
922139f
0e0600f
2fb81a4
922139f
 
0e0600f
922139f
2fb81a4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
055e9bd
 
 
 
 
2fb81a4
055e9bd
 
 
 
2fb81a4
 
 
055e9bd
2fb81a4
 
055e9bd
 
2fb81a4
 
055e9bd
 
 
3bacf2f
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 
import gradio as gr
from transformers import pipeline
from transformers import AutoTokenizer, AutoModelForCausalLM
from Ashaar.utils import get_output_df, get_highlighted_patterns_html
from Ashaar.bait_analysis import BaitAnalysis
from langs import *
import sys
import json
import argparse

# Command-line options: a single `--lang` flag selects the interface language
# and defaults to 'ar'.
arg_parser = argparse.ArgumentParser()
arg_parser.add_argument('--lang', type=str, default='ar')
args = arg_parser.parse_args()
lang = args.lang

# Resolve the UI strings for the selected language. Any value other than 'ar'
# falls back to the `_en` strings. Only the chosen side of each conditional
# expression is evaluated, so the unused language's names are never looked up.
_use_arabic = lang == 'ar'

TITLE = TITLE_ar if _use_arabic else TITLE_en
DESCRIPTION = DESCRIPTION_ar if _use_arabic else DESCRIPTION_en
textbox_trg_text = textbox_trg_text_ar if _use_arabic else textbox_trg_text_en
textbox_inp_text = textbox_inp_text_ar if _use_arabic else textbox_inp_text_en
btn_trg_text = btn_trg_text_ar if _use_arabic else btn_trg_text_en
btn_inp_text = btn_inp_text_ar if _use_arabic else btn_inp_text_en
# In Arabic mode the element with id "textbox" is rendered right-to-left;
# otherwise no extra CSS is applied.
css = """ #textbox{ direction: RTL;}""" if _use_arabic else ""

# Load the tokenizer and the causal language model used by `generate`.
# Both are loaded once, at import time, from the identifiers given here.
gpt_tokenizer = AutoTokenizer.from_pretrained('arbml/ashaar_tokenizer')
model = AutoModelForCausalLM.from_pretrained('arbml/Ashaar_model')

# Lookup tables mapping theme / meter names to the tokens placed in the
# generation prompt, together with the inverse (token -> name) mappings.
# The files are opened with context managers so the handles are closed as soon
# as they are parsed, and are decoded explicitly as UTF-8 instead of relying on
# the platform's default encoding.
with open("extra/theme_tokens.json", "r", encoding="utf-8") as theme_file:
    theme_to_token = json.load(theme_file)
token_to_theme = {token: name for name, token in theme_to_token.items()}

with open("extra/meter_tokens.json", "r", encoding="utf-8") as meter_file:
    meter_to_token = json.load(meter_file)
token_to_meter = {token: name for name, token in meter_to_token.items()}

# Single analyser instance shared by every `analyze` call.
analysis = BaitAnalysis()
# Results of the most recent analysis: written by `analyze`, read by
# `generate` when it builds the prompt. Empty until the first analysis runs.
meter, theme, qafiyah = "", "", ""

def analyze(poem):
    """Analyse a poem and record the result for later generation.

    The poem is split into lines (shatrs), consecutive lines are paired into
    baits joined by ``' # '``, and the baits are passed to the shared
    ``analysis`` object. The detected meter, the first qafiyah entry and the
    last theme entry are stored in the module-level ``meter``, ``qafiyah`` and
    ``theme`` variables, which ``generate`` reads when building its prompt.

    Args:
        poem: Poem text with one shatr per line. Blank lines are ignored and
            an unpaired trailing line is dropped.

    Returns:
        A tuple ``(html, button_update)``: the highlighted-pattern HTML for
        the analysis output, and a button update that makes the target button
        interactive.
    """
    global meter, theme, qafiyah
    # Skip blank / whitespace-only lines: an empty line between verses would
    # otherwise shift the pairing below and glue the wrong halves together.
    shatrs = [line for line in poem.split("\n") if line.strip()]
    # Pair lines (0, 1), (2, 3), ...; zip stops at the shorter slice, so an odd
    # trailing line is left out.
    baits = [' # '.join(pair) for pair in zip(shatrs[::2], shatrs[1::2])]
    output = analysis.analyze(baits, override_tashkeel=True)
    meter = output['meter']
    qafiyah = output['qafiyah'][0]
    theme = output['theme'][-1]
    df = get_output_df(output)
    return get_highlighted_patterns_html(df), gr.Button.update(interactive=True)

def generate(inputs, top_p = 3):
    """Continue a poem with the language model.

    The seed lines are wrapped in the ``<|bsep|>`` / ``<|vsep|>`` markup,
    prefixed with the meter token, qafiyah and theme token taken from the
    module-level results of the last ``analyze`` call, and fed to the model.
    The newly generated tokens are then decoded into plain text, with verse
    separators turned into newlines.

    Args:
        inputs: Seed poem, one shatr per line. Blank lines are ignored and an
            unpaired trailing line is dropped.
        top_p: Passed to ``model.generate`` as its ``top_p`` sampling argument.
            NOTE(review): the default of 3 is kept for backward compatibility,
            but nucleus-sampling thresholds are normally within (0, 1] --
            confirm the intended value.

    Returns:
        A tuple ``(text, button_update)``: the generated text (at most 10
        lines) and a button update that makes the target button
        non-interactive again.

    Raises:
        KeyError: If the current ``meter`` or ``theme`` has no entry in
            ``meter_to_token`` / ``theme_to_token``.
    """
    max_lines = 10

    # Ignore blank lines so they cannot shift the pairing; zip drops an
    # unpaired trailing line.
    shatrs = [line for line in inputs.split('\n') if line.strip()]
    poem = ' '.join(
        '<|bsep|> ' + first + ' <|vsep|> ' + second + ' </|bsep|>'
        for first, second in zip(shatrs[::2], shatrs[1::2])
    )
    # The literal below (including its inner indentation) is the exact prompt
    # layout sent to the model; only the outer whitespace is stripped.
    prompt = f"""
    {meter_to_token[meter]} {qafiyah} {theme_to_token[theme]}
    <|psep|>
    {poem}
    """.strip()
    print(prompt)
    encoded_input = gpt_tokenizer(prompt, return_tensors='pt')
    output = model.generate(**encoded_input, max_length = 512, top_p = top_p, do_sample=True)

    # Only look at what comes after the prompt tokens.
    prompt_length = len(encoded_input.input_ids[0])
    line_breaks = ("<|vsep|>", "</|bsep|>")
    skipped = ('<|bsep|>', '<|psep|>', '</|psep|>')

    pieces = []
    line_cnts = 0
    for sequence in output[:, prompt_length:]:
        if line_cnts >= max_lines:
            break
        for token in sequence:
            if line_cnts >= max_lines:
                break
            decoded = gpt_tokenizer.decode(token)
            # Presumably a meter/theme conditioning token: stop reading this
            # sequence once one shows up in the generated part.
            if 'meter' in decoded or 'theme' in decoded:
                break
            if decoded in line_breaks:
                pieces.append("\n")
                line_cnts += 1
            elif decoded in skipped:
                continue
            else:
                pieces.append(decoded)
        else:
            # The sequence was consumed to its end without hitting a stop
            # condition, so any further returned sequences are not read.
            break
    return "".join(pieces), gr.Button.update(interactive=False)

# Seed poems offered as examples in the UI. Each entry is a one-element list
# holding a two-line string (one line per shatr).
# NOTE(review): the Arabic text in these literals appears mis-encoded
# (mojibake) in this copy of the file -- confirm against the original UTF-8
# source before relying on it.
examples = [
    [
"""ุงู„ู‚ู„ุจ ุฃุนู„ู… ูŠุง ุนุฐูˆู„ ุจุฏุงุฆู‡
ูˆุฃุญู‚ ู…ู†ูƒ ุจุฌูู†ู‡ ูˆุจู…ุงุฆู‡"""
    ],
    [
"""ุฑู…ุชู ุงู„ูุคุงุฏูŽ ู…ู„ูŠุญุฉ ุนุฐุฑุงุกู
 ุจุณู‡ุงู…ู ู„ุญุธู ู…ุง ู„ู‡ู†ูŽู‘ ุฏูˆุงุกู"""
    ],
    [
"""ุฃุฐูŽู„ูŽู‘ ุงู„ุญูุฑู’ุตู ูˆุงู„ุทูŽู‘ู…ูŽุนู ุงู„ุฑูู‘ู‚ุงุจูŽุง
ูˆู‚ูŽุฏ ูŠูŽุนููˆ ุงู„ูƒูŽุฑูŠู…ูุŒ ุฅุฐุง ุงุณุชูŽุฑูŽุงุจูŽุง"""
    ]
]

# Build the Gradio page. Components are created inside nested
# Blocks/Row/Column contexts, so the statement order below is the page layout.
with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
    # Header: title and description for the selected language.
    with gr.Row():
        with gr.Column():
            gr.HTML(TITLE)
            gr.HTML(DESCRIPTION)

    # Two side-by-side textboxes. Both carry elem_id "textbox", the id targeted
    # by the right-to-left CSS rule used in Arabic mode.
    with gr.Row():
        with gr.Column():
            textbox_output = gr.Textbox(lines=10, label=textbox_trg_text, elem_id="textbox")
        with gr.Column():
            inputs = gr.Textbox(lines=10, label=textbox_inp_text, elem_id="textbox")

    # One button under each textbox. The button that is wired to `generate`
    # further down is the one created non-interactive; a successful `analyze`
    # call enables it.
    with gr.Row():
        with gr.Column():
            if lang == 'ar':
                trg_btn = gr.Button(btn_trg_text, interactive=False)
            else:
                trg_btn = gr.Button(btn_trg_text)

        with gr.Column():
            if lang == 'ar':
                inp_btn = gr.Button(btn_inp_text)
            else:
                inp_btn = gr.Button(btn_inp_text, interactive = False)

    # Area that displays the HTML returned by `analyze`.
    with gr.Row():
        html_output = gr.HTML()

    # Event wiring. The test is `lang != 'ar'` so that it splits exactly like
    # the button-interactivity branches above: with any non-Arabic language the
    # left textbox is the user's input and the layout is mirrored. Testing for
    # 'en' here would leave other language codes with a disabled analyse
    # button and no way to enable generation.
    if lang != 'ar':
        gr.Examples(examples, textbox_output)
        inp_btn.click(generate, inputs = textbox_output, outputs=[inputs, inp_btn])
        trg_btn.click(analyze, inputs = textbox_output, outputs=[html_output,inp_btn])
    else:
        gr.Examples(examples, inputs)
        trg_btn.click(generate, inputs = inputs, outputs=[textbox_output, trg_btn])
        inp_btn.click(analyze, inputs = inputs, outputs=[html_output,trg_btn] )

# To listen on all interfaces and create a public share link instead, use:
# demo.launch(server_name = '0.0.0.0', share=True)
demo.launch()