---
license: apache-2.0
pipeline_tag: image-to-text
datasets:
  - hoang-quoc-trung/fusion-image-to-latex-datasets
tags:
  - img2latex
  - latex ocr
  - Printed Mathematical Expression Recognition
  - Handwritten Mathematical Expression Recognition
---

Translating Math Formula Images To LaTeX Sequences

Scaling Up Image-to-LaTeX Performance: Sumen An End-to-End Transformer Model With Large Dataset


Performance

[Figures: performance results]
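Performance is reported in the figures for this section, and their exact metrics are not restated in the text. As a hedged illustration of how a single predicted LaTeX sequence can be scored against a reference, here is a minimal sketch of normalized edit similarity (1 minus Levenshtein distance divided by the longer sequence length). Splitting on whitespace and normalizing by the longer sequence are assumptions made for this sketch, not necessarily the protocol behind the reported numbers.

```python
def levenshtein(a, b):
    """Minimum number of single-element insertions, deletions and
    substitutions needed to turn sequence a into sequence b."""
    prev = list(range(len(b) + 1))  # distances from an empty prefix of a
    for i, x in enumerate(a, start=1):
        curr = [i]
        for j, y in enumerate(b, start=1):
            curr.append(min(
                prev[j] + 1,             # delete x
                curr[j - 1] + 1,         # insert y
                prev[j - 1] + (x != y),  # substitute (free when equal)
            ))
        prev = curr
    return prev[-1]


def normalized_edit_similarity(pred, ref):
    """1.0 means identical token sequences; 0.0 means nothing in common.
    Assumption: tokens are obtained by splitting on whitespace."""
    p, r = pred.split(), ref.split()
    if not p and not r:
        return 1.0
    return 1.0 - levenshtein(p, r) / max(len(p), len(r))


pred = r"\frac { a } { b }"
ref = r"\frac { a } { c }"
print(round(normalized_edit_similarity(pred, ref), 4))  # → 0.8571
```

Here the two sequences have 7 tokens each and differ by one substitution, so the score is 1 - 1/7.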

Uses

Source code: https://github.com/hoang-quoc-trung/sumen

Inference

```python
import torch
import requests
from PIL import Image
from transformers import AutoProcessor, VisionEncoderDecoderModel

# Load model & processor
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = VisionEncoderDecoderModel.from_pretrained('hoang-quoc-trung/sumen-base').to(device)
processor = AutoProcessor.from_pretrained('hoang-quoc-trung/sumen-base')

# The decoder is prompted with the BOS token only
task_prompt = processor.tokenizer.bos_token
decoder_input_ids = processor.tokenizer(
    task_prompt,
    add_special_tokens=False,
    return_tensors="pt"
).input_ids

# Load image and convert it to 3-channel RGB
img_url = 'https://raw.githubusercontent.com/hoang-quoc-trung/sumen/main/assets/example_1.png'
response = requests.get(img_url, stream=True, timeout=30)
response.raise_for_status()  # fail early on HTTP errors instead of inside PIL
image = Image.open(response.raw).convert('RGB')
pixel_values = processor.image_processor(
    image,
    return_tensors="pt",
    data_format="channels_first",
).pixel_values

# Generate LaTeX expression with beam search, never emitting the unknown token
with torch.no_grad():
    outputs = model.generate(
        pixel_values.to(device),
        decoder_input_ids=decoder_input_ids.to(device),
        max_length=model.decoder.config.max_length,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        use_cache=True,
        num_beams=4,
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )

# Decode the best sequence and remove the EOS, PAD and BOS token strings
sequence = processor.tokenizer.batch_decode(outputs.sequences)[0]
sequence = (
    sequence.replace(processor.tokenizer.eos_token, "")
    .replace(processor.tokenizer.pad_token, "")
    .replace(processor.tokenizer.bos_token, "")
)
print(sequence)
```
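The special-token cleanup at the end of the inference script can be factored into a small reusable helper. This is a minimal sketch: the function name `strip_special_tokens` is hypothetical, and unlike the original chain of `replace` calls it also trims surrounding whitespace and removes longer tokens first, so a short token that happens to be a substring of a longer one cannot break it apart.

```python
def strip_special_tokens(text, tokens):
    """Remove every occurrence of each special-token string from decoded text.

    Hypothetical helper: None or empty entries (e.g. an unset pad token) are
    skipped, and longer tokens are removed before shorter ones.
    """
    for token in sorted(filter(None, tokens), key=len, reverse=True):
        text = text.replace(token, "")
    return text.strip()


print(strip_special_tokens("<s> x ^ { 2 }</s><pad><pad>", ["<s>", "</s>", "<pad>"]))  # → x ^ { 2 }
```

In the inference script, the token list would be `[processor.tokenizer.bos_token, processor.tokenizer.eos_token, processor.tokenizer.pad_token]`. Alternatively, `processor.tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True)` drops every token the tokenizer registers as special, which is broader than the three tokens removed above.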