File size: 3,991 Bytes
cc5f7b7
61d07b2
cc5f7b7
a47bf95
a671104
a47bf95
cc5f7b7
a47bf95
cc5f7b7
a47bf95
bb7b589
a47bf95
 
 
 
a671104
 
a47bf95
 
 
 
 
 
 
a671104
a47bf95
 
 
 
 
 
 
 
 
 
 
a671104
a47bf95
 
 
 
 
cc5f7b7
 
a68694f
a671104
cc5f7b7
 
a47bf95
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a671104
cc5f7b7
 
 
 
a671104
 
cc5f7b7
a47bf95
 
 
 
 
 
cc5f7b7
 
 
 
a47bf95
 
 
 
 
 
 
 
cc5f7b7
 
 
 
 
a47bf95
cc5f7b7
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import os
import subprocess
import sys

# Pin transformers at runtime so the environment matches the model code.
# NOTE(review): installing at import time is fragile; prefer requirements.txt.
# subprocess.run with an argument list (shell=False) avoids shell injection,
# and `sys.executable -m pip` guarantees we install into *this* interpreter.
subprocess.run(
    [sys.executable, "-m", "pip", "install", "-U", "transformers==4.44.2"],
    check=False,  # preserve os.system semantics: do not abort on pip failure
)

import shutil
import torch
import argparse
import gradio as gr
import numpy as np
from PIL import Image
from huggingface_hub import snapshot_download

# Make the parent directory importable so the local `unimernet` package resolves.
sys.path.insert(0, os.path.join(os.getcwd(), ".."))
from unimernet.common.config import Config
import unimernet.tasks as tasks
from unimernet.processors import load_processor


def load_model_and_processor(cfg_path):
    """Build a UniMERNet model and its eval image processor from a YAML config.

    Args:
        cfg_path: Path to the model's YAML configuration file.

    Returns:
        Tuple ``(model, vis_processor)`` ready for inference.
    """
    namespace = argparse.Namespace(cfg_path=cfg_path, options=None)
    config = Config(namespace)
    task = tasks.setup_task(config)
    built_model = task.build_model(config)
    eval_proc_cfg = config.config.datasets.formula_rec_eval.vis_processor.eval
    processor = load_processor('formula_image_eval', eval_proc_cfg)
    return built_model, processor
    
def recognize_image(input_img, model_type):
    """Run formula recognition on an image and return the predicted LaTeX.

    Args:
        input_img: Image as a numpy array, either (H, W) grayscale or
            (H, W, C) color. Color inputs have their channel axis reversed
            before processing — presumably a BGR->RGB conversion to match
            the training pipeline; TODO confirm, since Gradio normally
            delivers RGB arrays.
        model_type: "base" or "small" to pick those models; any other value
            falls back to the tiny model. The models are module-level
            globals created in the __main__ section.

    Returns:
        The predicted LaTeX string, or "" when no image was provided.
    """
    # Gradio passes None when Recognize is clicked with no image loaded;
    # without this guard the `.shape` access below raises AttributeError.
    if input_img is None:
        return ""

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if model_type == "base":
        model = model_base.to(device)
    elif model_type == "small":
        model = model_small.to(device)
    else:
        model = model_tiny.to(device)

    # Reverse the channel axis for 3-channel inputs; .copy() produces a
    # contiguous array so PIL can wrap it.
    if len(input_img.shape) == 3:
        input_img = input_img[:, :, ::-1].copy()

    img = Image.fromarray(input_img)
    image = vis_processor(img).unsqueeze(0).to(device)
    output = model.generate({"image": image})
    latex_code = output["pred_str"][0]
    return latex_code
    
def gradio_reset():
    """Clear the image input and the LaTeX output widgets."""
    cleared_image = gr.update(value=None)
    cleared_latex = gr.update(value=None)
    return cleared_image, cleared_latex

    
if __name__ == "__main__":
    root_path = os.path.abspath(os.getcwd())

    # == download weights ==
    tiny_model_dir = snapshot_download('wanderkid/unimernet_tiny')
    small_model_dir = snapshot_download('wanderkid/unimernet_small')
    base_model_dir = snapshot_download('wanderkid/unimernet_base')

    os.makedirs(os.path.join(root_path, "models"), exist_ok=True)
    # shutil.move raises if the destination directory already exists, which
    # made any restart of this script crash; only move weights that are not
    # already in place.
    for src_dir, model_name in (
        (tiny_model_dir, "unimernet_tiny"),
        (small_model_dir, "unimernet_small"),
        (base_model_dir, "unimernet_base"),
    ):
        dst_dir = os.path.join(root_path, "models", model_name)
        if not os.path.exists(dst_dir):
            shutil.move(src_dir, dst_dir)
    # == download weights ==

    # == load model ==
    # NOTE(review): vis_processor is rebound on every call, so all three
    # models end up sharing the processor built from cfg_base.yaml. That is
    # only harmless if the eval vis_processor config is identical across the
    # three YAML files — confirm.
    model_tiny, vis_processor = load_model_and_processor(os.path.join(root_path, "cfg_tiny.yaml"))
    model_small, vis_processor = load_model_and_processor(os.path.join(root_path, "cfg_small.yaml"))
    model_base, vis_processor = load_model_and_processor(os.path.join(root_path, "cfg_base.yaml"))
    print("== load all models ==")
    # == load model ==

    with open("header.html", "r") as file:
        header = file.read()
    with gr.Blocks() as demo:
        gr.HTML(header)

        with gr.Row():
            with gr.Column():
                model_type = gr.Radio(
                        choices=["tiny", "small", "base"],
                        value="tiny",
                        label="Model Type",
                        interactive=True,
                    )
                input_img = gr.Image(label=" ", interactive=True)
                with gr.Row():
                    clear = gr.Button("Clear")
                    predict = gr.Button(value="Recognize", interactive=True, variant="primary")

                with gr.Accordion("Examples:"):
                    example_root = os.path.join(os.path.dirname(__file__), "examples")
                    gr.Examples(
                        examples=[os.path.join(example_root, fname)
                                  for fname in os.listdir(example_root)
                                  if fname.endswith("png")],
                        inputs=input_img,
                    )
            with gr.Column():
                gr.Button(value="Predict Latex:", interactive=False)
                pred_latex = gr.Textbox(label='Latex', interactive=False)

        # Wire the buttons: Clear resets both widgets, Recognize runs inference.
        clear.click(gradio_reset, inputs=None, outputs=[input_img, pred_latex])
        predict.click(recognize_image, inputs=[input_img, model_type], outputs=[pred_latex])

    # Bind to all interfaces on the conventional Gradio/Spaces port.
    demo.launch(server_name="0.0.0.0", server_port=7860, debug=True)