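"""Gradio demo for UniMERNet formula recognition.

Downloads the tiny, small, and base UniMERNet checkpoints, builds the
corresponding models, and serves a web UI that converts an uploaded
formula image into LaTeX code.
"""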
import os

# Ensure the pinned transformers version is installed before the model code is imported.
os.system('pip install -U transformers==4.44.2')
import sys
import shutil
import torch
import argparse
import gradio as gr
import numpy as np
from PIL import Image
from huggingface_hub import snapshot_download
# Make the parent directory importable so the unimernet package can be resolved.
sys.path.insert(0, os.path.join(os.getcwd(), ".."))
from unimernet.common.config import Config
import unimernet.tasks as tasks
from unimernet.processors import load_processor
def load_model_and_processor(cfg_path):
    # Build a UniMERNet model and its formula image processor from a config file.
    args = argparse.Namespace(cfg_path=cfg_path, options=None)
    cfg = Config(args)
    task = tasks.setup_task(cfg)
    model = task.build_model(cfg)
    vis_processor = load_processor('formula_image_eval', cfg.config.datasets.formula_rec_eval.vis_processor.eval)
    return model, vis_processor
def recognize_image(input_img, model_type):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Pick the checkpoint selected in the UI.
    if model_type == "base":
        model = model_base.to(device)
    elif model_type == "small":
        model = model_small.to(device)
    else:
        model = model_tiny.to(device)
    if len(input_img.shape) == 3:
        # Reverse the channel order of the incoming array before wrapping it in a PIL image.
        input_img = input_img[:, :, ::-1].copy()
    img = Image.fromarray(input_img)
    image = vis_processor(img).unsqueeze(0).to(device)
    output = model.generate({"image": image})
    latex_code = output["pred_str"][0]
    return latex_code


def gradio_reset():
    # Clear both the image input and the LaTeX output.
    return gr.update(value=None), gr.update(value=None)
if __name__ == "__main__":
    root_path = os.path.abspath(os.getcwd())

    # == download weights ==
    tiny_model_dir = snapshot_download('wanderkid/unimernet_tiny')
    small_model_dir = snapshot_download('wanderkid/unimernet_small')
    base_model_dir = snapshot_download('wanderkid/unimernet_base')
    os.makedirs(os.path.join(root_path, "models"), exist_ok=True)
    shutil.move(tiny_model_dir, os.path.join(root_path, "models", "unimernet_tiny"))
    shutil.move(small_model_dir, os.path.join(root_path, "models", "unimernet_small"))
    shutil.move(base_model_dir, os.path.join(root_path, "models", "unimernet_base"))
    # == download weights ==

    # == load model ==
    model_tiny, vis_processor = load_model_and_processor(os.path.join(root_path, "cfg_tiny.yaml"))
    model_small, vis_processor = load_model_and_processor(os.path.join(root_path, "cfg_small.yaml"))
    model_base, vis_processor = load_model_and_processor(os.path.join(root_path, "cfg_base.yaml"))
    print("== load all models ==")
    # == load model ==
with open("header.html", "r") as file:
header = file.read()
with gr.Blocks() as demo:
gr.HTML(header)
with gr.Row():
with gr.Column():
model_type = gr.Radio(
choices=["tiny", "small", "base"],
value="tiny",
label="Model Type",
interactive=True,
)
input_img = gr.Image(label=" ", interactive=True)
with gr.Row():
clear = gr.Button("Clear")
predict = gr.Button(value="Recognize", interactive=True, variant="primary")
with gr.Accordion("Examples:"):
example_root = os.path.join(os.path.dirname(__file__), "examples")
gr.Examples(
examples=[os.path.join(example_root, _) for _ in os.listdir(example_root) if
_.endswith("png")],
inputs=input_img,
)
with gr.Column():
gr.Button(value="Predict Latex:", interactive=False)
pred_latex = gr.Textbox(label='Latex', interactive=False)
clear.click(gradio_reset, inputs=None, outputs=[input_img, pred_latex])
predict.click(recognize_image, inputs=[input_img, model_type], outputs=[pred_latex])
demo.launch(server_name="0.0.0.0", server_port=7860, debug=True) |