Spaces:
Sleeping
Sleeping
new train script
Browse files- training/birds/train.py +80 -0
training/birds/train.py
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from pathlib import Path
|
3 |
+
|
4 |
+
from fastai.vision.data import (
|
5 |
+
IndexSplitter,
|
6 |
+
DataBlock,
|
7 |
+
ImageBlock,
|
8 |
+
CategoryBlock,
|
9 |
+
RegexLabeller,
|
10 |
+
)
|
11 |
+
from fastai.vision.augment import (
|
12 |
+
RandomResizedCrop,
|
13 |
+
aug_transforms,
|
14 |
+
Normalize,
|
15 |
+
imagenet_stats,
|
16 |
+
)
|
17 |
+
|
18 |
+
from fastai.callback import schedule # noqa: F401
|
19 |
+
from fastai.vision.learner import vision_learner, accuracy
|
20 |
+
|
21 |
+
from birds import config
|
22 |
+
from birds.utils.kaggle import download_dataset
|
23 |
+
|
24 |
+
|
25 |
+
def get_birds_images(path):
|
26 |
+
with open(path / "images.txt", "r") as file:
|
27 |
+
lines = [
|
28 |
+
path.resolve() / "images" / line.strip().split()[1]
|
29 |
+
for line in file.readlines()
|
30 |
+
]
|
31 |
+
return lines
|
32 |
+
|
33 |
+
|
34 |
+
def BirdsSplitter(path):
|
35 |
+
with open(path / "train_test_split.txt", "r") as file:
|
36 |
+
valid_idx = [
|
37 |
+
int(line.strip().split()[0]) - 1
|
38 |
+
for line in file.readlines()
|
39 |
+
if line.strip().split()[1] == "1"
|
40 |
+
]
|
41 |
+
return IndexSplitter(valid_idx)
|
42 |
+
|
43 |
+
|
44 |
+
if __name__ == "__main__":
|
45 |
+
bs = 64
|
46 |
+
|
47 |
+
if download_dataset(config.OWNER, config.DATASET, config.DATA_PATH):
|
48 |
+
import tarfile
|
49 |
+
|
50 |
+
with tarfile.open(Path(config.DATA_PATH) / "CUB_200_2011.tgz", "r:gz") as tar:
|
51 |
+
tar.extractall(path=config.DATA_PATH)
|
52 |
+
|
53 |
+
os.remove(Path(config.DATA_PATH) / "CUB_200_2011.tgz")
|
54 |
+
os.remove(Path(config.DATA_PATH) / "segmentations.tgz")
|
55 |
+
|
56 |
+
path = Path(config.DATA_PATH) / "CUB_200_2011"
|
57 |
+
|
58 |
+
item_tfms = RandomResizedCrop(460, min_scale=0.75, ratio=(1.0, 1.0))
|
59 |
+
batch_tfms = [
|
60 |
+
*aug_transforms(size=224, max_warp=0),
|
61 |
+
Normalize.from_stats(*imagenet_stats),
|
62 |
+
]
|
63 |
+
|
64 |
+
birds = DataBlock(
|
65 |
+
blocks=(ImageBlock, CategoryBlock),
|
66 |
+
get_items=get_birds_images,
|
67 |
+
splitter=BirdsSplitter(path),
|
68 |
+
get_y=RegexLabeller(pat=r"/([^/]+)_\d+_\d+\.jpg"),
|
69 |
+
item_tfms=item_tfms,
|
70 |
+
batch_tfms=batch_tfms,
|
71 |
+
)
|
72 |
+
|
73 |
+
dls = birds.dataloaders(path)
|
74 |
+
|
75 |
+
learner = vision_learner(dls, "vit_tiny_patch16_224", metrics=[accuracy])
|
76 |
+
|
77 |
+
learner.fine_tune(7, base_lr=0.001, freeze_epochs=12)
|
78 |
+
|
79 |
+
learner.export("models/vit_exported")
|
80 |
+
learner.save("vit_saved", with_opt=False)
|