LayoutLM / app.py
konfuzio-com's picture
Update app.py
dcb13ea
import os

# Runtime dependency bootstrap: detectron2 has to be installed when the app
# starts, before it can be imported below.
#
# PyTorch is pinned to the CUDA 11.1 builds (torch 1.10.0 / torchvision 0.11).
# History: an earlier revision pinned torch 1.8.0+cu101 / torchvision 0.9.0+cu101
# together with the prebuilt detectron2 wheel for cu101/torch1.8, as a workaround
# for detectron2 not having released packages for PyTorch 1.9
# (issue: https://github.com/facebookresearch/detectron2/issues/3158).
os.system(
    'pip install -q torch==1.10.0+cu111 torchvision==0.11+cu111 '
    '-f https://download.pytorch.org/whl/torch_stable.html'
)

# detectron2 is now installed straight from its git repository instead of a
# prebuilt wheel.
# See https://detectron2.readthedocs.io/tutorials/install.html for instructions.
# NOTE: the exit status of both install commands is ignored.
os.system('pip install git+https://github.com/facebookresearch/detectron2.git')

import detectron2
from detectron2.utils.logger import setup_logger

setup_logger()
import gradio as gr
import re
import string
from operator import itemgetter
import collections
import pypdf
from pypdf import PdfReader
from pypdf.errors import PdfReadError
import pdf2image
from pdf2image import convert_from_path
import langdetect
from langdetect import detect_langs
import pandas as pd
import numpy as np
import random
import tempfile
import itertools
from matplotlib import font_manager
from PIL import Image, ImageDraw, ImageFont
import cv2
## files
import sys
sys.path.insert(0, 'files/')
import functions
from functions import *
# Upgrade pip through a shell call; the exit status is ignored.
# NOTE(review): this runs after the torch / detectron2 installs above, so it
# does not influence them.
os.system("python -m pip install --upgrade pip")

## model / feature extractor / tokenizer
import torch
from transformers import (
    AutoTokenizer,
    LayoutLMv2FeatureExtractor,
    LayoutLMv2ForTokenClassification,
)

# Use the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Token-classification model, loaded from its pretrained checkpoint id and
# moved to the selected device.
model_id = "pierreguillou/layout-xlm-base-finetuned-with-DocLayNet-base-at-linelevel-ml384"
model = LayoutLMv2ForTokenClassification.from_pretrained(model_id)
model.to(device)

# Feature extractor with its built-in OCR switched off (apply_ocr=False).
feature_extractor = LayoutLMv2FeatureExtractor(apply_ocr=False)

# Tokenizer, loaded from a separate checkpoint id.
tokenizer_id = "xlm-roberta-base"
tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)

