Spaces:
Build error
Build error
import glob, json, os | |
import torch | |
import warnings | |
import numpy as np | |
from torch.utils.data import Dataset | |
class BaseDataset2(Dataset):
    """Template class for all datasets in the project.

    Wraps in-memory feature/target arrays as float32 torch tensors and yields
    ``(data, target, index)`` triples from ``__getitem__``.
    """

    def __init__(self, x, y):
        """Initialize dataset.

        Args:
            x(ndarray): Input features; the first axis indexes samples.
            y(ndarray): Targets; the first axis indexes samples.
        """
        # torch.from_numpy shares memory with the source array; .float()
        # converts to float32 (copying only if a dtype change is needed).
        self.data = torch.from_numpy(x).float()
        self.targets = torch.from_numpy(y).float()
        # The attributes below are not populated anywhere in this class
        # (presumably filled in by subclasses -- confirm).
        self.latents = None
        self.labels = None
        self.is_radial = []
        self.partition = True

    def __getitem__(self, index):
        # The index is returned too so callers can map a batch back to samples.
        return self.data[index], self.targets[index], index

    def __len__(self):
        return len(self.data)

    def numpy(self, idx=None):
        """Get dataset as ndarray.

        Specify indices to return a subset of the dataset, otherwise return
        the whole dataset.

        Args:
            idx(int, slice or array-like, optional): Index or indices to return.
        Returns:
            tuple(ndarray, ndarray): Features flattened to shape ``(n, width)``
            (a single flattened row when ``idx`` is an int) and the matching
            targets. The arrays may share memory with the underlying tensors.
        """
        n = len(self)
        # Compute the flattened width explicitly. reshape((n, -1)) is
        # ambiguous and raises when n == 0. np.prod(()) == 1, so 1-D data
        # still becomes (n, 1).
        width = int(np.prod(self.data.shape[1:]))
        data = self.data.numpy().reshape((n, width))
        if idx is None:
            return data, self.targets.numpy()
        return data[idx], self.targets[idx].numpy()

    def get_latents(self):
        """Get latent variables.

        Returns:
            latents: Whatever was stored in ``self.latents``; ``None`` unless
            set elsewhere (this class never assigns it).
        """
        return self.latents
def load_json(file_path):
    """Load and return the decoded contents of a JSON file.

    Args:
        file_path(str or os.PathLike): Path to the JSON file.
    Returns:
        The decoded JSON value, as produced by ``json.load``.
    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the contents are not valid JSON.
    """
    # JSON text is UTF-8 (RFC 8259). Be explicit so decoding does not depend
    # on the platform's locale default encoding (e.g. cp1252 on Windows).
    with open(file_path, 'r', encoding='utf-8') as f:
        return json.load(f)
# Per-sample array fields, concatenated in this order into the feature row.
_SENSOR_KEYS = (
    'AX1', 'AX2', 'AX3', 'AX4', 'AY1', 'AY2', 'AY3', 'AY4',
    'AZ1', 'AZ2', 'AZ3', 'AZ4', 'GX1', 'GX2', 'GX3', 'GX4',
    'GY1', 'GY2', 'GY3', 'GY4', 'GZ1', 'GZ2', 'GZ3', 'GZ4',
)
# Scalar timing fields appended after the sensor arrays, in this order.
_TIME_KEYS = (
    'GZ1_precise_time_diff', 'GZ2_precise_time_diff',
    'GZ3_precise_time_diff', 'GZ4_precise_time_diff', 'precise_time_diff',
)
# Required row length: 24 sensor keys * 64 values each + 5 time values.
_EXPECTED_FEATURES = 768 * 2 + 5


def _sample_to_row(sample, file):
    """Flatten one sample into a feature list; return None if any field is null.

    Missing keys are warned about and skipped, so the returned list may be
    shorter than ``_EXPECTED_FEATURES``. The caller checks the length.
    """
    row = []
    for key in _SENSOR_KEYS + _TIME_KEYS:
        if key not in sample:
            warnings.warn(f"KeyError: {key} not found in JSON file: {file}")
            continue
        value = sample[key]
        if value is None:
            # A null field invalidates the whole sample. This covers sensor
            # arrays too, which cannot be extended into the row.
            return None
        if key in _TIME_KEYS:
            # Time differences are rounded and scaled by 20. The reason is not
            # visible here (presumably a unit conversion -- confirm).
            row.append(round(value) * 20)
        else:
            row.extend(value)
    return row


def read_json_files(file):
    """Parse one JSON file of sensor samples into feature and label tensors.

    The file is read with ``load_json``, and its top-level value is iterated
    as a sequence of samples accessed by key (presumably a list of dicts;
    confirm against the data producer). Each sample's sensor arrays
    (``AX1`` ... ``GZ4``) are concatenated in a fixed order, followed by the
    five time-difference fields, each rounded and multiplied by 20.

    A sample is skipped silently if any of those fields is ``null``. It is
    skipped with a warning if its flattened row does not have exactly
    ``768 * 2 + 5`` entries.

    Args:
        file(str): Path to the JSON file.
    Returns:
        tuple(Tensor, Tensor): ``(features, labels)``. ``features`` is a
        float32 tensor of shape ``(num_valid_samples, 1541)``. ``labels`` is
        a long tensor of ones with shape ``(num_valid_samples,)``.
    Raises:
        ValueError: If the file contains no valid samples.
    """
    rows = []
    for sample in load_json(file):
        row = _sample_to_row(sample, file)
        if row is None:
            continue  # null field: drop the sample without a warning
        if len(row) != _EXPECTED_FEATURES:
            warnings.warn(f"Incomplete sample in JSON file: {file}")
            continue
        rows.append(row)
    if not rows:
        warnings.warn(f"No valid samples found in JSON file: {file}")
        raise ValueError(f"No valid samples found in JSON file: {file}")
    # All rows have equal length, so one tensor() call equals stacking rows.
    features = torch.tensor(rows, dtype=torch.float32)
    # Every valid sample is labelled 1.
    labels = torch.ones(len(rows), dtype=torch.long)
    return features, labels