File size: 5,049 Bytes
2fb81a4 728e2c0 2fb81a4 055e9bd 2fb81a4 055e9bd 2fb81a4 0e0600f 2fb81a4 055e9bd 2fb81a4 922139f 0e0600f 2fb81a4 922139f 0e0600f 922139f 2fb81a4 055e9bd 2fb81a4 055e9bd 2fb81a4 055e9bd 2fb81a4 055e9bd 2fb81a4 055e9bd 3bacf2f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 |
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import gradio as gr
from transformers import pipeline
from transformers import AutoTokenizer, AutoModelForCausalLM
from Ashaar.utils import get_output_df, get_highlighted_patterns_html
from Ashaar.bait_analysis import BaitAnalysis
from langs import *
import sys
import json
import argparse
# Read the UI language from the command line; 'ar' is the default and any
# other value selects the English strings.
arg_parser = argparse.ArgumentParser()
arg_parser.add_argument('--lang', type=str, default='ar')
args = arg_parser.parse_args()
lang = args.lang

# Each name picks its Arabic or English variant. The conditional expression
# only evaluates the selected side, exactly like an if/else block would.
TITLE = TITLE_ar if lang == 'ar' else TITLE_en
DESCRIPTION = DESCRIPTION_ar if lang == 'ar' else DESCRIPTION_en
textbox_trg_text = textbox_trg_text_ar if lang == 'ar' else textbox_trg_text_en
textbox_inp_text = textbox_inp_text_ar if lang == 'ar' else textbox_inp_text_en
btn_trg_text = btn_trg_text_ar if lang == 'ar' else btn_trg_text_en
btn_inp_text = btn_inp_text_ar if lang == 'ar' else btn_inp_text_en
# Arabic mode renders elements with id "textbox" right-to-left.
css = """ #textbox{ direction: RTL;}""" if lang == 'ar' else ""
def _load_json(path):
    """Return the parsed content of the JSON file at *path*.

    The file is opened as UTF-8 and closed as soon as parsing finishes.
    """
    with open(path, "r", encoding="utf-8") as handle:
        return json.load(handle)


# Tokenizer and causal language model used by generate().
gpt_tokenizer = AutoTokenizer.from_pretrained('arbml/ashaar_tokenizer')
model = AutoModelForCausalLM.from_pretrained('arbml/Ashaar_model')

# Name -> token lookup tables plus their inverses.
theme_to_token = _load_json("extra/theme_tokens.json")
token_to_theme = {token: name for name, token in theme_to_token.items()}
meter_to_token = _load_json("extra/meter_tokens.json")
token_to_meter = {token: name for name, token in meter_to_token.items()}

analysis = BaitAnalysis()

# Shared state: written by analyze(), read by generate() to build its prompt.
meter, theme, qafiyah = "", "", ""
def analyze(poem):
    """Analyze a poem and remember its meter, qafiyah and theme.

    The results are stored in the module-level ``meter``, ``qafiyah`` and
    ``theme`` variables, which ``generate`` reads when building its prompt.

    Args:
        poem: Text with one shatr per line. Consecutive pairs of lines are
            joined with ' # ' into baits. Blank lines are ignored and a
            trailing unpaired line is dropped.

    Returns:
        A tuple of the highlighted-patterns HTML and a button update with
        ``interactive=True`` (the click handlers route it to the button
        that triggers generation).
    """
    global meter, theme, qafiyah
    # Skip blank lines so a stray empty line does not shift the pairing.
    shatrs = [line for line in poem.split("\n") if line.strip()]
    # Stop at len - 1 so an unpaired last shatr never forms a bait.
    baits = [' # '.join(shatrs[i:i + 2]) for i in range(0, len(shatrs) - 1, 2)]
    output = analysis.analyze(baits, override_tashkeel=True)
    meter = output['meter']
    qafiyah = output['qafiyah'][0]
    theme = output['theme'][-1]
    df = get_output_df(output)
    return get_highlighted_patterns_html(df), gr.Button.update(interactive=True)
def _decode_poem(sequences, decode, max_lines=10):
    """Turn generated token sequences into plain text, one shatr per line.

    Decoding stops once *max_lines* line breaks were produced or a token
    whose text contains 'meter' or 'theme' is met.
    """
    pieces = []
    line_count = 0
    for sequence in sequences:
        if line_count >= max_lines:
            break
        for token in sequence:
            if line_count >= max_lines:
                break
            decoded = decode(token)
            if 'meter' in decoded or 'theme' in decoded:
                break
            if decoded in ("<|vsep|>", "</|bsep|>"):
                # Both separators end a shatr in the rendered text.
                pieces.append("\n")
                line_count += 1
            elif decoded in ('<|bsep|>', '<|psep|>', '</|psep|>'):
                # Structural markers carry no text of their own.
                continue
            else:
                pieces.append(decoded)
        else:
            # The sequence was consumed to its end: nothing more to read.
            break
    return "".join(pieces)


