johnlockejrr commited on
Commit
140a773
1 Parent(s): 1e236cb

Add application file

Browse files
Files changed (4) hide show
  1. README.md +6 -4
  2. app.py +128 -0
  3. examples/default.jpg +0 -0
  4. requirements.txt +1 -0
README.md CHANGED
@@ -1,13 +1,15 @@
1
  ---
2
- title: PyLaia-mcdonald V2
3
- emoji: 👁
4
  colorFrom: green
5
- colorTo: indigo
6
  sdk: gradio
7
- sdk_version: 4.36.1
8
  app_file: app.py
9
  pinned: false
10
  license: mit
 
 
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: PyLaia
3
+ emoji: 🐢
4
  colorFrom: green
5
+ colorTo: purple
6
  sdk: gradio
7
+ sdk_version: 4.13.0
8
  app_file: app.py
9
  pinned: false
10
  license: mit
11
+ models:
12
+ - Teklia/pylaia-rimes
13
  ---
14
 
15
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from uuid import uuid4
2
+ import gradio as gr
3
+ from laia.scripts.htr.decode_ctc import run as decode
4
+ from laia.common.arguments import CommonArgs, DataArgs, TrainerArgs, DecodeArgs
5
+ import sys
6
+ from tempfile import NamedTemporaryFile, mkdtemp
7
+ from pathlib import Path
8
+ from contextlib import redirect_stdout
9
+ import re
10
+ from huggingface_hub import snapshot_download
11
+
12
# Scratch directory where each uploaded line image is written before decoding.
images = Path(mkdtemp())

# PyLaia prints one prediction per line as: "<image_id> <confidence> <text>".
# The image id is a UUID4 string (36 chars of hex digits and dashes).
IMAGE_ID_PATTERN = r"(?P<image_id>[-a-z0-9]{36})"
CONFIDENCE_PATTERN = r"(?P<confidence>[0-9.]+)"  # For line
TEXT_PATTERN = r"\s*(?P<text>.*)\s*"
LINE_PREDICTION = re.compile(rf"{IMAGE_ID_PATTERN} {CONFIDENCE_PATTERN} {TEXT_PATTERN}")
# Hub repositories offered in the UI dropdown.
models_name = ["johnlockejrr/pylaia-mcdonald_v2"]
# Cache: model repo id -> local snapshot path (filled lazily by load_model).
MODELS = {}
# Every input image is resized to this height before decoding.
DEFAULT_HEIGHT = 128
21
+
22
+
23
def get_width(image, height=DEFAULT_HEIGHT):
    """Return the width that preserves *image*'s aspect ratio at *height* pixels."""
    return height * (image.width / image.height)
26
+
27
+
28
def load_model(model_name):
    """Download *model_name* from the Hugging Face Hub once; return its local path.

    Subsequent calls for the same repo id hit the in-memory MODELS cache and
    perform no network access.
    """
    try:
        return MODELS[model_name]
    except KeyError:
        path = Path(snapshot_download(model_name))
        MODELS[model_name] = path
        return path
32
+
33
+
34
def predict(model_name, input_img):
    """Decode the transcription of a single text-line image with a PyLaia model.

    Parameters
    ----------
    model_name : str
        Hugging Face Hub repository id of the PyLaia model to use.
    input_img : PIL.Image.Image
        Image containing one line of handwritten/printed text.

    Returns
    -------
    tuple
        ``(resized_image, {"text": transcription, "score": line_confidence})``.

    Raises
    ------
    RuntimeError
        If PyLaia does not emit exactly one prediction line.
    """
    model_dir = load_model(model_name)

    # Decoding hyper-parameters: softmax temperature and batch size.
    temperature = 2.0
    batch_size = 1

    weights_path = model_dir / "weights.ckpt"
    syms_path = model_dir / "syms.txt"
    language_model_params = {"language_model_weight": 1.0}
    # The presence of tokens.txt signals that the snapshot ships an n-gram
    # language model (arpa + lexicon) to rescore the CTC output.
    use_language_model = (model_dir / "tokens.txt").exists()
    if use_language_model:
        language_model_params.update(
            {
                "language_model_path": str(model_dir / "language_model.arpa.gz"),
                "lexicon_path": str(model_dir / "lexicon.txt"),
                "tokens_path": str(model_dir / "tokens.txt"),
            }
        )

    common_args = CommonArgs(
        checkpoint=str(weights_path.relative_to(model_dir)),
        train_path=str(model_dir),
        experiment_dirname="",
    )
    data_args = DataArgs(batch_size=batch_size, color_mode="L")
    trainer_args = TrainerArgs(
        # Disable progress bar else it messes with frontend display
        progress_bar_refresh_rate=0
    )
    decode_args = DecodeArgs(
        include_img_ids=True,
        join_string="",
        convert_spaces=True,
        print_line_confidence_scores=True,
        print_word_confidence_scores=False,
        temperature=temperature,
        use_language_model=use_language_model,
        **language_model_params,
    )

    with NamedTemporaryFile() as pred_stdout, NamedTemporaryFile() as img_list:
        image_id = uuid4()
        # Resize image to 128 if bigger/smaller
        input_img = input_img.resize((int(get_width(input_img)), DEFAULT_HEIGHT))
        input_img.save(str(images / f"{image_id}.jpg"))
        # Export image list
        Path(img_list.name).write_text("\n".join([str(image_id)]))

        # Capture stdout as that's where PyLaia outputs predictions.
        # FIX: open the capture file in a `with` block so it is flushed and
        # closed before we read it back — previously the handle leaked and
        # buffered output could be lost.
        with open(pred_stdout.name, mode="w") as capture, redirect_stdout(capture):
            decode(
                syms=str(syms_path),
                img_list=img_list.name,
                img_dirs=[str(images)],
                common=common_args,
                data=data_args,
                trainer=trainer_args,
                decode=decode_args,
                num_workers=1,
            )
            # Flush stdout to avoid output buffering
            sys.stdout.flush()
        predictions = Path(pred_stdout.name).read_text().strip().splitlines()
        # One image was submitted, so exactly one prediction line is expected.
        # Raise (not assert) so the check survives `python -O`.
        if len(predictions) != 1:
            raise RuntimeError(
                f"Expected exactly 1 prediction line, got {len(predictions)}"
            )
        _, score, text = LINE_PREDICTION.match(predictions[0]).groups()
        return input_img, {"text": text, "score": score}
100
+
101
+
102
# Gradio UI: pick a model, upload a single line image, get back the resized
# image plus the decoded text and its confidence score (from predict above).
gradio_app = gr.Interface(
    predict,
    inputs=[
        gr.Dropdown(models_name, value=models_name[0], label="Models"),
        gr.Image(
            label="Upload an image of a line",
            sources=["upload", "clipboard"],
            type="pil",
            height=DEFAULT_HEIGHT,
            width=2000,
            # Grayscale, matching DataArgs(color_mode="L") used at decode time.
            image_mode="L",
        ),
    ],
    outputs=[
        gr.Image(label="Processed Image"),
        gr.JSON(label="Decoded text"),
    ],
    # Every file in examples/ becomes a pre-cached demo input for the first model.
    examples=[
        ["johnlockejrr/pylaia-mcdonald_v2", str(filename)]
        for filename in Path("examples").iterdir()
    ],
    title="Decode the transcription of an image using a PyLaia model",
    cache_examples=True,
)

if __name__ == "__main__":
    gradio_app.launch()
examples/default.jpg ADDED
requirements.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ pylaia==1.1.0