"""@package docstring @file CustomDataset.py @brief This library contains usefull functions to visualise data and also Classes to store and organise data structures. @section libraries_CustomDataset Libraries - torch - numpy - dataset_utils - cv2 - torch.utils.data - random - re @section classes_CustomDataset Classes - CUDatasetBase - CUDatasetStg2Compl - CUDatasetStg3Compl - CUDatasetStg4Compl - CUDatasetStg5Compl - CUDatasetStg6Compl - SamplerStg6 - SamplerStg5 - SamplerStg4 @section functions_CustomDataset Functions - yuv2bgr(matrix) - bgr2yuv(matrix) - get_cu(f_path, f_size, cu_pos, cu_size, frame_number) - get_file_size(name) - resize(img, scale_percent) - show_CU(image, cu_size, cu_pos) - show_all_CUs(CUDataset, file_name, POC, cu_size) @section global_vars_CustomDataset Global Variables - None @section todo_CustomDataset TODO - None @section license License MIT License Copyright (c) 2022 Raul Kevin do Espirito Santo Viana Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. @section author_CustomDataset Author(s) - Created by Raul Kevin Viana - Last time modified is 2023-01-29 22:22:04.113134 """ # ============================================================== # Imports # ============================================================== import torch import numpy as np from torch.utils.data import Dataset import cv2 import re from dataset_utils import VideoCaptureYUV, VideoCaptureYUVV2 import dataset_utils import random # ============================================================== # Functions # ============================================================== def yuv2bgr(matrix): """! Converts yuv matrix to bgr matrix @param [in] matrix: Yuv matrix @param [out] bgr: Bgr conversion """ # Convert from yuv to bgr bgr = cv2.cvtColor(matrix, cv2.COLOR_YUV2BGR_I420) return bgr def bgr2yuv(matrix): """! Converts BGR matrix to YUV matrix @param [in] matrix: BGR matrix @param [out] YUV: YUV conversion """ # Convert from bgr to yuv YUV = cv2.cvtColor(matrix, cv2.COLOR_BGR2YUV_I420) return YUV def get_cu(f_path, f_size, cu_pos, cu_size, frame_number): """! Get CU from image @param [in] f_path: Path of file to get the CU from @param [in] f_size: YUV file dimensions (height, width) @param [in] cu_pos: Tuple with the position of the CU (y position (height), x position (width)) @param [in] cu_size: Tuple with the size of the CU (y position (height), x position (width)) @param [in] frame_number: Number of the frame containing the CU @param [out] CU: CU with all the information from all the components(Luminance and Chroma) @param [out] CU_Y: CU with the Luminance component @param [out] CU_U: CU with the Chroma (Blue) component @param [out] CU_V: CU with the Chroma (Red) component @param [out] frame_CU: specified frame that contains the CU """ # Get file data yuv_file = VideoCaptureYUV(f_path, f_size) # Get the specific frame ret, yuv_frame, luma_frame, chr_u, chr_v = yuv_file.read_raw(frame_number) # Read Frame from File # Return false in case a wrongful value is returned if not ret: return False # Get region that contain the CU CU_Y = luma_frame[cu_pos[0]: (cu_size[0] + cu_pos[0]), cu_pos[1]: (cu_size[1] + cu_pos[1])] # CU Luma component # Get the CU different components CU_U = chr_u[int(cu_pos[0]/2): int((cu_size[0] + cu_pos[0])/2), int(cu_pos[1]/2): int((cu_size[1] + cu_pos[1])/2)] # CU Chroma blue component CU_V = chr_v[int(cu_pos[0]/2): int((cu_size[0] + cu_pos[0])/2), int(cu_pos[1]/2): int((cu_size[1] + cu_pos[1])/2)] # CU Chroma red component return yuv_frame, CU_Y, CU_U, CU_V def get_cu_v2(img, f_size, cu_pos, cu_size): """! Get CU from image. In this version, an YUV420 image is passed instead of a path. @param [in] img: Image in YUV420 format @param [in] f_size: YUV image dimensions (height, width) @param [in] cu_pos: Tuple with the position of the CU (y position (height), x position (width)) @param [in] cu_size: Tuple with the size of the CU (y position (height), x position (width)) @param [out] CU: CU with all the information from all the components(Luminance and Chroma) @param [out] CU_Y: CU with the Luminance component @param [out] CU_U: CU with the Chroma (Blue) component @param [out] CU_V: CU with the Chroma (Red) component @param [out] frame_CU: specified frame that contains the CU """ # Get file data yuv_file = VideoCaptureYUVV2(img, f_size) # Get the specific frame ret, yuv_frame, luma_frame, chr_u, chr_v = yuv_file.read_raw() # Read Frame from File # Return false in case a wrongful value is returned if not ret: return False # Get region that contain the CU CU_Y = luma_frame[cu_pos[0]: (cu_size[0] + cu_pos[0]), cu_pos[1]: (cu_size[1] + cu_pos[1])] # CU Luma component # Get the CU different components CU_U = chr_u[int(cu_pos[0]/2): int((cu_size[0] + cu_pos[0])/2), int(cu_pos[1]/2): int((cu_size[1] + cu_pos[1])/2)] # CU Chroma blue component CU_V = chr_v[int(cu_pos[0]/2): int((cu_size[0] + cu_pos[0])/2), int(cu_pos[1]/2): int((cu_size[1] + cu_pos[1])/2)] # CU Chroma red component return yuv_frame, CU_Y, CU_U, CU_V def get_file_size(name): """! Retrieves information about the YUV file info (width and height) @param [in] name: Name of the file where the file is located @param [out] file_info: Dictionary with information about the yuv file """ # Initialize variable file_info = {} ## Look for size # Look for the type "Number x Number" size = re.findall("\d+x\d+", name) if len(size) == 1: # Get size size = list(map(int, re.findall('\d+', size[0]))) # Obtain the values in integer file_info["width"] = size[0] file_info["height"] = size[1] # Look for fps framerate = re.findall("_\d\d_|_\d\d\.", name) if len(framerate) == 1: file_info["frame_rate"] = int(framerate[0][1:3]) else: file_info["frame_rate"] = 30 # Default frame_rate # Look for the type cif size = re.findall("_cif", name) if len(size) == 1: # Size file_info["width"] = 352 file_info["height"] = 288 # Frame rate file_info["frame_rate"] = 30 # Look for the type sif size = re.findall("_sif", name) if len(size) == 1: # Size file_info["width"] = 352 file_info["height"] = 240 # Frame rate file_info["frame_rate"] = 30 # Look for the type 4cif size = re.findall("_4cif", name) if len(size) == 1: # Size file_info["width"] = 704 file_info["height"] = 576 # Frame rate file_info["frame_rate"] = 30 # Look for the type 1080p size = re.findall("1080p\d*", name) if len(size) == 1: # Size file_info["width"] = 1920 file_info["height"] = 1080 # Frame rate framerate = list(map(int, re.findall('\d+', size[0]))) # Get all numbers from string if len(framerate) == 2: file_info["frame_rate"] = framerate[1] # Get frame rate else: file_info["frame_rate"] = 30 # Default frame rate # Look for the type 720p size = re.findall("720p\d*", name) if len(size) == 1: # Size file_info["width"] = 1280 file_info["height"] = 720 # Frame rate framerate = list(map(int, re.findall('\d+', size[0]))) # Get all numbers from string if len(framerate) == 2: file_info["frame_rate"] = framerate[1] # Get frame rate else: file_info["frame_rate"] = 30 # Default frame rate if len(file_info) == 0: return False else: return file_info def resize(img, scale_percent): """! Resizes a BGR image """ width = int(img.shape[1] * scale_percent / 100) height = int(img.shape[0] * scale_percent / 100) dim = (width, height) return cv2.resize(img, dim, interpolation=cv2.INTER_AREA) def show_CU(image, cu_size, cu_pos): """! Shows CU in a image """ # Convert image to bgr image = yuv2bgr(image) # Blue color in BGR color = (255, 0, 0) cu_pos_begin = (cu_pos[1], cu_pos[0]) # Line thickness of 2 px thickness = 2 # CU end position cu_pos_end = (cu_pos[1] + cu_size[1], cu_pos[0] + cu_size[0]) # Draw a rectangle with blue line borders of thickness of 2 px image_final = cv2.rectangle(image, cu_pos_begin, cu_pos_end, color, thickness) resize_img = resize(image_final, 50) cv2.imshow('Frame', resize_img) # Press q on keyboard to exit print('\n\n\n Press q to exit') while not (cv2.waitKey(100) & 0xFF == ord('q')): pass # Destroy all windows cv2.destroyAllWindows() def show_all_CUs(CUDataset, file_name, POC, cu_size): """! Shows all CUs in a image with a specific size, TO BE USED WITH THE CU ORIENTE LABELS """ image_obtained = False image = None for sample in CUDataset: # print(sample['file_name'], sample['POC'], sample['CU_size']) # Draw CU in image if sample['file_name'] == file_name and sample['POC'] == POC and sample['CU_size'] == cu_size: # print('hello2') f_size = get_file_size(file_name) # Dict with the size of the image cu, frame_CU, CU_Y, CU_U, CU_V = get_cu(sample['img_path'], (f_size['height'], f_size['width']), sample['CU_pos'], sample['CU_size'], POC) if not image_obtained: # Convert image to bgr image = yuv2bgr(frame_CU) # Blue color in BGR color = (255, 0, 0) # Line thickness of 2 px thickness = 2 # Change flag image_obtained = True # CU end position cu_pos_end = (sample['CU_pos'][0] + sample['CU_size'][0], sample['CU_pos'][1] + sample['CU_size'][1]) # Draw a rectangle with blue line borders of thickness of 2 px image = cv2.rectangle(image, sample['CU_pos'], cu_pos_end, color, thickness) resize_img = resize(image, 50) cv2.imshow('Frame', resize_img) # Press Q on keyboard to exit while not (cv2.waitKey(100) & 0xFF == ord('q')): pass # Destroy all windows cv2.destroyAllWindows() # ============================================================== # Classes # ============================================================== class CUDatasetBase(Dataset): """! Dataset stage oriented with capability of loading different files and it's supposed to be used with the function dataset_utils.change_labels_function_again. Works for stage 2 and 3 """ def __init__(self, files_path, channel=0): """! Args: @param files_path (string): Path to the files with annotations. @param channel: Channel to get for the dataset (0:Luma, 1:Chroma) """ # Store file paths self.files_path = files_path self.files = dataset_utils.get_files_from_folder(self.files_path, endswith = ".txt") # Compute number of entries per file self.lst_entries_nums = self.obtain_files_sizes(self.files) # Obtain amount of entries in all dataset files self.total_num_entries = 0 for f in self.lst_entries_nums: self.total_num_entries += f # Initialize variables self.channel = channel self.index_lims = [] self.data_files = [] # Compute index limits for each file for k in range(len(self.lst_entries_nums)): sum = -1 for f in self.lst_entries_nums[:k+1]: sum += f self.index_lims.append(sum) # Dataset for each file for k in range(len(self.files)): self.data_files.append(dataset_utils.file2lst(self.files_path + "/" + self.files[k][0:-4])) def __len__(self): return self.total_num_entries def get_sample(self, entry): """! Args: @param entry (int): An instance from the labels. @return out: lst - CTU | RD_for_specific_stage | cu_left_of_stg_1 | cu_top_of_stg_1 | cu_left_for_specific_stage | cu_top_for_specific_stage | split_for_specific_stage """ # Initialize variable info_lst = [] # Add Luma CU_Y = entry["real_CTU"] # Add dimension CU_Y = torch.unsqueeze(CU_Y, 0) # Convert to float CU_Y = CU_Y.to(dtype=torch.float32) # CU positions within frame cu_pos = entry["cu_pos"] # CU positions within frame cu_size = entry["cu_size"] # Best split for CU split = entry["split"] # Rate distortion costs RDs = entry['RD'] # Unite RDs = np.reshape(RDs, (-1, 6)) # Save values info_lst.append(CU_Y) info_lst.append(cu_pos) info_lst.append(cu_size) info_lst.append(split) info_lst.append(RDs) return info_lst def obtain_files_sizes(self, files): """! Args: @param files (list): List containing the names of files with CUs info """ # Initialize variable lst = [] # Create list with the number of entries of each file for f in files: f_path = self.files_path + "/" + f[0:-4] file_obj = dataset_utils.file2lst(f_path) num_entries = len(file_obj) lst.append(num_entries) return lst def select_entry(self, idx): """! Args: @param idx (int): Index with the position to search for a specific entry """ for k in range(len(self.index_lims)): if idx <= self.index_lims[k]: # Obtain index in the object idx_obj = idx - (self.index_lims[k] - self.lst_entries_nums[k] + 1) file_obj = self.data_files[k] # Obtain entry entry = file_obj[idx_obj] return entry raise Exception("Entry not found!!") def __getitem__(self, idx): if torch.is_tensor(idx): idx = idx.tolist() # Select entry found = False while not found: entry = self.select_entry(idx) # Choose entry ### Get info from label and process # Get sample sample = self.get_sample(entry) ctu_size = (sample[0].shape[-2], sample[0].shape[-1]) if (128, 128) == ctu_size: found = True else: idx = (idx + 1) % self.total_num_entries # Increment index until a acceptable entry is found return sample class CUDatasetStg5Compl(CUDatasetBase): """! Dataset stage oriented with capability of loading different files and it's supposed to be used with the function dataset_utils.change_labels_function_again. For stage 5 This version contains all of the information available from the CU data """ def __init__(self, files_path, channel=0): """! Args: @param files_path (string): Path to the files with annotations. @param channel: Channel to get for the dataset (0:Luma, 1:Chroma) """ super(CUDatasetStg5Compl, self).__init__(files_path, channel) def get_sample(self, entry): """! Args: @param entry (int): An instance from the labels. @return out: lst - CTU | RD_for_specific_stage | cu_left_of_stg_1 | cu_top_of_stg_1 | cu_left_for_specific_stage | cu_top_for_specific_stage | split_for_specific_stage """ # Initialize variable info_lst = [] color_ch = entry[14] if color_ch == self.channel: # Add Real CU CU_Y = entry[13] # Add dimension CU_Y = torch.unsqueeze(CU_Y, 0) # Convert to float CU_Y = CU_Y.to(dtype=torch.float32) # CU positions within frame cu_pos = torch.reshape(torch.tensor(entry[4]), (-1, 2)) cu_pos_stg4 = torch.reshape(torch.tensor(entry[3]), (-1, 2)) cu_pos_stg3 = torch.reshape(torch.tensor(entry[2]), (-1, 2)) cu_pos_stg2 = torch.reshape(torch.tensor(entry[1]), (-1, 2)) # CU sizes within frame cu_size = torch.reshape(torch.tensor(entry[8]), (-1, 2)) cu_size_stg4 = torch.reshape(torch.tensor(entry[7]), (-1, 2)) cu_size_stg3 = torch.reshape(torch.tensor(entry[6]), (-1, 2)) cu_size_stg2 = torch.reshape(torch.tensor(entry[5]), (-1, 2)) # Best split for CU split = entry[12] split_stg4 = entry[11] split_stg3 = entry[10] split_stg2 = entry[9] # Other information POC = entry[15] pic_name = entry[16] orig_pos_x = entry[17][1] orig_pos_y = entry[17][0] orig_size_h = entry[18][1] orig_size_w = entry[18][0] # Rate distortion costs RDs = torch.reshape(torch.tensor(entry[0]), (-1, 6)) # Save values info_lst.append(CU_Y) info_lst.append(cu_pos_stg2) info_lst.append(cu_pos_stg3) info_lst.append(cu_pos_stg4) info_lst.append(cu_pos) info_lst.append(cu_size_stg2) info_lst.append(cu_size_stg3) info_lst.append(cu_size_stg4) info_lst.append(cu_size) info_lst.append(split_stg2) info_lst.append(split_stg3) info_lst.append(split_stg4) info_lst.append(split) info_lst.append(RDs) # Other data info_lst.append(orig_pos_x) info_lst.append(orig_pos_y) info_lst.append(orig_size_h) info_lst.append(orig_size_w) info_lst.append(POC) info_lst.append(pic_name) else: raise Exception("This can not happen! This CU color channel should be " + str(self.channel) + "Try generating the labels and obtain just a specific color channel (see dataset_utils)!") return info_lst class CUDatasetStg2Compl(CUDatasetBase): """! - Dataset stage oriented with capability of loading different files and it's supposed to be used with the function dataset_utils.change_labels_function_again. For stage 2 This version contains all of the information available from the CU data """ def __init__(self, files_path, channel=0): """! Args: @param files_path (string): Path to the files with annotations. @param channel: Channel to get for the dataset (0:Luma, 1:Chroma) """ super(CUDatasetStg2Compl, self).__init__(files_path, channel) def get_sample(self, entry): """! Args: @param entry (int): An instance from the labels. @return out: lst - CTU | RD_for_specific_stage | cu_left_of_stg_1 | cu_top_of_stg_1 | cu_left_for_specific_stage | cu_top_for_specific_stage | split_for_specific_stage """ # Initialize variable info_lst = [] color_ch = entry[5] if color_ch == self.channel: # Add Real CU CU_Y = entry[4] # Add dimension CU_Y = torch.unsqueeze(CU_Y, 0) # Convert to float CU_Y = CU_Y.to(dtype=torch.float32) # CU positions within frame cu_pos = torch.reshape(torch.tensor(entry[1]), (-1, 2)) # CU sizes within frame cu_size = torch.reshape(torch.tensor(entry[2]), (-1, 2)) # Best split for CU split = entry[3] # Rate distortion costs RDs = torch.reshape(torch.tensor(entry[0]), (-1, 6)) # Other information POC = entry[6] pic_name = entry[7] orig_pos_x = entry[8][1] orig_pos_y = entry[8][0] orig_size_h = entry[9][1] orig_size_w = entry[9][0] # Save values info_lst.append(CU_Y) info_lst.append(cu_pos) info_lst.append(cu_size) info_lst.append(split) info_lst.append(RDs) # Other data info_lst.append(orig_pos_x) info_lst.append(orig_pos_y) info_lst.append(orig_size_h) info_lst.append(orig_size_w) info_lst.append(POC) info_lst.append(pic_name) else: raise Exception("This can not happen! This CU color channel should be " + str(self.channel) + "Try generating the labels and obtain just a specific color channel (see dataset_utils)!") return info_lst class CUDatasetStg6Compl(CUDatasetBase): """! Dataset stage oriented with capability of loading different files and it's supposed to be used with the function dataset_utils.change_labels_function_again. For stage 6 This version contains all of the information available from the CU data """ def __init__(self, files_path, channel=0): """! Args: files_path (string): Path to the files with annotations. root_dir (string): Directory with all the images. channel: Channel to get for the dataset (0:Luma, 1:Chroma) """ super(CUDatasetStg6Compl, self).__init__(files_path, channel) def get_sample(self, entry): """! Args: @param entry (int): An instance from the labels. """ # Initialize variable info_lst = [] color_ch = entry[17] if color_ch == self.channel: # Add Real CU CU_Y = entry[16] # Add dimension CU_Y = torch.unsqueeze(CU_Y, 0) # Convert to float CU_Y = CU_Y.to(dtype=torch.float32) # CU positions within frame cu_pos = torch.reshape(torch.tensor(entry[5]), (-1, 2)) cu_pos_stg5 = torch.reshape(torch.tensor(entry[4]), (-1, 2)) cu_pos_stg4 = torch.reshape(torch.tensor(entry[3]), (-1, 2)) cu_pos_stg3 = torch.reshape(torch.tensor(entry[2]), (-1, 2)) cu_pos_stg2 = torch.reshape(torch.tensor(entry[1]), (-1, 2)) # CU sizes within frame cu_size = torch.reshape(torch.tensor(entry[10]), (-1, 2)) cu_size_stg5 = torch.reshape(torch.tensor(entry[9]), (-1, 2)) cu_size_stg4 = torch.reshape(torch.tensor(entry[8]), (-1, 2)) cu_size_stg3 = torch.reshape(torch.tensor(entry[7]), (-1, 2)) cu_size_stg2 = torch.reshape(torch.tensor(entry[6]), (-1, 2)) # Best split for CU split = entry[15] split_stg5 = entry[14] split_stg4 = entry[13] split_stg3 = entry[12] split_stg2 = entry[11] # Rate distortion costs RDs = torch.reshape(torch.tensor(entry[0]), (-1, 6)) # Other information POC = entry[18] pic_name = entry[19] orig_pos_x = entry[20][1] orig_pos_y = entry[20][0] orig_size_h = entry[21][1] orig_size_w = entry[21][0] # Save values info_lst.append(CU_Y) info_lst.append(cu_pos_stg2) info_lst.append(cu_pos_stg3) info_lst.append(cu_pos_stg4) info_lst.append(cu_pos_stg5) info_lst.append(cu_pos) info_lst.append(cu_size_stg2) info_lst.append(cu_size_stg3) info_lst.append(cu_size_stg4) info_lst.append(cu_size_stg5) info_lst.append(cu_size) info_lst.append(split_stg2) info_lst.append(split_stg3) info_lst.append(split_stg4) info_lst.append(split_stg5) info_lst.append(split) info_lst.append(RDs) # Other data info_lst.append(orig_pos_x) info_lst.append(orig_pos_y) info_lst.append(orig_size_h) info_lst.append(orig_size_w) info_lst.append(POC) info_lst.append(pic_name) else: raise Exception("This can not happen! This CU color channel should be " + str(self.channel) + "Try generating the labels and obtain just a specific color channel (see dataset_utils)!") return info_lst class CUDatasetStg4Compl(CUDatasetBase): """! Dataset stage oriented with capability of loading different files and it's supposed to be used with the function dataset_utils.change_labels_function_again. For stage 4 This version contains all of the information available from the CU data """ def __init__(self, files_path, channel=0): """! Args: @param files_path (string): Path to the files with annotations. @param channel: Channel to get for the dataset (0:Luma, 1:Chroma) """ super(CUDatasetStg4Compl, self).__init__(files_path, channel) def get_sample(self, entry): """! Args: @param entry (int): An instance from the labels. @param out: lst - CTU | RD_for_specific_stage | cu_left_of_stg_1 | cu_top_of_stg_1 | cu_left_for_specific_stage | cu_top_for_specific_stage | split_for_specific_stage """ # Initialize variable info_lst = [] color_ch = entry[11] if color_ch == self.channel: # Add Real CU CU_Y = entry[10] # Add dimension CU_Y = torch.unsqueeze(CU_Y, 0) # Convert to float CU_Y = CU_Y.to(dtype=torch.float32) # CU positions within frame cu_pos = torch.reshape(torch.tensor(entry[3]), (-1, 2)) cu_pos_stg3 = torch.reshape(torch.tensor(entry[2]), (-1, 2)) cu_pos_stg2 = torch.reshape(torch.tensor(entry[1]), (-1, 2)) # CU sizes within frame cu_size = torch.reshape(torch.tensor(entry[6]), (-1, 2)) cu_size_stg3 = torch.reshape(torch.tensor(entry[5]), (-1, 2)) cu_size_stg2 = torch.reshape(torch.tensor(entry[4]), (-1, 2)) # Best split for CU split = entry[9] split_stg3 = entry[8] split_stg2 = entry[7] # Rate distortion costs RDs = torch.reshape(torch.tensor(entry[0]), (-1, 6)) # Other information POC = entry[12] pic_name = entry[13] orig_pos_x = entry[14][1] orig_pos_y = entry[14][0] orig_size_h = entry[15][1] orig_size_w = entry[15][0] # Save values info_lst.append(CU_Y) info_lst.append(cu_pos_stg2) info_lst.append(cu_pos_stg3) info_lst.append(cu_pos) info_lst.append(cu_size_stg2) info_lst.append(cu_size_stg3) info_lst.append(cu_size) info_lst.append(split_stg2) info_lst.append(split_stg3) info_lst.append(split) info_lst.append(RDs) # Other data info_lst.append(orig_pos_x) info_lst.append(orig_pos_y) info_lst.append(orig_size_h) info_lst.append(orig_size_w) info_lst.append(POC) info_lst.append(pic_name) else: raise Exception("This can not happen! This CU color channel should be " + str(self.channel) + "Try generating the labels and obtain just a specific color channel (see dataset_utils)!") return info_lst class CUDatasetStg3Compl(CUDatasetBase): """! - Dataset stage oriented with capability of loading different files and it's supposed to be used with the function dataset_utils.change_labels_function_again. For stage 3 This version contains all of the information available from the CU data """ def __init__(self, files_path, channel=0): """! Args: @param files_path (string): Path to the files with annotations. @param channel: Channel to get for the dataset (0:Luma, 1:Chroma) """ super(CUDatasetStg3Compl, self).__init__(files_path, channel) def get_sample(self, entry): """! Args: @param entry (int): An instance from the labels. @return out: lst - CTU | RD_for_specific_stage | cu_left_of_stg_1 | cu_top_of_stg_1 | cu_left_for_specific_stage | cu_top_for_specific_stage | split_for_specific_stage """ # Initialize variable info_lst = [] color_ch = entry[8] if color_ch == self.channel: # Add Real CU CU_Y = entry[7] # Add dimension CU_Y = torch.unsqueeze(CU_Y, 0) # Convert to float CU_Y = CU_Y.to(dtype=torch.float32) # CU positions within frame cu_pos = torch.reshape(torch.tensor(entry[2]), (-1, 2)) cu_pos_stg2 = torch.reshape(torch.tensor(entry[1]), (-1, 2)) # CU sizes within frame cu_size = torch.reshape(torch.tensor(entry[4]), (-1, 2)) cu_size_stg2 = torch.reshape(torch.tensor(entry[3]), (-1, 2)) # Best split for CU split = entry[6] split_stg2 = entry[5] # Rate distortion costs RDs = torch.reshape(torch.tensor(entry[0]), (-1, 6)) # Other information POC = entry[9] pic_name = entry[10] orig_pos_x = entry[11][1] orig_pos_y = entry[11][0] orig_size_h = entry[12][1] orig_size_w = entry[12][0] # Save values info_lst.append(CU_Y) info_lst.append(cu_pos_stg2) info_lst.append(cu_pos) info_lst.append(cu_size_stg2) info_lst.append(cu_size) info_lst.append(split_stg2) info_lst.append(split) info_lst.append(RDs) # Other data info_lst.append(orig_pos_x) info_lst.append(orig_pos_y) info_lst.append(orig_size_h) info_lst.append(orig_size_w) info_lst.append(POC) info_lst.append(pic_name) else: raise Exception("This can not happen! This CU color channel should be " + str(self.channel) + "Try generating the labels and obtain just a specific color channel (see dataset_utils)!") return info_lst class SamplerStg6(torch.utils.data.Sampler): def __init__(self, data_source, batch_size): """! Args: @param data_source (Dataset): dataset to sample from @param batch_size (int): Batch size to sample data """ self.data_source = data_source self.batch_size = batch_size super(SamplerStg6, self).__init__(data_source) def __iter__(self): # Initialize variables data_size = len(self.data_source) indices = np.arange(data_size) indices = np.random.permutation(indices) random.shuffle(indices) dic_type_size = {} # Search for data for i in indices: # Build unique key key = str(self.data_source[i][8].squeeze()[0].item()) + str(self.data_source[i][8].squeeze()[1].item()) \ + str(self.data_source[i][9].squeeze()[0].item()) + str(self.data_source[i][9].squeeze()[1].item()) \ + str(self.data_source[i][10].squeeze()[0].item()) + str(self.data_source[i][10].squeeze()[1].item()) # Verify if key exists # If exists, add index to it # Else, create it and add index to it try: dic_type_size[key].append(i) except: dic_type_size[key] = [] dic_type_size[key].append(i) # Check if for each list the size equals the batch size or note # if it does, yield it and reset the list for k in dic_type_size.keys(): if len(dic_type_size[k]) == self.batch_size: yield dic_type_size[k] dic_type_size[k] = [] # if the loop is exited and didn't achieved batch_size # Yield the rest for k in dic_type_size.keys(): if len(dic_type_size[k]) > 0: yield dic_type_size[k] def __len__(self): return self.batch_size class SamplerStg5(torch.utils.data.Sampler): def __init__(self, data_source, batch_size): """! Args: @param data_source (Dataset): dataset to sample from @param batch_size (int): Batch size to sample data """ self.data_source = data_source self.batch_size = batch_size super(SamplerStg5, self).__init__(data_source) def __iter__(self): # Initialize variables data_size = len(self.data_source) indices = np.arange(data_size) indices = np.random.permutation(indices) random.shuffle(indices) dic_type_size = {} # Search for data for i in indices: # Build unique key key = str(self.data_source[i][7].squeeze()[0].item()) + str(self.data_source[i][7].squeeze()[1].item()) \ + str(self.data_source[i][8].squeeze()[0].item()) + str(self.data_source[i][8].squeeze()[1].item()) \ # Verify if key exists # If exists, add index to it # Else, create it and add index to it try: dic_type_size[key].append(i) except: dic_type_size[key] = [] dic_type_size[key].append(i) # Check if for each list the size equals the batch size or note # if it does, yield it and reset the list for k in dic_type_size.keys(): if len(dic_type_size[k]) == self.batch_size: yield dic_type_size[k] dic_type_size[k] = [] # if the loop is exited and didn't achieved batch_size # Yield the rest for k in dic_type_size.keys(): if len(dic_type_size[k]) > 0: yield dic_type_size[k] def __len__(self): return self.batch_size class SamplerStg4(torch.utils.data.Sampler): def __init__(self, data_source, batch_size): """! Args: @param data_source (Dataset): dataset to sample from @param batch_size (int): Batch size to sample data """ self.data_source = data_source self.batch_size = batch_size super(SamplerStg4, self).__init__(data_source) def __iter__(self): # Initialize variables data_size = len(self.data_source) indices = np.arange(data_size) indices = np.random.permutation(indices) random.shuffle(indices) dic_type_size = {} # Search for data for stage 4 and 5 for i in indices: # Build unique key key = str(self.data_source[i][4].squeeze()[0].item()) + str(self.data_source[i][4].squeeze()[1].item()) \ + str(self.data_source[i][5].squeeze()[0].item()) + str(self.data_source[i][5].squeeze()[1].item())\ + str(self.data_source[i][6].squeeze()[0].item()) + str(self.data_source[i][6].squeeze()[1].item()) # Verify if key exists # If exists, add index to it # Else, create it and add index to it try: dic_type_size[key].append(i) except: dic_type_size[key] = [] dic_type_size[key].append(i) # Check if for each list the size equals the batch size or note # if it does, yield it and reset the list for k in dic_type_size.keys(): if len(dic_type_size[k]) == self.batch_size: yield dic_type_size[k] dic_type_size[k] = [] # if the loop is exited and didn't achieved batch_size # Yield the rest for k in dic_type_size.keys(): if len(dic_type_size[k]) > 0: yield dic_type_size[k] def __len__(self): return self.batch_size