Spaces:
Runtime error
Runtime error
import functools
import html
import io
import os

import fitz  # PyMuPDF
import torch
from PIL import Image, ImageDraw, ImageFont
from transformers import AutoModelForTokenClassification, AutoProcessor
# Auth token handed to `from_pretrained` (as `use_auth_token`) when loading the
# fine-tuned model; None when the MODEL_KEY environment variable is unset.
MODEL_KEY = os.environ.get("MODEL_KEY")
def extract_data_from_pdf(pdf_path, page_number=0):
    """
    Extract the rendered image, words, and word bounding boxes from one PDF page.

    Args:
        pdf_path (str): Path to the PDF file (anything ``fitz.open`` accepts).
        page_number (int): Page to extract, 0-indexed. Integral floats such as
            ``1.0`` (what numeric UI widgets often deliver) are accepted and
            converted with ``int()``.

    Returns:
        tuple: ``(image, words, boxes)`` where
            - image (PIL.Image.Image): the page rendered by ``page.get_pixmap()``
              with its default settings,
            - words (list[str]): the words PyMuPDF found on the page,
            - boxes (list[tuple]): ``(x0, y0, x1, y1)`` for each word, in the
              same order as ``words``.
    """
    doc = fitz.open(pdf_path)
    try:
        # load_page expects an int index; a float would be rejected.
        page = doc.load_page(int(page_number))

        # Default get_pixmap() uses the identity matrix (72 dpi), so image
        # pixels line up 1:1 with the PDF-point word coordinates below.
        pix = page.get_pixmap()
        image = Image.open(io.BytesIO(pix.tobytes("png")))

        # Each entry is (x0, y0, x1, y1, word, block_no, line_no, word_no).
        entries = page.get_text("words")
        words = [entry[4] for entry in entries]
        boxes = [entry[:4] for entry in entries]
    finally:
        # Always release the document, even if page extraction fails.
        doc.close()
    return image, words, boxes
def merge_pairs_v2(pairs):
    """
    Collapse runs of adjacent ``[key, text]`` pairs that share the same key.

    Consecutive pairs whose first elements compare equal are fused into one
    ``[key, "text1 text2 ..."]`` list, joining the texts with single spaces.
    Pairs that are not merged with a following pair are returned as the same
    objects that were passed in. Equal keys that are not adjacent are kept
    as separate entries.

    Args:
        pairs (Sequence): two-item pairs ``(key, text)``.

    Returns:
        list: the collapsed pairs in their original order; ``[]`` for empty
        input.
    """
    if not pairs:
        return []

    first, *rest = pairs
    collapsed = [first]
    for pair in rest:
        previous = collapsed[-1]
        if previous[0] == pair[0]:
            # Same key as the run in progress: extend that run's text.
            collapsed[-1] = [previous[0], previous[1] + " " + pair[1]]
            continue
        collapsed.append(pair)
    return collapsed
def create_pretty_table(data):
    """
    Render labelled text rows as an HTML fragment (for Gradio's ``html`` output).

    Each row becomes a coloured ``---label---`` paragraph followed by the row's
    text. Rows labelled ``"Heder"`` are blue, ``"Section"`` green, and anything
    else black. Both label and text are HTML-escaped: the text originates from
    an uploaded PDF and must be displayed literally, never interpreted as markup.

    Args:
        data (Iterable[Sequence]): rows of ``(label, text)``.

    Returns:
        str: the rows wrapped in a single ``<div>``; ``"<div></div>"`` when
        ``data`` is empty.
    """
    label_colors = {"Heder": "blue", "Section": "green"}
    parts = ["<div>"]
    for row in data:
        label, text = row[0], row[1]
        color = label_colors.get(label, "black")
        parts.append(
            "<p style='color:{};'>---{}---</p>{}".format(
                color, html.escape(str(label)), html.escape(str(text))
            )
        )
    parts.append("</div>")
    # join() avoids quadratic string concatenation on long documents.
    return "".join(parts)
