# NOTE: removed stray artifact lines ("Spaces:", "Runtime error") that were
# not valid Python and made the module unparseable.
"""@package docstring | |
@file CustomDataset.py | |
@brief This library contains useful functions to visualise data and also classes to store and organise data structures.
@section libraries_CustomDataset Libraries | |
- torch | |
- numpy | |
- dataset_utils | |
- cv2 | |
- torch.utils.data | |
- random | |
- re | |
@section classes_CustomDataset Classes | |
- CUDatasetBase | |
- CUDatasetStg2Compl | |
- CUDatasetStg3Compl | |
- CUDatasetStg4Compl | |
- CUDatasetStg5Compl | |
- CUDatasetStg6Compl | |
- SamplerStg6 | |
- SamplerStg5 | |
- SamplerStg4 | |
@section functions_CustomDataset Functions | |
- yuv2bgr(matrix) | |
- bgr2yuv(matrix) | |
- get_cu(f_path, f_size, cu_pos, cu_size, frame_number) | |
- get_file_size(name) | |
- resize(img, scale_percent) | |
- show_CU(image, cu_size, cu_pos) | |
- show_all_CUs(CUDataset, file_name, POC, cu_size) | |
@section global_vars_CustomDataset Global Variables | |
- None | |
@section todo_CustomDataset TODO | |
- None | |
@section license License | |
MIT License | |
Copyright (c) 2022 Raul Kevin do Espirito Santo Viana | |
Permission is hereby granted, free of charge, to any person obtaining a copy | |
of this software and associated documentation files (the "Software"), to deal | |
in the Software without restriction, including without limitation the rights | |
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | |
copies of the Software, and to permit persons to whom the Software is | |
furnished to do so, subject to the following conditions: | |
The above copyright notice and this permission notice shall be included in all | |
copies or substantial portions of the Software. | |
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | |
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | |
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | |
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | |
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | |
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | |
SOFTWARE. | |
@section author_CustomDataset Author(s) | |
- Created by Raul Kevin Viana | |
- Last time modified is 2023-01-29 22:22:04.113134 | |
""" | |
# ============================================================== | |
# Imports | |
# ============================================================== | |
import torch | |
import numpy as np | |
from torch.utils.data import Dataset | |
import cv2 | |
import re | |
from dataset_utils import VideoCaptureYUV, VideoCaptureYUVV2 | |
import dataset_utils | |
import random | |
# ============================================================== | |
# Functions | |
# ============================================================== | |
def yuv2bgr(matrix):
    """!
    Converts a YUV (I420) matrix into its BGR representation
    @param [in] matrix: YUV matrix
    @param [out] bgr: BGR conversion
    """
    return cv2.cvtColor(matrix, cv2.COLOR_YUV2BGR_I420)
def bgr2yuv(matrix):
    """!
    Converts a BGR matrix into its YUV (I420) representation
    @param [in] matrix: BGR matrix
    @param [out] YUV: YUV conversion
    """
    return cv2.cvtColor(matrix, cv2.COLOR_BGR2YUV_I420)
