jcarnero committed on
Commit
92b515f
·
1 Parent(s): 57f6a10

new train script

Browse files
Files changed (1) hide show
  1. training/birds/train.py +80 -0
training/birds/train.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from pathlib import Path
3
+
4
+ from fastai.vision.data import (
5
+ IndexSplitter,
6
+ DataBlock,
7
+ ImageBlock,
8
+ CategoryBlock,
9
+ RegexLabeller,
10
+ )
11
+ from fastai.vision.augment import (
12
+ RandomResizedCrop,
13
+ aug_transforms,
14
+ Normalize,
15
+ imagenet_stats,
16
+ )
17
+
18
+ from fastai.callback import schedule # noqa: F401
19
+ from fastai.vision.learner import vision_learner, accuracy
20
+
21
+ from birds import config
22
+ from birds.utils.kaggle import download_dataset
23
+
24
+
25
def get_birds_images(path):
    """Return the absolute path of every image listed in ``images.txt``.

    Each non-blank line of ``path / "images.txt"`` is split on whitespace and
    its second field is taken as the image location relative to
    ``path / "images"``.

    Args:
        path: Dataset root directory (``str`` or ``Path``) that contains
            ``images.txt`` and the ``images`` folder.

    Returns:
        A list of ``Path`` objects, one per listed image, in file order.

    Raises:
        OSError: If ``images.txt`` cannot be opened.
        IndexError: If a non-blank line has fewer than two fields.
    """
    root = Path(path)
    # The images directory is the same for every line, so resolve it once
    # instead of once per listed image.
    images_root = root.resolve() / "images"
    with open(root / "images.txt", "r") as file:
        return [
            images_root / fields[1]
            for fields in (line.split() for line in file)
            # Blank lines (e.g. a trailing newline) carry no entry; skip them
            # rather than failing on the missing second field.
            if fields
        ]
32
+
33
+
34
def BirdsSplitter(path):
    """Build an ``IndexSplitter`` that holds out the dataset's test images.

    Reads ``path / "train_test_split.txt"``. Each non-blank line is split on
    whitespace into ``<image_id> <is_training_image>``. In CUB-200-2011 a flag
    of ``1`` marks a training image and ``0`` a test image, so the validation
    indices are the rows flagged ``0``. (Selecting the rows flagged ``1``
    would train on the official test set and validate on the training set.)

    Args:
        path: Dataset root directory (``str`` or ``Path``) that contains
            ``train_test_split.txt``.

    Returns:
        The result of ``IndexSplitter(valid_idx)``, where ``valid_idx`` holds
        the zero-based index (``image_id - 1``) of every held-out image.

    Raises:
        OSError: If ``train_test_split.txt`` cannot be opened.
        IndexError: If a non-blank line has fewer than two fields.
        ValueError: If an image id is not an integer.
    """
    valid_idx = []
    with open(Path(path) / "train_test_split.txt", "r") as file:
        for line in file:
            fields = line.split()
            if not fields:
                # Tolerate blank lines such as a trailing newline.
                continue
            image_id, is_training = fields[0], fields[1]
            if is_training == "0":
                # Image ids are 1-based; item indices are 0-based.
                # NOTE(review): assumes ids are consecutive and in the same
                # order as the items returned by the DataBlock's get_items.
                valid_idx.append(int(image_id) - 1)
    return IndexSplitter(valid_idx)
42
+
43
+
44
def main():
    """Fetch the birds dataset if needed, fine-tune a ViT on it and save it.

    Steps:
      1. Call ``download_dataset``; when it returns a truthy value, unpack
         ``CUB_200_2011.tgz`` into ``config.DATA_PATH`` and delete both
         downloaded archives.
      2. Build a fastai ``DataBlock`` over ``CUB_200_2011`` and its
         dataloaders.
      3. Fine-tune ``vit_tiny_patch16_224`` and write the exported learner
         and the saved weights.
    """
    bs = 64
    data_dir = Path(config.DATA_PATH)

    # NOTE(review): a truthy return is treated as "archives were just
    # downloaded and need unpacking" -- confirm against download_dataset.
    if download_dataset(config.OWNER, config.DATASET, config.DATA_PATH):
        import tarfile

        archive = data_dir / "CUB_200_2011.tgz"
        with tarfile.open(archive, "r:gz") as tar:
            tar.extractall(path=config.DATA_PATH)

        # The archives are no longer needed once extracted.
        os.remove(archive)
        os.remove(data_dir / "segmentations.tgz")

    path = data_dir / "CUB_200_2011"

    # Crop to a large square per item first, then augment and shrink to the
    # model's 224px input per batch.
    item_tfms = RandomResizedCrop(460, min_scale=0.75, ratio=(1.0, 1.0))
    batch_tfms = [
        *aug_transforms(size=224, max_warp=0),
        Normalize.from_stats(*imagenet_stats),
    ]

    birds = DataBlock(
        blocks=(ImageBlock, CategoryBlock),
        get_items=get_birds_images,
        splitter=BirdsSplitter(path),
        # The label is the file name with its two trailing numeric ids removed.
        get_y=RegexLabeller(pat=r"/([^/]+)_\d+_\d+\.jpg"),
        item_tfms=item_tfms,
        batch_tfms=batch_tfms,
    )

    # Pass the batch size explicitly; it was previously defined but unused.
    dls = birds.dataloaders(path, bs=bs)

    learner = vision_learner(dls, "vit_tiny_patch16_224", metrics=[accuracy])

    learner.fine_tune(7, base_lr=0.001, freeze_epochs=12)

    # export() writes to learner.path / fname without creating parent
    # directories, so make sure "models" exists before exporting into it;
    # otherwise the run can fail after training has already finished.
    (Path(learner.path) / "models").mkdir(parents=True, exist_ok=True)
    learner.export("models/vit_exported")
    learner.save("vit_saved", with_opt=False)


if __name__ == "__main__":
    main()