Upload dataset.py
Browse files- dataset.py +127 -0
dataset.py
ADDED
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# %%writefile dataset.py
|
2 |
+
import pandas as pd
|
3 |
+
import os
|
4 |
+
from PIL import Image, ImageDraw
|
5 |
+
import numpy as np
|
6 |
+
import pytesseract
|
7 |
+
import torch
|
8 |
+
|
9 |
+
## I guess, I got my own script for it, from https://github.com/shabie/docformer/blob/master/src/docformer/dataset.py
|
10 |
+
|
11 |
+
def rescale_bbox(bbox, img_width : int,
|
12 |
+
img_height : int, size : int = 1000):
|
13 |
+
x0, x1, y0, y1, width, height = bbox
|
14 |
+
x0 = int(size * (x0 / img_width))
|
15 |
+
x1 = int(size * (x1 / img_width))
|
16 |
+
y0 = int(size * (y0 / img_height))
|
17 |
+
y1 = int(size * (y1 / img_height))
|
18 |
+
width = int(size * (width / img_width))
|
19 |
+
height = int(size * (height / img_height))
|
20 |
+
return [x0, x1, y0, y1, width, height]
|
21 |
+
|
22 |
+
def coordinate_features(df_row):
|
23 |
+
xmin, ymin, width, height = df_row["left"], df_row["top"], df_row["width"], df_row["height"]
|
24 |
+
return [xmin, xmin + width, ymin, ymin + height, width, height] ## [xmin, xmax, ymin, ymax, width, height]
|
25 |
+
|
26 |
+
def get_ocr_results(image_path : str):
|
27 |
+
|
28 |
+
"""
|
29 |
+
Returns words and its bounding boxes from the image file path
|
30 |
+
image_path: string containing the path of the image
|
31 |
+
"""
|
32 |
+
|
33 |
+
## Getting the Image
|
34 |
+
image = Image.open(image_path)
|
35 |
+
width, height = image.size
|
36 |
+
|
37 |
+
## OCR Processing
|
38 |
+
ocr_df = pytesseract.image_to_data(image, output_type="data.frame")
|
39 |
+
ocr_df = ocr_df.dropna().reset_index(drop=True)
|
40 |
+
float_cols = ocr_df.select_dtypes("float").columns
|
41 |
+
ocr_df[float_cols] = ocr_df[float_cols].round(0).astype(int)
|
42 |
+
ocr_df = ocr_df.replace(r"^\s*$", np.nan, regex=True)
|
43 |
+
ocr_df = ocr_df.dropna().reset_index(drop=True)
|
44 |
+
ocr_df = ocr_df.sort_values(by=['left', 'top']) ## Sorting the values on the basis of left, top bounding box coordinates
|
45 |
+
|
46 |
+
## Finally getting the words and the bounding box
|
47 |
+
words = list(ocr_df.text.apply(lambda x: str(x).strip()))
|
48 |
+
actual_bboxes = ocr_df.apply(coordinate_features, axis=1).values.tolist()
|
49 |
+
|
50 |
+
# add as extra columns
|
51 |
+
assert len(words) == len(actual_bboxes)
|
52 |
+
return {"words": words, "bbox": actual_bboxes}
|
53 |
+
|
54 |
+
## Stealed from here: https://github.com/uakarsh/latr/blob/main/src/latr/dataset.py
|
55 |
+
|
56 |
+
def get_tokens_with_boxes(unnormalized_word_boxes, list_of_words,
|
57 |
+
tokenizer, pad_token_id : int = 0,
|
58 |
+
pad_token_box = [0, 0, 0, 0, 0, 0],
|
59 |
+
max_seq_len = 512,
|
60 |
+
sep_token_box = [0, 0, 1000, 1000, 0, 0]
|
61 |
+
):
|
62 |
+
|
63 |
+
'''
|
64 |
+
This function returns two items:
|
65 |
+
1. unnormalized_token_boxes -> a list of len = max_seq_len, containing the boxes corresponding to the tokenized words,
|
66 |
+
one box might repeat as per the tokenization procedure
|
67 |
+
2. tokenized_words -> tokenized words corresponding to the tokenizer and the list_of_words
|
68 |
+
'''
|
69 |
+
|
70 |
+
assert len(unnormalized_word_boxes) == len(list_of_words), "Bounding box length != total words length"
|
71 |
+
|
72 |
+
length_of_box = len(unnormalized_word_boxes)
|
73 |
+
unnormalized_token_boxes = []
|
74 |
+
tokenized_words = []
|
75 |
+
|
76 |
+
## CLS, SEP tokens have to be appended
|
77 |
+
unnormalized_token_boxes.extend([pad_token_box])
|
78 |
+
tokenized_words.extend([tokenizer.cls_token_id]) ## CLS Token Box is same as pad_token_box, if not, you can change here
|
79 |
+
|
80 |
+
## Normal for loop
|
81 |
+
idx = 0
|
82 |
+
for box, word in zip(unnormalized_word_boxes, list_of_words):
|
83 |
+
if idx != 0:
|
84 |
+
new_word = " " + word
|
85 |
+
else:
|
86 |
+
new_word = word
|
87 |
+
current_tokens = tokenizer(new_word, add_special_tokens = False).input_ids
|
88 |
+
unnormalized_token_boxes.extend([box]*len(current_tokens))
|
89 |
+
tokenized_words.extend(current_tokens)
|
90 |
+
idx += 1
|
91 |
+
|
92 |
+
## For post processing the token box
|
93 |
+
if len(unnormalized_token_boxes)<max_seq_len:
|
94 |
+
unnormalized_token_boxes.extend([sep_token_box])
|
95 |
+
unnormalized_token_boxes.extend([pad_token_box] * (max_seq_len-len(unnormalized_token_boxes)))
|
96 |
+
|
97 |
+
else:
|
98 |
+
unnormalized_token_boxes[max_seq_len - 1] = sep_token_box
|
99 |
+
|
100 |
+
## For post processing the tokenized words
|
101 |
+
if len(tokenized_words) < max_seq_len:
|
102 |
+
tokenized_words.extend([tokenizer.sep_token_id])
|
103 |
+
tokenized_words.extend([pad_token_id]* (max_seq_len-len(tokenized_words)))
|
104 |
+
|
105 |
+
else:
|
106 |
+
tokenized_words[max_seq_len - 1] = tokenizer.sep_token_id
|
107 |
+
|
108 |
+
return unnormalized_token_boxes[:max_seq_len], tokenized_words[:max_seq_len]
|
109 |
+
|
110 |
+
|
111 |
+
def create_features(
|
112 |
+
img_path : str,
|
113 |
+
tokenizer,
|
114 |
+
max_seq_length : int = 512,
|
115 |
+
size : int = 1000,
|
116 |
+
use_ocr : bool = True,
|
117 |
+
bounding_box = None,
|
118 |
+
words = None ):
|
119 |
+
|
120 |
+
image = Image.open(img_path).convert("RGB")
|
121 |
+
ocr_results = get_ocr_results(img_path)
|
122 |
+
ocr_results['rescale_bbox'] = list(map(lambda x: rescale_bbox(x, *image.size, size = size), ocr_results['bbox']))
|
123 |
+
boxes, words = get_tokens_with_boxes(ocr_results['rescale_bbox'], ocr_results['words'], tokenizer)
|
124 |
+
torch_boxes = torch.as_tensor(boxes)
|
125 |
+
torch_words = torch.as_tensor(words)
|
126 |
+
|
127 |
+
return torch_boxes, torch_words, ocr_results['bbox']
|