File size: 6,172 Bytes
9c909e3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 |
import torch
from torch.utils.data import Dataset
from PIL import Image
import json
from transformers import TrOCRProcessor
import pandas as pd
from sklearn.model_selection import train_test_split
import glob
import torchvision.transforms as transforms
import numpy as np
def prepare_data_frame(root_dir):
    """Load one JSON label file into a ``filename``/``text`` DataFrame.

    Args:
        root_dir: Path of the JSON label file to read. Despite its name the
            value is handed straight to ``open``, so it must point at a file.
            The file is expected to contain a list of objects that each
            provide the keys ``"word_id"`` and ``"text"``.

    Returns:
        A ``pandas.DataFrame`` with two columns, in file order:
        ``filename`` (``word_id`` with ``".png"`` appended) and ``text``.

    Raises:
        KeyError: If a record lacks ``"word_id"`` or ``"text"``.
    """
    # JSON is UTF-8 by specification; do not depend on the platform locale.
    with open(root_dir, encoding="utf-8") as f:
        records = json.load(f)
    data = {
        'filename': [record["word_id"] + ".png" for record in records],
        'text': [record["text"] for record in records],
    }
    return pd.DataFrame(data=data)
class AphaPenDataset(Dataset):
    """Image/transcription dataset producing ``pixel_values`` and ``labels``.

    Each item is a dict with:

    * ``pixel_values`` - the output of ``processor`` for the image, squeezed.
    * ``labels`` - token ids of the transcription padded/truncated to
      ``max_target_length``, with padding positions replaced by ``-100``.

    Args:
        root_dir: String prepended verbatim to every file name (plain string
            concatenation, so it may be a directory ending in a separator or
            a file-name prefix).
        df: DataFrame with ``filename`` and ``text`` columns.
        processor: Callable image processor that exposes a ``tokenizer``
            attribute (presumably a ``TrOCRProcessor``, as in the example at
            the bottom of this file).
        transform: Optional callable applied to the PIL image before it is
            handed to ``processor``.
        max_target_length: Length label sequences are padded and truncated to.
    """

    # Prefix under which transformed images are dumped for visual inspection.
    # Override with None (on the class or on an instance) to disable the dump.
    transformed_image_dir = "/mnt/data1/Datasets/AlphaPen/transformed_images/"

    def __init__(self, root_dir, df, processor, transform=None, max_target_length=128):
        self.root_dir = root_dir
        self.df = df
        self.processor = processor
        self.max_target_length = max_target_length
        self.transform = transform

    def __len__(self):
        """Return the number of rows in the backing DataFrame."""
        return len(self.df)

    def _lookup(self, idx):
        """Return ``(file_name, text)`` for the row at position ``idx``."""
        # Positional access: label-based ``df.filename[idx]`` fails (or picks
        # the wrong row) when the frame carries a non-default index, e.g.
        # after a train/test split or a concat without an index reset.
        return self.df["filename"].iloc[idx], self.df["text"].iloc[idx]

    def __getitem__(self, idx):
        """Build the training sample for the row at position ``idx``.

        Returns:
            dict with ``pixel_values`` and ``labels`` tensors.
        """
        file_name, text = self._lookup(idx)

        # prepare image (i.e. resize + normalize)
        image = Image.open(self.root_dir + file_name).convert("RGB")
        if self.transform is not None:
            image = self.transform(image)
            if self.transformed_image_dir is not None:
                # Debug aid: writes the transformed image to disk on every
                # access. ToPILImage requires the transform to return a
                # tensor or ndarray.
                dumped = transforms.ToPILImage()(image)
                dumped.save(self.transformed_image_dir + file_name)
        pixel_values = self.processor(image, return_tensors="pt").pixel_values

        # add labels (input_ids) by encoding the text; truncate so every
        # sample has exactly max_target_length ids and batches can be stacked
        tokenizer = self.processor.tokenizer
        label_ids = tokenizer(text,
                              padding="max_length",
                              truncation=True,
                              max_length=self.max_target_length).input_ids
        # important: make sure that PAD tokens are ignored by the loss function
        pad_token_id = tokenizer.pad_token_id
        labels = [-100 if token == pad_token_id else token for token in label_ids]

        return {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}

    def prepare_data(self):
        """Read file names and texts from the JSON file at ``self.path_json``.

        NOTE(review): ``path_json`` is not assigned by ``__init__`` and the
        records are keyed by ``"image_id"`` here, whereas
        ``prepare_data_frame`` reads ``"word_id"``. The only call to this
        method in the file is commented out; a caller must set ``path_json``
        first, otherwise this raises ``AttributeError``.

        Returns:
            Tuple ``(filename, text)`` of two parallel lists.
        """
        with open(self.path_json, encoding="utf-8") as f:
            records = json.load(f)
        filename = [record["image_id"] + ".png" for record in records]
        text = [record["text"] for record in records]
        return filename, text
