Dean commited on
Commit
3410172
·
1 Parent(s): a711240

Revert change to dataloader that uses is_test flag.

Browse files
src/code/custom_data_loading.py CHANGED
@@ -14,15 +14,14 @@ class ImageImageDataLoaders(DataLoaders):
14
  """Basic wrapper around several `DataLoader`s with factory methods for Image to Image problems"""
15
  @classmethod
16
  @delegates(DataLoaders.from_dblock)
17
- def from_label_func(cls, path, filenames, label_func, is_test, valid_pct=0.2, seed=None, item_transforms=None,
18
  batch_transforms=None, **kwargs):
19
  """Create from list of `fnames` in `path`s with `label_func`."""
20
  datablock = DataBlock(blocks=(ImageBlock(cls=PILImage), ImageBlock(cls=PILImageBW)),
21
  get_y=label_func,
 
22
  item_tfms=item_transforms,
23
  batch_tfms=batch_transforms)
24
- if not is_test:
25
- datablock.splitter = RandomSplitter(valid_pct, seed=seed)
26
  res = cls.from_dblock(datablock, filenames, path=path, **kwargs)
27
  return res
28
 
@@ -34,12 +33,11 @@ def get_y_fn(x):
34
  return y
35
 
36
 
37
- def create_data(data_path, is_test=False):
38
  filenames = get_files(data_path, extensions='.jpg')
39
  if len(filenames) == 0:
40
  raise ValueError("Could not find any files in the given path")
41
  dataset = ImageImageDataLoaders.from_label_func(data_path,
42
- is_test=is_test,
43
  seed=42,
44
  bs=4, num_workers=0,
45
  filenames=filenames,
 
14
  """Basic wrapper around several `DataLoader`s with factory methods for Image to Image problems"""
15
  @classmethod
16
  @delegates(DataLoaders.from_dblock)
17
+ def from_label_func(cls, path, filenames, label_func, valid_pct=0.2, seed=None, item_transforms=None,
18
  batch_transforms=None, **kwargs):
19
  """Create from list of `fnames` in `path`s with `label_func`."""
20
  datablock = DataBlock(blocks=(ImageBlock(cls=PILImage), ImageBlock(cls=PILImageBW)),
21
  get_y=label_func,
22
+ splitter=RandomSplitter(valid_pct, seed=seed),
23
  item_tfms=item_transforms,
24
  batch_tfms=batch_transforms)
 
 
25
  res = cls.from_dblock(datablock, filenames, path=path, **kwargs)
26
  return res
27
 
 
33
  return y
34
 
35
 
36
+ def create_data(data_path):
37
  filenames = get_files(data_path, extensions='.jpg')
38
  if len(filenames) == 0:
39
  raise ValueError("Could not find any files in the given path")
40
  dataset = ImageImageDataLoaders.from_label_func(data_path,
 
41
  seed=42,
42
  bs=4, num_workers=0,
43
  filenames=filenames,
src/code/eval.py CHANGED
@@ -1,5 +1,5 @@
1
  import sys
2
- from fastai.vision.all import unet_learner, Path, resnet34, MSELossFlat
3
  import torch
4
  from src.code.custom_data_loading import create_data
5
  from dagshub.fastai import DAGsHubLogger
@@ -39,8 +39,11 @@ if __name__ == "__main__":
39
  sys.exit(0)
40
 
41
  data_path = Path(sys.argv[1])
42
- data = create_data(data_path, is_test=True)
43
 
 
 
 
44
  learner = unet_learner(data,
45
  resnet34,
46
  n_out=3,
@@ -48,5 +51,5 @@ if __name__ == "__main__":
48
  path='src/',
49
  model_dir='models')
50
  learner = learner.load('model')
51
- predictions, targets = learner.get_preds()
52
  print(compute_errors(targets, predictions))
 
1
  import sys
2
+ from fastai.vision.all import unet_learner, Path, resnet34, MSELossFlat, get_files, L
3
  import torch
4
  from src.code.custom_data_loading import create_data
5
  from dagshub.fastai import DAGsHubLogger
 
39
  sys.exit(0)
40
 
41
  data_path = Path(sys.argv[1])
42
+ data = create_data(data_path)
43
 
44
+ filenames = get_files(Path(sys.argv[1]), extensions='.jpg')
45
+ test_files = L([Path(i) for i in filenames])
46
+ test_dl = data.test_dl(test_files, with_labels=True)
47
  learner = unet_learner(data,
48
  resnet34,
49
  n_out=3,
 
51
  path='src/',
52
  model_dir='models')
53
  learner = learner.load('model')
54
+ predictions, targets = learner.get_preds(dl=test_dl)
55
  print(compute_errors(targets, predictions))