# Label mappings and label count, read from the model configuration.
id2label, label2id = model.config.id2label, model.config.label2id
num_labels = len(id2label)
# APP outputs
def app_outputs(uploaded_pdf):
    """Build the values for every Gradio output box from an uploaded PDF.

    The PDF is converted to page images by ``pdf_to_images``. Unless the
    returned message starts with ``"Error with the PDF"``, the pages go through
    the extraction / encoding / prediction helpers, each page image is saved to
    disk and each per-page table is written to a CSV file. The results are then
    padded with blank placeholders or truncated so that exactly
    ``max_imgboxes`` pages are reported.

    Args:
        uploaded_pdf: The value of the PDF file input, forwarded unchanged to
            ``pdf_to_images``.

    Returns:
        A flat tuple ``(msg, *image_files, *images, *csv_files, *tables)`` in
        which each of the four groups has ``max_imgboxes`` entries, i.e. the
        same layout as ``[output_msg] + fileboxes + imgboxes + csvboxes +
        dfboxes``. With ``max_imgboxes == 2`` this is exactly the 9-tuple the
        function has always returned.

    Side effects:
        Writes ``img_*`` image files and ``csv_*`` CSV files (or
        ``csv_wo_content.csv`` on the error path) relative to the current
        working directory.
    """
    filename, msg, images = pdf_to_images(uploaded_pdf)

    if msg.startswith("Error with the PDF"):
        # Error path: fill every slot with the blank image and an empty CSV.
        csv_file = "csv_wo_content.csv"
        # One write is enough; every slot points at the same file.
        pd.DataFrame().to_csv(csv_file, encoding="utf-8", index=False)
        img_files = [image_blank] * max_imgboxes
        images = [Image.open(image_blank) for _ in range(max_imgboxes)]
        csv_files = [
            gr.File.update(value=csv_file, visible=True)
            for _ in range(max_imgboxes)
        ]
        # The table outputs were previously the return value of
        # ``DataFrame.to_csv(path)``, which is None; keep None, but explicitly.
        tables = [None] * max_imgboxes
        return (msg, *img_files, *images, *csv_files, *tables)

    num_images = len(images)

    # Extraction of image data (text and bounding boxes)
    dataset, _lines, _row_indexes, _par_boxes, _line_boxes = extraction_data_from_image(images)

    # Prepare the data in the format expected by the model
    encoded_dataset = dataset.map(
        prepare_inference_features,
        batched=True,
        batch_size=64,
        remove_columns=dataset.column_names,
    )
    custom_encoded_dataset = CustomDataset(encoded_dataset, tokenizer)

    # Get predictions (token level)
    outputs, images_ids_list, chunk_ids, input_ids, bboxes = predictions_token_level(
        images, custom_encoded_dataset
    )

    # Get predictions (line level); ``df`` is indexed by page number below.
    _probs_bbox, bboxes_list_dict, _input_ids_dict_dict, probs_dict_dict, df = predictions_line_level(
        dataset, outputs, images_ids_list, chunk_ids, input_ids, bboxes
    )

    # Get labeled images with lines bounding boxes
    images = get_labeled_images(dataset, images_ids_list, bboxes_list_dict, probs_dict_dict)

    # Save every page image to disk and remember its file name.
    img_files = []
    for i in range(num_images):
        if filename != "files/blank.png":
            img_file = f"img_{i}_" + filename.replace(".pdf", ".png")
        else:
            img_file = filename.replace(".pdf", ".png")
        images[i].save(img_file)
        img_files.append(img_file)

    # Pad with blank placeholders, or truncate, to exactly max_imgboxes pages.
    missing = max_imgboxes - num_images
    if missing > 0:
        img_files += [image_blank] * missing
        images += [Image.open(image_blank)] * missing
        for count in range(missing):
            df[num_images + count] = pd.DataFrame()
    else:
        img_files = img_files[:max_imgboxes]
        images = images[:max_imgboxes]
        df = dict(itertools.islice(df.items(), max_imgboxes))

    # Write one CSV per displayed page and expose it through a File update.
    csv_files = []
    for i in range(max_imgboxes):
        csv_file = f"csv_{i}_" + filename.replace(".pdf", ".csv")
        df[i].to_csv(csv_file, encoding="utf-8", index=False)
        csv_files.append(gr.File.update(value=csv_file, visible=True))

    tables = [df[i] for i in range(max_imgboxes)]
    return (msg, *img_files, *images, *csv_files, *tables)
# gradio APP
# UI layout: one PDF input, submit / clear buttons, a message box, and four
# rows of max_imgboxes per-page outputs (image file, image, CSV file, table).
with gr.Blocks(title="", css=".gradio-container") as demo:
    with gr.Row():
        pdf_file = gr.File(label="PDF")

    with gr.Row():
        submit_btn = gr.Button(f"Display first {max_imgboxes} labeled PDF pages")
        reset_btn = gr.Button(value="Clear")

    with gr.Row():
        output_msg = gr.Textbox(label="Output message")

    # Saved image file of each page.
    with gr.Row():
        fileboxes = [
            gr.File(visible=True, label=f"Image file of the PDF page n°{page_idx}")
            for page_idx in range(max_imgboxes)
        ]

    # Image of each page.
    with gr.Row():
        imgboxes = [
            gr.Image(type="pil", label=f"Image of the PDF page n°{page_idx}")
            for page_idx in range(max_imgboxes)
        ]

    # Line-level CSV file of each page.
    with gr.Row():
        csvboxes = [
            gr.File(visible=True, label=f"CSV file at line level (page {page_idx})")
            for page_idx in range(max_imgboxes)
        ]

    # Line-level data of each page as a table.
    with gr.Row():
        dfboxes = [
            gr.Dataframe(
                headers=["bounding boxes", "texts", "labels"],
                datatype=["str", "str", "str"],
                col_count=(3, "fixed"),
                visible=True,
                label=f"Data of page {page_idx}",
                type="pandas",
                wrap=True,
            )
            for page_idx in range(max_imgboxes)
        ]

    # Output order must match the order of the values returned by app_outputs.
    outputboxes = [output_msg, *fileboxes, *imgboxes, *csvboxes, *dfboxes]
    submit_btn.click(app_outputs, inputs=[pdf_file], outputs=outputboxes)

    # "Clear" resets the input and every output component to an empty value.
    resettable = [pdf_file, *outputboxes]
    reset_btn.click(
        lambda: [component.update(value=None) for component in resettable],
        inputs=[],
        outputs=resettable,
    )

    gr.Examples(
        examples=[["files/example.pdf"]],
        inputs=[pdf_file],
        outputs=outputboxes,
        fn=app_outputs,
        cache_examples=True,
    )

demo.launch()