from typing import Tuple, List, Sequence, Optional, Union
from pathlib import Path
import re
import torch
import tokenizers as tk
from PIL import Image
from matplotlib import pyplot as plt
from matplotlib import patches
from torchvision import transforms
from torch import nn, Tensor
from functools import partial
import numpy.typing as npt
from numpy import uint8

ImageType = npt.NDArray[uint8]

import warnings
import time
import argparse

from .src.model import EncoderDecoder, ImgLinearBackbone, Encoder, Decoder
from .src.utils import subsequent_mask, pred_token_within_range, greedy_sampling, bbox_str_to_token_list, html_str_to_token_list
from .src.trainer.utils import VALID_HTML_TOKEN, VALID_BBOX_TOKEN, INVALID_CELL_TOKEN

warnings.filterwarnings('ignore')
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


class UnitablePredictor():
    """Run the UniTable table-structure (HTML) and cell-bbox models on table images."""

    def __init__(self):
        pass

    def load_vocab_and_model(
        self,
        backbone,
        encoder,
        decoder,
        vocab_path: Union[str, Path],
        max_seq_len: int,
        model_weights: Union[str, Path],
    ) -> Tuple[tk.Tokenizer, EncoderDecoder]:
        """Load a tokenizer vocabulary and an ``EncoderDecoder`` with the given weights.

        Args:
            backbone: image backbone module wired into the model.
            encoder: encoder module wired into the model.
            decoder: decoder module wired into the model. These three are passed
                by reference; presumably ``EncoderDecoder`` registers them as
                submodules, in which case ``load_state_dict`` below writes the
                checkpoint into these very instances (verify in src/model).
            vocab_path: path to a ``tokenizers`` JSON vocabulary.
            max_seq_len: maximum sequence length handed to the model.
            model_weights: path to a state-dict checkpoint.

        Returns:
            ``(vocab, model)`` with the model moved to CUDA when available,
            otherwise CPU.
        """
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Tokenizer.from_file expects a str; convert so Path arguments work too.
        vocab = tk.Tokenizer.from_file(str(vocab_path))
        d_model = 768
        dropout = 0.2
        model = EncoderDecoder(
            backbone=backbone,
            encoder=encoder,
            decoder=decoder,
            vocab_size=vocab.get_vocab_size(),
            d_model=d_model,
            # "<pad>" is the padding token name used by the UniTable vocab files
            # (an empty string has no id) — confirm against the vocab JSON.
            padding_idx=vocab.token_to_id("<pad>"),
            max_seq_len=max_seq_len,
            dropout=dropout,
            norm_layer=partial(nn.LayerNorm, eps=1e-6),
        )
        # Load weights onto the CPU first, then move the model to the target device.
        model.load_state_dict(torch.load(model_weights, map_location="cpu"))
        model = model.to(device)
        return vocab, model

    def autoregressive_decode(
        self,
        model: EncoderDecoder,
        image: Tensor,
        prefix: Sequence[int],
        max_decode_len: int,
        eos_id: int,
        token_whitelist: Optional[Sequence[int]] = None,
        token_blacklist: Optional[Sequence[int]] = None,
    ) -> Tensor:
        """Greedily decode token ids for a batch of images.

        Args:
            model: the encoder-decoder to run (switched to eval mode).
            image: batched image tensor; ``image.shape[0]`` is the batch size.
            prefix: token ids every sequence starts with.
            max_decode_len: maximum number of tokens to append.
            eos_id: decoding stops once every sequence contains this id.
            token_whitelist: if given, restricts the predictable tokens.
            token_blacklist: if given, forbids these tokens.

        Returns:
            Token-id tensor with one row per image, prefix included.
        """
        model.eval()
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        with torch.no_grad():
            # Encode the image once; the memory tensor is reused at every step.
            memory = model.encode(image)
            # Every sequence in the batch starts from the same prefix.
            context = torch.tensor(prefix, dtype=torch.int32).repeat(image.shape[0], 1).to(device)
            for _ in range(max_decode_len):
                if all(eos_id in seq for seq in context):
                    break
                causal_mask = subsequent_mask(context.shape[1]).to(device)
                logits = model.decode(memory, context, tgt_mask=causal_mask, tgt_padding_mask=None)
                # Only the last position's logits predict the next token.
                logits = model.generator(logits)[:, -1, :]
                logits = pred_token_within_range(
                    logits.detach(),
                    white_list=token_whitelist,
                    black_list=token_blacklist,
                )
                _, next_tokens = greedy_sampling(logits)
                context = torch.cat([context, next_tokens], dim=1)
        return context

    @staticmethod
    def image_to_tensor(image: Image, size: Tuple[int, int]) -> Tensor:
        """Resize and normalize a PIL image and return it as a batch of one.

        The result has a leading batch dimension of 1, because callers read
        ``shape[0]`` as the batch size. It lives on the active device.
        """
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        T = transforms.Compose([
            transforms.Resize(size),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.86597056, 0.88463002, 0.87491087],
                std=[0.20686628, 0.18201602, 0.18485524],
            ),
        ])
        image_tensor = T(image)
        image_tensor = image_tensor.to(device).unsqueeze(0)
        return image_tensor

    def rescale_bbox(
        self,
        bbox: Sequence[Sequence[float]],
        src: Tuple[int, int],
        tgt: Tuple[int, int],
    ) -> Sequence[Sequence[float]]:
        """Scale ``[xmin, ymin, xmax, ymax]`` boxes from ``src`` to ``tgt`` size.

        ``src`` and ``tgt`` are ``(width, height)`` pairs, such as PIL ``Image.size``.
        Coordinates are rounded to ints.
        """
        assert len(src) == len(tgt) == 2
        # [x_ratio, y_ratio, x_ratio, y_ratio] to line up with the box layout.
        ratio = [tgt[0] / src[0], tgt[1] / src[1]] * 2
        print(ratio)
        bbox = [[int(round(i * j)) for i, j in zip(entry, ratio)] for entry in bbox]
        return bbox

    def predict(self, images: List[Image.Image], debugfolder_filename_page_name: str):
        """Predict HTML structure tokens and cell boxes for each table image.

        Args:
            images: table images.
            debugfolder_filename_page_name: path prefix for debug renders. The
                boxes for image ``i`` are drawn to ``prefix + str(i) + ".png"``.

        Returns:
            ``(pred_htmls, pred_bboxs)``: for each image, its list of HTML
            structure tokens, and its list of ``[xmin, ymin, xmax, ymax]`` boxes
            in that image's pixel coordinates.
        """
        # NOTE(review): the checkpoint names are empty, so MODEL_DIR / "" is the
        # directory itself and torch.load will fail. Fill in the file names.
        MODEL_FILE_NAME = ["", "", ""]
        MODEL_DIR = Path("./unitable/experiments/unitable_weights")

        # UniTable large model
        d_model = 768
        patch_size = 16
        nhead = 12
        dropout = 0.2

        # NOTE(review): these module instances are shared by both models below.
        # Loading the bbox checkpoint in step 3 presumably overwrites modelS's
        # backbone/encoder/decoder weights. That is only safe because modelS is
        # not used after step 2.
        backbone = ImgLinearBackbone(d_model=d_model, patch_size=patch_size)
        encoder = Encoder(
            d_model=d_model,
            nhead=nhead,
            dropout=dropout,
            activation="gelu",
            norm_first=True,
            nlayer=12,
            ff_ratio=4,
        )
        decoder = Decoder(
            d_model=d_model,
            nhead=nhead,
            dropout=dropout,
            activation="gelu",
            norm_first=True,
            nlayer=4,
            ff_ratio=4,
        )

        # Step 1: load the table-structure (HTML) model.
        start1 = time.time()
        vocabS, modelS = self.load_vocab_and_model(
            backbone=backbone,
            encoder=encoder,
            decoder=decoder,
            vocab_path="./unitable/vocab/vocab_html.json",
            max_seq_len=784,
            model_weights=MODEL_DIR / MODEL_FILE_NAME[0],
        )
        end1 = time.time()
        print("time to load table structure model ", end1 - start1, "seconds")

        # Step 2: convert every image to a batched 448x448 tensor.
        image_tensors = [self.image_to_tensor(image, (448, 448)) for image in images]

        # One list of HTML structure tokens per image.
        html_whitelist = [vocabS.token_to_id(tok) for tok in VALID_HTML_TOKEN]
        pred_htmls = []
        for i, image_tensor in enumerate(image_tensors):
            print("Processing table " + str(i))
            start2 = time.time()
            pred_html = self.autoregressive_decode(
                model=modelS,
                image=image_tensor,
                prefix=[vocabS.token_to_id("[html]")],
                max_decode_len=512,
                eos_id=vocabS.token_to_id("<eos>"),
                token_whitelist=html_whitelist,
                token_blacklist=None,
            )
            end2 = time.time()
            print("time for inference table structure ", end2 - start2, "seconds")
            # Take the single batch row, decode ids to text, then split into HTML tokens.
            pred_html = pred_html.detach().cpu().numpy()[0]
            pred_html = vocabS.decode(pred_html, skip_special_tokens=False)
            pred_html = html_str_to_token_list(pred_html)
            pred_htmls.append(pred_html)
            print(pred_html)

        # Step 3: load the cell bbox detection model.
        start3 = time.time()
        vocabB, modelB = self.load_vocab_and_model(
            backbone=backbone,
            encoder=encoder,
            decoder=decoder,
            vocab_path="./unitable/vocab/vocab_bbox.json",
            max_seq_len=1024,
            model_weights=MODEL_DIR / MODEL_FILE_NAME[1],
        )
        end3 = time.time()
        print("time to load cell bbox detection model ", end3 - start3, "seconds")

        # Step 4: detect cell bounding boxes.
        bbox_whitelist = [vocabB.token_to_id(tok) for tok in VALID_BBOX_TOKEN[:449]]
        bbox_prefix = [vocabB.token_to_id("[bbox]")]
        bbox_eos = vocabB.token_to_id("<eos>")
        pred_bboxs = []
        for i, image_tensor in enumerate(image_tensors):
            image = images[i]
            start4 = time.time()
            pred_bbox = self.autoregressive_decode(
                model=modelB,
                image=image_tensor,
                prefix=bbox_prefix,
                max_decode_len=1024,
                eos_id=bbox_eos,
                token_whitelist=bbox_whitelist,
                token_blacklist=None,
            )
            end4 = time.time()
            print("Processing table " + str(i))
            print("time to do inference for table cell bbox detection model ", end4 - start4, "seconds")
            # Convert token ids to text, then to [xmin, ymin, xmax, ymax] boxes in image pixels.
            pred_bbox = pred_bbox.detach().cpu().numpy()[0]
            pred_bbox = vocabB.decode(pred_bbox, skip_special_tokens=False)
            pred_bbox = bbox_str_to_token_list(pred_bbox)
            pred_bbox = self.rescale_bbox(pred_bbox, src=(448, 448), tgt=image.size)
            print(pred_bbox)
            print("Size of the image ")
            print(image.size)
            print("Number of bounding boxes ")
            print(len(pred_bbox))

            # Cells that hold content in the predicted structure.
            countcells = sum(1 for tok in pred_htmls[i] if tok in ('<td>[]</td>', '>[]</td>'))
            print("number of countcells")
            print(countcells)

            # Presumably 256 = max_decode_len (1024) / 4 tokens per box. Beyond
            # that the bbox decode is truncated, so decode the rest of the table
            # from a crop that starts at the last (possibly incomplete) row.
            if countcells > 256 and pred_bbox:
                # Boxes are [xmin, ymin, xmax, ymax]; the last box's ymin marks the cut-off row.
                cut_off = pred_bbox[-1][1]
                # NOTE(review): the code from here to the `eos_id=` argument of the
                # second decode call was lost from the source, because text between
                # '<' and '>' was stripped. It is reconstructed from how the
                # surviving code uses last_cells_redudant, cropped_image and
                # pred_bbox_extra. Confirm against version control.
                # Count the cut-off row's cells already detected (walking back from the end),
                # so that their duplicates from the second pass can be dropped.
                last_cells_redudant = 0
                for cell in reversed(pred_bbox):
                    if cut_off - 5 < cell[1] < cut_off + 5:
                        last_cells_redudant += 1
                    else:
                        break
                # Crop from the cut-off row down. x is unchanged, so extra boxes only need a y shift.
                cropped_image = image.crop((0, cut_off, image.size[0], image.size[1]))
                pred_bbox_extra = self.autoregressive_decode(
                    model=modelB,
                    image=self.image_to_tensor(cropped_image, (448, 448)),
                    prefix=bbox_prefix,
                    max_decode_len=1024,
                    eos_id=bbox_eos,
                    token_whitelist=bbox_whitelist,
                    token_blacklist=None,
                )
                # Convert token ids to token text.
                pred_bbox_extra = pred_bbox_extra.detach().cpu().numpy()[0]
                pred_bbox_extra = vocabB.decode(pred_bbox_extra, skip_special_tokens=False)
                pred_bbox_extra = bbox_str_to_token_list(pred_bbox_extra)
                pred_bbox_extra = pred_bbox_extra[last_cells_redudant:]
                pred_bbox_extra = self.rescale_bbox(pred_bbox_extra, src=(448, 448), tgt=cropped_image.size)
                # Shift crop-relative y coordinates back into full-image coordinates.
                pred_bbox_extra = [[b[0], b[1] + cut_off, b[2], b[3] + cut_off] for b in pred_bbox_extra]
                pred_bbox = pred_bbox + pred_bbox_extra
                print("extra boxes:")
                print(pred_bbox_extra)
                print("length of extra boxes")
                print(len(pred_bbox_extra))

            pred_bboxs.append(pred_bbox)

            # Debug render of the detected boxes over the image.
            fig, ax = plt.subplots(figsize=(12, 10))
            for box in pred_bbox:
                rect = patches.Rectangle(
                    box[:2], box[2] - box[0], box[3] - box[1],
                    linewidth=1, edgecolor='r', facecolor='none',
                )
                ax.add_patch(rect)
            ax.set_axis_off()
            ax.imshow(image)
            fig.savefig(debugfolder_filename_page_name + str(i) + ".png", bbox_inches='tight', dpi=300)
            # Release the figure; pyplot otherwise keeps every figure alive.
            plt.close(fig)

        return pred_htmls, pred_bboxs