Spaces:
Paused
Paused
Dean
commited on
Commit
•
b013dfa
1
Parent(s):
889b443
reformatting the training file
Browse files- src/code/training.py +9 -4
src/code/training.py
CHANGED
@@ -3,11 +3,14 @@ import sys
|
|
3 |
from fastai.vision.all import *
|
4 |
from torchvision.utils import save_image
|
5 |
|
|
|
6 |
class ImageImageDataLoaders(DataLoaders):
|
7 |
"Basic wrapper around several `DataLoader`s with factory methods for Image to Image problems"
|
|
|
8 |
@classmethod
|
9 |
@delegates(DataLoaders.from_dblock)
|
10 |
-
def from_label_func(cls, path, fnames, label_func, valid_pct=0.2, seed=None, item_tfms=None,
|
|
|
11 |
"Create from list of `fnames` in `path`s with `label_func`."
|
12 |
dblock = DataBlock(blocks=(ImageBlock(cls=PILImage), ImageBlock(cls=PILImageBW)),
|
13 |
splitter=RandomSplitter(valid_pct, seed=seed),
|
@@ -26,8 +29,9 @@ def get_y_fn(x):
|
|
26 |
|
27 |
|
28 |
def create_data(data_path):
|
29 |
-
fnames = get_files(data_path/'train', extensions='.jpg')
|
30 |
-
data = ImageImageDataLoaders.from_label_func(data_path/'train', seed=42, bs=4, num_workers=0,
|
|
|
31 |
return data
|
32 |
|
33 |
|
@@ -37,7 +41,8 @@ if __name__ == "__main__":
|
|
37 |
sys.exit(0)
|
38 |
|
39 |
data = create_data(Path(sys.argv[1]))
|
40 |
-
learner = unet_learner(data, resnet34, metrics=rmse, wd=1e-2, n_out=3, loss_func=MSELossFlat(),
|
|
|
41 |
print("Training model...")
|
42 |
learner.fine_tune(1)
|
43 |
print("Saving model...")
|
|
|
3 |
from fastai.vision.all import *
|
4 |
from torchvision.utils import save_image
|
5 |
|
6 |
+
|
7 |
class ImageImageDataLoaders(DataLoaders):
|
8 |
"Basic wrapper around several `DataLoader`s with factory methods for Image to Image problems"
|
9 |
+
|
10 |
@classmethod
|
11 |
@delegates(DataLoaders.from_dblock)
|
12 |
+
def from_label_func(cls, path, fnames, label_func, valid_pct=0.2, seed=None, item_tfms=None,
|
13 |
+
batch_tfms=None, **kwargs):
|
14 |
"Create from list of `fnames` in `path`s with `label_func`."
|
15 |
dblock = DataBlock(blocks=(ImageBlock(cls=PILImage), ImageBlock(cls=PILImageBW)),
|
16 |
splitter=RandomSplitter(valid_pct, seed=seed),
|
|
|
29 |
|
30 |
|
31 |
def create_data(data_path):
|
32 |
+
fnames = get_files(data_path / 'train', extensions='.jpg')
|
33 |
+
data = ImageImageDataLoaders.from_label_func(data_path / 'train', seed=42, bs=4, num_workers=0,
|
34 |
+
fnames=fnames, label_func=get_y_fn)
|
35 |
return data
|
36 |
|
37 |
|
|
|
41 |
sys.exit(0)
|
42 |
|
43 |
data = create_data(Path(sys.argv[1]))
|
44 |
+
learner = unet_learner(data, resnet34, metrics=rmse, wd=1e-2, n_out=3, loss_func=MSELossFlat(),
|
45 |
+
path='src/test/')
|
46 |
print("Training model...")
|
47 |
learner.fine_tune(1)
|
48 |
print("Saving model...")
|