File size: 4,357 Bytes
cc5f7b7
61d07b2
cc5f7b7
a47bf95
a671104
a47bf95
cc5f7b7
a47bf95
cc5f7b7
a47bf95
bb7b589
9eb5356
cad4ba9
 
 
c0d1798
cad4ba9
 
 
ec0123f
 
 
 
 
9eb5356
 
a47bf95
 
 
 
a671104
 
a47bf95
 
 
 
 
 
 
a671104
a47bf95
 
 
 
 
 
 
 
 
 
 
a671104
a47bf95
 
 
 
 
cc5f7b7
 
a68694f
a671104
cc5f7b7
 
a47bf95
 
 
 
 
 
 
a671104
cc5f7b7
 
 
 
a671104
 
cc5f7b7
a47bf95
 
 
 
 
 
cc5f7b7
 
 
 
a47bf95
 
 
 
 
 
 
 
cc5f7b7
 
 
 
 
a47bf95
cc5f7b7
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import os
# Runtime dependency pin — presumably this script runs on a hosted Space where
# the environment cannot be pre-built; pinning transformers avoids API drift.
# NOTE(review): installing via os.system at import time is fragile — confirm
# this is required by the deployment target.
os.system('pip install -U transformers==4.44.2')
import sys
import shutil
import torch
import argparse
import gradio as gr
import numpy as np
from PIL import Image
from huggingface_hub import snapshot_download

# == download weights ==
# Fetch the three UniMERNet checkpoints (tiny/small/base) from the Hub into
# local ./models/ directories; snapshot_download is a no-op if already cached.
tiny_model_dir = snapshot_download('wanderkid/unimernet_tiny', local_dir='./models/unimernet_tiny')
small_model_dir = snapshot_download('wanderkid/unimernet_small', local_dir='./models/unimernet_small')
base_model_dir = snapshot_download('wanderkid/unimernet_base', local_dir='./models/unimernet_base')
# Debug aid: list the downloaded files in the build/run log.
os.system("ls -l models/unimernet_tiny")
# os.system(f"sed -i 's/MODEL_DIR/{tiny_model_dir}/g' cfg_tiny.yaml")
# os.system(f"sed -i 's/MODEL_DIR/{small_model_dir}/g' cfg_small.yaml")
# os.system(f"sed -i 's/MODEL_DIR/{base_model_dir}/g' cfg_base.yaml")
# root_path = os.path.abspath(os.getcwd())
# os.makedirs(os.path.join(root_path, "models"), exist_ok=True)
# shutil.move(tiny_model_dir, os.path.join(root_path, "models", "unimernet_tiny"))
# shutil.move(small_model_dir, os.path.join(root_path, "models", "unimernet_small"))
# shutil.move(base_model_dir, os.path.join(root_path, "models", "unimernet_base"))
# == download weights ==

# Make the parent directory importable so the local `unimernet` package resolves.
sys.path.insert(0, os.path.join(os.getcwd(), ".."))
from unimernet.common.config import Config
import unimernet.tasks as tasks
from unimernet.processors import load_processor


def load_model_and_processor(cfg_path):
    """Build a UniMERNet model and its evaluation image processor.

    Args:
        cfg_path: Path to a YAML config file describing the model and the
            ``formula_rec_eval`` dataset section.

    Returns:
        Tuple of ``(model, vis_processor)``.
    """
    # Config expects an argparse-style namespace; fabricate one in place.
    config = Config(argparse.Namespace(cfg_path=cfg_path, options=None))
    formula_task = tasks.setup_task(config)
    built_model = formula_task.build_model(config)
    eval_proc_cfg = config.config.datasets.formula_rec_eval.vis_processor.eval
    processor = load_processor('formula_image_eval', eval_proc_cfg)
    return built_model, processor
    
