wufan committed on
Commit
a47bf95
1 Parent(s): a68694f

Upload 5 files

Browse files
Files changed (5) hide show
  1. app.py +65 -58
  2. cfg_base.yaml +46 -0
  3. cfg_small.yaml +46 -0
  4. cfg_tiny.yaml +46 -0
  5. requirements.txt +2 -1
app.py CHANGED
@@ -1,76 +1,69 @@
1
  import os
2
  import sys
3
- import argparse
4
- import numpy as np
5
-
6
- import cv2
7
  import torch
 
8
  import gradio as gr
 
9
  from PIL import Image
 
10
 
11
- # sys.path.insert(0, os.path.join(os.getcwd(), ".."))
12
- # from unimernet.common.config import Config
13
- # import unimernet.tasks as tasks
14
- # from unimernet.processors import load_processor
15
-
16
-
17
- # class ImageProcessor:
18
- # def __init__(self, cfg_path):
19
- # self.cfg_path = cfg_path
20
- # self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
21
- # self.model, self.vis_processor = self.load_model_and_processor()
22
-
23
- # def load_model_and_processor(self):
24
- # args = argparse.Namespace(cfg_path=self.cfg_path, options=None)
25
- # cfg = Config(args)
26
- # task = tasks.setup_task(cfg)
27
- # model = task.build_model(cfg).to(self.device)
28
- # vis_processor = load_processor('formula_image_eval', cfg.config.datasets.formula_rec_eval.vis_processor.eval)
29
 
30
- # return model, vis_processor
31
 
32
- # def process_single_image(self, image_path):
33
- # try:
34
- # raw_image = Image.open(image_path)
35
- # except IOError:
36
- # print(f"Error: Unable to open image at {image_path}")
37
- # return
38
- # # Convert PIL Image to OpenCV format
39
- # open_cv_image = np.array(raw_image)
40
- # # Convert RGB to BGR
41
- # if len(open_cv_image.shape) == 3:
42
- # # Convert RGB to BGR
43
- # open_cv_image = open_cv_image[:, :, ::-1].copy()
44
- # # Display the image using cv2
45
-
46
- # image = self.vis_processor(raw_image).unsqueeze(0).to(self.device)
47
- # output = self.model.generate({"image": image})
48
- # pred = output["pred_str"][0]
49
- # print(f'Prediction:\n{pred}')
50
-
51
- # cv2.imshow('Original Image', open_cv_image)
52
- # cv2.waitKey(0)
53
- # cv2.destroyAllWindows()
54
-
55
- # return pred
56
 
 
 
 
 
 
 
 
 
 
 
 
57
 
58
- def recognize_image(input_img):
59
- # latex_code = processor.process_single_image(input_img.name)
60
- return "100"
 
 
61
 
62
  def gradio_reset():
63
  return gr.update(value=None), gr.update(value=None)
64
 
65
 
66
  if __name__ == "__main__":
67
- # == init model ==
68
- # root_path = os.path.abspath(os.getcwd())
69
- # config_path = os.path.join(root_path, "cfg_tiny.yaml")
70
-
71
- # processor_tiny = ImageProcessor(config_path)
72
- # print("== all models init. ==")
73
- # == init model ==
 
 
 
 
 
 
 
 
 
 
 
 
74
 
75
  with open("header.html", "r") as file:
76
  header = file.read()
@@ -79,15 +72,29 @@ if __name__ == "__main__":
79
 
80
  with gr.Row():
81
  with gr.Column():
 
 
 
 
 
 
82
  input_img = gr.Image(label=" ", interactive=True)
83
  with gr.Row():
84
  clear = gr.Button("Clear")
85
  predict = gr.Button(value="Recognize", interactive=True, variant="primary")
 
 
 
 
 
 
 
 
86
  with gr.Column():
87
  gr.Button(value="Predict Latex:", interactive=False)
88
  pred_latex = gr.Textbox(label='Latex', interactive=False)
89
 
90
  clear.click(gradio_reset, inputs=None, outputs=[input_img, pred_latex])
91
- predict.click(recognize_image, inputs=[input_img], outputs=[pred_latex])
92
 
93
  demo.launch(server_name="0.0.0.0", server_port=7860, debug=True)
 
1
  import os
2
  import sys
3
+ import shutil
 
 
 
4
  import torch
5
+ import argparse
6
  import gradio as gr
7
+ import numpy as np
8
  from PIL import Image
9
+ from huggingface_hub import snapshot_download
10
 
11
+ sys.path.insert(0, os.path.join(os.getcwd(), ".."))
12
+ from unimernet.common.config import Config
13
+ import unimernet.tasks as tasks
14
+ from unimernet.processors import load_processor
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
 
