File size: 326 Bytes
122d428
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
from torch.utils.data import Dataset

# Custom dataset class that wraps a single input sample
class SingleInputDataset(Dataset):
    """Dataset of length one that serves a single stored sample.

    The sample passed to the constructor is stored as-is and returned
    (the same object, not a copy) from ``__getitem__``.

    Args:
        input_single: The one sample this dataset holds.
    """

    def __init__(self, input_single):
        self.sample = input_single

    def __len__(self) -> int:
        """Return the number of samples, which is always 1."""
        return 1

    def __getitem__(self, index):
        """Return the stored sample.

        Args:
            index: Position of the requested item. Integer indices must be
                ``0`` (or ``-1``, the last item); non-integer indices are
                not validated and always yield the stored sample.

        Returns:
            The sample given at construction time.

        Raises:
            IndexError: If ``index`` is an integer outside the range of a
                one-element sequence.
        """
        # Without an IndexError, the legacy __getitem__ iteration protocol
        # (``for x in dataset`` / ``list(dataset)``) would never terminate,
        # because this class does not define __iter__.
        if isinstance(index, int) and not -1 <= index < 1:
            raise IndexError(
                f"index {index} is out of range for a dataset of length 1"
            )
        return self.sample