class AlphaPenPhi3Dataset(Dataset):
    """Image/transcription dataset that wraps each sample in a chat prompt.

    Each item is a dict of tensors: the tokenizer outputs for the prompt
    (which embeds the transcription), ``pixel_values`` for the resized image,
    and ``labels`` (token ids of the transcription with padding replaced by
    ``-100``). ``None`` is returned when the image cannot be opened.

    Args:
        root_dir: String prepended verbatim to every file name.
        dataframe: DataFrame with ``filename`` and ``text`` columns.
        tokenizer: Tokenizer called on the prompt and on the transcription.
            NOTE: its ``padding_side`` is set to ``'left'`` in place, which
            affects every other user of the same tokenizer object.
        max_length: Length token sequences are padded and truncated to.
        image_size: Images are resized to ``(image_size, image_size)``.
    """

    def __init__(self, root_dir, dataframe, tokenizer, max_length, image_size):
        self.dataframe = dataframe
        self.tokenizer = tokenizer
        self.tokenizer.padding_side = 'left'
        self.max_length = max_length
        self.root_dir = root_dir
        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
        ])

    def __len__(self):
        """Return the number of rows in the backing DataFrame."""
        return len(self.dataframe)

    def __getitem__(self, idx):
        """Build the sample for the row at position ``idx``.

        Returns:
            dict of tensors, or ``None`` if the image file is missing or
            unreadable. The sample is not skipped automatically: the
            consumer (e.g. a custom ``collate_fn``) has to drop ``None``
            entries itself.
        """
        row = self.dataframe.iloc[idx]
        image_path = self.root_dir + row['filename']

        # Load the image first so an unreadable file costs no tokenizer work.
        try:
            image = Image.open(image_path).convert("RGB")
            image = self.image_transform_function(image)
        except OSError:
            # FileNotFoundError and IOError are both OSError.
            return None

        text = f"<|user|>\n<|image_1|>What is shown in this image?<|end|><|assistant|>\n {row['text']} <|end|>"
        encodings = self.tokenizer(text, truncation=True, padding='max_length', max_length=self.max_length)

        # NOTE(review): labels come from tokenizing the transcription alone,
        # so they are not position-aligned with the prompt's input_ids -
        # confirm this is what the training loop expects.
        # Truncate so labels always have the same length as input_ids.
        label_ids = self.tokenizer(row['text'],
                                   padding="max_length",
                                   truncation=True,
                                   max_length=self.max_length).input_ids
        # important: make sure that PAD tokens are ignored by the loss function
        pad_token_id = self.tokenizer.pad_token_id
        labels = [-100 if token == pad_token_id else token for token in label_ids]

        encodings['pixel_values'] = image
        encodings['labels'] = labels
        # as_tensor converts the id lists and passes the already-built image
        # tensor through; torch.tensor() on a tensor copy-constructs it and
        # emits a UserWarning.
        return {key: torch.as_tensor(val) for key, val in encodings.items()}

    def image_transform_function(self, image):
        """Apply the resize + to-tensor pipeline built in ``__init__``."""
        return self.transform(image)
if __name__ == "__main__":
    # Export the JSON word labels of both data batches to one CSV per batch.

    # Directories holding the JSON label files of batch 1 and batch 2.
    json_path = "/mnt/data1/Datasets/OCR/Alphapen/label_check/"
    json_path_b2 = "/mnt/data1/Datasets/OCR/Alphapen/DataBatch2/label_check/"
    # Image path prefixes; only the commented-out example below refers to them.
    root_dir = "/mnt/data1/Datasets/OCR/Alphapen/clean_data/final_cropped_rotated_"
    root_dir_b2 = "/mnt/data1/Datasets/OCR/Alphapen/clean_data/final_cropped_rotated_"
    output_dir = "/mnt/data1/Datasets/AlphaPen/"

    batch_frames = {}
    for batch_tag, label_dir in (("b1", json_path), ("b2", json_path_b2)):
        # glob yields files in arbitrary order; sort for reproducible CSVs.
        label_files = sorted(glob.glob(label_dir + "*.json"))
        if not label_files:
            # pd.concat([]) would raise an opaque "No objects to concatenate".
            raise ValueError(f"No JSON label files found in {label_dir}")
        batch_frames[batch_tag] = pd.concat(
            [prepare_data_frame(label_file) for label_file in label_files]
        )

    # Write only once both batches have been loaded successfully.
    for batch_tag, frame in batch_frames.items():
        frame.to_csv(output_dir + "testing_data_" + batch_tag + ".csv")

    # Example: split the data and inspect one encoded TrOCR sample.
    # train_df, test_df = train_test_split(df, test_size=0.15)
    # # we reset the indices to start from zero
    # train_df.reset_index(drop=True, inplace=True)
    # test_df.reset_index(drop=True, inplace=True)
    # processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
    # train_dataset = AphaPenDataset(root_dir=root_dir, df=train_df, processor=processor)
    # eval_dataset = AphaPenDataset(root_dir=root_dir, df=test_df, processor=processor)
    # print("Number of training examples:", len(train_dataset))
    # print("Number of validation examples:", len(eval_dataset))
    # encoding = train_dataset[0]
    # for k,v in encoding.items():
    #     print(k, v.shape)
    # image = Image.open(train_dataset.root_dir + df.filename[0]).convert("RGB")
    # print('Label: '+df.text[0])
    # print(image)
    # labels = encoding['labels']
    # print(labels)
    # labels[labels == -100] = processor.tokenizer.pad_token_id
    # label_str = processor.decode(labels, skip_special_tokens=True)
    # print('Decoded Label:', label_str)