from typing import Tuple, List, Sequence, Optional, Union
from pathlib import Path

import numpy as np
import numpy.typing as npt
from numpy import uint8
import torch
from torch import nn, Tensor
from torchvision import transforms
from transformers import AutoModelForObjectDetection
from PIL import Image
from bs4 import BeautifulSoup as bs
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.patches import Patch
from unitable import UnitableFullPredictor

ImageType = npt.NDArray[uint8]

# Based on this notebook:
# https://github.com/NielsRogge/Transformers-Tutorials/blob/master/Table%20Transformer/Inference_with_Table_Transformer_(TATR)_for_parsing_tables.ipynb

class MaxResize(object):
    """Resize a PIL image so that its longer side equals max_size."""

    def __init__(self, max_size=800):
        self.max_size = max_size

    def __call__(self, image):
        width, height = image.size
        current_max_size = max(width, height)
        scale = self.max_size / current_max_size
        resized_image = image.resize((int(round(scale * width)), int(round(scale * height))))
        return resized_image
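
# Example usage (a minimal sketch; the file path below is a placeholder):
#   page = Image.open("page.png").convert("RGB")
#   resized = MaxResize(800)(page)  # longer side is now 800 px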

def iob(boxA, boxB):
    """
    Calculate the Intersection over Bounding Box (IoB) of two bounding boxes.

    Parameters:
    - boxA: list or tuple with [xmin, ymin, xmax, ymax] of the first box
    - boxB: list or tuple with [xmin, ymin, xmax, ymax] of the second box

    Returns:
    - iob: float, the IoB ratio (intersection area divided by the area of boxB)
    """
    # Determine the coordinates of the intersection rectangle
    xA = max(boxA[0], boxB[0])
    yA = max(boxA[1], boxB[1])
    xB = min(boxA[2], boxB[2])
    yB = min(boxA[3], boxB[3])

    # Compute the area of the intersection rectangle
    interWidth = max(0, xB - xA)
    interHeight = max(0, yB - yA)
    interArea = interWidth * interHeight

    # Compute the area of boxB (the second box)
    boxBArea = (boxB[2] - boxB[0]) * (boxB[3] - boxB[1])

    # Compute the Intersection over Bounding Box (IoB) ratio
    iob = interArea / float(boxBArea)
    return iob
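
# Worked example (sketch): boxA = [0, 0, 10, 10], boxB = [5, 5, 15, 15].
# The intersection is the square [5, 5, 10, 10] with area 25; the area of boxB is 100,
# so iob(boxA, boxB) == 0.25. Below, this is used to decide whether a word token
# belongs to a cropped table (threshold 0.5).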

class DetectionAndOcrTable2():
    # This component can take an entire PDF page as input, scan it for tables, and return each table in HTML format.
    # It uses the full UniTable model, unlike DetectionAndOcrTable1.
    def __init__(self):
        self.unitableFullPredictor = UnitableFullPredictor()

    @staticmethod
    def save_detection(detected_lines_images: List[ImageType], prefix='./res/test1/res_'):
        for i, img in enumerate(detected_lines_images):
            pilimg = Image.fromarray(img)
            pilimg.save(prefix + str(i) + '.png')

    # For output bounding box post-processing
    @staticmethod
    def box_cxcywh_to_xyxy(x):
        # Convert boxes from (center_x, center_y, width, height) to (xmin, ymin, xmax, ymax)
        x_c, y_c, w, h = x.unbind(-1)
        b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
        return torch.stack(b, dim=1)

    @staticmethod
    def rescale_bboxes(out_bbox, size):
        # Scale normalized boxes up to absolute pixel coordinates of the original image
        img_w, img_h = size
        b = DetectionAndOcrTable2.box_cxcywh_to_xyxy(out_bbox)
        b = b * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32)
        return b
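
    # Worked example (sketch): a normalized box (cx, cy, w, h) = (0.5, 0.5, 0.2, 0.4) on a
    # 1000 x 800 image becomes (0.4, 0.3, 0.6, 0.7) in xyxy form, and rescale_bboxes maps it to
    # pixel coordinates (400.0, 240.0, 600.0, 560.0).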

    @staticmethod
    def outputs_to_objects(outputs, img_size, id2label):
        m = outputs.logits.softmax(-1).max(-1)
        pred_labels = list(m.indices.detach().cpu().numpy())[0]
        pred_scores = list(m.values.detach().cpu().numpy())[0]
        pred_bboxes = outputs['pred_boxes'].detach().cpu()[0]
        pred_bboxes = [elem.tolist() for elem in DetectionAndOcrTable2.rescale_bboxes(pred_bboxes, img_size)]

        objects = []
        for label, score, bbox in zip(pred_labels, pred_scores, pred_bboxes):
            class_label = id2label[int(label)]
            if not class_label == 'no object':
                objects.append({'label': class_label, 'score': float(score),
                                'bbox': [float(elem) for elem in bbox]})
        return objects
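
    # The returned structure looks like (values are illustrative):
    # [{'label': 'table', 'score': 0.99, 'bbox': [110.2, 73.3, 1024.6, 308.7]}]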

    @staticmethod
    def visualize_detected_tables(img, det_tables, out_path=None):
        plt.imshow(img, interpolation="lanczos")
        fig = plt.gcf()
        fig.set_size_inches(20, 20)
        ax = plt.gca()

        for det_table in det_tables:
            bbox = det_table['bbox']

            if det_table['label'] == 'table':
                facecolor = (1, 0, 0.45)
                edgecolor = (1, 0, 0.45)
                alpha = 0.3
                linewidth = 2
                hatch = '//////'
            elif det_table['label'] == 'table rotated':
                facecolor = (0.95, 0.6, 0.1)
                edgecolor = (0.95, 0.6, 0.1)
                alpha = 0.3
                linewidth = 2
                hatch = '//////'
            else:
                continue

            rect = patches.Rectangle(bbox[:2], bbox[2] - bbox[0], bbox[3] - bbox[1], linewidth=linewidth,
                                     edgecolor='none', facecolor=facecolor, alpha=0.1)
            ax.add_patch(rect)
            rect = patches.Rectangle(bbox[:2], bbox[2] - bbox[0], bbox[3] - bbox[1], linewidth=linewidth,
                                     edgecolor=edgecolor, facecolor='none', linestyle='-', alpha=alpha)
            ax.add_patch(rect)
            rect = patches.Rectangle(bbox[:2], bbox[2] - bbox[0], bbox[3] - bbox[1], linewidth=0,
                                     edgecolor=edgecolor, facecolor='none', linestyle='-', hatch=hatch, alpha=0.2)
            ax.add_patch(rect)

        plt.xticks([], [])
        plt.yticks([], [])

        legend_elements = [Patch(facecolor=(1, 0, 0.45), edgecolor=(1, 0, 0.45),
                                 label='Table', hatch='//////', alpha=0.3),
                           Patch(facecolor=(0.95, 0.6, 0.1), edgecolor=(0.95, 0.6, 0.1),
                                 label='Table (rotated)', hatch='//////', alpha=0.3)]
        plt.legend(handles=legend_elements, bbox_to_anchor=(0.5, -0.02), loc='upper center', borderaxespad=0,
                   fontsize=10, ncol=2)
        plt.gcf().set_size_inches(10, 10)
        plt.axis('off')

        if out_path is not None:
            plt.savefig(out_path, bbox_inches='tight', dpi=150)

        return fig

    # For cropping, the TATR authors employ some padding to make sure the borders of the table are included.
    @staticmethod
    def objects_to_crops(img, tokens, objects, class_thresholds, padding=10):
        """
        Process the bounding boxes produced by the table detection model into
        cropped table images and cropped tokens.
        """
        table_crops = []
        for obj in objects:
            # Somewhat redundant here, since objects were already filtered by score before being passed in.
            if obj['score'] < class_thresholds[obj['label']]:
                print('skipping object with score', obj['score'])
                continue

            cropped_table = {}

            bbox = obj['bbox']
            bbox = [bbox[0] - padding, bbox[1] - padding, bbox[2] + padding, bbox[3] + padding]
            cropped_img = img.crop(bbox)

            # Add padding to the cropped image (a 20 px white border on each side)
            padded_width = cropped_img.width + 40
            padded_height = cropped_img.height + 40
            new_img_np = np.full((padded_height, padded_width, 3), fill_value=255, dtype=np.uint8)
            y_offset = (padded_height - cropped_img.height) // 2
            x_offset = (padded_width - cropped_img.width) // 2
            new_img_np[y_offset:y_offset + cropped_img.height, x_offset:x_offset + cropped_img.width] = np.array(cropped_img)
            padded_img = Image.fromarray(new_img_np, 'RGB')

            table_tokens = [token for token in tokens if iob(token['bbox'], bbox) >= 0.5]
            for token in table_tokens:
                token['bbox'] = [token['bbox'][0] - bbox[0] + padding,
                                 token['bbox'][1] - bbox[1] + padding,
                                 token['bbox'][2] - bbox[0] + padding,
                                 token['bbox'][3] - bbox[1] + padding]

            # If table is predicted to be rotated, rotate cropped image and tokens/words:
            if obj['label'] == 'table rotated':
                padded_img = padded_img.rotate(270, expand=True)
                for token in table_tokens:
                    bbox = token['bbox']
                    bbox = [padded_img.size[0] - bbox[3] - 1,
                            bbox[0],
                            padded_img.size[0] - bbox[1] - 1,
                            bbox[2]]
                    token['bbox'] = bbox

            cropped_table['image'] = padded_img
            cropped_table['tokens'] = table_tokens

            table_crops.append(cropped_table)

        return table_crops

    def predict(self, image: Image.Image, debugfolder_filename_page_name):
        """
        0. Locate the tables using table detection
        1. Run UniTable on the cropped tables
        """
        # Step 0: Locate the tables using table detection.
        # First we load a Table Transformer pre-trained for table detection. We use the "no_timm" revision
        # here to load the checkpoint with a Transformers-native backbone.
        model = AutoModelForObjectDetection.from_pretrained("microsoft/table-transformer-detection", revision="no_timm")

        device = "cuda" if torch.cuda.is_available() else "cpu"
        model.to(device)

        # Prepare the image for the model
        detection_transform = transforms.Compose([
            MaxResize(800),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
        pixel_values = detection_transform(image).unsqueeze(0)
        pixel_values = pixel_values.to(device)

        # Next, we forward the pixel values through the model.
        # The model outputs logits of shape (batch_size, num_queries, num_labels + 1). The +1 is for the "no object" class.
        with torch.no_grad():
            outputs = model(pixel_values)

        # Update id2label to include "no object"
        id2label = model.config.id2label
        id2label[len(model.config.id2label)] = "no object"

        # Example output: [{'label': 'table', 'score': 0.9999570846557617, 'bbox': [110.24547576904297, 73.31171417236328, 1024.609130859375, 308.7159423828125]}]
        objects = DetectionAndOcrTable2.outputs_to_objects(outputs, image.size, id2label)

        # Only keep objects with a score greater than 0.95
        objects = [obj for obj in objects if obj['score'] > 0.95]
        print(objects)

        if objects:
            fig = DetectionAndOcrTable2.visualize_detected_tables(image, objects, out_path="./res/table_debug/table_former_detection.jpg")

            # Next, we crop the table out of the image. For that, the TATR authors employ some padding
            # to make sure the borders of the table are included.
            tokens = []
            detection_class_thresholds = {
                "table": 0.95,
                "table rotated": 0.95,
                "no object": 10
            }
            crop_padding = 10

            tables_crops = DetectionAndOcrTable2.objects_to_crops(image, tokens, objects, detection_class_thresholds, padding=crop_padding)
            # [{'image': <PIL.Image.Image image mode=RGB size=1392x903 at 0x7F71B02BCB50>, 'tokens': []}]
            # print(tables_crops)

            # TODO: Handle the case where there are multiple tables
            cropped_tables = []
            for i in range(len(tables_crops)):
                cropped_table = tables_crops[i]['image'].convert("RGB")
                cropped_table.save(debugfolder_filename_page_name + "cropped_table_" + str(i) + ".png")
                cropped_tables.append(cropped_table)

            print("number of cropped tables found: " + str(len(cropped_tables)))

            # Step 1: UniTable
            # This takes PIL Images as input
            table_codes = self.unitableFullPredictor.predict(cropped_tables, debugfolder_filename_page_name)
            return table_codes
        else:
            return
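

# Example usage (a minimal sketch; the input image path and debug prefix below are placeholders,
# not files shipped with this repo):
if __name__ == "__main__":
    page = Image.open("./res/test1/page_0.png").convert("RGB")
    detector = DetectionAndOcrTable2()
    html_tables = detector.predict(page, "./res/table_debug/page_0_")
    print(html_tables)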