16
 
17
def load_model_and_processor(cfg_path):
    """Build a UniMERNet model and its evaluation image processor from a config file.

    Args:
        cfg_path: Path of the YAML config file passed on to ``Config``.

    Returns:
        A ``(model, vis_processor)`` tuple. The model is whatever the configured
        task builds; it is not moved to any device here. The processor is the
        ``formula_image_eval`` processor configured under
        ``datasets.formula_rec_eval.vis_processor.eval``.
    """
    # Config is fed an argparse-style namespace instead of a bare path.
    namespace = argparse.Namespace(cfg_path=cfg_path, options=None)
    config = Config(namespace)

    model = tasks.setup_task(config).build_model(config)

    eval_processor_cfg = config.config.datasets.formula_rec_eval.vis_processor.eval
    return model, load_processor('formula_image_eval', eval_processor_cfg)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
def recognize_image(input_img, model_type):
    """Recognize the formula in an uploaded image and return its LaTeX string.

    Args:
        input_img: Image delivered by the Gradio image component as a NumPy
            array, or ``None`` when no image has been provided.
            NOTE(review): assumes Gradio's default RGB channel order - confirm
            if the component's ``type``/``image_mode`` is ever changed.
        model_type: ``"base"`` or ``"small"``; any other value selects the
            tiny model.

    Returns:
        The first predicted string from the selected model, or ``""`` when
        there is no input image.
    """
    # Clicking "Recognize" with an empty image box passes None; there is
    # nothing to recognize, so return an empty result instead of failing
    # on ``None.shape``.
    if input_img is None:
        return ""

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if model_type == "base":
        model = model_base.to(device)
    elif model_type == "small":
        model = model_small.to(device)
    else:
        model = model_tiny.to(device)

    # Hand the array to PIL unchanged. Reversing the channel axis first (an
    # RGB -> BGR swap that is only useful for OpenCV display) would make PIL
    # label the swapped channels as RGB and feed wrong colours to the model.
    img = Image.fromarray(input_img)
    image = vis_processor(img).unsqueeze(0).to(device)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        output = model.generate({"image": image})
    return output["pred_str"][0]
42
 
43
def gradio_reset():
    """Return two Gradio updates that each reset a component's value to ``None``.

    The "Clear" button wires these to the image input and the LaTeX textbox,
    in that order.
    """
    cleared_image = gr.update(value=None)
    cleared_latex = gr.update(value=None)
    return cleared_image, cleared_latex
45
 
46
 
47
  if __name__ == "__main__":
48
root_path = os.path.abspath(os.getcwd())

# == download weights ==
# Download each checkpoint straight into ./models/<name>, the location the
# cfg_*.yaml files point at. Moving the directory returned by
# snapshot_download() out of the Hub cache removes the snapshot from the
# cache, and on a second run the move would nest inside (or fail on) the
# already-existing destination.
models_root = os.path.join(root_path, "models")
os.makedirs(models_root, exist_ok=True)
for model_name in ("unimernet_tiny", "unimernet_small", "unimernet_base"):
    snapshot_download(
        f"wanderkid/{model_name}",
        local_dir=os.path.join(models_root, model_name),
    )
# == download weights ==

# == load model ==
# Each call also returns a processor; the name is rebound every time, so the
# one from cfg_base.yaml is what ends up in use (all three configs declare
# the same vis_processor settings).
model_tiny, vis_processor = load_model_and_processor(os.path.join(root_path, "cfg_tiny.yaml"))
model_small, vis_processor = load_model_and_processor(os.path.join(root_path, "cfg_small.yaml"))
model_base, vis_processor = load_model_and_processor(os.path.join(root_path, "cfg_base.yaml"))
print("== load all models ==")
# == load model ==
67
 
68
  with open("header.html", "r") as file:
69
  header = file.read()
 
72
 
73
  with gr.Row():
74
  with gr.Column():
75
+ model_type = gr.Radio(
76
+ choices=["tiny", "small", "base"],
77
+ value="tiny",
78
+ label="Model Type",
79
+ interactive=True,
80
+ )
81
  input_img = gr.Image(label=" ", interactive=True)
82
  with gr.Row():
83
  clear = gr.Button("Clear")
84
  predict = gr.Button(value="Recognize", interactive=True, variant="primary")
85
+
86
+ with gr.Accordion("Examples:"):
87
+ example_root = os.path.join(os.path.dirname(__file__), "examples")
88
+ gr.Examples(
89
+ examples=[os.path.join(example_root, _) for _ in os.listdir(example_root) if
90
+ _.endswith("png")],
91
+ inputs=input_img,
92
+ )
93
  with gr.Column():