def recognize_image(input_img, model_type):
    """Run formula recognition on an image and return the predicted LaTeX.

    Args:
        input_img: Image from the gradio Image component as a numpy array,
            either (H, W) grayscale or (H, W, C) color. May be ``None`` when
            the user clicks Recognize without uploading an image.
        model_type: One of ``"tiny"``, ``"small"``, ``"base"`` selecting which
            pre-loaded model to use (anything else falls back to tiny).

    Returns:
        The predicted LaTeX string, or ``""`` when no image was provided.
    """
    # Guard: gradio passes None when no image has been uploaded; without this
    # the .ndim access below raises AttributeError and the UI shows an error.
    if input_img is None:
        return ""

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if model_type == "base":
        model = model_base.to(device)
    elif model_type == "small":
        model = model_small.to(device)
    else:
        model = model_tiny.to(device)

    # Reverse channel order for color images (RGB <-> BGR swap).
    # NOTE(review): gradio delivers RGB; this presumably matches the processor's
    # expected channel order — confirm against the training pipeline.
    if input_img.ndim == 3:
        input_img = input_img[:, :, ::-1].copy()

    img = Image.fromarray(input_img)
    image = vis_processor(img).unsqueeze(0).to(device)
    # Inference only — disable autograd to avoid building a graph.
    with torch.no_grad():
        output = model.generate({"image": image})
    latex_code = output["pred_str"][0]
    return latex_code
    
def gradio_reset():
    """Clear handler: blank out both the image input and the LaTeX output."""
    return tuple(gr.update(value=None) for _ in range(2))

    
if __name__ == "__main__":
    root_path = os.path.abspath(os.getcwd())
    # == load model ==
    # Load all three model sizes up front so switching in the UI is instant.
    # NOTE(review): vis_processor is reassigned three times, so only the one
    # built from cfg_base.yaml survives and is used for every model in
    # recognize_image — confirm the three configs share an identical eval
    # processor, otherwise tiny/small get the wrong preprocessing.
    model_tiny, vis_processor = load_model_and_processor(os.path.join(root_path, "cfg_tiny.yaml"))
    model_small, vis_processor = load_model_and_processor(os.path.join(root_path, "cfg_small.yaml"))
    model_base, vis_processor = load_model_and_processor(os.path.join(root_path, "cfg_base.yaml"))
    print("== load all models ==")
    # == load model ==
    
    # Static page header injected as raw HTML at the top of the UI.
    with open("header.html", "r") as file:
        header = file.read()
    with gr.Blocks() as demo:
        gr.HTML(header)
        
        with gr.Row():
            # Left column: model selector, image input, action buttons, examples.
            with gr.Column():
                model_type = gr.Radio(
                        choices=["tiny", "small", "base"],
                        value="tiny",
                        label="Model Type",
                        interactive=True,
                    )
                input_img = gr.Image(label=" ", interactive=True)
                with gr.Row():
                    clear = gr.Button("Clear")
                    predict = gr.Button(value="Recognize", interactive=True, variant="primary")
                    
                with gr.Accordion("Examples:"):
                    # Only .png files from ./examples are offered as samples.
                    example_root = os.path.join(os.path.dirname(__file__), "examples")
                    gr.Examples(
                        examples=[os.path.join(example_root, _) for _ in os.listdir(example_root) if
                                    _.endswith("png")],
                        inputs=input_img,
                    )
            # Right column: read-only output. The disabled button acts as a label.
            with gr.Column():
                gr.Button(value="Predict Latex:", interactive=False)
                pred_latex = gr.Textbox(label='Latex', interactive=False)
    
        # Wire up events: Clear resets both components; Recognize runs inference.
        clear.click(gradio_reset, inputs=None, outputs=[input_img, pred_latex])
        predict.click(recognize_image, inputs=[input_img, model_type], outputs=[pred_latex])
    
    # Bind on all interfaces on the conventional hosted-demo port.
    demo.launch(server_name="0.0.0.0", server_port=7860, debug=True)