Dean commited on
Commit
b013dfa
·
1 Parent(s): 889b443

reformatting the training file

Browse files
Files changed (1) hide show
  1. src/code/training.py +9 -4
src/code/training.py CHANGED
@@ -3,11 +3,14 @@ import sys
3
  from fastai.vision.all import *
4
  from torchvision.utils import save_image
5
 
 
6
  class ImageImageDataLoaders(DataLoaders):
7
  "Basic wrapper around several `DataLoader`s with factory methods for Image to Image problems"
 
8
  @classmethod
9
  @delegates(DataLoaders.from_dblock)
10
- def from_label_func(cls, path, fnames, label_func, valid_pct=0.2, seed=None, item_tfms=None, batch_tfms=None, **kwargs):
 
11
  "Create from list of `fnames` in `path`s with `label_func`."
12
  dblock = DataBlock(blocks=(ImageBlock(cls=PILImage), ImageBlock(cls=PILImageBW)),
13
  splitter=RandomSplitter(valid_pct, seed=seed),
@@ -26,8 +29,9 @@ def get_y_fn(x):
26
 
27
 
28
  def create_data(data_path):
29
- fnames = get_files(data_path/'train', extensions='.jpg')
30
- data = ImageImageDataLoaders.from_label_func(data_path/'train', seed=42, bs=4, num_workers=0, fnames=fnames, label_func=get_y_fn)
 
31
  return data
32
 
33
 
@@ -37,7 +41,8 @@ if __name__ == "__main__":
37
  sys.exit(0)
38
 
39
  data = create_data(Path(sys.argv[1]))
40
- learner = unet_learner(data, resnet34, metrics=rmse, wd=1e-2, n_out=3, loss_func=MSELossFlat(), path='src/test/')
 
41
  print("Training model...")
42
  learner.fine_tune(1)
43
  print("Saving model...")
 
3
  from fastai.vision.all import *
4
  from torchvision.utils import save_image
5
 
6
+
7
  class ImageImageDataLoaders(DataLoaders):
8
  "Basic wrapper around several `DataLoader`s with factory methods for Image to Image problems"
9
+
10
  @classmethod
11
  @delegates(DataLoaders.from_dblock)
12
+ def from_label_func(cls, path, fnames, label_func, valid_pct=0.2, seed=None, item_tfms=None,
13
+ batch_tfms=None, **kwargs):
14
  "Create from list of `fnames` in `path`s with `label_func`."
15
  dblock = DataBlock(blocks=(ImageBlock(cls=PILImage), ImageBlock(cls=PILImageBW)),
16
  splitter=RandomSplitter(valid_pct, seed=seed),
 
29
 
30
 
31
  def create_data(data_path):
32
+ fnames = get_files(data_path / 'train', extensions='.jpg')
33
+ data = ImageImageDataLoaders.from_label_func(data_path / 'train', seed=42, bs=4, num_workers=0,
34
+ fnames=fnames, label_func=get_y_fn)
35
  return data
36
 
37
 
 
41
  sys.exit(0)
42
 
43
  data = create_data(Path(sys.argv[1]))
44
+ learner = unet_learner(data, resnet34, metrics=rmse, wd=1e-2, n_out=3, loss_func=MSELossFlat(),
45
+ path='src/test/')
46
  print("Training model...")
47
  learner.fine_tune(1)
48
  print("Saving model...")