94
  gr.Button(value="Predict Latex:", interactive=False)
95
  pred_latex = gr.Textbox(label='Latex', interactive=False)
96
 
97
  clear.click(gradio_reset, inputs=None, outputs=[input_img, pred_latex])
98
+ predict.click(recognize_image, inputs=[input_img, model_type], outputs=[pred_latex])
99
 
100
  demo.launch(server_name="0.0.0.0", server_port=7860, debug=True)
cfg_base.yaml ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ arch: unimernet
3
+ model_type: unimernet
4
+ model_config:
5
+ model_name: ./models/unimernet_base
6
+ max_seq_len: 1536
7
+
8
+ load_pretrained: True
9
+ pretrained: './models/unimernet_base/unimernet_base.pth'
10
+ tokenizer_config:
11
+ path: ./models/unimernet_base
12
+
13
+ datasets:
14
+ formula_rec_eval:
15
+ vis_processor:
16
+ eval:
17
+ name: "formula_image_eval"
18
+ image_size:
19
+ - 192
20
+ - 672
21
+
22
+ run:
23
+ runner: runner_iter
24
+ task: unimernet_train
25
+
26
+ batch_size_train: 64
27
+ batch_size_eval: 64
28
+ num_workers: 1
29
+
30
+ iters_per_inner_epoch: 2000
31
+ max_iters: 60000
32
+
33
+ seed: 42
34
+ output_dir: "../output/demo"
35
+
36
+ evaluate: True
37
+ test_splits: [ "eval" ]
38
+
39
+ device: "cuda"
40
+ world_size: 1
41
+ dist_url: "env://"
42
+ distributed: True
43
+ distributed_type: ddp # or fsdp when training an LLM
44
+
45
+ generate_cfg:
46
+ temperature: 0.0
cfg_small.yaml ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ arch: unimernet
3
+ model_type: unimernet
4
+ model_config:
5
+ model_name: ./models/unimernet_small
6
+ max_seq_len: 1536
7
+
8
+ load_pretrained: True
9
+ pretrained: './models/unimernet_small/unimernet_small.pth'
10
+ tokenizer_config:
11
+ path: ./models/unimernet_small
12
+
13
+ datasets:
14
+ formula_rec_eval:
15
+ vis_processor:
16
+ eval:
17
+ name: "formula_image_eval"
18
+ image_size:
19
+ - 192
20
+ - 672
21
+
22
+ run:
23
+ runner: runner_iter
24
+ task: unimernet_train
25
+
26
+ batch_size_train: 64
27
+ batch_size_eval: 64
28
+ num_workers: 1
29
+
30
+ iters_per_inner_epoch: 2000
31
+ max_iters: 60000
32
+
33
+ seed: 42
34
+ output_dir: "../output/demo"
35
+
36
+ evaluate: True
37
+ test_splits: [ "eval" ]
38
+
39
+ device: "cuda"
40
+ world_size: 1
41
+ dist_url: "env://"
42
+ distributed: True
43
+ distributed_type: ddp # or fsdp when training an LLM
44
+
45
+ generate_cfg:
46
+ temperature: 0.0
cfg_tiny.yaml ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ arch: unimernet
3
+ model_type: unimernet
4
+ model_config:
5
+ model_name: ./models/unimernet_tiny
6
+ max_seq_len: 1536
7
+
8
+ load_pretrained: True
9
+ pretrained: './models/unimernet_tiny/unimernet_tiny.pth'
10
+ tokenizer_config:
11
+ path: ./models/unimernet_tiny
12
+
13
+ datasets:
14
+ formula_rec_eval:
15
+ vis_processor:
16
+ eval:
17
+ name: "formula_image_eval"
18
+ image_size:
19
+ - 192
20
+ - 672
21
+
22
+ run:
23
+ runner: runner_iter
24
+ task: unimernet_train
25
+
26
+ batch_size_train: 64
27
+ batch_size_eval: 64
28
+ num_workers: 1
29
+
30
+ iters_per_inner_epoch: 2000
31
+ max_iters: 60000
32
+
33
+ seed: 42
34
+ output_dir: "../output/demo"
35
+
36
+ evaluate: True
37
+ test_splits: [ "eval" ]
38
+
39
+ device: "cuda"
40
+ world_size: 1
41
+ dist_url: "env://"
42
+ distributed: True
43
+ distributed_type: ddp # or fsdp when training an LLM
44
+
45
+ generate_cfg:
46
+ temperature: 0.0
requirements.txt CHANGED
@@ -1,2 +1,3 @@
1
  unimernet==0.1.6
2
- gradio
 
 
1
  unimernet==0.1.6
2
+ gradio
3
+ transformers==4.44.2