# Author: Alex Strick van Linschoten
# Last commit: 2f2050d "update examples" (file size 7.33 kB)
import os
import tempfile
import fitz
import gradio as gr
import PIL
import skimage
from fastai.learner import load_learner
from fastai.vision.all import *
from fpdf import FPDF
from huggingface_hub import hf_hub_download
from icevision.all import *
from icevision.models.checkpoint import *
from PIL import Image as PILImage
# --- Object-detection model (IceVision) -----------------------------------
# The checkpoint bundles the trained model together with the metadata needed
# to run it; unpack the four pieces the rest of the app relies on.
checkpoint_path = "./allsynthetic-imgsize768.pth"
checkpoint_and_model = model_from_checkpoint(checkpoint_path)
model, model_type, class_map, img_size = (
    checkpoint_and_model[key]
    for key in ("model", "model_type", "class_map", "img_size")
)
# Inference-time transforms: resize/pad to the checkpoint's image size, then
# normalize.
valid_tfms = tfms.A.Adapter([*tfms.A.resize_and_pad(img_size), tfms.A.Normalize()])

# --- Page classifier (fastai) -----------------------------------------------
# Downloaded from the Hugging Face Hub; its vocab provides the class names
# used to label each page's predicted probabilities.
learn = load_learner(hf_hub_download("strickvl/redaction-classifier-fastai", "model.pkl"))
labels = learn.dls.vocab
def get_content_area(pred_dict) -> int:
    """Return the area of the first box labelled "content" in a detection result.

    Args:
        pred_dict: Detector output whose ``["detection"]`` entry holds parallel
            ``"labels"`` and ``"bboxes"`` sequences; each bbox exposes
            ``xmin``/``xmax``/``ymin``/``ymax``.

    Returns:
        ``(xmax - xmin) * (ymax - ymin)`` of the first "content" box, or 0 when
        no "content" label is present. Any additional "content" boxes are ignored.
    """
    detection = pred_dict["detection"]
    labels = detection["labels"]
    if "content" not in labels:
        return 0
    first_idx = next(idx for idx, label in enumerate(labels) if label == "content")
    box = detection["bboxes"][first_idx]
    width = box.xmax - box.xmin
    height = box.ymax - box.ymin
    return width * height
def get_redaction_area(pred_dict) -> int:
    """Return the summed area of every box labelled "redaction".

    Args:
        pred_dict: Detector output whose ``["detection"]`` entry holds parallel
            ``"labels"`` and ``"bboxes"`` sequences; each bbox exposes
            ``xmin``/``xmax``/``ymin``/``ymax``.

    Returns:
        Sum of ``(xmax - xmin) * (ymax - ymin)`` over all "redaction" boxes, or
        0 when there are none. Overlapping boxes are counted separately.
    """
    detection = pred_dict["detection"]
    labels = detection["labels"]
    if "redaction" not in labels:
        return 0
    boxes = detection["bboxes"]
    selected = [boxes[idx] for idx, label in enumerate(labels) if label == "redaction"]
    return sum((box.xmax - box.xmin) * (box.ymax - box.ymin) for box in selected)
def _percentage(part, whole) -> float:
    """Return ``part`` as a percentage of ``whole``, rounded to one decimal place.

    Returns 0.0 when ``whole`` is not positive instead of raising
    ``ZeroDivisionError`` (e.g. when the detector found no "content" box on any
    redacted page).
    """
    if whole <= 0:
        return 0.0
    return round((part / whole) * 100, 1)


def _classify_pages(pdf_path, output_dir, threshold):
    """Classify every page of a PDF and save the pages judged redacted as PNGs.

    Args:
        pdf_path: Filesystem path of the PDF, opened with PyMuPDF (``fitz``).
        output_dir: Directory that receives one ``page-<n>.png`` per redacted
            page. It is only created if at least one page is redacted.
        threshold: Value the classifier's "redacted" probability must exceed.

    Returns:
        A tuple ``(redacted_pages, gallery, page_paths)``:
        - 1-based page numbers as strings;
        - ``[caption, png_path]`` pairs for the carousel output;
        - PNG paths in page order.
    """
    redacted_pages = []
    gallery = []
    page_paths = []
    # The context manager closes the document even if prediction fails.
    with fitz.open(pdf_path) as document:
        for page_num, page in enumerate(document, start=1):
            pixmap = page.get_pixmap()
            _, _, probs = learn.predict(pixmap.tobytes())
            page_probs = {labels[i]: float(probs[i]) for i in range(len(labels))}
            # Look the class up by name rather than assuming "redacted" is at
            # index 0 of the vocab. This also keeps the saved pages and the
            # reported page list driven by the same value.
            if page_probs["redacted"] > threshold:
                os.makedirs(output_dir, exist_ok=True)
                page_path = os.path.join(output_dir, f"page-{page_num}.png")
                pixmap.save(page_path)
                redacted_pages.append(str(page_num))
                gallery.append(
                    [f"Redacted page #{len(gallery) + 1} on page {page_num}", page_path]
                )
                page_paths.append(page_path)
    return redacted_pages, gallery, page_paths


