# cabasus/funcs/dataloader.py
# Last commit 41ed540 by arcan3: "added finalise"
import glob, json, os
import torch
import warnings
import numpy as np
from torch.utils.data import Dataset
class BaseDataset2(Dataset):
    """Template class for all datasets in the project.

    Wraps a pair of NumPy arrays as float32 torch tensors. Indexing yields
    ``(sample, target, index)`` triples.
    """

    def __init__(self, x, y):
        """Initialize dataset.

        Args:
            x(ndarray): Input features; the first axis indexes samples.
            y(ndarray): Targets, aligned with ``x`` along the first axis.
        """
        # NOTE: .float() returns the same tensor when the input is already
        # float32, in which case these tensors share memory with x / y.
        self.data = torch.from_numpy(x).float()
        self.targets = torch.from_numpy(y).float()
        # Latent variables exposed via get_latents(); never populated here.
        self.latents = None
        self.labels = None
        # NOTE(review): is_radial / partition are not used in this class;
        # presumably consumed by callers or subclasses — confirm.
        self.is_radial = []
        self.partition = True

    def __getitem__(self, index):
        """Return ``(data[index], targets[index], index)``."""
        return self.data[index], self.targets[index], index

    def __len__(self):
        """Return the number of samples (size of the first axis of data)."""
        return len(self.data)

    def numpy(self, idx=None):
        """Get dataset as ndarray.

        Specify indices to return a subset of the dataset, otherwise return
        whole dataset.

        Args:
            idx(int, optional): Specify index or indices to return.

        Returns:
            tuple: ``(data, targets)``. ``data`` is the feature array
            flattened to shape ``(n_samples, n_features)``; ``targets`` is
            the targets array. Both are restricted to ``idx`` when given.
        """
        n = len(self)
        # Compute the flattened width explicitly: reshape((0, -1)) cannot
        # infer -1 and raises for an empty dataset. int() because np.prod of
        # an empty shape (1-D data) returns the float 1.0.
        n_features = int(np.prod(self.data.shape[1:]))
        data = self.data.numpy().reshape((n, n_features))
        if idx is None:
            return data, self.targets.numpy()
        return data[idx], self.targets[idx].numpy()

    def get_latents(self):
        """Get latent variables.

        Returns:
            latents(ndarray): Latent variables for each sample (``None``
            unless set externally).
        """
        return self.latents
def load_json(file_path):
    """Read and parse a JSON file.

    Args:
        file_path: Path to the JSON file.

    Returns:
        The decoded JSON value (for the loaders in this module, a list of
        sample dicts).
    """
    # JSON text is UTF-8 (RFC 8259); don't depend on the platform's locale
    # encoding, which is not UTF-8 everywhere (e.g. Windows).
    with open(file_path, 'r', encoding='utf-8') as f:
        return json.load(f)
# Per-sample array fields, in feature order (presumably accelerometer A* and
# gyroscope G* channels — confirm against the data source).
_SENSOR_KEYS = (
    'AX1', 'AX2', 'AX3', 'AX4', 'AY1', 'AY2', 'AY3', 'AY4',
    'AZ1', 'AZ2', 'AZ3', 'AZ4', 'GX1', 'GX2', 'GX3', 'GX4',
    'GY1', 'GY2', 'GY3', 'GY4', 'GZ1', 'GZ2', 'GZ3', 'GZ4',
)
# Scalar timing fields, appended after the sensor arrays.
_TIME_DIFF_KEYS = (
    'GZ1_precise_time_diff', 'GZ2_precise_time_diff',
    'GZ3_precise_time_diff', 'GZ4_precise_time_diff',
    'precise_time_diff',
)
_ALL_KEYS = _SENSOR_KEYS + _TIME_DIFF_KEYS
# Expected number of values in each sensor array.
_VALUES_PER_CHANNEL = 64
# 24 arrays * 64 values + 5 scalars = 1541 features per valid sample.
_SAMPLE_LENGTH = len(_SENSOR_KEYS) * _VALUES_PER_CHANNEL + len(_TIME_DIFF_KEYS)


def _parse_sample(sample, file):
    """Flatten one JSON sample into a feature list, or None to skip it.

    A null value in any field means the sample is skipped (returns None).
    A missing key only emits a warning, so the returned list may be short;
    the caller's length check rejects it.
    """
    features = []
    for key in _ALL_KEYS:
        if key not in sample:
            warnings.warn(f"KeyError: {key} not found in JSON file: {file}")
            continue
        value = sample[key]
        if value is None:
            # Null fields cannot be converted to numbers: drop the sample.
            return None
        if key in _TIME_DIFF_KEYS:
            # Scalar timing value: rounded, then scaled by 20.
            features.append(round(value) * 20)
        else:
            features.extend(value)
    return features


def read_json_files(file):
    """Load one JSON file of samples into feature and label tensors.

    Args:
        file: Path to a JSON file whose top level is a list of sample dicts
            containing the keys in ``_SENSOR_KEYS`` and ``_TIME_DIFF_KEYS``.

    Returns:
        tuple: ``(data_x, data_y)``. ``data_x`` is a float32 tensor of shape
        ``(n_valid, _SAMPLE_LENGTH)``. ``data_y`` is an int64 tensor of
        ``n_valid`` ones.

    Raises:
        ValueError: If the file contains no valid sample.

    Samples with a null field are skipped silently. Samples with missing
    keys or the wrong number of values are skipped with a warning.
    """
    rows = []
    for sample in load_json(file):
        features = _parse_sample(sample, file)
        if features is None:
            continue
        if len(features) != _SAMPLE_LENGTH:
            warnings.warn(f"Incomplete sample in JSON file: {file}")
            continue
        rows.append(features)
    if not rows:
        raise ValueError(f"No valid samples found in JSON file: {file}")
    data_x = torch.tensor(rows, dtype=torch.float32)
    # Every sample gets label 1; labels are not read from the file.
    data_y = torch.ones(len(rows), dtype=torch.long)
    return data_x, data_y