# LayoutLMv3_RFP / main.py
# Last commit d15bee8 (verified): "Update main.py" by Karol Idaszak
import functools
import html
import io
import os

import fitz  # PyMuPDF
import torch
from PIL import Image, ImageDraw, ImageFont
from transformers import AutoModelForTokenClassification, AutoProcessor
# Access token passed as `use_auth_token` when loading the fine-tuned model;
# None when the MODEL_KEY environment variable is not set.
MODEL_KEY = os.environ.get("MODEL_KEY")
def extract_data_from_pdf(pdf_path, page_number=0):
    """Extract the rendered image, words and word boxes from one PDF page.

    Args:
        pdf_path: Path to the PDF file (anything ``fitz.open`` accepts).
        page_number: 0-indexed page number. A float with a whole value (as a
            Gradio ``"number"`` input produces) is converted to ``int``.

    Returns:
        A tuple ``(image, words, boxes)``:
        - image: PIL image of the page, rendered by ``page.get_pixmap()``
          with default settings. At that default scale, PyMuPDF's pixel grid
          matches the page coordinates used by ``boxes``.
        - words: list of word strings from ``page.get_text("words")``.
        - boxes: list of ``(x0, y0, x1, y1)`` tuples, one per word, in the
          same order as ``words``.

    Raises:
        ValueError: if ``page_number`` is a float with a fractional part.
    """
    if isinstance(page_number, float):
        # Gradio "number" inputs arrive as floats, but PyMuPDF's load_page
        # only treats an int as a page index.
        if not page_number.is_integer():
            raise ValueError(
                f"page_number must be a whole number, got {page_number!r}"
            )
        page_number = int(page_number)

    # The context manager closes the document even if loading, rendering or
    # text extraction raises (the original leaked it in that case).
    with fitz.open(pdf_path) as doc:
        page = doc.load_page(page_number)
        png_bytes = page.get_pixmap().tobytes("png")
        # Each entry is (x0, y0, x1, y1, word, ...).
        word_entries = page.get_text("words")

    image = Image.open(io.BytesIO(png_bytes))
    words = [entry[4] for entry in word_entries]
    boxes = [entry[:4] for entry in word_entries]  # (x0, y0, x1, y1)
    return image, words, boxes
def merge_pairs_v2(pairs):
    """Collapse runs of consecutive pairs that share the same first element.

    Adjacent pairs whose ``[0]`` items compare equal are merged into a single
    ``[key, text]`` list. The ``[1]`` items are joined with a single space,
    so they are expected to be strings. Non-adjacent repeats are not merged.

    Args:
        pairs: sequence of ``(key, text)``-like items; may be empty or None.

    Returns:
        list: the merged pairs. Pairs that were not merged are returned as-is.
    """
    if not pairs:
        return []
    collapsed = []
    for item in pairs:
        if collapsed and collapsed[-1][0] == item[0]:
            head = collapsed[-1]
            collapsed[-1] = [head[0], head[1] + " " + item[1]]
        else:
            collapsed.append(item)
    return collapsed
def create_pretty_table(data):
    """Render labelled text rows as a small HTML fragment (for a Gradio HTML output).

    Each row becomes a coloured ``<p>---label---</p>`` line followed by the
    row's text. Rows labelled "Heder" (or "Header") are blue, "Section" rows
    are green, and anything else is black.

    The label and text are HTML-escaped. The text comes from an uploaded PDF,
    and escaping makes it display literally instead of being parsed as markup.

    Args:
        data: iterable of rows; ``row[0]`` is the label and ``row[1]`` the text.

    Returns:
        str: the rows wrapped in a single ``<div>``.
    """
    # "Heder" is the spelling the caller in this file emits; the correct
    # spelling "Header" is accepted too.
    label_colors = {"Heder": "blue", "Header": "blue", "Section": "green"}
    parts = ["<div>"]
    for row in data:
        label, text = row[0], row[1]
        color = label_colors.get(label, "black")
        parts.append(
            "<p style='color:{};'>---{}---</p>{}".format(
                color, html.escape(str(label)), html.escape(str(text))
            )
        )
    parts.append("</div>")
    return "".join(parts)
# Box outline colour per predicted entity (after IOB prefix removal).
_LABEL2COLOR = {
    "question": "blue",
    "answer": "green",
    "header": "orange",
    "other": "violet",
}