@functools.lru_cache(maxsize=1)
def _load_model_and_processor():
    """Load the fine-tuned token classifier and its processor once and reuse them.

    Cached so each request does not re-load the weights via ``from_pretrained``.
    """
    # NOTE(review): recent transformers releases deprecate `use_auth_token` in
    # favour of `token`; confirm against the pinned transformers version.
    model = AutoModelForTokenClassification.from_pretrained(
        "karida/LayoutLMv3_RFP",
        use_auth_token=MODEL_KEY,
    )
    # apply_ocr=False: words and boxes come from the PDF text layer instead.
    processor = AutoProcessor.from_pretrained(
        "microsoft/layoutlmv3-base", apply_ocr=False
    )
    return model, processor


def _iob_to_label(label):
    """Drop the two-character IOB prefix and lowercase the rest.

    Labels with nothing after the prefix (e.g. ``"O"``) map to ``"other"``.
    """
    label = label[2:]
    return "other" if not label else label.lower()


def interference(example, page_number=0):
    """
    Classify the words on one PDF page and visualise the predictions.

    When wiring this into Gradio, the second output must use the ``html``
    output type.

    Args:
        example: the uploaded PDF — a path (str / os.PathLike) or a file-like
            wrapper exposing ``.name`` (as some Gradio versions pass for
            ``file`` inputs).
        page_number (int | float | None): 0-indexed page. Floats from a numeric
            widget are converted with ``int()``; None is treated as 0.

    Returns:
        tuple: ``(image, table)`` — the page image with a coloured rectangle
        drawn around every token box, and an HTML string (built by
        ``create_pretty_table``) listing the words whose predicted label is
        not "other", with adjacent same-label words merged.
    """
    if not isinstance(example, (str, os.PathLike)):
        example = getattr(example, "name", example)
    page_number = 0 if page_number is None else int(page_number)

    image, words, boxes = extract_data_from_pdf(example, page_number)
    if not words:
        # No text layer (e.g. a scanned page): nothing to classify.
        return image, create_pretty_table([])
    boxes = [list(map(int, box)) for box in boxes]

    model, processor = _load_model_and_processor()
    # truncation=True: LayoutLMv3 handles at most 512 tokens; longer inputs
    # overflow its position embeddings and make the forward pass fail.
    encoding = processor(
        image, words, boxes=boxes, return_tensors="pt", truncation=True
    )
    with torch.no_grad():
        outputs = model(**encoding)

    # Take batch element 0 rather than squeeze(): squeeze() turns a
    # single-token sequence into a scalar / flat box and breaks the zip below.
    predictions = outputs.logits.argmax(-1)[0].tolist()
    token_boxes = encoding["bbox"][0].tolist()
    word_ids = encoding.word_ids()
    token_labels = [model.config.id2label[pred] for pred in predictions]

    label2color = {
        "question": "blue",
        "answer": "green",
        "header": "orange",
        "other": "violet",
    }

    draw = ImageDraw.Draw(image)
    table = []
    seen_word_ids = set()
    for token_label, box, word_id in zip(token_labels, token_boxes, word_ids):
        predicted_label = _iob_to_label(token_label)
        # NOTE(review): boxes are drawn as returned by the processor; this
        # assumes they are still in image pixel space (raw PDF points were
        # fed in, not boxes normalised to 0-1000).
        draw.rectangle(
            box,
            outline=label2color.get(predicted_label, label2color["other"]),
            width=2,
        )
        # word_id is None for special tokens. 0 is a valid word index, so
        # compare against None instead of testing truthiness.
        if (
            word_id is not None
            and word_id not in seen_word_ids
            and predicted_label != "other"
        ):
            seen_word_ids.add(word_id)
            table.append([predicted_label[0], words[word_id]])

    values = [
        ["Heder", text] if initial == "q" else ["Section", text]
        for initial, text in merge_pairs_v2(table)
    ]
    return image, create_pretty_table(values)
import gradio as gr

# Legend for the colours used in the annotated image / HTML table.
description_text = """
<p>
Heading - <span style='color: blue;'>shown in blue</span><br>
Section - <span style='color: green;'>shown in green</span><br>
other - (ignored) <span style='color: violet;'>shown in violet</span>
</p>
"""
flagging_options = ["great example", "bad example"]
iface = gr.Interface(
    fn=interference,
    inputs=[
        "file",
        # precision=0 makes Gradio deliver an int page index, not a float.
        gr.Number(value=0, precision=0, label="Page number (0-indexed)"),
    ],
    outputs=["image", "html"],
    description=description_text,
    flagging_options=flagging_options,
)

if __name__ == "__main__":
    iface.launch()