johnlockejrr committed on
Commit
3c9eefa
1 Parent(s): 5d0610f

Add application file

Browse files
README.md CHANGED
@@ -1,13 +1,15 @@
1
  ---
2
- title: PyLaia-ar-hand V1
3
- emoji: 👀
4
  colorFrom: green
5
- colorTo: yellow
6
  sdk: gradio
7
- sdk_version: 4.37.1
8
  app_file: app.py
9
  pinned: false
10
  license: mit
 
 
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: PyLaia Arabic Handwritten v1
3
+ emoji: 🐢
4
  colorFrom: green
5
+ colorTo: purple
6
  sdk: gradio
7
+ sdk_version: 4.36.1
8
  app_file: app.py
9
  pinned: false
10
  license: mit
11
+ models:
12
+ - johnlockejrr/pylaia-ar-hand_v1
13
  ---
14
 
15
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from uuid import uuid4
2
+ import gradio as gr
3
+ from laia.scripts.htr.decode_ctc import run as decode
4
+ from laia.common.arguments import CommonArgs, DataArgs, TrainerArgs, DecodeArgs
5
+ import sys
6
+ from tempfile import NamedTemporaryFile, mkdtemp
7
+ from pathlib import Path
8
+ from contextlib import redirect_stdout
9
+ import re
10
+ from huggingface_hub import snapshot_download
11
+ from bidi.algorithm import get_display
12
+
13
# Scratch directory (created once at import time) where line images are
# written as "<image_id>.jpg" before being handed to the decoder.
images = Path(mkdtemp())

# Image ids are the 36-character string form of the uuid4 values used to name
# the saved images.
IMAGE_ID_PATTERN = r"(?P<image_id>[-a-z0-9]{36})"
CONFIDENCE_PATTERN = r"(?P<confidence>[0-9.]+)"  # For line
# The transcription is optional: once the decoder output has been stripped, an
# empty transcription leaves nothing after the confidence, so accept either a
# whitespace separator or end-of-line. The lazy quantifier plus the "$" anchor
# keeps trailing whitespace out of the captured text.
TEXT_PATTERN = r"(?:\s+|$)(?P<text>.*?)\s*$"
# One decoded line has the shape "<image_id> <confidence> <text>".
LINE_PREDICTION = re.compile(
    rf"{IMAGE_ID_PATTERN} {CONFIDENCE_PATTERN}{TEXT_PATTERN}"
)
# Hub repository ids offered in the UI.
models_name = ["johnlockejrr/pylaia-ar-hand_v1"]
# Cache mapping a repository id to its local snapshot directory.
MODELS = {}
21
# Height, in pixels, that every line image is rescaled to before decoding.
DEFAULT_HEIGHT = 128


def get_width(image, height=DEFAULT_HEIGHT):
    """Return the width that preserves *image*'s aspect ratio at *height*.

    Args:
        image: Object exposing numeric ``width`` and ``height`` attributes
            (the app passes the uploaded PIL image).
        height: Target height in pixels.

    Returns:
        The scaled width as a float, never smaller than 1.0 so that the
        value is still a usable pixel count after truncation to ``int``.

    Raises:
        ZeroDivisionError: If ``image.height`` is 0.
    """
    aspect_ratio = image.width / image.height
    # A very tall, narrow input would otherwise scale to less than one pixel
    # and truncate to a zero width when the caller converts the result.
    return max(1.0, height * aspect_ratio)
27
+
28
+
29
def load_model(model_name):
    """Return the local snapshot directory for *model_name*.

    The repository is fetched with ``snapshot_download`` the first time a
    name is requested; the resulting path is stored in ``MODELS`` and reused
    on every later call.
    """
    try:
        return MODELS[model_name]
    except KeyError:
        snapshot_dir = Path(snapshot_download(model_name))
        MODELS[model_name] = snapshot_dir
        return snapshot_dir