@functools.lru_cache(maxsize=1)
def _load_model_and_processor():
    """Load the token classifier and its processor once, then reuse them."""
    # NOTE(review): newer transformers releases deprecate `use_auth_token` in
    # favour of `token=`; kept for the installed version — confirm.
    model = AutoModelForTokenClassification.from_pretrained(
        "karida/LayoutLMv3_RFP",
        use_auth_token=MODEL_KEY,
    )
    # apply_ocr=False: words and boxes come from the PDF text layer instead.
    processor = AutoProcessor.from_pretrained(
        "microsoft/layoutlmv3-base", apply_ocr=False
    )
    return model, processor


def _iob_to_label(label):
    """Drop the two-character IOB prefix and lowercase; an empty rest is "other"."""
    label = label[2:]
    return "other" if not label else label.lower()


def interference(example, page_number=0):
    """Classify the words of one PDF page and visualise the result.

    Uses the ``karida/LayoutLMv3_RFP`` token-classification model. When wiring
    this into Gradio, use ``outputs=["image", "html"]``.

    Args:
        example: the PDF, passed straight to ``extract_data_from_pdf``.
        page_number: 0-indexed page to analyse.

    Returns:
        ``(image, html)``, where:
        - image: the rendered page with a rectangle drawn around each token
          box, coloured by predicted label.
        - html: built by ``create_pretty_table``. Consecutive words with the
          same (non-"other") label are merged. Labels starting with "q" are
          shown as "Heder"; all other labels are shown as "Section".
    """
    image, words, boxes = extract_data_from_pdf(example, page_number)
    if not words:
        # No text layer on this page: nothing to classify.
        return image, create_pretty_table([])
    # NOTE(review): LayoutLMv3 processors document boxes normalised to 0-1000.
    # These are raw page coordinates, which also lets the drawing below reuse
    # them directly on the image. Kept as-is because the fine-tuned model may
    # have been trained this way — confirm.
    boxes = [list(map(int, box)) for box in boxes]

    # Loaded once per process instead of on every request.
    model, processor = _load_model_and_processor()
    # truncation=True: the model has a fixed maximum sequence length, and
    # longer pages would otherwise fail. Tokens past the limit are dropped,
    # so those words get no box and no table entry.
    encoding = processor(
        image, words, boxes=boxes, truncation=True, return_tensors="pt"
    )

    with torch.no_grad():
        logits = model(**encoding).logits

    # Batch size is 1; index it explicitly instead of squeeze().
    predictions = logits.argmax(-1)[0].tolist()
    token_boxes = encoding["bbox"][0].tolist()
    word_ids = encoding.word_ids()

    draw = ImageDraw.Draw(image)
    table = []
    seen_words = set()
    for prediction, box, word_id in zip(predictions, token_boxes, word_ids):
        if word_id is None:
            # Special tokens ([CLS]/[SEP]) carry a dummy box, not a word.
            continue
        predicted_label = _iob_to_label(model.config.id2label[prediction])
        draw.rectangle(
            box, outline=_LABEL2COLOR.get(predicted_label, "violet"), width=2
        )
        # word_id 0 is the first word and is valid, so the None test above is
        # needed rather than truthiness. Each word is recorded once: the first
        # of its sub-tokens that is not labelled "other".
        if word_id not in seen_words and predicted_label != "other":
            seen_words.add(word_id)
            table.append([predicted_label[0], words[word_id]])

    values = [
        ["Heder", text] if kind == "q" else ["Section", text]
        for kind, text in merge_pairs_v2(table)
    ]
    return image, create_pretty_table(values)
import gradio as gr
# Legend shown above the demo: table colours (Heading/Section) and the
# colour of boxes for tokens left out of the table.
description_text = """
<p>
Heading - <span style='color: blue;'>shown in blue</span><br>
Section - <span style='color: green;'>shown in green</span><br>
other - (ignored) <span style='color: violet;'>shown in violet</span>
</p>
"""
flagging_options = ["great example", "bad example"]

# precision=0 makes Gradio pass an int rather than a float. PyMuPDF's
# load_page expects an int page index, and value=0 matches interference's
# default page.
iface = gr.Interface(
    fn=interference,
    inputs=[
        "file",
        gr.Number(value=0, precision=0, label="Page number (0-indexed)"),
    ],
    outputs=["image", "html"],
    description=description_text,
    flagging_options=flagging_options,
)

if __name__ == "__main__":
    iface.launch()