alps / unitable /unitable_predictor.py
yumikimi381's picture
Upload folder using huggingface_hub
daf0288 verified
raw
history blame
12.9 kB
from typing import Tuple, List, Sequence, Optional, Union
from pathlib import Path
import re
import torch
import tokenizers as tk
from PIL import Image
from matplotlib import pyplot as plt
from matplotlib import patches
from torchvision import transforms
from torch import nn, Tensor
from functools import partial
import numpy.typing as npt
from numpy import uint8
# Type alias for a NumPy array of uint8 values (an image array); not referenced by the code in this file section.
ImageType = npt.NDArray[uint8]
import warnings
import time
import argparse
from .src.model import EncoderDecoder, ImgLinearBackbone, Encoder, Decoder
from .src.utils import subsequent_mask, pred_token_within_range, greedy_sampling, bbox_str_to_token_list, html_str_to_token_list
from .src.trainer.utils import VALID_HTML_TOKEN, VALID_BBOX_TOKEN, INVALID_CELL_TOKEN
# NOTE(review): this silences *every* Python warning for the whole process as soon
# as this module is imported, not only warnings raised by this file.
warnings.filterwarnings(action='ignore')
# Module-level default device: CUDA when available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class UnitablePredictor:
    """Run the UniTable "large" table-recognition pipeline on table images.

    ``predict`` loads a table-structure (HTML) model and a cell-bbox model,
    decodes every image greedily with both, and returns the HTML structure
    tokens and the cell bounding boxes for each image.
    """

    def __init__(self):
        pass

    @staticmethod
    def _build_components(
        d_model: int = 768,
        patch_size: int = 16,
        nhead: int = 12,
        dropout: float = 0.2,
    ) -> Tuple[nn.Module, nn.Module, nn.Module]:
        """Build a fresh, untrained (backbone, encoder, decoder) triple for UniTable large.

        A new triple is built for every model. Otherwise two EncoderDecoder models
        would share these instances, and loading the second checkpoint would
        overwrite weights the first model uses (assuming EncoderDecoder keeps the
        passed instances as sub-modules - confirm in src/model).
        """
        backbone = ImgLinearBackbone(d_model=d_model, patch_size=patch_size)
        encoder = Encoder(
            d_model=d_model,
            nhead=nhead,
            dropout=dropout,
            activation="gelu",
            norm_first=True,
            nlayer=12,
            ff_ratio=4,
        )
        decoder = Decoder(
            d_model=d_model,
            nhead=nhead,
            dropout=dropout,
            activation="gelu",
            norm_first=True,
            nlayer=4,
            ff_ratio=4,
        )
        return backbone, encoder, decoder

    def load_vocab_and_model(
        self,
        backbone,
        encoder,
        decoder,
        vocab_path: Union[str, Path],
        max_seq_len: int,
        model_weights: Union[str, Path],
        d_model: int = 768,
        dropout: float = 0.2,
    ) -> Tuple[tk.Tokenizer, EncoderDecoder]:
        """Load a tokenizer vocabulary and an EncoderDecoder with trained weights.

        Args:
            backbone, encoder, decoder: Modules assembled into the EncoderDecoder.
                Pass instances that are not used by any other loaded model.
            vocab_path: Path to a ``tokenizers`` JSON vocabulary file.
            max_seq_len: Maximum sequence length passed to EncoderDecoder.
            model_weights: Path to a saved ``state_dict``.
            d_model: Model width (768 for UniTable large).
            dropout: Dropout rate passed to EncoderDecoder.

        Returns:
            ``(vocab, model)`` with the model moved to CUDA when available, else CPU.
        """
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Tokenizer.from_file expects a string path, so Path objects are converted.
        vocab = tk.Tokenizer.from_file(str(vocab_path))
        model = EncoderDecoder(
            backbone=backbone,
            encoder=encoder,
            decoder=decoder,
            vocab_size=vocab.get_vocab_size(),
            d_model=d_model,
            padding_idx=vocab.token_to_id("<pad>"),
            max_seq_len=max_seq_len,
            dropout=dropout,
            norm_layer=partial(nn.LayerNorm, eps=1e-6),
        )
        # Load onto the CPU first so checkpoints saved from a GPU also load on
        # CPU-only hosts, then move the model to the target device.
        model.load_state_dict(torch.load(model_weights, map_location="cpu"))
        model = model.to(device)
        return vocab, model

    def autoregressive_decode(
        self,
        model: EncoderDecoder,
        image: Tensor,
        prefix: Sequence[int],
        max_decode_len: int,
        eos_id: int,
        token_whitelist: Optional[Sequence[int]] = None,
        token_blacklist: Optional[Sequence[int]] = None,
    ) -> Tensor:
        """Greedily decode token ids for ``image``, starting from ``prefix``.

        Args:
            model: The encoder-decoder; it is switched to eval mode.
            image: Preprocessed image tensor; dim 0 is treated as the batch size.
            prefix: Token ids every sequence starts with.
            max_decode_len: Maximum number of tokens appended.
            eos_id: Decoding stops once every sequence contains this id.
            token_whitelist / token_blacklist: Restrict the sampled tokens.

        Returns:
            Integer tensor of shape (batch, len(prefix) + generated) holding the
            prefix followed by the generated ids.
        """
        model.eval()
        # Keep the context on the same device as the input image
        # (image_to_tensor places images on the model's device).
        device = image.device
        with torch.no_grad():
            # The image is encoded once; the memory is reused at every step.
            memory = model.encode(image)
            # One copy of the prefix per batch item.
            context = (
                torch.tensor(prefix, dtype=torch.int32)
                .repeat(image.shape[0], 1)
                .to(device)
            )
            for _ in range(max_decode_len):
                if all(eos_id in seq for seq in context):
                    break
                causal_mask = subsequent_mask(context.shape[1]).to(device)
                logits = model.decode(
                    memory, context, tgt_mask=causal_mask, tgt_padding_mask=None
                )
                # Only the distribution over the next position is needed.
                logits = model.generator(logits)[:, -1, :]
                logits = pred_token_within_range(
                    logits.detach(),
                    white_list=token_whitelist,
                    black_list=token_blacklist,
                )
                _, next_tokens = greedy_sampling(logits)
                context = torch.cat([context, next_tokens], dim=1)
        return context

    @staticmethod
    def image_to_tensor(image: Image.Image, size: Tuple[int, int]) -> Tensor:
        """Resize, normalize and batch a PIL image.

        Returns a (1, C, H, W) tensor on CUDA when available, else the CPU.
        The normalization constants are the per-channel mean/std used by UniTable.
        """
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        T = transforms.Compose([
            transforms.Resize(size),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.86597056, 0.88463002, 0.87491087],
                std=[0.20686628, 0.18201602, 0.18485524],
            ),
        ])
        image_tensor = T(image)
        image_tensor = image_tensor.to(device).unsqueeze(0)
        return image_tensor

    def rescale_bbox(
        self,
        bbox: Sequence[Sequence[float]],
        src: Tuple[int, int],
        tgt: Tuple[int, int]
    ) -> Sequence[Sequence[float]]:
        """Scale (xmin, ymin, xmax, ymax) boxes from a ``src`` (w, h) frame to ``tgt`` (w, h).

        Coordinates are rounded to the nearest integer.
        """
        assert len(src) == len(tgt) == 2
        # [rx, ry] * 2 == [rx, ry, rx, ry], one factor per box coordinate.
        ratio = [tgt[0] / src[0], tgt[1] / src[1]] * 2
        return [[int(round(i * j)) for i, j in zip(entry, ratio)] for entry in bbox]

    def _decode_bbox_tokens(self, model, vocab, image_tensor: Tensor, token_whitelist):
        """Run the bbox model on one image tensor.

        Returns the parsed boxes in the model's input-resolution frame.
        """
        pred = self.autoregressive_decode(
            model=model,
            image=image_tensor,
            prefix=[vocab.token_to_id("[bbox]")],
            max_decode_len=1024,
            eos_id=vocab.token_to_id("<eos>"),
            token_whitelist=token_whitelist,
            token_blacklist=None,
        )
        # Convert token ids to token text, then parse the bbox token string.
        pred = pred.detach().cpu().numpy()[0]
        pred = vocab.decode(pred, skip_special_tokens=False)
        return bbox_str_to_token_list(pred)

    def _detect_overflow_bboxes(
        self, image: Image.Image, pred_bbox, model, vocab, token_whitelist, input_size
    ):
        """Re-run bbox detection on the part of ``image`` below the last detected row.

        Returns the extra boxes in full-image coordinates. Cells of the cut-off row
        that were already detected are dropped. An empty list is returned when the
        cut-off is outside the image.
        """
        # Boxes are (xmin, ymin, xmax, ymax): the last box's ymin marks the top of
        # the row where decoding was cut off.
        cut_off = pred_bbox[-1][1]
        width, height = image.size
        if not 0 <= cut_off < height:
            return []
        # Count trailing boxes whose ymin is within 5 px of cut_off, i.e. cells of
        # the cut-off row that were already detected. The code assumes they come
        # first again in the new prediction and drops that many boxes from it.
        last_cells_redundant = 0
        for cell in reversed(pred_bbox):
            if cut_off - 5 < cell[1] < cut_off + 5:
                last_cells_redundant += 1
            else:
                break
        # PIL's crop box is (left, upper, right, lower) = (xmin, ymin, xmax, ymax),
        # with y growing downwards.
        cropped_image = image.crop((0, cut_off, width, height))
        image_tensor = self.image_to_tensor(cropped_image, input_size)
        extra = self._decode_bbox_tokens(model, vocab, image_tensor, token_whitelist)
        extra = extra[last_cells_redundant:]
        extra = self.rescale_bbox(extra, src=input_size, tgt=cropped_image.size)
        # Shift from crop coordinates back to full-image coordinates.
        return [[box[0], box[1] + cut_off, box[2], box[3] + cut_off] for box in extra]

    @staticmethod
    def _save_debug_plot(image: Image.Image, bboxes, path: str) -> None:
        """Draw (xmin, ymin, xmax, ymax) boxes in red over ``image`` and save to ``path``."""
        fig, ax = plt.subplots(figsize=(12, 10))
        try:
            for box in bboxes:
                rect = patches.Rectangle(
                    box[:2], box[2] - box[0], box[3] - box[1],
                    linewidth=1, edgecolor='r', facecolor='none',
                )
                ax.add_patch(rect)
            ax.set_axis_off()
            ax.imshow(image)
            fig.savefig(path, bbox_inches='tight', dpi=300)
        finally:
            # pyplot keeps every figure alive until it is closed; without this,
            # one figure per table accumulates for the life of the process.
            plt.close(fig)

    def predict(self, images: List[Image.Image], debugfolder_filename_page_name: str):
        """Predict HTML structure tokens and cell bounding boxes for each table image.

        Model weights and vocabularies are read from paths relative to the current
        working directory (./unitable/...). Progress and intermediate results are
        printed. A debug image with the boxes drawn is saved to
        ``debugfolder_filename_page_name + str(i) + ".png"`` for table ``i``.

        Returns:
            ``(pred_htmls, pred_bboxs)``. Both are lists with one entry per input
            image: the HTML token list, and the list of (xmin, ymin, xmax, ymax)
            boxes in original-image pixels.
        """
        MODEL_FILE_NAME = ["unitable_large_structure.pt", "unitable_large_bbox.pt", "unitable_large_content.pt"]
        MODEL_DIR = Path("./unitable/experiments/unitable_weights")
        # Resolution images are resized to before inference; boxes come back in this frame.
        input_size = (448, 448)

        # Step 1: load the table-structure model.
        backbone, encoder, decoder = self._build_components()
        start1 = time.time()
        vocabS, modelS = self.load_vocab_and_model(
            backbone=backbone,
            encoder=encoder,
            decoder=decoder,
            vocab_path="./unitable/vocab/vocab_html.json",
            max_seq_len=784,
            model_weights=MODEL_DIR / MODEL_FILE_NAME[0],
        )
        end1 = time.time()
        print("time to load table structure model ",end1-start1,"seconds")

        # Step 2: convert images to model input tensors.
        image_tensors = [self.image_to_tensor(image, input_size) for image in images]

        # Structure inference: one list of HTML tokens per image.
        html_whitelist = [vocabS.token_to_id(tok) for tok in VALID_HTML_TOKEN]
        html_prefix = [vocabS.token_to_id("[html]")]
        html_eos = vocabS.token_to_id("<eos>")
        pred_htmls = []
        for i, image_tensor in enumerate(image_tensors):
            print("Processing table "+str(i))
            start2 = time.time()
            pred_html = self.autoregressive_decode(
                model=modelS,
                image=image_tensor,
                prefix=html_prefix,
                max_decode_len=512,
                eos_id=html_eos,
                token_whitelist=html_whitelist,
                token_blacklist=None,
            )
            end2 = time.time()
            print("time for inference table structure ",end2-start2,"seconds")
            pred_html = pred_html.detach().cpu().numpy()[0]
            pred_html = vocabS.decode(pred_html, skip_special_tokens=False)
            pred_html = html_str_to_token_list(pred_html)
            pred_htmls.append(pred_html)
            print(pred_html)
        # The structure model is no longer needed; release it before loading the bbox model.
        del modelS

        # Step 3: load the cell-bbox detection model (with its own fresh modules).
        backbone, encoder, decoder = self._build_components()
        start3 = time.time()
        vocabB, modelB = self.load_vocab_and_model(
            backbone=backbone,
            encoder=encoder,
            decoder=decoder,
            vocab_path="./unitable/vocab/vocab_bbox.json",
            max_seq_len=1024,
            model_weights=MODEL_DIR / MODEL_FILE_NAME[1],
        )
        end3 = time.time()
        print("time to load cell bbox detection model ",end3-start3,"seconds")

        # Step 4: bbox detection for every image.
        bbox_whitelist = [vocabB.token_to_id(tok) for tok in VALID_BBOX_TOKEN[:449]]
        pred_bboxs = []
        for i, (image, image_tensor) in enumerate(zip(images, image_tensors)):
            start4 = time.time()
            pred_bbox = self._decode_bbox_tokens(modelB, vocabB, image_tensor, bbox_whitelist)
            end4 = time.time()
            print("Processing table "+str(i))
            print("time to do inference for table cell bbox detection model ",end4-start4,"seconds")
            pred_bbox = self.rescale_bbox(pred_bbox, src=input_size, tgt=image.size)
            print(pred_bbox)
            print("Size of the image ")
            print(image.size)
            print("Number of bounding boxes ")
            print(len(pred_bbox))

            countcells = sum(1 for elem in pred_htmls[i] if elem in ('<td>[]</td>', '>[]</td>'))
            print("number of countcells")
            print(countcells)

            # The bbox decoder is capped at 1024 tokens (presumably 4 per box, so at
            # most 256 boxes). Larger tables are re-detected below the cut-off row.
            # TODO: the crop can itself exceed the budget; only one extra pass is made.
            if countcells > 256 and pred_bbox:
                pred_bbox_extra = self._detect_overflow_bboxes(
                    image, pred_bbox, modelB, vocabB, bbox_whitelist, input_size
                )
                pred_bbox = pred_bbox + pred_bbox_extra
                print("extra boxes:")
                print(pred_bbox_extra)
                print("length of extra boxes")
                print(len(pred_bbox_extra))

            pred_bboxs.append(pred_bbox)
            self._save_debug_plot(image, pred_bbox, debugfolder_filename_page_name + str(i) + ".png")
        return pred_htmls, pred_bboxs