import glob
import json
import os
import warnings

import numpy as np
import torch
from torch.utils.data import Dataset


class BaseDataset2(Dataset):
    """Template class for all datasets in the project."""

    def __init__(self, x, y):
        """Initialize dataset.

        Args:
            x(ndarray): Input features.
            y(ndarray): Targets.
        """
        # torch.from_numpy requires an ndarray; .float() casts to float32.
        self.data = torch.from_numpy(x).float()
        self.targets = torch.from_numpy(y).float()
        # Optional metadata; left empty here and presumably filled by subclasses.
        self.latents = None
        self.labels = None
        self.is_radial = []
        self.partition = True

    def __getitem__(self, index):
        """Return ``(input, target, index)`` for the sample at *index*."""
        return self.data[index], self.targets[index], index

    def __len__(self):
        """Return the number of samples (size of the first data dimension)."""
        return len(self.data)

    def numpy(self, idx=None):
        """Get dataset as ndarrays.

        Specify indices to return a subset of the dataset, otherwise return
        the whole dataset.

        Args:
            idx(int | slice | sequence, optional): Index or indices to return.

        Returns:
            tuple: ``(data, targets)`` where ``data`` is flattened to shape
            ``(n, -1)`` before indexing, and ``targets`` is an ndarray.
        """
        n = len(self)
        data = self.data.numpy().reshape((n, -1))
        if idx is None:
            return data, self.targets.numpy()
        return data[idx], self.targets[idx].numpy()

    def get_latents(self):
        """Get latent variables.

        Returns:
            latents(ndarray): Latent variables for each sample.
        """
        return self.latents


# Array-valued keys; each sample value is extended element-wise into the
# feature vector (presumably accelerometer "A*" / gyroscope "G*" channels -
# confirm against the data producer).
_SENSOR_KEYS = (
    'AX1', 'AX2', 'AX3', 'AX4',
    'AY1', 'AY2', 'AY3', 'AY4',
    'AZ1', 'AZ2', 'AZ3', 'AZ4',
    'GX1', 'GX2', 'GX3', 'GX4',
    'GY1', 'GY2', 'GY3', 'GY4',
    'GZ1', 'GZ2', 'GZ3', 'GZ4',
)
# Scalar keys appended (rounded and scaled) after all sensor arrays.
_TIME_DIFF_KEYS = (
    'GZ1_precise_time_diff',
    'GZ2_precise_time_diff',
    'GZ3_precise_time_diff',
    'GZ4_precise_time_diff',
    'precise_time_diff',
)
# Factor applied to each rounded time diff.
_TIME_DIFF_SCALE = 20
# 24 sensor keys * 64 values each + 5 scalar time diffs = 1541 features.
_EXPECTED_SAMPLE_LENGTH = len(_SENSOR_KEYS) * 64 + len(_TIME_DIFF_KEYS)


def load_json(file_path):
    """Parse and return the JSON document stored at *file_path*.

    The file is decoded as UTF-8 (the JSON interchange encoding) rather than
    with the platform's locale-dependent default.
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        return json.load(f)


def _sample_to_features(sample, file):
    """Flatten one sample dict into a feature list, or return None to skip it."""
    features = []
    for key in _SENSOR_KEYS + _TIME_DIFF_KEYS:
        if key not in sample:
            # Missing keys only warn; the length check below then rejects
            # the (now too short) sample.
            warnings.warn(f"KeyError: {key} not found in JSON file: {file}")
            continue
        value = sample[key]
        if value is None:
            # A null reading (scalar or array) makes the sample unusable.
            return None
        if key in _TIME_DIFF_KEYS:
            # Note: Python's round() rounds halves to the nearest even integer.
            features.append(round(value) * _TIME_DIFF_SCALE)
        else:
            features.extend(value)
    if len(features) != _EXPECTED_SAMPLE_LENGTH:
        warnings.warn(f"Incomplete sample in JSON file: {file}")
        return None
    return features


def read_json_files(file):
    """Load fixed-length feature vectors from a single JSON file.

    The file must contain an iterable of sample dicts. Each valid sample is
    flattened into a vector with all sensor arrays first, followed by the 5
    rounded and scaled time-diff scalars. Samples with a null value in any key,
    or whose flattened length is not 1541, are skipped.

    Args:
        file(str): Path to the JSON file.

    Returns:
        tuple: ``(x, y)`` where ``x`` is a float32 tensor of shape
        ``(num_valid, 1541)``, and ``y`` is a long tensor of ones (every sample
        receives label 1; presumably a single-class source - confirm).

    Raises:
        ValueError: If the file contains no valid samples.
    """
    data_x = []
    for sample in load_json(file):
        features = _sample_to_features(sample, file)
        if features is not None:
            data_x.append(torch.tensor(features, dtype=torch.float32))
    if not data_x:
        raise ValueError(f"No valid samples found in JSON file: {file}")
    data_y = torch.ones(len(data_x), dtype=torch.long)
    return torch.stack(data_x), data_y