def get_cu(f_path, f_size, cu_pos, cu_size, frame_number):
    """!
    Extracts a CU from one frame of a YUV file
    @param [in] f_path: Path of the file to get the CU from
    @param [in] f_size: YUV file dimensions (height, width)
    @param [in] cu_pos: Tuple with the position of the CU (y position (height), x position (width))
    @param [in] cu_size: Tuple with the size of the CU (height, width)
    @param [in] frame_number: Number of the frame containing the CU
    @param [out] yuv_frame: The specified full frame that contains the CU
    @param [out] CU_Y: CU with the Luminance component
    @param [out] CU_U: CU with the Chroma (Blue) component
    @param [out] CU_V: CU with the Chroma (Red) component
    Returns False when the requested frame could not be read.
    """
    reader = VideoCaptureYUV(f_path, f_size)
    ret, yuv_frame, luma, chroma_u, chroma_v = reader.read_raw(frame_number)
    if not ret:
        return False

    y0, x0 = cu_pos
    h, w = cu_size
    # Luma plane is full resolution; chroma planes are subsampled by 2 (YUV420).
    # Assumes non-negative integer coordinates/sizes.
    CU_Y = luma[y0: y0 + h, x0: x0 + w]
    CU_U = chroma_u[y0 // 2: (y0 + h) // 2, x0 // 2: (x0 + w) // 2]
    CU_V = chroma_v[y0 // 2: (y0 + h) // 2, x0 // 2: (x0 + w) // 2]
    return yuv_frame, CU_Y, CU_U, CU_V
def get_cu_v2(img, f_size, cu_pos, cu_size):
    """!
    Extracts a CU directly from a YUV420 image (instead of a file path)
    @param [in] img: Image in YUV420 format
    @param [in] f_size: YUV image dimensions (height, width)
    @param [in] cu_pos: Tuple with the position of the CU (y position (height), x position (width))
    @param [in] cu_size: Tuple with the size of the CU (height, width)
    @param [out] yuv_frame: The full frame that contains the CU
    @param [out] CU_Y: CU with the Luminance component
    @param [out] CU_U: CU with the Chroma (Blue) component
    @param [out] CU_V: CU with the Chroma (Red) component
    Returns False when the frame could not be read.
    """
    reader = VideoCaptureYUVV2(img, f_size)
    ret, yuv_frame, luma, chroma_u, chroma_v = reader.read_raw()
    if not ret:
        return False

    y0, x0 = cu_pos
    h, w = cu_size
    # Luma plane is full resolution; chroma planes are subsampled by 2 (YUV420).
    # Assumes non-negative integer coordinates/sizes.
    CU_Y = luma[y0: y0 + h, x0: x0 + w]
    CU_U = chroma_u[y0 // 2: (y0 + h) // 2, x0 // 2: (x0 + w) // 2]
    CU_V = chroma_v[y0 // 2: (y0 + h) // 2, x0 // 2: (x0 + w) // 2]
    return yuv_frame, CU_Y, CU_U, CU_V
def get_file_size(name):
    """!
    Retrieves YUV file information (width, height and frame rate) from its name.

    Recognised name tokens: "<W>x<H>" with optional "_NN_"/"_NN." frame rate,
    "_cif", "_sif", "_4cif", "1080p[NN]" and "720p[NN]". Later tokens override
    earlier ones (same precedence as the original implementation).

    @param [in] name: Name of the file
    @param [out] file_info: Dict with "width", "height" and "frame_rate",
                 or False when no size information was found
    """
    file_info = {}

    # Explicit "<width>x<height>" token (regexes are raw strings: "\d" in a
    # normal string is an invalid escape sequence)
    size = re.findall(r"\d+x\d+", name)
    if len(size) == 1:
        width, height = map(int, re.findall(r"\d+", size[0]))
        file_info["width"] = width
        file_info["height"] = height
        # Frame rate encoded as "_NN_" or "_NN."
        framerate = re.findall(r"_\d\d_|_\d\d\.", name)
        if len(framerate) == 1:
            file_info["frame_rate"] = int(framerate[0][1:3])
        else:
            file_info["frame_rate"] = 30  # Default frame rate

    # Standard picture-format markers, all at 30 fps
    for marker, (width, height) in (("_cif", (352, 288)),
                                    ("_sif", (352, 240)),
                                    ("_4cif", (704, 576))):
        if len(re.findall(marker, name)) == 1:
            file_info["width"] = width
            file_info["height"] = height
            file_info["frame_rate"] = 30

    # "1080p"/"720p" markers, optionally followed by the frame rate digits
    for token, (width, height) in (("1080p", (1920, 1080)),
                                   ("720p", (1280, 720))):
        match = re.findall(token + r"\d*", name)
        if len(match) == 1:
            file_info["width"] = width
            file_info["height"] = height
            nums = list(map(int, re.findall(r"\d+", match[0])))
            if len(nums) == 2:
                file_info["frame_rate"] = nums[1]  # e.g. "1080p50" -> 50
            else:
                file_info["frame_rate"] = 30  # Default frame rate

    return file_info if file_info else False
def resize(img, scale_percent):
    """!
    Rescales a BGR image by a percentage of its original dimensions
    @param [in] img: Image to rescale
    @param [in] scale_percent: Scale factor in percent (e.g. 50 halves each side)
    @param [out] Resized image
    """
    new_dims = (int(img.shape[1] * scale_percent / 100),   # width
                int(img.shape[0] * scale_percent / 100))   # height
    return cv2.resize(img, new_dims, interpolation=cv2.INTER_AREA)
def show_CU(image, cu_size, cu_pos):
    """!
    Displays a frame with the given CU outlined in blue; press 'q' to close
    @param [in] image: Frame in YUV format
    @param [in] cu_size: CU size as (height, width)
    @param [in] cu_pos: CU position as (y, x)
    """
    frame_bgr = yuv2bgr(image)
    # cv2 points are (x, y) while cu_pos/cu_size are (y, x), hence the swap
    top_left = (cu_pos[1], cu_pos[0])
    bottom_right = (cu_pos[1] + cu_size[1], cu_pos[0] + cu_size[0])
    # Blue rectangle (BGR), 2 px border
    outlined = cv2.rectangle(frame_bgr, top_left, bottom_right, (255, 0, 0), 2)
    cv2.imshow('Frame', resize(outlined, 50))
    print('\n\n\n Press q to exit')
    # Block until the user presses 'q'
    while cv2.waitKey(100) & 0xFF != ord('q'):
        pass
    cv2.destroyAllWindows()
def show_all_CUs(CUDataset, file_name, POC, cu_size):
    """!
    Shows all CUs in an image with a specific size, TO BE USED WITH THE CU ORIENTED LABELS
    @param [in] CUDataset: Iterable of CU label samples (dicts)
    @param [in] file_name: Name of the YUV file the CUs belong to
    @param [in] POC: Frame number to draw CUs for
    @param [in] cu_size: Only CUs with this size are drawn
    """
    image = None
    color = (255, 0, 0)  # Blue in BGR
    thickness = 2        # Rectangle border thickness in px
    for sample in CUDataset:
        # Only draw CUs belonging to the requested file, frame and size
        if sample['file_name'] == file_name and sample['POC'] == POC and sample['CU_size'] == cu_size:
            f_size = get_file_size(file_name)  # Dict with the size of the image
            # BUG FIX: get_cu returns 4 values (frame, Y, U, V); the original
            # unpacked 5 names and raised ValueError on the first matching sample
            frame_CU, CU_Y, CU_U, CU_V = get_cu(sample['img_path'], (f_size['height'], f_size['width']),
                                                sample['CU_pos'], sample['CU_size'], POC)
            if image is None:
                # Convert the frame to BGR once, on the first matching CU
                image = yuv2bgr(frame_CU)
            # CU end position
            # NOTE(review): CU_pos is passed to cv2.rectangle as-is, while
            # show_CU swaps it to (x, y) first — kept as the original did,
            # but worth confirming against the label format
            cu_pos_end = (sample['CU_pos'][0] + sample['CU_size'][0], sample['CU_pos'][1] + sample['CU_size'][1])
            image = cv2.rectangle(image, sample['CU_pos'], cu_pos_end, color, thickness)
    if image is None:
        return  # No matching CU was found; nothing to display
    resize_img = resize(image, 50)
    cv2.imshow('Frame', resize_img)
    # Press q on keyboard to exit
    while not (cv2.waitKey(100) & 0xFF == ord('q')):
        pass
    cv2.destroyAllWindows()
# ============================================================== | |
# Classes | |
# ============================================================== | |
class CUDatasetBase(Dataset):
    """!
    Stage-oriented dataset with the capability of loading several label files at
    once; meant to be used with dataset_utils.change_labels_function_again.
    Works for stages 2 and 3.
    """

    def __init__(self, files_path, channel=0):
        """!
        Args:
            @param files_path (string): Path to the files with annotations.
            @param channel: Channel to get for the dataset (0:Luma, 1:Chroma)
        """
        # Store file paths
        self.files_path = files_path
        self.files = dataset_utils.get_files_from_folder(self.files_path, endswith=".txt")
        # Number of entries in each file
        self.lst_entries_nums = self.obtain_files_sizes(self.files)
        # Total amount of entries across all dataset files
        self.total_num_entries = 0
        for num in self.lst_entries_nums:
            self.total_num_entries += num
        self.channel = channel
        # index_lims[k] is the last (inclusive) global index served by file k.
        # Computed with a running total instead of the original O(n^2)
        # re-summation, and without shadowing the builtin `sum`.
        self.index_lims = []
        last_idx = -1
        for num in self.lst_entries_nums:
            last_idx += num
            self.index_lims.append(last_idx)
        # Load the label entries of each file
        self.data_files = []
        for f in self.files:
            self.data_files.append(dataset_utils.file2lst(self.files_path + "/" + f[0:-4]))

    def __len__(self):
        """! Total number of entries across every label file."""
        return self.total_num_entries

    def get_sample(self, entry):
        """!
        Builds a sample from a single label entry.
        Args:
            @param entry (dict): An instance from the labels.
            @return out: lst - CTU | cu_pos | cu_size | split | RDs
        """
        # Luma CTU: add a channel dimension and convert to float32
        CU_Y = torch.unsqueeze(entry["real_CTU"], 0).to(dtype=torch.float32)
        # CU position and size within the frame
        cu_pos = entry["cu_pos"]
        cu_size = entry["cu_size"]
        # Best split for the CU
        split = entry["split"]
        # Rate-distortion costs, one row of 6 values per CU
        RDs = np.reshape(entry['RD'], (-1, 6))
        return [CU_Y, cu_pos, cu_size, split, RDs]

    def obtain_files_sizes(self, files):
        """!
        Computes the number of label entries in each file.
        Args:
            @param files (list): List containing the names of files with CUs info
            @return List with the entry count of each file
        """
        lst = []
        for f in files:
            file_obj = dataset_utils.file2lst(self.files_path + "/" + f[0:-4])
            lst.append(len(file_obj))
        return lst

    def select_entry(self, idx):
        """!
        Maps a global dataset index to the entry inside the corresponding file.
        Args:
            @param idx (int): Index with the position to search for a specific entry
        """
        for k, lim in enumerate(self.index_lims):
            if idx <= lim:
                # Offset of idx inside file k
                idx_obj = idx - (lim - self.lst_entries_nums[k] + 1)
                return self.data_files[k][idx_obj]
        raise Exception("Entry not found!!")

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        # Walk forward (wrapping around) until an entry whose CTU is 128x128 is
        # found. NOTE(review): loops forever if the dataset holds no 128x128
        # CTU at all — same behavior as the original.
        while True:
            sample = self.get_sample(self.select_entry(idx))
            ctu_size = (sample[0].shape[-2], sample[0].shape[-1])
            if ctu_size == (128, 128):
                return sample
            idx = (idx + 1) % self.total_num_entries
class CUDatasetStg5Compl(CUDatasetBase):
    """!
    Stage 5 variant of the stage-oriented dataset (see CUDatasetBase); meant to
    be used with dataset_utils.change_labels_function_again.
    This version exposes all of the information available from the CU data.
    """

    def __init__(self, files_path, channel=0):
        """!
        Args:
            @param files_path (string): Path to the files with annotations.
            @param channel: Channel to get for the dataset (0:Luma, 1:Chroma)
        """
        super(CUDatasetStg5Compl, self).__init__(files_path, channel)

    def get_sample(self, entry):
        """!
        Args:
            @param entry (list): An instance from the labels.
            @return out: CTU | positions stg2-stg5 | sizes stg2-stg5 | splits stg2-stg5 | RDs | orig pos (x, y) | orig size (h, w) | POC | picture name
        """
        if entry[14] != self.channel:
            raise Exception("This can not happen! This CU color channel should be " + str(self.channel) + "Try generating the labels and obtain just a specific color channel (see dataset_utils)!")

        def pairs(i):
            # Flattened (y, x) values are reshaped into rows of two
            return torch.reshape(torch.tensor(entry[i]), (-1, 2))

        # Real CU: add a channel dimension and cast to float32
        ctu = torch.unsqueeze(entry[13], 0).to(dtype=torch.float32)
        sample = [ctu]
        sample += [pairs(i) for i in (1, 2, 3, 4)]       # positions: stg2 ... stg5
        sample += [pairs(i) for i in (5, 6, 7, 8)]       # sizes: stg2 ... stg5
        sample += [entry[i] for i in (9, 10, 11, 12)]    # splits: stg2 ... stg5
        sample.append(torch.reshape(torch.tensor(entry[0]), (-1, 6)))  # RD costs
        sample += [entry[17][1], entry[17][0],           # original position: x then y
                   entry[18][1], entry[18][0],           # original size: h then w
                   entry[15], entry[16]]                 # POC and picture name
        return sample
class CUDatasetStg2Compl(CUDatasetBase):
    """!
    Stage 2 variant of the stage-oriented dataset (see CUDatasetBase); meant to
    be used with dataset_utils.change_labels_function_again.
    This version exposes all of the information available from the CU data.
    """

    def __init__(self, files_path, channel=0):
        """!
        Args:
            @param files_path (string): Path to the files with annotations.
            @param channel: Channel to get for the dataset (0:Luma, 1:Chroma)
        """
        super(CUDatasetStg2Compl, self).__init__(files_path, channel)

    def get_sample(self, entry):
        """!
        Args:
            @param entry (list): An instance from the labels.
            @return out: CTU | cu_pos | cu_size | split | RDs | orig pos (x, y) | orig size (h, w) | POC | picture name
        """
        if entry[5] != self.channel:
            raise Exception("This can not happen! This CU color channel should be " + str(self.channel) + "Try generating the labels and obtain just a specific color channel (see dataset_utils)!")

        # Real CU: add a channel dimension and cast to float32
        ctu = torch.unsqueeze(entry[4], 0).to(dtype=torch.float32)
        # Flattened (y, x) pairs are reshaped into rows of two
        pos = torch.reshape(torch.tensor(entry[1]), (-1, 2))
        size = torch.reshape(torch.tensor(entry[2]), (-1, 2))
        rd_costs = torch.reshape(torch.tensor(entry[0]), (-1, 6))
        return [ctu, pos, size, entry[3], rd_costs,
                entry[8][1], entry[8][0],   # original position: x then y
                entry[9][1], entry[9][0],   # original size: h then w
                entry[6], entry[7]]         # POC and picture name
class CUDatasetStg6Compl(CUDatasetBase):
    """!
    Stage 6 variant of the stage-oriented dataset (see CUDatasetBase); meant to
    be used with dataset_utils.change_labels_function_again.
    This version exposes all of the information available from the CU data.
    """

    def __init__(self, files_path, channel=0):
        """!
        Args:
            @param files_path (string): Path to the files with annotations.
            @param channel: Channel to get for the dataset (0:Luma, 1:Chroma)
        """
        super(CUDatasetStg6Compl, self).__init__(files_path, channel)

    def get_sample(self, entry):
        """!
        Args:
            @param entry (list): An instance from the labels.
            @return out: CTU | positions stg2-stg6 | sizes stg2-stg6 | splits stg2-stg6 | RDs | orig pos (x, y) | orig size (h, w) | POC | picture name
        """
        if entry[17] != self.channel:
            raise Exception("This can not happen! This CU color channel should be " + str(self.channel) + "Try generating the labels and obtain just a specific color channel (see dataset_utils)!")

        def pairs(i):
            # Flattened (y, x) values are reshaped into rows of two
            return torch.reshape(torch.tensor(entry[i]), (-1, 2))

        # Real CU: add a channel dimension and cast to float32
        ctu = torch.unsqueeze(entry[16], 0).to(dtype=torch.float32)
        sample = [ctu]
        sample += [pairs(i) for i in (1, 2, 3, 4, 5)]        # positions: stg2 ... stg6
        sample += [pairs(i) for i in (6, 7, 8, 9, 10)]       # sizes: stg2 ... stg6
        sample += [entry[i] for i in (11, 12, 13, 14, 15)]   # splits: stg2 ... stg6
        sample.append(torch.reshape(torch.tensor(entry[0]), (-1, 6)))  # RD costs
        sample += [entry[20][1], entry[20][0],               # original position: x then y
                   entry[21][1], entry[21][0],               # original size: h then w
                   entry[18], entry[19]]                     # POC and picture name
        return sample
class CUDatasetStg4Compl(CUDatasetBase):
    """!
    Stage 4 variant of the stage-oriented dataset (see CUDatasetBase); meant to
    be used with dataset_utils.change_labels_function_again.
    This version exposes all of the information available from the CU data.
    """

    def __init__(self, files_path, channel=0):
        """!
        Args:
            @param files_path (string): Path to the files with annotations.
            @param channel: Channel to get for the dataset (0:Luma, 1:Chroma)
        """
        super(CUDatasetStg4Compl, self).__init__(files_path, channel)

    def get_sample(self, entry):
        """!
        Args:
            @param entry (list): An instance from the labels.
            @return out: CTU | positions stg2-stg4 | sizes stg2-stg4 | splits stg2-stg4 | RDs | orig pos (x, y) | orig size (h, w) | POC | picture name
        """
        if entry[11] != self.channel:
            raise Exception("This can not happen! This CU color channel should be " + str(self.channel) + "Try generating the labels and obtain just a specific color channel (see dataset_utils)!")

        def pairs(i):
            # Flattened (y, x) values are reshaped into rows of two
            return torch.reshape(torch.tensor(entry[i]), (-1, 2))

        # Real CU: add a channel dimension and cast to float32
        ctu = torch.unsqueeze(entry[10], 0).to(dtype=torch.float32)
        sample = [ctu]
        sample += [pairs(i) for i in (1, 2, 3)]          # positions: stg2, stg3, stg4
        sample += [pairs(i) for i in (4, 5, 6)]          # sizes: stg2, stg3, stg4
        sample += [entry[i] for i in (7, 8, 9)]          # splits: stg2, stg3, stg4
        sample.append(torch.reshape(torch.tensor(entry[0]), (-1, 6)))  # RD costs
        sample += [entry[14][1], entry[14][0],           # original position: x then y
                   entry[15][1], entry[15][0],           # original size: h then w
                   entry[12], entry[13]]                 # POC and picture name
        return sample
class CUDatasetStg3Compl(CUDatasetBase):
    """!
    Stage 3 variant of the stage-oriented dataset (see CUDatasetBase); meant to
    be used with dataset_utils.change_labels_function_again.
    This version exposes all of the information available from the CU data.
    """

    def __init__(self, files_path, channel=0):
        """!
        Args:
            @param files_path (string): Path to the files with annotations.
            @param channel: Channel to get for the dataset (0:Luma, 1:Chroma)
        """
        super(CUDatasetStg3Compl, self).__init__(files_path, channel)

    def get_sample(self, entry):
        """!
        Args:
            @param entry (list): An instance from the labels.
            @return out: CTU | positions stg2-stg3 | sizes stg2-stg3 | splits stg2-stg3 | RDs | orig pos (x, y) | orig size (h, w) | POC | picture name
        """
        if entry[8] != self.channel:
            raise Exception("This can not happen! This CU color channel should be " + str(self.channel) + "Try generating the labels and obtain just a specific color channel (see dataset_utils)!")

        def pairs(i):
            # Flattened (y, x) values are reshaped into rows of two
            return torch.reshape(torch.tensor(entry[i]), (-1, 2))

        # Real CU: add a channel dimension and cast to float32
        ctu = torch.unsqueeze(entry[7], 0).to(dtype=torch.float32)
        sample = [ctu]
        sample += [pairs(1), pairs(2)]                   # positions: stg2, stg3
        sample += [pairs(3), pairs(4)]                   # sizes: stg2, stg3
        sample += [entry[5], entry[6]]                   # splits: stg2, stg3
        sample.append(torch.reshape(torch.tensor(entry[0]), (-1, 6)))  # RD costs
        sample += [entry[11][1], entry[11][0],           # original position: x then y
                   entry[12][1], entry[12][0],           # original size: h then w
                   entry[9], entry[10]]                  # POC and picture name
        return sample
class SamplerStg6(torch.utils.data.Sampler):
    """!
    Sampler for stage 6 samples: groups dataset indices by the CU size entries
    stored at sample positions 8, 9 and 10 so that every yielded batch only
    contains CUs with matching dimensions.
    """

    def __init__(self, data_source, batch_size):
        """!
        Args:
            @param data_source (Dataset): dataset to sample from
            @param batch_size (int): Batch size to sample data
        """
        self.data_source = data_source
        self.batch_size = batch_size
        super(SamplerStg6, self).__init__(data_source)

    def __iter__(self):
        # One shuffle of all dataset indices (the original permuted with numpy
        # and then shuffled again with random.shuffle, which is redundant)
        indices = np.random.permutation(len(self.data_source))
        groups = {}
        for i in indices:
            sample = self.data_source[i]
            # Key concatenates both values of the size entries at 8, 9 and 10;
            # samples sharing the key have identical dimensions
            key = "".join(str(sample[j].squeeze()[c].item())
                          for j in (8, 9, 10) for c in (0, 1))
            # setdefault replaces the original bare try/except pattern, which
            # could mask unrelated errors
            groups.setdefault(key, []).append(i)
        # Same yield order as the original: first the groups holding exactly
        # batch_size indices, then the leftover groups
        for batch in groups.values():
            if len(batch) == self.batch_size:
                yield batch
        for batch in groups.values():
            if 0 < len(batch) != self.batch_size:
                yield batch

    def __len__(self):
        # NOTE(review): returns the batch size, not the number of batches;
        # kept unchanged for backward compatibility with existing callers
        return self.batch_size
class SamplerStg5(torch.utils.data.Sampler):
    """!
    Sampler for stage 5 samples: groups dataset indices by the CU size entries
    stored at sample positions 7 and 8 so that every yielded batch only
    contains CUs with matching dimensions.
    """

    def __init__(self, data_source, batch_size):
        """!
        Args:
            @param data_source (Dataset): dataset to sample from
            @param batch_size (int): Batch size to sample data
        """
        self.data_source = data_source
        self.batch_size = batch_size
        super(SamplerStg5, self).__init__(data_source)

    def __iter__(self):
        # One shuffle of all dataset indices (the original permuted with numpy
        # and then shuffled again with random.shuffle, which is redundant)
        indices = np.random.permutation(len(self.data_source))
        groups = {}
        for i in indices:
            sample = self.data_source[i]
            # BUG FIX: the original key expression ended with a line
            # continuation backslash followed by a comment line, which is a
            # SyntaxError. Key concatenates both values of the size entries
            # at 7 and 8.
            key = "".join(str(sample[j].squeeze()[c].item())
                          for j in (7, 8) for c in (0, 1))
            # setdefault replaces the original bare try/except pattern
            groups.setdefault(key, []).append(i)
        # First yield the groups holding exactly batch_size indices, then the
        # leftover groups (same order as the sibling samplers)
        for batch in groups.values():
            if len(batch) == self.batch_size:
                yield batch
        for batch in groups.values():
            if 0 < len(batch) != self.batch_size:
                yield batch

    def __len__(self):
        # NOTE(review): returns the batch size, not the number of batches;
        # kept unchanged for backward compatibility with existing callers
        return self.batch_size
class SamplerStg4(torch.utils.data.Sampler):
    """!
    Sampler for stage 4 samples: groups dataset indices by the CU size entries
    stored at sample positions 4, 5 and 6 so that every yielded batch only
    contains CUs with matching dimensions.
    """

    def __init__(self, data_source, batch_size):
        """!
        Args:
            @param data_source (Dataset): dataset to sample from
            @param batch_size (int): Batch size to sample data
        """
        self.data_source = data_source
        self.batch_size = batch_size
        super(SamplerStg4, self).__init__(data_source)

    def __iter__(self):
        # One shuffle of all dataset indices (the original permuted with numpy
        # and then shuffled again with random.shuffle, which is redundant)
        indices = np.random.permutation(len(self.data_source))
        groups = {}
        for i in indices:
            sample = self.data_source[i]
            # Key concatenates both values of the size entries at 4, 5 and 6;
            # samples sharing the key have identical dimensions
            key = "".join(str(sample[j].squeeze()[c].item())
                          for j in (4, 5, 6) for c in (0, 1))
            # setdefault replaces the original bare try/except pattern, which
            # could mask unrelated errors
            groups.setdefault(key, []).append(i)
        # Same yield order as the original: first the groups holding exactly
        # batch_size indices, then the leftover groups
        for batch in groups.values():
            if len(batch) == self.batch_size:
                yield batch
        for batch in groups.values():
            if 0 < len(batch) != self.batch_size:
                yield batch

    def __len__(self):
        # NOTE(review): returns the batch size, not the number of batches;
        # kept unchanged for backward compatibility with existing callers
        return self.batch_size