import random

import cv2
import numpy as np
import torch
import torchvision.transforms.functional as F
from torch.utils.data import Dataset, DataLoader, Subset, ConcatDataset
from torch.utils.data.dataloader import default_collate

from modules.utils import object_dict, arrow_dict, resize_boxes, resize_keypoints

class RandomCrop:
    def __init__(self, new_size=(1333, 800), crop_fraction=0.5, min_objects=4):
        """
        Initialize the RandomCrop transformation.

        Parameters:
        - new_size (tuple): The target size for the image after cropping.
        - crop_fraction (float): The fraction of the original width to use when cropping.
        - min_objects (int): Minimum number of objects required to be within the crop.
        """
        self.crop_fraction = crop_fraction
        self.min_objects = min_objects
        self.new_size = new_size

    def __call__(self, image, target):
        """
        Apply the RandomCrop transformation to the image and its target.

        Parameters:
        - image (PIL Image): The image to be cropped.
        - target (dict): The target dictionary containing 'boxes' and optional 'keypoints'.

        Returns:
        - PIL Image, dict: The cropped image and its updated target dictionary.
        """
        new_w1, new_h1 = self.new_size
        w, h = image.size
        new_w = int(w * self.crop_fraction)
        new_h = int(new_w * new_h1 / new_w1)

        # Shrink the crop fraction in small steps if the crop would be taller than the image
        shrink = 0.0
        for _ in range(4):
            if new_h < h:
                break
            shrink += 0.05
            new_w = int(w * (self.crop_fraction - shrink))
            new_h = int(new_w * new_h1 / new_w1)
        if new_h >= h:  # If still not valid, return the original image and target
            return image, target

        boxes = target["boxes"]
        if 'keypoints' in target:
            keypoints = target["keypoints"]
        else:
            keypoints = [torch.zeros((2, 3)) for _ in range(len(boxes))]

        # Attempt to find a suitable crop region
        success = False
        for _ in range(100):  # Max 100 attempts to find a valid crop
            top = random.randint(0, h - new_h)
            left = random.randint(0, w - new_w)
            crop_region = [left, top, left + new_w, top + new_h]

            # Check how many objects are fully contained in this region
            contained_boxes = []
            contained_keypoints = []
            for box, kp in zip(boxes, keypoints):
                if box[0] >= crop_region[0] and box[1] >= crop_region[1] and box[2] <= crop_region[2] and box[3] <= crop_region[3]:
                    # Shift box and keypoint coordinates into the crop's frame
                    new_box = box - torch.tensor([crop_region[0], crop_region[1], crop_region[0], crop_region[1]])
                    new_kp = kp - torch.tensor([crop_region[0], crop_region[1], 0])
                    contained_boxes.append(new_box)
                    contained_keypoints.append(new_kp)

            if len(contained_boxes) >= self.min_objects:
                success = True
                break

        if success:
            # Perform the actual crop
            image = F.crop(image, top, left, new_h, new_w)
            target["boxes"] = torch.stack(contained_boxes) if contained_boxes else torch.zeros((0, 4))
            if 'keypoints' in target:
                target["keypoints"] = torch.stack(contained_keypoints) if contained_keypoints else torch.zeros((0, 2, 3))

        return image, target
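
# Sketch of how RandomCrop is invoked (hypothetical values, not executed at import):
# 'boxes' is an (N, 4) float tensor in (xmin, ymin, xmax, ymax) order and, for the
# arrow model, 'keypoints' is an (N, 2, 3) tensor of (x, y, visibility) rows.
# The crop is only applied when at least `min_objects` boxes lie fully inside the
# sampled window; otherwise the original image and target are returned unchanged.
#
#   crop = RandomCrop(new_size=(1333, 800), crop_fraction=0.5, min_objects=1)
#   target = {"boxes": torch.tensor([[10., 10., 60., 60.]]),
#             "keypoints": torch.zeros((1, 2, 3))}
#   image, target = crop(image, target)  # image is any PIL image larger than the crop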

class RandomFlip:
    def __init__(self, h_flip_prob=0.5, v_flip_prob=0.5):
        """
        Initialize the RandomFlip transformation with probabilities for flipping.

        Parameters:
        - h_flip_prob (float): Probability of applying a horizontal flip to the image.
        - v_flip_prob (float): Probability of applying a vertical flip to the image.
        """
        self.h_flip_prob = h_flip_prob
        self.v_flip_prob = v_flip_prob

    def __call__(self, image, target):
        """
        Apply random horizontal and/or vertical flips to the image and update the target accordingly.

        Parameters:
        - image (PIL Image): The image to be flipped.
        - target (dict): The target dictionary containing 'boxes' and 'keypoints'.

        Returns:
        - PIL Image, dict: The flipped image and its updated target dictionary.
        """
        if random.random() < self.h_flip_prob:
            image = F.hflip(image)
            w, _ = image.size  # Width of the image, used to mirror x-coordinates

            # Adjust bounding boxes for horizontal flip
            for i, box in enumerate(target['boxes']):
                xmin, ymin, xmax, ymax = box
                target['boxes'][i] = torch.tensor([w - xmax, ymin, w - xmin, ymax], dtype=torch.float32)

            # Adjust keypoints for horizontal flip
            if 'keypoints' in target:
                new_keypoints = []
                for keypoints_for_object in target['keypoints']:
                    flipped_keypoints_for_object = []
                    for kp in keypoints_for_object:
                        x, y = kp[:2]
                        new_x = w - x
                        flipped_keypoints_for_object.append(torch.tensor([new_x, y] + list(kp[2:])))
                    new_keypoints.append(torch.stack(flipped_keypoints_for_object))
                target['keypoints'] = torch.stack(new_keypoints)

        if random.random() < self.v_flip_prob:
            image = F.vflip(image)
            _, h = image.size  # Height of the image, used to mirror y-coordinates

            # Adjust bounding boxes for vertical flip
            for i, box in enumerate(target['boxes']):
                xmin, ymin, xmax, ymax = box
                target['boxes'][i] = torch.tensor([xmin, h - ymax, xmax, h - ymin], dtype=torch.float32)

            # Adjust keypoints for vertical flip
            if 'keypoints' in target:
                new_keypoints = []
                for keypoints_for_object in target['keypoints']:
                    flipped_keypoints_for_object = []
                    for kp in keypoints_for_object:
                        x, y = kp[:2]
                        new_y = h - y
                        flipped_keypoints_for_object.append(torch.tensor([x, new_y] + list(kp[2:])))
                    new_keypoints.append(torch.stack(flipped_keypoints_for_object))
                target['keypoints'] = torch.stack(new_keypoints)

        return image, target

class RandomRotate:
    def __init__(self, max_rotate_deg=20, rotate_proba=0.3):
        """
        Initialize the RandomRotate transformation with a maximum rotation angle and probability of rotating.

        Parameters:
        - max_rotate_deg (int): Maximum degree to rotate the image.
        - rotate_proba (float): Probability of applying rotation to the image.
        """
        self.max_rotate_deg = max_rotate_deg
        self.rotate_proba = rotate_proba

    def __call__(self, image, target):
        """
        Randomly rotate the image and update the target data accordingly.

        Parameters:
        - image (PIL Image): The image to be rotated.
        - target (dict): The target dictionary containing 'boxes', 'labels', and 'keypoints'.

        Returns:
        - PIL Image, dict: The rotated image and its updated target dictionary.
        """
        if random.random() < self.rotate_proba:
            angle = random.uniform(-self.max_rotate_deg, self.max_rotate_deg)
            image = F.rotate(image, angle, expand=False, fill=255)

            # Rotate bounding boxes
            w, h = image.size
            cx, cy = w / 2, h / 2
            boxes = target["boxes"]
            new_boxes = []
            for box in boxes:
                new_box = self.rotate_box(box, angle, cx, cy)
                new_boxes.append(new_box)
            target["boxes"] = torch.stack(new_boxes)

            # Rotate keypoints
            if 'keypoints' in target:
                new_keypoints = []
                for keypoints in target["keypoints"]:
                    new_kp = self.rotate_keypoints(keypoints, angle, cx, cy)
                    new_keypoints.append(new_kp)
                target["keypoints"] = torch.stack(new_keypoints)

        return image, target

    def rotate_box(self, box, angle, cx, cy):
        """
        Rotate a bounding box by a given angle around the center of the image.

        Parameters:
        - box (tensor): The bounding box to be rotated.
        - angle (float): The angle to rotate the box.
        - cx (float): The x-coordinate of the image center.
        - cy (float): The y-coordinate of the image center.

        Returns:
        - tensor: The axis-aligned box enclosing the rotated corners.
        """
        x1, y1, x2, y2 = box
        corners = torch.tensor([
            [x1, y1],
            [x2, y1],
            [x2, y2],
            [x1, y2]
        ])
        corners = torch.cat((corners, torch.ones(corners.shape[0], 1)), dim=1)
        M = cv2.getRotationMatrix2D((cx, cy), angle, 1)
        corners = torch.matmul(torch.tensor(M, dtype=torch.float32), corners.T).T
        x_ = corners[:, 0]
        y_ = corners[:, 1]
        x_min, x_max = torch.min(x_), torch.max(x_)
        y_min, y_max = torch.min(y_), torch.max(y_)
        return torch.tensor([x_min, y_min, x_max, y_max], dtype=torch.float32)

    def rotate_keypoints(self, keypoints, angle, cx, cy):
        """
        Rotate keypoints by a given angle around the center of the image.

        Parameters:
        - keypoints (tensor): The keypoints to be rotated.
        - angle (float): The angle to rotate the keypoints.
        - cx (float): The x-coordinate of the image center.
        - cy (float): The y-coordinate of the image center.

        Returns:
        - tensor: The rotated keypoints.
        """
        new_keypoints = []
        M = cv2.getRotationMatrix2D((cx, cy), angle, 1)
        for kp in keypoints:
            x, y, v = kp
            point = torch.tensor([x, y, 1])
            new_point = torch.matmul(torch.tensor(M, dtype=torch.float32), point)
            new_keypoints.append(torch.tensor([new_point[0], new_point[1], v], dtype=torch.float32))
        return torch.stack(new_keypoints)

def rotate_90_box(box, angle, w, h):
    """
    Rotate a bounding box to match a 90-degree image rotation.

    Parameters:
    - box (tensor): The bounding box to be rotated.
    - angle (int): The rotation angle (90 or -90/270 degrees).
    - w (int): The width of the image after rotation.
    - h (int): The height of the image after rotation.

    Returns:
    - tensor: The rotated bounding box.
    """
    x1, y1, x2, y2 = box
    if angle == 90:
        return torch.tensor([y1, h - x2, y2, h - x1])
    elif angle == 270 or angle == -90:
        return torch.tensor([w - y2, x1, w - y1, x2])
    else:
        raise ValueError(f"Angle {angle} not supported; use 90 or -90/270.")


def rotate_90_keypoints(kp, angle, w, h):
    """
    Rotate keypoints to match a 90-degree image rotation.

    Parameters:
    - kp (tensor): The keypoints to be rotated.
    - angle (int): The rotation angle (90 or -90/270 degrees).
    - w (int): The width of the image after rotation.
    - h (int): The height of the image after rotation.

    Returns:
    - tensor: The rotated keypoints.
    """
    # Extract coordinates and visibility from each keypoint tensor
    x1, y1, v1 = kp[0][0], kp[0][1], kp[0][2]
    x2, y2, v2 = kp[1][0], kp[1][1], kp[1][2]

    # Swap x and y coordinates for each keypoint
    if angle == 90:
        new = [[y1, h - x1, v1], [y2, h - x2, v2]]
    elif angle == 270 or angle == -90:
        new = [[w - y1, x1, v1], [w - y2, x2, v2]]
    else:
        raise ValueError(f"Angle {angle} not supported; use 90 or -90/270.")
    return torch.tensor(new, dtype=torch.float32)
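
# Quick sanity check for the 90-degree mappings above. Note that rotate_vertical
# passes the w and h of the *already rotated* image. Assuming PIL's convention that
# a positive angle rotates counter-clockwise, a 200x100 image rotated by +90 becomes
# 100x200 and a point (x, y) moves to (y, 200 - x), so:
#
#   rotate_90_box(torch.tensor([50., 20., 80., 60.]), 90, w=100, h=200)
#   -> tensor([20., 120., 60., 150.])  # corners (50, 20) -> (20, 150), (80, 60) -> (60, 120)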

def rotate_vertical(image, target):
    """
    Rotate the image and its target by a random 90 degrees (clockwise or counter-clockwise),
    e.g. to simulate vertically oriented diagrams.

    Parameters:
    - image (PIL Image): The image to be rotated.
    - target (dict): The target dictionary containing 'boxes' and 'keypoints'.

    Returns:
    - PIL Image, dict: The rotated image and its updated target dictionary.
    """
    new_boxes = []
    angle = random.choice([-90, 90])
    image = F.rotate(image, angle, expand=True, fill=200)
    for box in target["boxes"]:
        new_box = rotate_90_box(box, angle, image.size[0], image.size[1])
        new_boxes.append(new_box)
    target["boxes"] = torch.stack(new_boxes)

    if 'keypoints' in target:
        new_kp = []
        for kp in target['keypoints']:
            new_key = rotate_90_keypoints(kp, angle, image.size[0], image.size[1])
            new_kp.append(new_key)
        target['keypoints'] = torch.stack(new_kp)

    return image, target

def resize_and_pad(image, target, new_size=(1333, 800)):
    """
    Resize and pad the image and target to the specified new size while maintaining the aspect ratio.

    Parameters:
    - image (PIL Image): The image to be resized and padded.
    - target (dict): The target dictionary containing 'boxes' and optional 'keypoints'.
    - new_size (tuple): The target size for the image after resizing and padding.

    Returns:
    - PIL Image, dict: The resized and padded image and its updated target dictionary.
    """
    original_size = image.size

    # Calculate scale to fit the new size while maintaining aspect ratio
    scale = min(new_size[0] / original_size[0], new_size[1] / original_size[1])
    new_scaled_size = (int(original_size[0] * scale), int(original_size[1] * scale))

    target['boxes'] = resize_boxes(target['boxes'], (image.size[0], image.size[1]), new_scaled_size)
    if 'area' in target:
        target['area'] = (target['boxes'][:, 3] - target['boxes'][:, 1]) * (target['boxes'][:, 2] - target['boxes'][:, 0])
    if 'keypoints' in target:
        for i in range(len(target['keypoints'])):
            target['keypoints'][i] = resize_keypoints(target['keypoints'][i], (image.size[0], image.size[1]), new_scaled_size)

    # Resize image to the new scaled size
    image = F.resize(image, (new_scaled_size[1], new_scaled_size[0]))

    # Pad the resized image (right and bottom) to make it exactly the desired size
    padding = [0, 0, new_size[0] - new_scaled_size[0], new_size[1] - new_scaled_size[1]]
    image = F.pad(image, padding, fill=200, padding_mode='edge')

    return image, target
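
# Worked example (not executed): a 2000x1000 image with new_size=(1333, 800) gets
# scale = min(1333/2000, 800/1000) = 0.6665, so new_scaled_size = (1333, 666); the
# resized image is then padded by 0 px on the right and 800 - 666 = 134 px on the
# bottom to reach exactly 1333x800.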

class BPMN_Dataset(Dataset):
    def __init__(self, annotations, transform=None, crop_transform=None, crop_prob=0.3, rotate_90_proba=0.2,
                 flip_transform=None, rotate_transform=None, new_size=(1333, 1333), keep_ratio=0.1, resize=True, model_type='object'):
        """
        Initialize the BPMN_Dataset with annotations and optional transformations.

        Parameters:
        - annotations (list): List of annotations for the dataset.
        - transform (callable, optional): Transformation function to apply to each image.
        - crop_transform (callable, optional): Custom cropping transformation.
        - crop_prob (float): Probability of applying the crop transformation.
        - rotate_90_proba (float): Probability of rotating the image by 90 degrees.
        - flip_transform (callable, optional): Custom flipping transformation.
        - rotate_transform (callable, optional): Custom rotation transformation.
        - new_size (tuple): Target size for the images.
        - keep_ratio (float): Probability of keeping the aspect ratio during resizing.
        - resize (bool): Flag indicating whether to resize images after transformations.
        - model_type (str): Type of model ('object' or 'arrow') to determine the target dictionary.
        """
        self.annotations = annotations
        print(f"Loaded {len(self.annotations)} annotations.")
        self.transform = transform
        self.crop_transform = crop_transform
        self.crop_prob = crop_prob
        self.flip_transform = flip_transform
        self.rotate_transform = rotate_transform
        self.resize = resize
        self.new_size = new_size
        self.keep_ratio = keep_ratio
        self.model_type = model_type
        if model_type == 'object':
            self.dict = object_dict
        elif model_type == 'arrow':
            self.dict = arrow_dict
        self.rotate_90_proba = rotate_90_proba

    def __len__(self):
        """
        Return the number of annotations in the dataset.

        Returns:
        - int: The number of annotations.
        """
        return len(self.annotations)

    def __getitem__(self, idx):
        """
        Get an item (image and target) from the dataset at the specified index.

        Parameters:
        - idx (int): The index of the item to retrieve.

        Returns:
        - PIL Image, dict: The transformed image and its updated target dictionary.
        """
        annotation = self.annotations[idx]
        image = annotation.img.convert("RGB")
        boxes = torch.tensor(np.array(annotation.boxes_ltrb), dtype=torch.float32)
        labels_names = [ann for ann in annotation.categories]

        # Only keep the labels, boxes, and keypoints that are in the class dict
        kept_indices = [i for i, ann in enumerate(annotation.categories) if ann in self.dict.values()]
        boxes = boxes[kept_indices]
        labels_names = [ann for i, ann in enumerate(labels_names) if i in kept_indices]

        # Replace any subProcess label by task
        labels_names = ['task' if ann == 'subProcess' else ann for ann in labels_names]
        labels_id = torch.tensor([(list(self.dict.values()).index(ann)) for ann in labels_names], dtype=torch.int64)

        # Initialize keypoints tensor
        max_keypoints = 2
        keypoints = torch.zeros((len(labels_id), max_keypoints, 3), dtype=torch.float32)
        ii = 0
        for i, ann in enumerate(annotation.annotations):
            # Only keep the keypoints that are in the kept indices
            if i not in kept_indices:
                continue
            if ann.category in ["sequenceFlow", "messageFlow", "dataAssociation"]:
                # Fill the keypoints tensor for this annotation, mark as visible (1)
                kp = np.array(ann.keypoints, dtype=np.float32).reshape(-1, 3)
                kp = kp[:, :2]
                visible = np.ones((kp.shape[0], 1), dtype=np.float32)
                kp = np.hstack([kp, visible])
                keypoints[ii, :kp.shape[0], :] = torch.tensor(kp, dtype=torch.float32)
            ii += 1

        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])

        if self.model_type == 'object':
            target = {
                "boxes": boxes,
                "labels": labels_id,
                # "area": area,
            }
        elif self.model_type == 'arrow':
            target = {
                "boxes": boxes,
                "labels": labels_id,
                # "area": area,
                "keypoints": keypoints,
            }

        # Randomly apply flip transform
        if self.flip_transform:
            image, target = self.flip_transform(image, target)

        # Randomly apply rotate transform
        if self.rotate_transform:
            image, target = self.rotate_transform(image, target)

        # Randomly apply the custom cropping transform
        if self.crop_transform and random.random() < self.crop_prob:
            image, target = self.crop_transform(image, target)

        # Randomly rotate the image by 90 degrees to simulate vertical diagrams
        if random.random() < self.rotate_90_proba:
            image, target = rotate_vertical(image, target)

        if self.resize:
            if random.random() < self.keep_ratio:
                # Resize while keeping the aspect ratio, then pad to the target size
                image, target = resize_and_pad(image, target, self.new_size)
            else:
                target['boxes'] = resize_boxes(target['boxes'], (image.size[0], image.size[1]), self.new_size)
                if 'area' in target:
                    target['area'] = (target['boxes'][:, 3] - target['boxes'][:, 1]) * (target['boxes'][:, 2] - target['boxes'][:, 0])
                if 'keypoints' in target:
                    for i in range(len(target['keypoints'])):
                        target['keypoints'][i] = resize_keypoints(target['keypoints'][i], (image.size[0], image.size[1]), self.new_size)
                image = F.resize(image, (self.new_size[1], self.new_size[0]))

        return self.transform(image), target

def collate_fn(batch):
    """
    Custom collation function for DataLoader that handles batches of images and targets.

    This function ensures that images are properly batched together using PyTorch's default collation,
    while keeping the targets (such as bounding boxes and labels) in a list of dictionaries,
    as each image might have a different number of objects detected.

    Parameters:
    - batch (list): A list of tuples, where each tuple contains an image and its corresponding target dictionary.

    Returns:
    - Tuple containing:
      - Tensor: Batched images.
      - List of dicts: Targets corresponding to each image in the batch.
    """
    images, targets = zip(*batch)  # Unzip the batch into separate lists for images and targets

    # Batch images using the default collate function, which handles tensors, numpy arrays, numbers, etc.
    images = default_collate(images)

    return images, targets

def create_loader(new_size, transformation, annotations1, annotations2=None,
                  batch_size=4, crop_prob=0.0, crop_fraction=0.7, min_objects=3,
                  h_flip_prob=0.0, v_flip_prob=0.0, max_rotate_deg=5, rotate_90_proba=0.0, rotate_proba=0.0,
                  seed=42, resize=True, keep_ratio=1, model_type='object'):
    """
    Create a DataLoader for BPMN datasets with optional transformations and concatenation of two datasets.

    Parameters:
    - new_size (tuple): The target size for the images.
    - transformation (callable): Transformation function to apply to each image (e.g., normalization).
    - annotations1 (list): Primary list of annotations.
    - annotations2 (list, optional): Secondary list of annotations to concatenate with the first.
    - batch_size (int): Number of images per batch.
    - crop_prob (float): Probability of applying the crop transformation.
    - crop_fraction (float): Fraction of the original width to use when cropping.
    - min_objects (int): Minimum number of objects required to be within the crop.
    - h_flip_prob (float): Probability of applying horizontal flip.
    - v_flip_prob (float): Probability of applying vertical flip.
    - max_rotate_deg (int): Maximum degree to rotate the image.
    - rotate_90_proba (float): Probability of rotating the image by 90 degrees.
    - rotate_proba (float): Probability of applying rotation to the image.
    - seed (int): Seed for random number generators for reproducibility.
    - resize (bool): Flag indicating whether to resize images after transformations.
    - keep_ratio (float): Probability of keeping the aspect ratio during resizing.
    - model_type (str): Type of model ('object' or 'arrow') to determine the target dictionary.

    Returns:
    - DataLoader: Configured data loader for the dataset.
    """
    # Initialize custom transformations for cropping, flipping, and rotating
    custom_crop_transform = RandomCrop(new_size, crop_fraction, min_objects)
    custom_flip_transform = RandomFlip(h_flip_prob, v_flip_prob)
    custom_rotate_transform = RandomRotate(max_rotate_deg, rotate_proba)

    # Create the primary dataset
    dataset = BPMN_Dataset(
        annotations=annotations1,
        transform=transformation,
        crop_transform=custom_crop_transform,
        crop_prob=crop_prob,
        rotate_90_proba=rotate_90_proba,
        flip_transform=custom_flip_transform,
        rotate_transform=custom_rotate_transform,
        new_size=new_size,
        keep_ratio=keep_ratio,
        model_type=model_type,
        resize=resize
    )

    # Optionally concatenate a second dataset
    if annotations2:
        dataset2 = BPMN_Dataset(
            annotations=annotations2,
            transform=transformation,
            crop_transform=custom_crop_transform,
            crop_prob=crop_prob,
            rotate_90_proba=rotate_90_proba,
            flip_transform=custom_flip_transform,
            new_size=new_size,
            keep_ratio=keep_ratio,
            model_type=model_type,
            resize=resize
        )
        dataset = ConcatDataset([dataset, dataset2])  # Concatenate the two datasets

    # Set the seed for reproducibility in random operations within transformations and data loading
    random.seed(seed)
    torch.manual_seed(seed)

    # Create the DataLoader with the dataset
    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)

    return data_loader
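
if __name__ == "__main__":
    # Minimal smoke test, kept as a sketch: real annotation objects come from the
    # project's own parsing code, so this stub only mimics the attributes that
    # BPMN_Dataset.__getitem__ reads (img, boxes_ltrb, categories, and per-annotation
    # category/keypoints). It assumes that 'task' is one of the values of object_dict
    # and that a plain ToTensor transform is acceptable; adjust both to your setup.
    from types import SimpleNamespace

    from PIL import Image
    import torchvision.transforms as T

    stub = SimpleNamespace(
        img=Image.new("RGB", (1200, 800), "white"),
        boxes_ltrb=[[100, 100, 300, 200], [400, 300, 600, 450]],
        categories=['task', 'task'],
        annotations=[SimpleNamespace(category='task', keypoints=[]),
                     SimpleNamespace(category='task', keypoints=[])],
    )

    loader = create_loader(
        new_size=(1333, 1333),
        transformation=T.ToTensor(),  # assumed transform: converts the PIL image to a tensor
        annotations1=[stub] * 8,
        batch_size=4,
        model_type='object',
    )
    images, targets = next(iter(loader))
    print(images.shape, [t['boxes'].shape for t in targets])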