Spaces:
Paused
Paused
Dean
commited on
Commit
·
3410172
1
Parent(s):
a711240
Revert change to dataloader that uses is_test flag.
Browse files- src/code/custom_data_loading.py +3 -5
- src/code/eval.py +6 -3
src/code/custom_data_loading.py
CHANGED
@@ -14,15 +14,14 @@ class ImageImageDataLoaders(DataLoaders):
|
|
14 |
"""Basic wrapper around several `DataLoader`s with factory methods for Image to Image problems"""
|
15 |
@classmethod
|
16 |
@delegates(DataLoaders.from_dblock)
|
17 |
-
def from_label_func(cls, path, filenames, label_func,
|
18 |
batch_transforms=None, **kwargs):
|
19 |
"""Create from list of `fnames` in `path`s with `label_func`."""
|
20 |
datablock = DataBlock(blocks=(ImageBlock(cls=PILImage), ImageBlock(cls=PILImageBW)),
|
21 |
get_y=label_func,
|
|
|
22 |
item_tfms=item_transforms,
|
23 |
batch_tfms=batch_transforms)
|
24 |
-
if not is_test:
|
25 |
-
datablock.splitter = RandomSplitter(valid_pct, seed=seed)
|
26 |
res = cls.from_dblock(datablock, filenames, path=path, **kwargs)
|
27 |
return res
|
28 |
|
@@ -34,12 +33,11 @@ def get_y_fn(x):
|
|
34 |
return y
|
35 |
|
36 |
|
37 |
-
def create_data(data_path
|
38 |
filenames = get_files(data_path, extensions='.jpg')
|
39 |
if len(filenames) == 0:
|
40 |
raise ValueError("Could not find any files in the given path")
|
41 |
dataset = ImageImageDataLoaders.from_label_func(data_path,
|
42 |
-
is_test=is_test,
|
43 |
seed=42,
|
44 |
bs=4, num_workers=0,
|
45 |
filenames=filenames,
|
|
|
14 |
"""Basic wrapper around several `DataLoader`s with factory methods for Image to Image problems"""
|
15 |
@classmethod
|
16 |
@delegates(DataLoaders.from_dblock)
|
17 |
+
def from_label_func(cls, path, filenames, label_func, valid_pct=0.2, seed=None, item_transforms=None,
|
18 |
batch_transforms=None, **kwargs):
|
19 |
"""Create from list of `fnames` in `path`s with `label_func`."""
|
20 |
datablock = DataBlock(blocks=(ImageBlock(cls=PILImage), ImageBlock(cls=PILImageBW)),
|
21 |
get_y=label_func,
|
22 |
+
splitter=RandomSplitter(valid_pct, seed=seed),
|
23 |
item_tfms=item_transforms,
|
24 |
batch_tfms=batch_transforms)
|
|
|
|
|
25 |
res = cls.from_dblock(datablock, filenames, path=path, **kwargs)
|
26 |
return res
|
27 |
|
|
|
33 |
return y
|
34 |
|
35 |
|
36 |
+
def create_data(data_path):
|
37 |
filenames = get_files(data_path, extensions='.jpg')
|
38 |
if len(filenames) == 0:
|
39 |
raise ValueError("Could not find any files in the given path")
|
40 |
dataset = ImageImageDataLoaders.from_label_func(data_path,
|
|
|
41 |
seed=42,
|
42 |
bs=4, num_workers=0,
|
43 |
filenames=filenames,
|
src/code/eval.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
import sys
|
2 |
-
from fastai.vision.all import unet_learner, Path, resnet34, MSELossFlat
|
3 |
import torch
|
4 |
from src.code.custom_data_loading import create_data
|
5 |
from dagshub.fastai import DAGsHubLogger
|
@@ -39,8 +39,11 @@ if __name__ == "__main__":
|
|
39 |
sys.exit(0)
|
40 |
|
41 |
data_path = Path(sys.argv[1])
|
42 |
-
data = create_data(data_path
|
43 |
|
|
|
|
|
|
|
44 |
learner = unet_learner(data,
|
45 |
resnet34,
|
46 |
n_out=3,
|
@@ -48,5 +51,5 @@ if __name__ == "__main__":
|
|
48 |
path='src/',
|
49 |
model_dir='models')
|
50 |
learner = learner.load('model')
|
51 |
-
predictions, targets = learner.get_preds()
|
52 |
print(compute_errors(targets, predictions))
|
|
|
1 |
import sys
|
2 |
+
from fastai.vision.all import unet_learner, Path, resnet34, MSELossFlat, get_files, L
|
3 |
import torch
|
4 |
from src.code.custom_data_loading import create_data
|
5 |
from dagshub.fastai import DAGsHubLogger
|
|
|
39 |
sys.exit(0)
|
40 |
|
41 |
data_path = Path(sys.argv[1])
|
42 |
+
data = create_data(data_path)
|
43 |
|
44 |
+
filenames = get_files(Path(sys.argv[1]), extensions='.jpg')
|
45 |
+
test_files = L([Path(i) for i in filenames])
|
46 |
+
test_dl = data.test_dl(test_files, with_labels=True)
|
47 |
learner = unet_learner(data,
|
48 |
resnet34,
|
49 |
n_out=3,
|
|
|
51 |
path='src/',
|
52 |
model_dir='models')
|
53 |
learner = learner.load('model')
|
54 |
+
predictions, targets = learner.get_preds(dl=test_dl)
|
55 |
print(compute_errors(targets, predictions))
|