def _build_report(page_paths, report_path, threshold):
    """Run the object detector on each saved page and write an annotated PDF.

    Each annotated image is saved next to its source as ``pred-<name>`` and
    placed on its own report page. The page is landscape when the image is
    wider than it is tall, otherwise portrait.

    Args:
        page_paths: PNG paths to annotate, in the order they should appear.
        report_path: Destination path of the generated PDF.
        threshold: Detection threshold passed to ``end2end_detect``.

    Returns:
        ``(image_area, content_area, redaction_area)`` summed over all pages.
    """
    report = FPDF()
    report.set_auto_page_break(0)
    image_area = content_area = redaction_area = 0
    for page_path in page_paths:
        with PILImage.open(page_path) as img:
            width, height = img.size
            report.add_page(orientation="L" if width > height else "P")
            pred_dict = model_type.end2end_detect(
                img,
                valid_tfms,
                model,
                class_map=class_map,
                detection_threshold=threshold,
                display_label=True,
                display_bbox=True,
                return_img=True,
                font_size=16,
                label_color="#FF59D6",
            )
            image_area += pred_dict["width"] * pred_dict["height"]
            content_area += get_content_area(pred_dict)
            redaction_area += get_redaction_area(pred_dict)
            directory, name = os.path.split(page_path)
            pred_path = os.path.join(directory, f"pred-{name}")
            pred_dict["img"].save(pred_path)
            # TODO: resize image such that it fits the pdf
            report.image(pred_path)
    report.output(report_path, "F")
    return image_area, content_area, redaction_area


def predict(pdf, confidence, generate_file):
    """Find redacted pages in a PDF and optionally build an annotated report.

    Args:
        pdf: Uploaded file object; only its ``.name`` (a path to the PDF) is used.
        confidence: Threshold as a percentage (0-100). It is divided by 100 and
            used both for the page classifier and the object detector.
        generate_file: When true, run the object detector on the redacted pages,
            write ``redacted_pages.pdf`` and append an area analysis to the text.

    Returns:
        ``(text, gallery, report_path)``. ``report_path`` is None when no report
        was requested or when no page was classified as redacted.
    """
    threshold = confidence / 100
    filename_without_extension = os.path.splitext(pdf.name)[0]
    # NOTE(review): if pdf.name is an absolute path, os.path.join discards the
    # temp dir and this resolves next to the upload itself. This is the same
    # location the original code used.
    output_dir = os.path.join(tempfile.gettempdir(), filename_without_extension)

    redacted_pages, images, page_paths = _classify_pages(pdf.name, output_dir, threshold)
    text_output = f"A total of {len(redacted_pages)} pages were redacted. \n\nThe redacted page numbers were: {', '.join(redacted_pages)}. \n\n"
    # With nothing redacted there is nothing to annotate. This also avoids
    # listing a directory that was never created.
    if not generate_file or not page_paths:
        return text_output, images, None

    report = os.path.join(output_dir, "redacted_pages.pdf")
    # Only this run's pages are annotated, in page order. Stale PNGs left in the
    # directory by earlier runs are never picked up.
    image_area, content_area, redaction_area = _build_report(page_paths, report, threshold)
    total_redaction_proportion = _percentage(redaction_area, image_area)
    content_redaction_proportion = _percentage(redaction_area, content_area)
    redaction_analysis = f"- {total_redaction_proportion}% of the total area of the redacted pages was redacted. \n- {content_redaction_proportion}% of the actual content of those redacted pages was redacted."
    return text_output + redaction_analysis, images, report
# Header text shown above the interface.
title = "Redaction Detector for PDFs"
description = "An MVP app for detection, extraction and analysis of PDF documents that contain redactions. Two models are used for this demo, both trained on publicly released redacted (and unredacted) FOIA documents: \n\n - Classification model trained using [fastai](https://github.com/fastai/fastai) \n- Object detection model trained using [IceVision](https://airctic.com/0.12.0/)"
# Long-form markdown shown beneath the interface. The explicit encoding makes
# decoding independent of the host's locale default (e.g. cp1252 on Windows).
with open("article.md", encoding="utf-8") as f:
    article = f.read()
# Sample inputs offered below the interface. Each row is
# [PDF path, confidence percentage, extract redacted images?].
examples = [
    [f"test{n}.pdf", 80, extract]
    for n, extract in zip(range(1, 5), (True, False, True, False))
]

# Interface / launch settings.
interpretation = "default"
enable_queue = True  # passed to demo.launch()
theme = "grass"
allow_flagging = "never"  # hide the flagging button
# Build the web UI around predict().
# - Inputs, in predict()'s argument order: a single PDF upload, a confidence
#   slider (percent), and a checkbox for report generation.
# - Outputs: predict() returns three values, shown in the textbox, the carousel
#   of [caption, image] pairs and the report download.
# NOTE(review): this uses the legacy gr.inputs / gr.outputs namespaces and
# gr.outputs.Carousel. These presumably require an old (2.x) Gradio release;
# confirm against the Space's pinned requirements before upgrading.
demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.inputs.File(label="PDF file", file_count="single"),
        # Percentage; predict() divides it by 100 to get the threshold.
        gr.inputs.Slider(
            minimum=0,
            maximum=100,
            step=None,
            default=80,
            label="Confidence",
            optional=False,
        ),
        gr.inputs.Checkbox(label="Extract redacted images", default=True),
    ],
    outputs=[
        gr.outputs.Textbox(label="Document Analysis"),
        gr.outputs.Carousel(["text", "image"], label="Redacted pages"),
        gr.outputs.File(label="Download redacted pages"),
    ],
    title=title,
    description=description,
    article=article,
    theme=theme,
    allow_flagging=allow_flagging,
    examples=examples,
    interpretation=interpretation,
)
# Starts the server as soon as this module runs.
# NOTE(review): cache_examples=True presumably runs predict() on each example at
# startup and caches the outputs; confirm for the pinned Gradio version.
demo.launch(
    cache_examples=True,
    enable_queue=enable_queue,
)