def generate(inputs, top_p = 3):
    """Continue a poem with the language model.

    The prompt is built from the meter, qafiyah and theme stored by
    ``analyze`` followed by the given verses, so ``analyze`` must have run
    first.

    Args:
        inputs: Text with one shatr per line; a trailing unpaired line is
            dropped before the lines are paired into baits.
        top_p: Nucleus-sampling value forwarded to ``model.generate``.
            NOTE(review): transformers documents ``top_p`` as a float in
            (0, 1]; the default of 3 is kept for backward compatibility,
            confirm the intended value.

    Returns:
        A tuple of the generated text (at most 10 lines) and a button
        update with ``interactive=False``.
    """
    shatrs = inputs.split('\n')
    if len(shatrs) % 2 != 0:
        shatrs = shatrs[:-1]
    poem = ' '.join(
        '<|bsep|> ' + first + ' <|vsep|> ' + second + ' </|bsep|>'
        for first, second in zip(shatrs[0::2], shatrs[1::2])
    )
    prompt = f"""
{meter_to_token[meter]} {qafiyah} {theme_to_token[theme]}
<|psep|>
{poem}
""".strip()
    print(prompt)
    encoded_input = gpt_tokenizer(prompt, return_tensors='pt')
    output = model.generate(**encoded_input, max_length = 512, top_p = top_p, do_sample=True)
    # Drop the prompt tokens so only the newly generated part is decoded.
    prompt_length = len(encoded_input.input_ids[0])
    result = _decode_poem(output[:, prompt_length:], gpt_tokenizer.decode)
    return result, gr.Button.update(interactive=False)
# Sample verses offered in the UI. Each example is a single bait written as
# two lines (one shatr per line), which is the layout analyze() pairs up.
# NOTE(review): the Arabic text was restored from a mis-encoded copy of this
# file; verify the wording and diacritics against the upstream source.
examples = [
    [
        """القلب أعلم يا عذول بدائه
وأحق منك بجفنه وبمائه"""
    ],
    [
        """رمتِ الفؤادَ مليحة عذراءُ
بسهامِ لحظٍ ما لهنَّ دواءُ"""
    ],
    [
        """أذَلَّ الحِرْصُ والطَّمَعُ الرِّقابَا
وقَد يَعفو الكَريمُ، إذا استَرَابَا"""
    ],
]
# Build the two-column UI. In Arabic mode the right-hand textbox is the input
# and the left-hand one receives the generated poem; every other language
# uses the mirrored (English) arrangement.
with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
    with gr.Row():
        with gr.Column():
            gr.HTML(TITLE)
            gr.HTML(DESCRIPTION)

    with gr.Row():
        with gr.Column():
            textbox_output = gr.Textbox(lines=10, label=textbox_trg_text, elem_id="textbox")
        with gr.Column():
            inputs = gr.Textbox(lines=10, label=textbox_inp_text, elem_id="textbox")

    # The button that triggers generation starts disabled; analyze() enables
    # it and generate() disables it again.
    with gr.Row():
        with gr.Column():
            if lang == 'ar':
                trg_btn = gr.Button(btn_trg_text, interactive=False)
            else:
                trg_btn = gr.Button(btn_trg_text)
        with gr.Column():
            if lang == 'ar':
                inp_btn = gr.Button(btn_inp_text)
            else:
                inp_btn = gr.Button(btn_inp_text, interactive = False)

    with gr.Row():
        html_output = gr.HTML()

    # Branch on 'ar' here as well, so the wiring always matches the labels
    # and button states chosen above, including for --lang values other
    # than 'ar' and 'en'.
    if lang == 'ar':
        gr.Examples(examples, inputs)
        trg_btn.click(generate, inputs = inputs, outputs=[textbox_output, trg_btn])
        inp_btn.click(analyze, inputs = inputs, outputs=[html_output,trg_btn] )
    else:
        gr.Examples(examples, textbox_output)
        inp_btn.click(generate, inputs = textbox_output, outputs=[inputs, inp_btn])
        trg_btn.click(analyze, inputs = textbox_output, outputs=[html_output,inp_btn])

# To expose the app on the network instead, use:
# demo.launch(server_name = '0.0.0.0', share=True)
demo.launch()