Spaces:
Running
Running
rzimmerdev
commited on
Commit
·
1de9461
1
Parent(s):
49b098d
fix: Changed default dataset image type
Browse files- src/dataset.py +5 -4
src/dataset.py
CHANGED
@@ -2,10 +2,11 @@
|
|
2 |
# coding: utf-8
|
3 |
import gzip
|
4 |
|
5 |
-
|
6 |
-
|
7 |
-
import numpy as np
|
8 |
from torch.utils.data import Dataset
|
|
|
|
|
|
|
9 |
|
10 |
|
11 |
def load_mnist(download_dir):
|
@@ -37,7 +38,7 @@ class DatasetMNIST(Dataset):
|
|
37 |
def __getitem__(self, n):
|
38 |
if n > self.total:
|
39 |
raise ValueError(f"Dataset doesn't have enough elements to suffice request of {n} elements.")
|
40 |
-
return self.data[n]
|
41 |
|
42 |
def __len__(self):
|
43 |
return len(self.data)
|
|
|
2 |
# coding: utf-8
|
3 |
import gzip
|
4 |
|
5 |
+
import torch
|
|
|
|
|
6 |
from torch.utils.data import Dataset
|
7 |
+
import numpy as np
|
8 |
+
|
9 |
+
from src.downloader import download_dataset
|
10 |
|
11 |
|
12 |
def load_mnist(download_dir):
|
|
|
38 |
def __getitem__(self, n):
|
39 |
if n > self.total:
|
40 |
raise ValueError(f"Dataset doesn't have enough elements to suffice request of {n} elements.")
|
41 |
+
return torch.tensor(self.data[n][0].reshape(1, 28, 28), dtype=torch.float32), torch.tensor(self.data[n][1])
|
42 |
|
43 |
def __len__(self):
|
44 |
return len(self.data)
|