Spaces:
Runtime error
Runtime error
import os | |
os.system('pip install -U transformers==4.44.2') | |
import sys | |
import shutil | |
import torch | |
import argparse | |
import gradio as gr | |
import numpy as np | |
from PIL import Image | |
from huggingface_hub import snapshot_download | |
# == download weights ==
# Mirror the three UniMERNet checkpoints (tiny / small / base) from the
# "wanderkid" namespace on the Hugging Face Hub into ./models/<repo-name>.
# The generator is consumed left-to-right, so downloads happen in the order
# tiny, small, base, and each name is bound to snapshot_download's return value.
# NOTE(review): presumably cfg_*.yaml reference these ./models/ paths — confirm.
tiny_model_dir, small_model_dir, base_model_dir = (
    snapshot_download(f"wanderkid/{repo}", local_dir=f"./models/{repo}")
    for repo in ("unimernet_tiny", "unimernet_small", "unimernet_base")
)
# Debug aid: print the tiny checkpoint's directory listing to the logs.
os.system("ls -l models/unimernet_tiny")
# == download weights ==
# Put the parent of the working directory at the front of the import path,
# presumably so the `unimernet` package imported below can be resolved.
sys.path.insert(0, os.path.join(os.getcwd(), os.pardir))
from unimernet.common.config import Config | |
import unimernet.tasks as tasks | |
from unimernet.processors import load_processor | |
def load_model_and_processor(cfg_path):
    """Build a UniMERNet model and its evaluation image processor from a config file.

    Args:
        cfg_path: Path to a YAML config (e.g. cfg_tiny.yaml) consumed by
            ``unimernet.common.config.Config``.

    Returns:
        A ``(model, vis_processor)`` tuple: the model built by the task that
        the config declares, and the ``formula_image_eval`` processor configured
        from ``datasets.formula_rec_eval.vis_processor.eval``.
    """
    # Config is constructed from an argparse-style namespace carrying the
    # config path and (no) override options.
    cfg = Config(argparse.Namespace(cfg_path=cfg_path, options=None))
    model = tasks.setup_task(cfg).build_model(cfg)
    eval_proc_cfg = cfg.config.datasets.formula_rec_eval.vis_processor.eval
    return model, load_processor('formula_image_eval', eval_proc_cfg)
def recognize_image(input_img, model_type):
    """Recognize the formula in an image and return its LaTeX source.

    Args:
        input_img: Image array from the ``gr.Image`` input, or ``None`` when
            the user clicks Recognize without providing an image.
        model_type: ``"base"`` or ``"small"``; any other value selects the
            tiny model.

    Returns:
        The first string in the ``"pred_str"`` list returned by ``model.generate``.

    Raises:
        gr.Error: if no image was provided (shown to the user in the UI).
    """
    if input_img is None:
        # Gradio passes None for an empty image component; report it clearly
        # instead of failing with an AttributeError on ``.shape``.
        raise gr.Error("Please upload an image before clicking Recognize.")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # model_base / model_small / model_tiny and vis_processor are module
    # globals bound in the __main__ block.
    models_by_type = {"base": model_base, "small": model_small}
    model = models_by_type.get(model_type, model_tiny).to(device)

    if len(input_img.shape) == 3:
        # Reverse the channel axis (RGB <-> BGR) before handing the array to PIL.
        # NOTE(review): presumably this matches the channel order used when the
        # model was trained; gr.Image arrays are typically RGB — confirm.
        input_img = input_img[:, :, ::-1].copy()
    img = Image.fromarray(input_img)
    image = vis_processor(img).unsqueeze(0).to(device)

    # Inference only: avoid building an autograd graph during generation.
    with torch.no_grad():
        output = model.generate({"image": image})
    return output["pred_str"][0]
def gradio_reset():
    """Return two Gradio updates that set both wired output components to ``None``."""
    cleared = [gr.update(value=None) for _ in range(2)]
    return tuple(cleared)
if __name__ == "__main__":
    root_path = os.path.abspath(os.getcwd())

    # == load model ==
    # NOTE(review): vis_processor is rebound by each call, so recognize_image
    # uses the processor built from cfg_base.yaml for all three models. This is
    # only correct if the three configs share the same formula_rec_eval
    # processor settings — confirm.
    model_tiny, vis_processor = load_model_and_processor(os.path.join(root_path, "cfg_tiny.yaml"))
    model_small, vis_processor = load_model_and_processor(os.path.join(root_path, "cfg_small.yaml"))
    model_base, vis_processor = load_model_and_processor(os.path.join(root_path, "cfg_base.yaml"))
    print("== load all models ==")
    # == load model ==

    # Read the page header explicitly as UTF-8 so decoding does not depend on
    # the host locale's default encoding.
    with open(os.path.join(root_path, "header.html"), "r", encoding="utf-8") as fh:
        header = fh.read()

    example_root = os.path.join(os.path.dirname(__file__), "examples")
    # os.listdir returns entries in arbitrary order; sort for a stable gallery.
    example_images = [
        os.path.join(example_root, name)
        for name in sorted(os.listdir(example_root))
        if name.endswith(".png")
    ]

    with gr.Blocks() as demo:
        gr.HTML(header)
        with gr.Row():
            with gr.Column():
                model_type = gr.Radio(
                    choices=["tiny", "small", "base"],
                    value="tiny",
                    label="Model Type",
                    interactive=True,
                )
                input_img = gr.Image(label=" ", interactive=True)
                with gr.Row():
                    clear = gr.Button("Clear")
                    predict = gr.Button(value="Recognize", interactive=True, variant="primary")
                with gr.Accordion("Examples:"):
                    gr.Examples(
                        examples=example_images,
                        inputs=input_img,
                    )
            with gr.Column():
                # A non-interactive button used as a heading for the output box.
                gr.Button(value="Predict Latex:", interactive=False)
                pred_latex = gr.Textbox(label='Latex', interactive=False)

        clear.click(gradio_reset, inputs=None, outputs=[input_img, pred_latex])
        predict.click(recognize_image, inputs=[input_img, model_type], outputs=[pred_latex])

    demo.launch(server_name="0.0.0.0", server_port=7860, debug=True)