Spaces:
Runtime error
Runtime error
import os | |
import tempfile | |
import fitz | |
import gradio as gr | |
import PIL | |
import skimage | |
from fastai.learner import load_learner | |
from fastai.vision.all import * | |
from fpdf import FPDF | |
from huggingface_hub import hf_hub_download | |
from icevision.all import * | |
from icevision.models.checkpoint import * | |
from PIL import Image as PILImage | |
# --- Object detector -------------------------------------------------------
# The checkpoint bundles the model together with the metadata needed to run
# it (model type, class map and expected image size).
checkpoint_path = "./2022-01-15-vfnet-post-self-train.pth"
checkpoint_and_model = model_from_checkpoint(checkpoint_path)
model, model_type, class_map, img_size = (
    checkpoint_and_model[key]
    for key in ("model", "model_type", "class_map", "img_size")
)

# Transforms handed to the detector at inference time: resize and pad to the
# checkpoint's image size, then normalize.
valid_tfms = tfms.A.Adapter(
    [*tfms.A.resize_and_pad(img_size), tfms.A.Normalize()]
)

# --- Page classifier -------------------------------------------------------
# The exported learner is fetched from the Hugging Face Hub; `labels` is the
# learner's vocabulary, used to name the predicted probabilities.
classifier_path = hf_hub_download(
    "strickvl/redaction-classifier-fastai", "model.pkl"
)
learn = load_learner(classifier_path)
labels = learn.dls.vocab
def _write_redaction_report(page_image_paths, output_dir, report_path):
    """Annotate each saved page image with detections and bundle them into a PDF.

    Args:
        page_image_paths: Paths of the page PNGs to include, in the order they
            should appear in the report.
        output_dir: Directory in which the annotated ``pred-*.png`` images are
            written.
        report_path: Destination path of the generated PDF.
    """
    report_pdf = FPDF()
    report_pdf.set_auto_page_break(0)
    for page_image_path in page_image_paths:
        annotated_path = os.path.join(
            output_dir, f"pred-{os.path.basename(page_image_path)}"
        )
        with PILImage.open(page_image_path) as img:
            width, height = img.size
            # Landscape page for images wider than tall, portrait otherwise.
            report_pdf.add_page("L" if width > height else "P")
            pred_dict = model_type.end2end_detect(
                img,
                valid_tfms,
                model,
                class_map=class_map,
                detection_threshold=0.7,
                display_label=True,
                display_bbox=True,
                return_img=True,
                font_size=16,
                label_color="#FF59D6",
            )
            pred_dict["img"].save(annotated_path)
            report_pdf.image(annotated_path)
    report_pdf.output(report_path, "F")


def predict(pdf, confidence, generate_file):
    """Classify every page of a PDF and collect the pages flagged as redacted.

    Each page is rendered to an image and scored by the classifier. Pages
    whose "redacted" probability exceeds the threshold are saved as PNGs in a
    per-document directory under the system temp directory.

    Args:
        pdf: Object whose ``name`` attribute is the path of the PDF to analyse
            (presumably the temp-file wrapper Gradio passes for a "file"
            input -- confirm against the Gradio version in use).
        confidence: Threshold as a percentage; it is divided by 100 before
            being compared with the classifier's probability.
        generate_file: When truthy, also build a PDF report containing the
            flagged pages annotated by the object detector.

    Returns:
        A ``(text_output, images, report)`` tuple: a summary string, a list of
        ``[caption, image_path]`` pairs for the flagged pages, and the path of
        the PDF report (``None`` when ``generate_file`` is falsy).
    """
    # Use only the file's stem: ``pdf.name`` may be an absolute path (which
    # would make ``os.path.join`` discard the temp dir) and the extension is
    # not guaranteed to be exactly four characters long.
    document_stem = os.path.splitext(os.path.basename(pdf.name))[0]
    output_dir = os.path.join(tempfile.gettempdir(), document_stem)
    # Create the directory up front so that report generation also works when
    # no page is flagged.
    os.makedirs(output_dir, exist_ok=True)
    report = os.path.join(output_dir, "redacted_pages.pdf")
    threshold = confidence / 100

    images = []
    redacted_pages = []
    page_image_paths = []
    # The context manager closes the document even if prediction fails.
    with fitz.open(pdf.name) as document:
        for page_num, page in enumerate(document, start=1):
            image_pixmap = page.get_pixmap()
            _, _, probs = learn.predict(image_pixmap.tobytes())
            scores = {label: float(prob) for label, prob in zip(labels, probs)}
            # Look the score up by label so the save decision and the summary
            # cannot disagree if the vocabulary order ever changes.
            if scores["redacted"] > threshold:
                page_image_path = os.path.join(
                    output_dir, f"page-{page_num}.png"
                )
                image_pixmap.save(page_image_path)
                page_image_paths.append(page_image_path)
                redacted_pages.append(str(page_num))
                images.append(
                    [
                        f"Redacted page #{len(images) + 1} on page {page_num}",
                        page_image_path,
                    ]
                )

    if generate_file:
        # Build the report from the pages saved in *this* run, in page order,
        # instead of listing the directory: a listing sorts "page-10" before
        # "page-2" and picks up stale files left by earlier runs.
        _write_redaction_report(page_image_paths, output_dir, report)

    text_output = f"A total of {len(redacted_pages)} pages were redacted. \n\n The redacted page numbers were: {', '.join(redacted_pages)}."
    return text_output, images, (report if generate_file else None)
# --- Interface text and options ---------------------------------------------
title = "Redaction Detector"
description = "A classifier trained on publicly released redacted (and unredacted) FOIA documents, using [fastai](https://github.com/fastai/fastai)."
examples = [["test1.pdf", 80, False], ["test2.pdf", 80, False]]
interpretation = "default"
enable_queue = True
theme = "grass"
allow_flagging = "never"

# Article text passed to the interface, read from article.md (path is relative
# to the working directory).
with open("article.md") as f:
    article = f.read()

# Inputs, in the order of predict's parameters: the PDF file, the confidence
# percentage, and the "generate report" checkbox.
confidence_slider = gr.inputs.Slider(
    minimum=0,
    maximum=100,
    step=None,
    default=80,
    label="Confidence",
    optional=False,
)
input_components = ["file", confidence_slider, "checkbox"]

# Outputs, in the order of predict's return values: summary text, the flagged
# page images, and the downloadable report.
output_components = [
    gr.outputs.Textbox(label="Document Analysis"),
    gr.outputs.Carousel(["text", "image"], label="Redacted pages"),
    gr.outputs.File(label="Download redacted pages"),
]

demo = gr.Interface(
    fn=predict,
    inputs=input_components,
    outputs=output_components,
    title=title,
    description=description,
    article=article,
    theme=theme,
    allow_flagging=allow_flagging,
    examples=examples,
    interpretation=interpretation,
)

demo.launch(
    cache_examples=True,
    enable_queue=enable_queue,
)