33
+
34
+
35
def predict(model_name, input_img):
    """Transcribe one text-line image with a PyLaia model.

    The image is rescaled to ``DEFAULT_HEIGHT`` (keeping its aspect ratio),
    saved into the shared scratch directory under a fresh UUID, decoded with
    PyLaia's CTC decoder, and the single line printed by the decoder is
    parsed into a confidence score and a transcription.

    Args:
        model_name: Hub repository id of the model to use.
        input_img: PIL image of a single text line.

    Returns:
        A tuple ``(resized_image, result)`` where ``result`` is a dict with
        the keys ``"text"`` (transcription reordered for right-to-left
        display) and ``"score"`` (line confidence, kept as the string
        printed by the decoder).

    Raises:
        gr.Error: If no image was provided, or if the decoder output is not
            exactly one parsable prediction line.
    """
    if input_img is None:
        raise gr.Error("Please provide an image of a text line.")

    model_dir = load_model(model_name)

    temperature = 2.0
    batch_size = 1

    weights_path = model_dir / "weights.ckpt"
    syms_path = model_dir / "syms.txt"
    language_model_params = {"language_model_weight": 1.0}
    # The presence of tokens.txt in the snapshot switches the language model
    # on; the other language-model files are looked up next to it.
    use_language_model = (model_dir / "tokens.txt").exists()
    if use_language_model:
        language_model_params.update(
            {
                "language_model_path": str(model_dir / "language_model.arpa.gz"),
                "lexicon_path": str(model_dir / "lexicon.txt"),
                "tokens_path": str(model_dir / "tokens.txt"),
            }
        )

    common_args = CommonArgs(
        checkpoint=str(weights_path.relative_to(model_dir)),
        train_path=str(model_dir),
        experiment_dirname="",
    )
    data_args = DataArgs(batch_size=batch_size, color_mode="L")
    trainer_args = TrainerArgs(
        # Disable progress bar else it messes with frontend display
        progress_bar_refresh_rate=0
    )
    decode_args = DecodeArgs(
        include_img_ids=True,
        join_string="",
        convert_spaces=True,
        print_line_confidence_scores=True,
        print_word_confidence_scores=False,
        temperature=temperature,
        use_language_model=use_language_model,
        **language_model_params,
    )

    image_id = uuid4()
    # Resize image to 128 if bigger/smaller; keep at least one pixel of width
    # so extremely narrow inputs do not produce an invalid size.
    target_width = max(1, int(get_width(input_img)))
    input_img = input_img.resize((target_width, DEFAULT_HEIGHT))
    image_path = images / f"{image_id}.jpg"

    try:
        with NamedTemporaryFile() as pred_stdout, NamedTemporaryFile() as img_list:
            input_img.save(str(image_path))
            # Export image list (a single id)
            Path(img_list.name).write_text(str(image_id), encoding="utf-8")

            # Capture stdout as that's where PyLaia outputs predictions.
            # Managing the capture file with "with" closes (and therefore
            # flushes) it before it is read back; UTF-8 is forced so the
            # transcription survives regardless of the process locale.
            with open(
                pred_stdout.name, mode="w", encoding="utf-8"
            ) as captured, redirect_stdout(captured):
                decode(
                    syms=str(syms_path),
                    img_list=img_list.name,
                    img_dirs=[str(images)],
                    common=common_args,
                    data=data_args,
                    trainer=trainer_args,
                    decode=decode_args,
                    num_workers=1,
                )
            predictions = (
                Path(pred_stdout.name)
                .read_text(encoding="utf-8")
                .strip()
                .splitlines()
            )
    finally:
        # The scratch directory is shared by every request; remove this
        # request's image so it does not grow without bound.
        if image_path.exists():
            image_path.unlink()

    if len(predictions) != 1:
        raise gr.Error(
            f"Expected exactly one prediction line, got {len(predictions)}."
        )
    parsed = LINE_PREDICTION.match(predictions[0])
    if parsed is None:
        raise gr.Error("Could not parse the prediction returned by PyLaia.")

    score = parsed.group("confidence")
    text = parsed.group("text")
    return input_img, {"text": get_display(text, base_dir="R"), "score": score}
101
+
102
+
103
# Demo UI: a model selector and a line-image upload feed `predict`, whose
# outputs are the processed image and the decoded text as JSON.
gradio_app = gr.Interface(
    predict,
    inputs=[
        gr.Dropdown(models_name, value=models_name[0], label="Models"),
        gr.Image(
            label="Upload an image of a line",
            sources=["upload", "clipboard"],
            type="pil",
            height=DEFAULT_HEIGHT,
            width=2000,
            image_mode="L",
        ),
    ],
    outputs=[
        gr.Image(label="Processed Image"),
        gr.JSON(label="Decoded text"),
    ],
    # Sort the example files so their order is stable across runs
    # (directory iteration order is arbitrary) and skip anything that is not
    # a regular file. The model id comes from `models_name` so the examples
    # stay in sync with the dropdown.
    examples=[
        [models_name[0], str(filename)]
        for filename in sorted(Path("examples").iterdir())
        if filename.is_file()
    ],
    title="Decode the transcription of an image using a PyLaia model",
    cache_examples=True,
)

if __name__ == "__main__":
    gradio_app.launch()
examples/AHTD3A0036_Para1_1.jpg ADDED
examples/AHTD3A0051_Para1_1.jpg ADDED
examples/AHTD3A0098_Para1_1.jpg ADDED
examples/AHTD3A0158_Para4_2.jpg ADDED
examples/AHTD3A0222_Para4_5.jpg ADDED
examples/AHTD3A0348_Para4_8.jpg ADDED
examples/AHTD3A0486_Para1_3.jpg ADDED
examples/default.jpg ADDED
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ pylaia==1.1.0
2
+ python-bidi