# Spaces: Running on A10G
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import json
import os
from abc import abstractmethod
from pathlib import Path

import json5
import torch
import yaml
# TODO: implement for training and validation
class BaseDataset(torch.utils.data.Dataset):
    r"""Base dataset for training and validating.

    Placeholder: the constructor accepts its arguments but neither stores
    nor uses them.
    """

    def __init__(self, args, cfg, is_valid=False):
        """Accept the construction arguments without using them.

        Args:
            args: Not used by this placeholder.
            cfg: Not used by this placeholder.
            is_valid: Not used by this placeholder; defaults to ``False``.
        """
        pass
class BaseTestDataset(torch.utils.data.Dataset):
    r"""Test dataset for inference.

    Stores the construction arguments and provides ``get_metadata`` to load
    a JSON/JSONC or YAML file named by ``args.source``.
    """

    def __init__(self, args=None, cfg=None, infer_type="from_dataset"):
        """Store the inference settings on the instance.

        Args:
            args: Stored as ``self.args``; ``get_metadata`` reads
                ``args.source`` from it.
            cfg: Stored as ``self.cfg``.
            infer_type: Either ``"from_dataset"`` or ``"from_file"``.

        Raises:
            AssertionError: If ``infer_type`` is not one of the two
                accepted values.
        """
        # Kept as an assert (not a ValueError) so the exception type seen by
        # existing callers does not change.
        assert infer_type in ["from_dataset", "from_file"], (
            f"Unsupported infer_type: {infer_type!r}"
        )

        self.args = args
        self.cfg = cfg
        self.infer_type = infer_type

    def __getitem__(self, index):
        """Return ``None``; this base implementation has no item logic."""
        pass

    def __len__(self):
        """Return ``len(self.metadata)``.

        NOTE(review): ``self.metadata`` is not assigned anywhere in this
        class; presumably a subclass sets it (e.g. from ``get_metadata``)
        before the length is requested -- confirm against subclasses.
        """
        return len(self.metadata)

    def get_metadata(self):
        """Load and return the parsed contents of ``self.args.source``.

        The parser is chosen by file suffix (compared case-insensitively):
        ``.json`` / ``.jsonc`` use ``json5.load``; ``.yaml`` / ``.yml`` use
        ``yaml.full_load``.

        Returns:
            Whatever the selected parser produces for the file.

        Raises:
            ValueError: If the file suffix is not one of the supported ones.
        """
        path = Path(self.args.source)
        suffix = path.suffix.lower()

        if suffix in (".json", ".jsonc"):
            # Context manager closes the handle even if parsing fails.
            with open(path, "r", encoding="utf-8") as f:
                return json5.load(f)

        if suffix in (".yaml", ".yml"):
            # NOTE(review): yaml.full_load can build more than plain data
            # types; prefer yaml.safe_load if the source may be untrusted.
            with open(path, "r", encoding="utf-8") as f:
                return yaml.full_load(f)

        raise ValueError(f"Unsupported file type: {path.suffix}")