Dean commited on
Commit
79fd7d0
·
1 Parent(s): 0b86a0a

Fixed a bug in the training stage where the model was not saved, commiting before training on colab

Browse files
Files changed (2) hide show
  1. dvc.yaml +1 -1
  2. src/code/training.py +5 -12
dvc.yaml CHANGED
@@ -9,7 +9,7 @@ stages:
9
  outs:
10
  - src/data/processed/
11
  train:
12
- cmd: python3 src/code/training.py src/data/processed src/models
13
  deps:
14
  - src/code/training.py
15
  - src/data/processed/
 
9
  outs:
10
  - src/data/processed/
11
  train:
12
+ cmd: python3 src/code/training.py src/data/processed
13
  deps:
14
  - src/code/training.py
15
  - src/data/processed/
src/code/training.py CHANGED
@@ -17,20 +17,13 @@ def create_data(data_path):
17
  return data
18
 
19
 
20
- def train(data):
21
- learner = unet_learner(data, resnet34, metrics=rmse, wd=1e-2, n_out=1, loss_func=MSELossFlat())
22
- learner.fine_tune(1)
23
-
24
-
25
  if __name__ == "__main__":
26
- if len(sys.argv) < 3:
27
- print("usage: %s <data_path> <out_folder>" % sys.argv[0], file=sys.stderr)
28
  sys.exit(0)
29
 
30
  data = create_data(Path(sys.argv[1]))
31
- data.batch_size = 1
32
- data.num_workers = 0
33
- learner = train(data)
34
 
35
- learner.save(sys.argv[2])
36
- learner.show_results()
 
17
  return data
18
 
19
 
 
 
 
 
 
20
  if __name__ == "__main__":
21
+ if len(sys.argv) < 2:
22
+ print("usage: %s <data_path>" % sys.argv[0], file=sys.stderr)
23
  sys.exit(0)
24
 
25
  data = create_data(Path(sys.argv[1]))
26
+ learner = unet_learner(data, resnet34, metrics=rmse, wd=1e-2, n_out=1, loss_func=MSELossFlat(), path='src/')
27
+ learner.fine_tune(1)
 
28
 
29
+ learner.save('model')