init commit
This view is limited to 50 files because it contains too many changes; see the raw diff for the full change set.
- .gitignore +1 -0
- CLS2IDX.py +1000 -0
- README.md +124 -13
- RobustViT.ipynb +0 -0
- SegmentationTest/data/Imagenet.py +74 -0
- SegmentationTest/data/VOC.py +372 -0
- SegmentationTest/data/__init__.py +0 -0
- SegmentationTest/data/imagenet_utils.py +1002 -0
- SegmentationTest/data/transforms.py +442 -0
- SegmentationTest/imagenet_seg_eval.py +319 -0
- SegmentationTest/utils/__init__.py +0 -0
- SegmentationTest/utils/confusionmatrix.py +88 -0
- SegmentationTest/utils/iou.py +93 -0
- SegmentationTest/utils/metric.py +12 -0
- SegmentationTest/utils/metrices.py +208 -0
- SegmentationTest/utils/parallel.py +260 -0
- SegmentationTest/utils/render.py +266 -0
- SegmentationTest/utils/saver.py +34 -0
- SegmentationTest/utils/summaries.py +11 -0
- ViT/ViT.py +308 -0
- ViT_new.py → ViT/ViT_new.py +0 -0
- ViT/__init__.py +0 -0
- ViT/explainer.py +71 -0
- ViT/helpers.py +295 -0
- ViT/layer_helpers.py +21 -0
- ViT/weight_init.py +60 -0
- imagenet_ablation_gt.py +590 -0
- imagenet_classes.json +1002 -0
- imagenet_eval_robustness.py +337 -0
- imagenet_eval_robustness_per_class.py +343 -0
- imagenet_finetune.py +567 -0
- imagenet_finetune_gradmask.py +586 -0
- imagenet_finetune_rrr.py +570 -0
- imagenet_finetune_tokencut.py +577 -0
- label_str_to_imagenet_classes.py +133 -0
- objectnet_dataset.py +117 -0
- robustness_dataset.py +66 -0
- robustness_dataset_per_class.py +65 -0
- samples/augreg_base/1_in.png +0 -0
- samples/augreg_base/2_in.png +0 -0
- samples/augreg_base/3_in.png +0 -0
- samples/augreg_base/a.png +0 -0
- samples/augreg_base/a_2.png +0 -0
- samples/augreg_base/a_3.png +0 -0
- samples/catdog.png +0 -0
- samples/deit_base/1_in.png +0 -0
- samples/deit_base/2_in.png +0 -0
- samples/deit_base/3_in.png +0 -0
- samples/deit_base/a.png +0 -0
- samples/deit_base/a_2.png +0 -0
.gitignore
ADDED
@@ -0,0 +1 @@
.idea
CLS2IDX.py
ADDED
@@ -0,0 +1,1000 @@
CLS2IDX = {0: 'tench, Tinca tinca',
           1: 'goldfish, Carassius auratus',
           2: 'great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias',
           3: 'tiger shark, Galeocerdo cuvieri',
           4: 'hammerhead, hammerhead shark',
           5: 'electric ray, crampfish, numbfish, torpedo',
           6: 'stingray',
           7: 'cock',
           8: 'hen',
           9: 'ostrich, Struthio camelus',
           ...
           998: 'ear, spike, capitulum',
           999: 'toilet tissue, toilet paper, bathroom tissue'}
(Entries 10 through 997 are omitted here for brevity; the file maps all 1000 ImageNet class indices to their label strings. See raw diff for the full mapping.)
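For reference, a minimal sketch of how such an index-to-label mapping is typically consumed when decoding predictions; the logits tensor below is hypothetical and not part of this commit:

```python
import torch

from CLS2IDX import CLS2IDX  # the index-to-label dictionary added above

# Hypothetical logits for one image over the 1000 ImageNet classes.
logits = torch.randn(1, 1000)
top5 = torch.topk(logits, k=5, dim=-1).indices[0].tolist()

# Translate predicted class indices into human-readable labels.
for idx in top5:
    print(idx, CLS2IDX[idx])
```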
README.md
CHANGED
@@ -1,13 +1,124 @@
# RobustViT

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/hila-chefer/RobustViT/blob/master/RobustViT.ipynb)

Official PyTorch implementation of **Optimizing Relevance Maps of Vision Transformers Improves Robustness**. This code allows you to
finetune the explainability maps of Vision Transformers to enhance robustness.

The method applies loss functions directly to the explainability maps to ensure that the model focuses mostly on the foreground of the image:
<p align="center">
  <img width="500" height="400" src="teaser.png">
</p>
Using a short finetuning process with only 3 labeled examples from 500 classes, our method improves the robustness of ViT models across different model sizes and training techniques, even when data augmentations/regularization are applied.

## Producing Segmentation Data
### Using ImageNet-S
To use the ImageNet-S labeled data, [download the `ImageNetS919` dataset](https://github.com/UnsupervisedSemanticSegmentation/ImageNet-S).

### Using TokenCut for unsupervised segmentation
1. Clone the TokenCut project
```
git clone https://github.com/YangtaoWANG95/TokenCut.git
```
2. Install the dependencies
Python 3.7, PyTorch 1.7.1, and CUDA 11.2. Please refer to the official installation instructions. If CUDA 10.2 has been properly installed:
```
pip install torch==1.7.1 torchvision==0.8.2
```
Followed by
```
pip install -r TokenCut/requirements.txt
```

3. Use the following command to extract the segmentation maps:
```
python tokencut_generate_segmentation.py --img_path <PATH_TO_IMAGE> --out_dir <PATH_TO_OUTPUT_DIRECTORY>
```


## Finetuning ViT models

To finetune a pretrained ViT model, use the `imagenet_finetune.py` script. Make sure to uncomment the import line containing the pretrained model you
wish to finetune.

Usage example:

```bash
python imagenet_finetune.py --seg_data <PATH_TO_SEGMENTATION_DATA> --data <PATH_TO_IMAGENET> --gpu 0 --lr <LR> --lambda_seg <SEG> --lambda_acc <ACC> --lambda_background <BACK> --lambda_foreground <FORE>
```

Notes:

* For all models we use:
  * `lambda_seg=0.8`
  * `lambda_acc=0.2`
  * `lambda_background=2`
  * `lambda_foreground=0.3`
* For **DeiT** models, a temperature is required as follows:
  * `temperature=0.65` for DeiT-B
  * `temperature=0.55` for DeiT-S
* The learning rates per model are:
  * ViT-B: 3e-6
  * ViT-L: 9e-7
  * AR-S: 2e-6
  * AR-B: 6e-7
  * AR-L: 9e-7
  * DeiT-S: 1e-6
  * DeiT-B: 8e-7

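As a concrete illustration, the hyperparameters and the ViT-B learning rate listed above combine into a command like the following; the angle-bracket paths remain placeholders for your local data:

```bash
# Illustrative only: ViT-B finetuning with the values documented in the notes above.
# <PATH_TO_SEGMENTATION_DATA> and <PATH_TO_IMAGENET> must point to your local copies.
python imagenet_finetune.py \
    --seg_data <PATH_TO_SEGMENTATION_DATA> \
    --data <PATH_TO_IMAGENET> \
    --gpu 0 \
    --lr 3e-6 \
    --lambda_seg 0.8 \
    --lambda_acc 0.2 \
    --lambda_background 2 \
    --lambda_foreground 0.3
```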
## Baseline methods
Make sure to uncomment the import line containing the pretrained model you wish to finetune in the code.

### GradMask
Run the following command:
```bash
python imagenet_finetune_gradmask.py --seg_data <PATH_TO_SEGMENTATION_DATA> --data <PATH_TO_IMAGENET> --gpu 0 --lr <LR> --lambda_seg <SEG> --lambda_acc <ACC>
```
All hyperparameters for the different models can be found in section D of the supplementary material.

### Right for the Right Reasons
Run the following command:
```bash
python imagenet_finetune_rrr.py --seg_data <PATH_TO_SEGMENTATION_DATA> --data <PATH_TO_IMAGENET> --gpu 0 --lr <LR> --lambda_seg <SEG> --lambda_acc <ACC>
```
All hyperparameters for the different models can be found in section D of the supplementary material.

## Evaluation

### Robustness Evaluation

1. Download the evaluation datasets:
   * [INet-A](https://github.com/hendrycks/natural-adv-examples)
   * [INet-R](https://github.com/hendrycks/imagenet-r)
   * [INet-v2](https://github.com/modestyachts/ImageNetV2)
   * [ObjectNet](https://objectnet.dev/)
   * [SI-Score](https://github.com/google-research/si-score)

2. Run the following script to evaluate:

```bash
python imagenet_eval_robustness.py --data <PATH_TO_ROBUSTNESS_DATASET> --batch-size <BATCH_SIZE> --evaluate --checkpoint <PATH_TO_FINETUNED_CHECKPOINT>
```
* Make sure to uncomment the import line containing the pretrained model you wish to evaluate in the code.
* To evaluate the original model, simply omit the `checkpoint` parameter.
* For the INet-v2 dataset add `--isV2`.
* For the ObjectNet dataset add `--isObjectNet`.
* For the SI datasets add `--isSI`.

### Segmentation Evaluation
Our segmentation tests are based on the test in the official implementation of [Transformer Interpretability Beyond Attention Visualization](https://github.com/hila-chefer/Transformer-Explainability).
1. [Download the ImageNet segmentation test set](https://github.com/hila-chefer/Transformer-Explainability#section-a-segmentation-results).
2. Run the following script to evaluate:

```bash
PYTHONPATH=./:$PYTHONPATH python SegmentationTest/imagenet_seg_eval.py --imagenet-seg-path <PATH_TO_gtsegs_ijcv.mat>
```
* Make sure to uncomment the import line containing the pretrained model you wish to evaluate in the code.

### Credits
* The TokenCut code is built on top of [LOST](https://github.com/valeoai/LOST), [DINO](https://github.com/facebookresearch/dino), [Segswap](https://github.com/XiSHEN0220/SegSwap), and [Bilateral_Solver](https://github.com/poolio/bilateral_solver).
* Our ViT code is based on the [pytorch-image-models](https://github.com/rwightman/pytorch-image-models) repository.
* Our ImageNet finetuning code is based on [code from the official PyTorch repo](https://github.com/pytorch/examples/blob/main/imagenet/main.py).
* The code to convert ObjectNet classes to ImageNet classes was taken from [the torchprune repo](https://github.com/lucaslie/torchprune/blob/b753745b773c3ed259bf819d193ce8573d89efbb/src/torchprune/torchprune/util/datasets/objectnet.py).
* The code to convert SI-Score classes to ImageNet classes was taken from [the official implementation](https://github.com/google-research/si-score).

We would like to sincerely thank the authors for their great work.
RobustViT.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
SegmentationTest/data/Imagenet.py
ADDED
@@ -0,0 +1,74 @@
import os
import torch
import torch.utils.data as data
import numpy as np

from PIL import Image
import h5py

__all__ = ['ImagenetResults']


class Imagenet_Segmentation(data.Dataset):
    CLASSES = 2

    def __init__(self,
                 path,
                 transform=None,
                 target_transform=None):
        self.path = path
        self.transform = transform
        self.target_transform = target_transform
        self.h5py = None
        tmp = h5py.File(path, 'r')
        self.data_length = len(tmp['/value/img'])
        tmp.close()
        del tmp

    def __getitem__(self, index):

        if self.h5py is None:
            self.h5py = h5py.File(self.path, 'r')

        img = np.array(self.h5py[self.h5py['/value/img'][index, 0]]).transpose((2, 1, 0))
        target = np.array(self.h5py[self.h5py[self.h5py['/value/gt'][index, 0]][0, 0]]).transpose((1, 0))

        img = Image.fromarray(img).convert('RGB')
        target = Image.fromarray(target)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = np.array(self.target_transform(target)).astype('int32')
            target = torch.from_numpy(target).long()

        return img, target

    def __len__(self):
        return self.data_length


class ImagenetResults(data.Dataset):
    def __init__(self, path):
        super(ImagenetResults, self).__init__()

        self.path = os.path.join(path, 'results.hdf5')
        self.data = None

        print('Reading dataset length...')
        with h5py.File(self.path, 'r') as f:
            self.data_length = len(f['/image'])

    def __len__(self):
        return self.data_length

    def __getitem__(self, item):
        if self.data is None:
            self.data = h5py.File(self.path, 'r')

        image = torch.tensor(self.data['image'][item])
        vis = torch.tensor(self.data['vis'][item])
        target = torch.tensor(self.data['target'][item]).long()

        return image, vis, target
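A minimal sketch of how `Imagenet_Segmentation` might be wired up for the segmentation test, assuming a local copy of the `gtsegs_ijcv.mat` file referenced in the README; the paths and transform values here are illustrative and not part of the diff:

```python
import torch
import torchvision.transforms as transforms
from PIL import Image

from SegmentationTest.data.Imagenet import Imagenet_Segmentation

# Illustrative preprocessing; the evaluation script may use different sizes/statistics.
img_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])
mask_transform = transforms.Resize((224, 224), interpolation=Image.NEAREST)

ds = Imagenet_Segmentation('/path/to/gtsegs_ijcv.mat',  # placeholder path
                           transform=img_transform,
                           target_transform=mask_transform)
loader = torch.utils.data.DataLoader(ds, batch_size=1, shuffle=False)
img, target = next(iter(loader))  # image tensor and binary segmentation mask
```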
SegmentationTest/data/VOC.py
ADDED
@@ -0,0 +1,372 @@
import os
import tarfile
import torch
import torch.utils.data as data
import numpy as np
import h5py

from PIL import Image
from scipy import io
from torchvision.datasets.utils import download_url

DATASET_YEAR_DICT = {
    '2012': {
        'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar',
        'filename': 'VOCtrainval_11-May-2012.tar',
        'md5': '6cd6e144f989b92b3379bac3b3de84fd',
        'base_dir': 'VOCdevkit/VOC2012'
    },
    '2011': {
        'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2011/VOCtrainval_25-May-2011.tar',
        'filename': 'VOCtrainval_25-May-2011.tar',
        'md5': '6c3384ef61512963050cb5d687e5bf1e',
        'base_dir': 'TrainVal/VOCdevkit/VOC2011'
    },
    '2010': {
        'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2010/VOCtrainval_03-May-2010.tar',
        'filename': 'VOCtrainval_03-May-2010.tar',
        'md5': 'da459979d0c395079b5c75ee67908abb',
        'base_dir': 'VOCdevkit/VOC2010'
    },
    '2009': {
        'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2009/VOCtrainval_11-May-2009.tar',
        'filename': 'VOCtrainval_11-May-2009.tar',
        'md5': '59065e4b188729180974ef6572f6a212',
        'base_dir': 'VOCdevkit/VOC2009'
    },
    '2008': {
        'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2008/VOCtrainval_14-Jul-2008.tar',
        'filename': 'VOCtrainval_11-May-2012.tar',
        'md5': '2629fa636546599198acfcfbfcf1904a',
        'base_dir': 'VOCdevkit/VOC2008'
    },
    '2007': {
        'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar',
        'filename': 'VOCtrainval_06-Nov-2007.tar',
        'md5': 'c52e279531787c972589f7e41ab4ae64',
        'base_dir': 'VOCdevkit/VOC2007'
    }
}


class VOCSegmentation(data.Dataset):
    """`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Segmentation Dataset.

    Args:
        root (string): Root directory of the VOC Dataset.
        year (string, optional): The dataset year, supports years 2007 to 2012.
        image_set (string, optional): Select the image_set to use, ``train``, ``trainval`` or ``val``
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
    """

    CLASSES = 20
    CLASSES_NAMES = [
        'aeroplane', 'bicycle', 'bird', 'boat', 'bottle',
        'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse',
        'motorbike', 'person', 'potted-plant', 'sheep', 'sofa', 'train',
        'tvmonitor', 'ambigious'
    ]

    def __init__(self,
                 root,
                 year='2012',
                 image_set='train',
                 download=False,
                 transform=None,
                 target_transform=None):
        self.root = os.path.expanduser(root)
        self.year = year
        self.url = DATASET_YEAR_DICT[year]['url']
        self.filename = DATASET_YEAR_DICT[year]['filename']
        self.md5 = DATASET_YEAR_DICT[year]['md5']
        self.transform = transform
        self.target_transform = target_transform
        self.image_set = image_set
        base_dir = DATASET_YEAR_DICT[year]['base_dir']
        voc_root = os.path.join(self.root, base_dir)
        image_dir = os.path.join(voc_root, 'JPEGImages')
        mask_dir = os.path.join(voc_root, 'SegmentationClass')

        if download:
            download_extract(self.url, self.root, self.filename, self.md5)

        if not os.path.isdir(voc_root):
            raise RuntimeError('Dataset not found or corrupted.' +
                               ' You can use download=True to download it')

        splits_dir = os.path.join(voc_root, 'ImageSets/Segmentation')

        split_f = os.path.join(splits_dir, image_set.rstrip('\n') + '.txt')

        if not os.path.exists(split_f):
            raise ValueError(
                'Wrong image_set entered! Please use image_set="train" '
                'or image_set="trainval" or image_set="val"')

        with open(os.path.join(split_f), "r") as f:
            file_names = [x.strip() for x in f.readlines()]

        self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names]
        self.masks = [os.path.join(mask_dir, x + ".png") for x in file_names]
        assert (len(self.images) == len(self.masks))

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is the image segmentation.
        """
        img = Image.open(self.images[index]).convert('RGB')
        target = Image.open(self.masks[index])

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = np.array(self.target_transform(target)).astype('int32')
            target[target == 255] = -1
            target = torch.from_numpy(target).long()

        return img, target

    @staticmethod
    def _mask_transform(mask):
        target = np.array(mask).astype('int32')
        target[target == 255] = -1
        return torch.from_numpy(target).long()

    def __len__(self):
        return len(self.images)

    @property
    def pred_offset(self):
        return 0


class VOCClassification(data.Dataset):
    """`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Segmentation Dataset.

    Args:
        root (string): Root directory of the VOC Dataset.
        year (string, optional): The dataset year, supports years 2007 to 2012.
        image_set (string, optional): Select the image_set to use, ``train``, ``trainval`` or ``val``
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.RandomCrop``
    """
    CLASSES = 20

    def __init__(self,
                 root,
                 year='2012',
                 image_set='train',
                 download=False,
                 transform=None):
        self.root = os.path.expanduser(root)
        self.year = year
        self.url = DATASET_YEAR_DICT[year]['url']
        self.filename = DATASET_YEAR_DICT[year]['filename']
        self.md5 = DATASET_YEAR_DICT[year]['md5']
        self.transform = transform
        self.image_set = image_set
        base_dir = DATASET_YEAR_DICT[year]['base_dir']
        voc_root = os.path.join(self.root, base_dir)
        image_dir = os.path.join(voc_root, 'JPEGImages')
        mask_dir = os.path.join(voc_root, 'SegmentationClass')

        if download:
            download_extract(self.url, self.root, self.filename, self.md5)

        if not os.path.isdir(voc_root):
            raise RuntimeError('Dataset not found or corrupted.' +
                               ' You can use download=True to download it')

        splits_dir = os.path.join(voc_root, 'ImageSets/Segmentation')

        split_f = os.path.join(splits_dir, image_set.rstrip('\n') + '.txt')

        if not os.path.exists(split_f):
            raise ValueError(
                'Wrong image_set entered! Please use image_set="train" '
                'or image_set="trainval" or image_set="val"')

        with open(os.path.join(split_f), "r") as f:
            file_names = [x.strip() for x in f.readlines()]

        self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names]
        self.masks = [os.path.join(mask_dir, x + ".png") for x in file_names]
        assert (len(self.images) == len(self.masks))

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is the image segmentation.
        """
        img = Image.open(self.images[index]).convert('RGB')
        target = Image.open(self.masks[index])

        # if self.transform is not None:
        #     img = self.transform(img)
        if self.transform is not None:
            img, target = self.transform(img, target)

        visible_classes = np.unique(target)
        labels = torch.zeros(self.CLASSES)
        for id in visible_classes:
            if id not in (0, 255):
                labels[id - 1].fill_(1)

        return img, labels

    def __len__(self):
        return len(self.images)


class VOCSBDClassification(data.Dataset):
    """`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Segmentation Dataset.

    Args:
        root (string): Root directory of the VOC Dataset.
        year (string, optional): The dataset year, supports years 2007 to 2012.
        image_set (string, optional): Select the image_set to use, ``train``, ``trainval`` or ``val``
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.RandomCrop``
    """
    CLASSES = 20

    def __init__(self,
                 root,
                 sbd_root,
                 year='2012',
                 image_set='train',
                 download=False,
                 transform=None):
        self.root = os.path.expanduser(root)
        self.sbd_root = os.path.expanduser(sbd_root)
        self.year = year
        self.url = DATASET_YEAR_DICT[year]['url']
        self.filename = DATASET_YEAR_DICT[year]['filename']
        self.md5 = DATASET_YEAR_DICT[year]['md5']
        self.transform = transform
        self.image_set = image_set
        base_dir = DATASET_YEAR_DICT[year]['base_dir']
        voc_root = os.path.join(self.root, base_dir)
        image_dir = os.path.join(voc_root, 'JPEGImages')
        mask_dir = os.path.join(voc_root, 'SegmentationClass')
        sbd_image_dir = os.path.join(sbd_root, 'img')
        sbd_mask_dir = os.path.join(sbd_root, 'cls')

        if download:
            download_extract(self.url, self.root, self.filename, self.md5)

        if not os.path.isdir(voc_root):
            raise RuntimeError('Dataset not found or corrupted.' +
                               ' You can use download=True to download it')

        splits_dir = os.path.join(voc_root, 'ImageSets/Segmentation')

        split_f = os.path.join(splits_dir, image_set.rstrip('\n') + '.txt')
        sbd_split = os.path.join(sbd_root, 'train.txt')

        if not os.path.exists(split_f):
            raise ValueError(
                'Wrong image_set entered! Please use image_set="train" '
                'or image_set="trainval" or image_set="val"')

        with open(os.path.join(split_f), "r") as f:
            voc_file_names = [x.strip() for x in f.readlines()]

        with open(os.path.join(sbd_split), "r") as f:
            sbd_file_names = [x.strip() for x in f.readlines()]

        self.images = [os.path.join(image_dir, x + ".jpg") for x in voc_file_names]
        self.images += [os.path.join(sbd_image_dir, x + ".jpg") for x in sbd_file_names]
        self.masks = [os.path.join(mask_dir, x + ".png") for x in voc_file_names]
        self.masks += [os.path.join(sbd_mask_dir, x + ".mat") for x in sbd_file_names]
        assert (len(self.images) == len(self.masks))

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is the image segmentation.
        """
        img = Image.open(self.images[index]).convert('RGB')
        mask_path = self.masks[index]
        if mask_path[-3:] == 'mat':
            target = io.loadmat(mask_path, struct_as_record=False, squeeze_me=True)['GTcls'].Segmentation
            target = Image.fromarray(target, mode='P')
        else:
            target = Image.open(self.masks[index])

        if self.transform is not None:
            img, target = self.transform(img, target)

        visible_classes = np.unique(target)
        labels = torch.zeros(self.CLASSES)
        for id in visible_classes:
            if id not in (0, 255):
                labels[id - 1].fill_(1)

        return img, labels

    def __len__(self):
        return len(self.images)


def download_extract(url, root, filename, md5):
    download_url(url, root, filename, md5)
    with tarfile.open(os.path.join(root, filename), "r") as tar:
        tar.extractall(path=root)


class VOCResults(data.Dataset):
    CLASSES = 20
    CLASSES_NAMES = [
        'aeroplane', 'bicycle', 'bird', 'boat', 'bottle',
        'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse',
        'motorbike', 'person', 'potted-plant', 'sheep', 'sofa', 'train',
        'tvmonitor', 'ambigious'
    ]

    def __init__(self, path):
        super(VOCResults, self).__init__()

        self.path = os.path.join(path, 'results.hdf5')
        self.data = None

        print('Reading dataset length...')
        with h5py.File(self.path, 'r') as f:
            self.data_length = len(f['/image'])

    def __len__(self):
        return self.data_length

    def __getitem__(self, item):
        if self.data is None:
            self.data = h5py.File(self.path, 'r')

        image = torch.tensor(self.data['image'][item])
        vis = torch.tensor(self.data['vis'][item])
        target = torch.tensor(self.data['target'][item])
        class_pred = torch.tensor(self.data['class_pred'][item])

        return image, vis, target, class_pred
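A small usage sketch of the classification wrapper above; the dataset root and the joint image/target transform are placeholders, not part of VOC.py:

# Hypothetical usage sketch for VOCClassification -- the root path and the joint
# transform are illustrative assumptions, not defined in VOC.py.
from PIL import Image
from SegmentationTest.data.VOC import VOCClassification

def joint_transform(img, target):
    # VOCClassification applies one callable to (img, target) together.
    img = img.resize((224, 224))
    target = target.resize((224, 224), Image.NEAREST)  # keep label values intact
    return img, target

ds = VOCClassification(root='/path/to/VOC', year='2012', image_set='train',
                       transform=joint_transform)
img, labels = ds[0]  # labels: 20-dim multi-hot vector over the VOC classes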
SegmentationTest/data/__init__.py
ADDED
File without changes
SegmentationTest/data/imagenet_utils.py
ADDED
@@ -0,0 +1,1002 @@
1 |
+
CLS2IDX = {
|
2 |
+
0: 'tench, Tinca tinca',
|
3 |
+
1: 'goldfish, Carassius auratus',
|
4 |
+
2: 'great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias',
|
5 |
+
3: 'tiger shark, Galeocerdo cuvieri',
|
6 |
+
4: 'hammerhead, hammerhead shark',
|
7 |
+
5: 'electric ray, crampfish, numbfish, torpedo',
|
8 |
+
6: 'stingray',
|
9 |
+
7: 'cock',
|
10 |
+
8: 'hen',
|
11 |
+
9: 'ostrich, Struthio camelus',
|
12 |
+
10: 'brambling, Fringilla montifringilla',
|
13 |
+
11: 'goldfinch, Carduelis carduelis',
|
14 |
+
12: 'house finch, linnet, Carpodacus mexicanus',
|
15 |
+
13: 'junco, snowbird',
|
16 |
+
14: 'indigo bunting, indigo finch, indigo bird, Passerina cyanea',
|
17 |
+
15: 'robin, American robin, Turdus migratorius',
|
18 |
+
16: 'bulbul',
|
19 |
+
17: 'jay',
|
20 |
+
18: 'magpie',
|
21 |
+
19: 'chickadee',
|
22 |
+
20: 'water ouzel, dipper',
|
23 |
+
21: 'kite',
|
24 |
+
22: 'bald eagle, American eagle, Haliaeetus leucocephalus',
|
25 |
+
23: 'vulture',
|
26 |
+
24: 'great grey owl, great gray owl, Strix nebulosa',
|
27 |
+
25: 'European fire salamander, Salamandra salamandra',
|
28 |
+
26: 'common newt, Triturus vulgaris',
|
29 |
+
27: 'eft',
|
30 |
+
28: 'spotted salamander, Ambystoma maculatum',
|
31 |
+
29: 'axolotl, mud puppy, Ambystoma mexicanum',
|
32 |
+
30: 'bullfrog, Rana catesbeiana',
|
33 |
+
31: 'tree frog, tree-frog',
|
34 |
+
32: 'tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui',
|
35 |
+
33: 'loggerhead, loggerhead turtle, Caretta caretta',
|
36 |
+
34: 'leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea',
|
37 |
+
35: 'mud turtle',
|
38 |
+
36: 'terrapin',
|
39 |
+
37: 'box turtle, box tortoise',
|
40 |
+
38: 'banded gecko',
|
41 |
+
39: 'common iguana, iguana, Iguana iguana',
|
42 |
+
40: 'American chameleon, anole, Anolis carolinensis',
|
43 |
+
41: 'whiptail, whiptail lizard',
|
44 |
+
42: 'agama',
|
45 |
+
43: 'frilled lizard, Chlamydosaurus kingi',
|
46 |
+
44: 'alligator lizard',
|
47 |
+
45: 'Gila monster, Heloderma suspectum',
|
48 |
+
46: 'green lizard, Lacerta viridis',
|
49 |
+
47: 'African chameleon, Chamaeleo chamaeleon',
|
50 |
+
48: 'Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis',
|
51 |
+
49: 'African crocodile, Nile crocodile, Crocodylus niloticus',
|
52 |
+
50: 'American alligator, Alligator mississipiensis',
|
53 |
+
51: 'triceratops',
|
54 |
+
52: 'thunder snake, worm snake, Carphophis amoenus',
|
55 |
+
53: 'ringneck snake, ring-necked snake, ring snake',
|
56 |
+
54: 'hognose snake, puff adder, sand viper',
|
57 |
+
55: 'green snake, grass snake',
|
58 |
+
56: 'king snake, kingsnake',
|
59 |
+
57: 'garter snake, grass snake',
|
60 |
+
58: 'water snake',
|
61 |
+
59: 'vine snake',
|
62 |
+
60: 'night snake, Hypsiglena torquata',
|
63 |
+
61: 'boa constrictor, Constrictor constrictor',
|
64 |
+
62: 'rock python, rock snake, Python sebae',
|
65 |
+
63: 'Indian cobra, Naja naja',
|
66 |
+
64: 'green mamba',
|
67 |
+
65: 'sea snake',
|
68 |
+
66: 'horned viper, cerastes, sand viper, horned asp, Cerastes cornutus',
|
69 |
+
67: 'diamondback, diamondback rattlesnake, Crotalus adamanteus',
|
70 |
+
68: 'sidewinder, horned rattlesnake, Crotalus cerastes',
|
71 |
+
69: 'trilobite',
|
72 |
+
70: 'harvestman, daddy longlegs, Phalangium opilio',
|
73 |
+
71: 'scorpion',
|
74 |
+
72: 'black and gold garden spider, Argiope aurantia',
|
75 |
+
73: 'barn spider, Araneus cavaticus',
|
76 |
+
74: 'garden spider, Aranea diademata',
|
77 |
+
75: 'black widow, Latrodectus mactans',
|
78 |
+
76: 'tarantula',
|
79 |
+
77: 'wolf spider, hunting spider',
|
80 |
+
78: 'tick',
|
81 |
+
79: 'centipede',
|
82 |
+
80: 'black grouse',
|
83 |
+
81: 'ptarmigan',
|
84 |
+
82: 'ruffed grouse, partridge, Bonasa umbellus',
|
85 |
+
83: 'prairie chicken, prairie grouse, prairie fowl',
|
86 |
+
84: 'peacock',
|
87 |
+
85: 'quail',
|
88 |
+
86: 'partridge',
|
89 |
+
87: 'African grey, African gray, Psittacus erithacus',
|
90 |
+
88: 'macaw',
|
91 |
+
89: 'sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita',
|
92 |
+
90: 'lorikeet',
|
93 |
+
91: 'coucal',
|
94 |
+
92: 'bee eater',
|
95 |
+
93: 'hornbill',
|
96 |
+
94: 'hummingbird',
|
97 |
+
95: 'jacamar',
|
98 |
+
96: 'toucan',
|
99 |
+
97: 'drake',
|
100 |
+
98: 'red-breasted merganser, Mergus serrator',
|
101 |
+
99: 'goose',
|
102 |
+
100: 'black swan, Cygnus atratus',
|
103 |
+
101: 'tusker',
|
104 |
+
102: 'echidna, spiny anteater, anteater',
|
105 |
+
103: 'platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus',
|
106 |
+
104: 'wallaby, brush kangaroo',
|
107 |
+
105: 'koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus',
|
108 |
+
106: 'wombat',
|
109 |
+
107: 'jellyfish',
|
110 |
+
108: 'sea anemone, anemone',
|
111 |
+
109: 'brain coral',
|
112 |
+
110: 'flatworm, platyhelminth',
|
113 |
+
111: 'nematode, nematode worm, roundworm',
|
114 |
+
112: 'conch',
|
115 |
+
113: 'snail',
|
116 |
+
114: 'slug',
|
117 |
+
115: 'sea slug, nudibranch',
|
118 |
+
116: 'chiton, coat-of-mail shell, sea cradle, polyplacophore',
|
119 |
+
117: 'chambered nautilus, pearly nautilus, nautilus',
|
120 |
+
118: 'Dungeness crab, Cancer magister',
|
121 |
+
119: 'rock crab, Cancer irroratus',
|
122 |
+
120: 'fiddler crab',
|
123 |
+
121: 'king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica',
|
124 |
+
122: 'American lobster, Northern lobster, Maine lobster, Homarus americanus',
|
125 |
+
123: 'spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish',
|
126 |
+
124: 'crayfish, crawfish, crawdad, crawdaddy',
|
127 |
+
125: 'hermit crab',
|
128 |
+
126: 'isopod',
|
129 |
+
127: 'white stork, Ciconia ciconia',
|
130 |
+
128: 'black stork, Ciconia nigra',
|
131 |
+
129: 'spoonbill',
|
132 |
+
130: 'flamingo',
|
133 |
+
131: 'little blue heron, Egretta caerulea',
|
134 |
+
132: 'American egret, great white heron, Egretta albus',
|
135 |
+
133: 'bittern',
|
136 |
+
134: 'crane',
|
137 |
+
135: 'limpkin, Aramus pictus',
|
138 |
+
136: 'European gallinule, Porphyrio porphyrio',
|
139 |
+
137: 'American coot, marsh hen, mud hen, water hen, Fulica americana',
|
140 |
+
138: 'bustard',
|
141 |
+
139: 'ruddy turnstone, Arenaria interpres',
|
142 |
+
140: 'red-backed sandpiper, dunlin, Erolia alpina',
|
143 |
+
141: 'redshank, Tringa totanus',
|
144 |
+
142: 'dowitcher',
|
145 |
+
143: 'oystercatcher, oyster catcher',
|
146 |
+
144: 'pelican',
|
147 |
+
145: 'king penguin, Aptenodytes patagonica',
|
148 |
+
146: 'albatross, mollymawk',
|
149 |
+
147: 'grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus',
|
150 |
+
148: 'killer whale, killer, orca, grampus, sea wolf, Orcinus orca',
|
151 |
+
149: 'dugong, Dugong dugon',
|
152 |
+
150: 'sea lion',
|
153 |
+
151: 'Chihuahua',
|
154 |
+
152: 'Japanese spaniel',
|
155 |
+
153: 'Maltese dog, Maltese terrier, Maltese',
|
156 |
+
154: 'Pekinese, Pekingese, Peke',
|
157 |
+
155: 'Shih-Tzu',
|
158 |
+
156: 'Blenheim spaniel',
|
159 |
+
157: 'papillon',
|
160 |
+
158: 'toy terrier',
|
161 |
+
159: 'Rhodesian ridgeback',
|
162 |
+
160: 'Afghan hound, Afghan',
|
163 |
+
161: 'basset, basset hound',
|
164 |
+
162: 'beagle',
|
165 |
+
163: 'bloodhound, sleuthhound',
|
166 |
+
164: 'bluetick',
|
167 |
+
165: 'black-and-tan coonhound',
|
168 |
+
166: 'Walker hound, Walker foxhound',
|
169 |
+
167: 'English foxhound',
|
170 |
+
168: 'redbone',
|
171 |
+
169: 'borzoi, Russian wolfhound',
|
172 |
+
170: 'Irish wolfhound',
|
173 |
+
171: 'Italian greyhound',
|
174 |
+
172: 'whippet',
|
175 |
+
173: 'Ibizan hound, Ibizan Podenco',
|
176 |
+
174: 'Norwegian elkhound, elkhound',
|
177 |
+
175: 'otterhound, otter hound',
|
178 |
+
176: 'Saluki, gazelle hound',
|
179 |
+
177: 'Scottish deerhound, deerhound',
|
180 |
+
178: 'Weimaraner',
|
181 |
+
179: 'Staffordshire bullterrier, Staffordshire bull terrier',
|
182 |
+
180: 'American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier',
|
183 |
+
181: 'Bedlington terrier',
|
184 |
+
182: 'Border terrier',
|
185 |
+
183: 'Kerry blue terrier',
|
186 |
+
184: 'Irish terrier',
|
187 |
+
185: 'Norfolk terrier',
|
188 |
+
186: 'Norwich terrier',
|
189 |
+
187: 'Yorkshire terrier',
|
190 |
+
188: 'wire-haired fox terrier',
|
191 |
+
189: 'Lakeland terrier',
|
192 |
+
190: 'Sealyham terrier, Sealyham',
|
193 |
+
191: 'Airedale, Airedale terrier',
|
194 |
+
192: 'cairn, cairn terrier',
|
195 |
+
193: 'Australian terrier',
|
196 |
+
194: 'Dandie Dinmont, Dandie Dinmont terrier',
|
197 |
+
195: 'Boston bull, Boston terrier',
|
198 |
+
196: 'miniature schnauzer',
|
199 |
+
197: 'giant schnauzer',
|
200 |
+
198: 'standard schnauzer',
|
201 |
+
199: 'Scotch terrier, Scottish terrier, Scottie',
|
202 |
+
200: 'Tibetan terrier, chrysanthemum dog',
|
203 |
+
201: 'silky terrier, Sydney silky',
|
204 |
+
202: 'soft-coated wheaten terrier',
|
205 |
+
203: 'West Highland white terrier',
|
206 |
+
204: 'Lhasa, Lhasa apso',
|
207 |
+
205: 'flat-coated retriever',
|
208 |
+
206: 'curly-coated retriever',
|
209 |
+
207: 'golden retriever',
|
210 |
+
208: 'Labrador retriever',
|
211 |
+
209: 'Chesapeake Bay retriever',
|
212 |
+
210: 'German short-haired pointer',
|
213 |
+
211: 'vizsla, Hungarian pointer',
|
214 |
+
212: 'English setter',
|
215 |
+
213: 'Irish setter, red setter',
|
216 |
+
214: 'Gordon setter',
|
217 |
+
215: 'Brittany spaniel',
|
218 |
+
216: 'clumber, clumber spaniel',
|
219 |
+
217: 'English springer, English springer spaniel',
|
220 |
+
218: 'Welsh springer spaniel',
|
221 |
+
219: 'cocker spaniel, English cocker spaniel, cocker',
|
222 |
+
220: 'Sussex spaniel',
|
223 |
+
221: 'Irish water spaniel',
|
224 |
+
222: 'kuvasz',
|
225 |
+
223: 'schipperke',
|
226 |
+
224: 'groenendael',
|
227 |
+
225: 'malinois',
|
228 |
+
226: 'briard',
|
229 |
+
227: 'kelpie',
|
230 |
+
228: 'komondor',
|
231 |
+
229: 'Old English sheepdog, bobtail',
|
232 |
+
230: 'Shetland sheepdog, Shetland sheep dog, Shetland',
|
233 |
+
231: 'collie',
|
234 |
+
232: 'Border collie',
|
235 |
+
233: 'Bouvier des Flandres, Bouviers des Flandres',
|
236 |
+
234: 'Rottweiler',
|
237 |
+
235: 'German shepherd, German shepherd dog, German police dog, alsatian',
|
238 |
+
236: 'Doberman, Doberman pinscher',
|
239 |
+
237: 'miniature pinscher',
|
240 |
+
238: 'Greater Swiss Mountain dog',
|
241 |
+
239: 'Bernese mountain dog',
|
242 |
+
240: 'Appenzeller',
|
243 |
+
241: 'EntleBucher',
|
244 |
+
242: 'boxer',
|
245 |
+
243: 'bull mastiff',
|
246 |
+
244: 'Tibetan mastiff',
|
247 |
+
245: 'French bulldog',
|
248 |
+
246: 'Great Dane',
|
249 |
+
247: 'Saint Bernard, St Bernard',
|
250 |
+
248: 'Eskimo dog, husky',
|
251 |
+
249: 'malamute, malemute, Alaskan malamute',
|
252 |
+
250: 'Siberian husky',
|
253 |
+
251: 'dalmatian, coach dog, carriage dog',
|
254 |
+
252: 'affenpinscher, monkey pinscher, monkey dog',
|
255 |
+
253: 'basenji',
|
256 |
+
254: 'pug, pug-dog',
|
257 |
+
255: 'Leonberg',
|
258 |
+
256: 'Newfoundland, Newfoundland dog',
|
259 |
+
257: 'Great Pyrenees',
|
260 |
+
258: 'Samoyed, Samoyede',
|
261 |
+
259: 'Pomeranian',
|
262 |
+
260: 'chow, chow chow',
|
263 |
+
261: 'keeshond',
|
264 |
+
262: 'Brabancon griffon',
|
265 |
+
263: 'Pembroke, Pembroke Welsh corgi',
|
266 |
+
264: 'Cardigan, Cardigan Welsh corgi',
|
267 |
+
265: 'toy poodle',
|
268 |
+
266: 'miniature poodle',
|
269 |
+
267: 'standard poodle',
|
270 |
+
268: 'Mexican hairless',
|
271 |
+
269: 'timber wolf, grey wolf, gray wolf, Canis lupus',
|
272 |
+
270: 'white wolf, Arctic wolf, Canis lupus tundrarum',
|
273 |
+
271: 'red wolf, maned wolf, Canis rufus, Canis niger',
|
274 |
+
272: 'coyote, prairie wolf, brush wolf, Canis latrans',
|
275 |
+
273: 'dingo, warrigal, warragal, Canis dingo',
|
276 |
+
274: 'dhole, Cuon alpinus',
|
277 |
+
275: 'African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus',
|
278 |
+
276: 'hyena, hyaena',
|
279 |
+
277: 'red fox, Vulpes vulpes',
|
280 |
+
278: 'kit fox, Vulpes macrotis',
|
281 |
+
279: 'Arctic fox, white fox, Alopex lagopus',
|
282 |
+
280: 'grey fox, gray fox, Urocyon cinereoargenteus',
|
283 |
+
281: 'tabby, tabby cat',
|
284 |
+
282: 'tiger cat',
|
285 |
+
283: 'Persian cat',
|
286 |
+
284: 'Siamese cat, Siamese',
|
287 |
+
285: 'Egyptian cat',
|
288 |
+
286: 'cougar, puma, catamount, mountain lion, painter, panther, Felis concolor',
|
289 |
+
287: 'lynx, catamount',
|
290 |
+
288: 'leopard, Panthera pardus',
|
291 |
+
289: 'snow leopard, ounce, Panthera uncia',
|
292 |
+
290: 'jaguar, panther, Panthera onca, Felis onca',
|
293 |
+
291: 'lion, king of beasts, Panthera leo',
|
294 |
+
292: 'tiger, Panthera tigris',
|
295 |
+
293: 'cheetah, chetah, Acinonyx jubatus',
|
296 |
+
294: 'brown bear, bruin, Ursus arctos',
|
297 |
+
295: 'American black bear, black bear, Ursus americanus, Euarctos americanus',
|
298 |
+
296: 'ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus',
|
299 |
+
297: 'sloth bear, Melursus ursinus, Ursus ursinus',
|
300 |
+
298: 'mongoose',
|
301 |
+
299: 'meerkat, mierkat',
|
302 |
+
300: 'tiger beetle',
|
303 |
+
301: 'ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle',
|
304 |
+
302: 'ground beetle, carabid beetle',
|
305 |
+
303: 'long-horned beetle, longicorn, longicorn beetle',
|
306 |
+
304: 'leaf beetle, chrysomelid',
|
307 |
+
305: 'dung beetle',
|
308 |
+
306: 'rhinoceros beetle',
|
309 |
+
307: 'weevil',
|
310 |
+
308: 'fly',
|
311 |
+
309: 'bee',
|
312 |
+
310: 'ant, emmet, pismire',
|
313 |
+
311: 'grasshopper, hopper',
|
314 |
+
312: 'cricket',
|
315 |
+
313: 'walking stick, walkingstick, stick insect',
|
316 |
+
314: 'cockroach, roach',
|
317 |
+
315: 'mantis, mantid',
|
318 |
+
316: 'cicada, cicala',
|
319 |
+
317: 'leafhopper',
|
320 |
+
318: 'lacewing, lacewing fly',
|
321 |
+
319: "dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk",
|
322 |
+
320: 'damselfly',
|
323 |
+
321: 'admiral',
|
324 |
+
322: 'ringlet, ringlet butterfly',
|
325 |
+
323: 'monarch, monarch butterfly, milkweed butterfly, Danaus plexippus',
|
326 |
+
324: 'cabbage butterfly',
|
327 |
+
325: 'sulphur butterfly, sulfur butterfly',
|
328 |
+
326: 'lycaenid, lycaenid butterfly',
|
329 |
+
327: 'starfish, sea star',
|
330 |
+
328: 'sea urchin',
|
331 |
+
329: 'sea cucumber, holothurian',
|
332 |
+
330: 'wood rabbit, cottontail, cottontail rabbit',
|
333 |
+
331: 'hare',
|
334 |
+
332: 'Angora, Angora rabbit',
|
335 |
+
333: 'hamster',
|
336 |
+
334: 'porcupine, hedgehog',
|
337 |
+
335: 'fox squirrel, eastern fox squirrel, Sciurus niger',
|
338 |
+
336: 'marmot',
|
339 |
+
337: 'beaver',
|
340 |
+
338: 'guinea pig, Cavia cobaya',
|
341 |
+
339: 'sorrel',
|
342 |
+
340: 'zebra',
|
343 |
+
341: 'hog, pig, grunter, squealer, Sus scrofa',
|
344 |
+
342: 'wild boar, boar, Sus scrofa',
|
345 |
+
343: 'warthog',
|
346 |
+
344: 'hippopotamus, hippo, river horse, Hippopotamus amphibius',
|
347 |
+
345: 'ox',
|
348 |
+
346: 'water buffalo, water ox, Asiatic buffalo, Bubalus bubalis',
|
349 |
+
347: 'bison',
|
350 |
+
348: 'ram, tup',
|
351 |
+
349: 'bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis',
|
352 |
+
350: 'ibex, Capra ibex',
|
353 |
+
351: 'hartebeest',
|
354 |
+
352: 'impala, Aepyceros melampus',
|
355 |
+
353: 'gazelle',
|
356 |
+
354: 'Arabian camel, dromedary, Camelus dromedarius',
|
357 |
+
355: 'llama',
|
358 |
+
356: 'weasel',
|
359 |
+
357: 'mink',
|
360 |
+
358: 'polecat, fitch, foulmart, foumart, Mustela putorius',
|
361 |
+
359: 'black-footed ferret, ferret, Mustela nigripes',
|
362 |
+
360: 'otter',
|
363 |
+
361: 'skunk, polecat, wood pussy',
|
364 |
+
362: 'badger',
|
365 |
+
363: 'armadillo',
|
366 |
+
364: 'three-toed sloth, ai, Bradypus tridactylus',
|
367 |
+
365: 'orangutan, orang, orangutang, Pongo pygmaeus',
|
368 |
+
366: 'gorilla, Gorilla gorilla',
|
369 |
+
367: 'chimpanzee, chimp, Pan troglodytes',
|
370 |
+
368: 'gibbon, Hylobates lar',
|
371 |
+
369: 'siamang, Hylobates syndactylus, Symphalangus syndactylus',
|
372 |
+
370: 'guenon, guenon monkey',
|
373 |
+
371: 'patas, hussar monkey, Erythrocebus patas',
|
374 |
+
372: 'baboon',
|
375 |
+
373: 'macaque',
|
376 |
+
374: 'langur',
|
377 |
+
375: 'colobus, colobus monkey',
|
378 |
+
376: 'proboscis monkey, Nasalis larvatus',
|
379 |
+
377: 'marmoset',
|
380 |
+
378: 'capuchin, ringtail, Cebus capucinus',
|
381 |
+
379: 'howler monkey, howler',
|
382 |
+
380: 'titi, titi monkey',
|
383 |
+
381: 'spider monkey, Ateles geoffroyi',
|
384 |
+
382: 'squirrel monkey, Saimiri sciureus',
|
385 |
+
383: 'Madagascar cat, ring-tailed lemur, Lemur catta',
|
386 |
+
384: 'indri, indris, Indri indri, Indri brevicaudatus',
|
387 |
+
385: 'Indian elephant, Elephas maximus',
|
388 |
+
386: 'African elephant, Loxodonta africana',
|
389 |
+
387: 'lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens',
|
390 |
+
388: 'giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca',
|
391 |
+
389: 'barracouta, snoek',
|
392 |
+
390: 'eel',
|
393 |
+
391: 'coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch',
|
394 |
+
392: 'rock beauty, Holocanthus tricolor',
|
395 |
+
393: 'anemone fish',
|
396 |
+
394: 'sturgeon',
|
397 |
+
395: 'gar, garfish, garpike, billfish, Lepisosteus osseus',
|
398 |
+
396: 'lionfish',
|
399 |
+
397: 'puffer, pufferfish, blowfish, globefish',
|
400 |
+
398: 'abacus',
|
401 |
+
399: 'abaya',
|
402 |
+
400: "academic gown, academic robe, judge's robe",
|
403 |
+
401: 'accordion, piano accordion, squeeze box',
|
404 |
+
402: 'acoustic guitar',
|
405 |
+
403: 'aircraft carrier, carrier, flattop, attack aircraft carrier',
|
406 |
+
404: 'airliner',
|
407 |
+
405: 'airship, dirigible',
|
408 |
+
406: 'altar',
|
409 |
+
407: 'ambulance',
|
410 |
+
408: 'amphibian, amphibious vehicle',
|
411 |
+
409: 'analog clock',
|
412 |
+
410: 'apiary, bee house',
|
413 |
+
411: 'apron',
|
414 |
+
412: 'ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin',
|
415 |
+
413: 'assault rifle, assault gun',
|
416 |
+
414: 'backpack, back pack, knapsack, packsack, rucksack, haversack',
|
417 |
+
415: 'bakery, bakeshop, bakehouse',
|
418 |
+
416: 'balance beam, beam',
|
419 |
+
417: 'balloon',
|
420 |
+
418: 'ballpoint, ballpoint pen, ballpen, Biro',
|
421 |
+
419: 'Band Aid',
|
422 |
+
420: 'banjo',
|
423 |
+
421: 'bannister, banister, balustrade, balusters, handrail',
|
424 |
+
422: 'barbell',
|
425 |
+
423: 'barber chair',
|
426 |
+
424: 'barbershop',
|
427 |
+
425: 'barn',
|
428 |
+
426: 'barometer',
|
429 |
+
427: 'barrel, cask',
|
430 |
+
428: 'barrow, garden cart, lawn cart, wheelbarrow',
|
431 |
+
429: 'baseball',
|
432 |
+
430: 'basketball',
|
433 |
+
431: 'bassinet',
|
434 |
+
432: 'bassoon',
|
435 |
+
433: 'bathing cap, swimming cap',
|
436 |
+
434: 'bath towel',
|
437 |
+
435: 'bathtub, bathing tub, bath, tub',
|
438 |
+
436: 'beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon',
|
439 |
+
437: 'beacon, lighthouse, beacon light, pharos',
|
440 |
+
438: 'beaker',
|
441 |
+
439: 'bearskin, busby, shako',
|
442 |
+
440: 'beer bottle',
|
443 |
+
441: 'beer glass',
|
444 |
+
442: 'bell cote, bell cot',
|
445 |
+
443: 'bib',
|
446 |
+
444: 'bicycle-built-for-two, tandem bicycle, tandem',
|
447 |
+
445: 'bikini, two-piece',
|
448 |
+
446: 'binder, ring-binder',
|
449 |
+
447: 'binoculars, field glasses, opera glasses',
|
450 |
+
448: 'birdhouse',
|
451 |
+
449: 'boathouse',
|
452 |
+
450: 'bobsled, bobsleigh, bob',
|
453 |
+
451: 'bolo tie, bolo, bola tie, bola',
|
454 |
+
452: 'bonnet, poke bonnet',
|
455 |
+
453: 'bookcase',
|
456 |
+
454: 'bookshop, bookstore, bookstall',
|
457 |
+
455: 'bottlecap',
|
458 |
+
456: 'bow',
|
459 |
+
457: 'bow tie, bow-tie, bowtie',
|
460 |
+
458: 'brass, memorial tablet, plaque',
|
461 |
+
459: 'brassiere, bra, bandeau',
|
462 |
+
460: 'breakwater, groin, groyne, mole, bulwark, seawall, jetty',
|
463 |
+
461: 'breastplate, aegis, egis',
|
464 |
+
462: 'broom',
|
465 |
+
463: 'bucket, pail',
|
466 |
+
464: 'buckle',
|
467 |
+
465: 'bulletproof vest',
|
468 |
+
466: 'bullet train, bullet',
|
469 |
+
467: 'butcher shop, meat market',
|
470 |
+
468: 'cab, hack, taxi, taxicab',
|
471 |
+
469: 'caldron, cauldron',
|
472 |
+
470: 'candle, taper, wax light',
|
473 |
+
471: 'cannon',
|
474 |
+
472: 'canoe',
|
475 |
+
473: 'can opener, tin opener',
|
476 |
+
474: 'cardigan',
|
477 |
+
475: 'car mirror',
|
478 |
+
476: 'carousel, carrousel, merry-go-round, roundabout, whirligig',
|
479 |
+
477: "carpenter's kit, tool kit",
|
480 |
+
478: 'carton',
|
481 |
+
479: 'car wheel',
|
482 |
+
480: 'cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM',
|
483 |
+
481: 'cassette',
|
484 |
+
482: 'cassette player',
|
485 |
+
483: 'castle',
|
486 |
+
484: 'catamaran',
|
487 |
+
485: 'CD player',
|
488 |
+
486: 'cello, violoncello',
|
489 |
+
487: 'cellular telephone, cellular phone, cellphone, cell, mobile phone',
|
490 |
+
488: 'chain',
|
491 |
+
489: 'chainlink fence',
|
492 |
+
490: 'chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour',
|
493 |
+
491: 'chain saw, chainsaw',
|
494 |
+
492: 'chest',
|
495 |
+
493: 'chiffonier, commode',
|
496 |
+
494: 'chime, bell, gong',
|
497 |
+
495: 'china cabinet, china closet',
|
498 |
+
496: 'Christmas stocking',
|
499 |
+
497: 'church, church building',
|
500 |
+
498: 'cinema, movie theater, movie theatre, movie house, picture palace',
|
501 |
+
499: 'cleaver, meat cleaver, chopper',
|
502 |
+
500: 'cliff dwelling',
|
503 |
+
501: 'cloak',
|
504 |
+
502: 'clog, geta, patten, sabot',
|
505 |
+
503: 'cocktail shaker',
|
506 |
+
504: 'coffee mug',
|
507 |
+
505: 'coffeepot',
|
508 |
+
506: 'coil, spiral, volute, whorl, helix',
|
509 |
+
507: 'combination lock',
|
510 |
+
508: 'computer keyboard, keypad',
|
511 |
+
509: 'confectionery, confectionary, candy store',
|
512 |
+
510: 'container ship, containership, container vessel',
|
513 |
+
511: 'convertible',
|
514 |
+
512: 'corkscrew, bottle screw',
|
515 |
+
513: 'cornet, horn, trumpet, trump',
|
516 |
+
514: 'cowboy boot',
|
517 |
+
515: 'cowboy hat, ten-gallon hat',
|
518 |
+
516: 'cradle',
|
519 |
+
517: 'crane',
|
520 |
+
518: 'crash helmet',
|
521 |
+
519: 'crate',
|
522 |
+
520: 'crib, cot',
|
523 |
+
521: 'Crock Pot',
|
524 |
+
522: 'croquet ball',
|
525 |
+
523: 'crutch',
|
526 |
+
524: 'cuirass',
|
527 |
+
525: 'dam, dike, dyke',
|
528 |
+
526: 'desk',
|
529 |
+
527: 'desktop computer',
|
530 |
+
528: 'dial telephone, dial phone',
|
531 |
+
529: 'diaper, nappy, napkin',
|
532 |
+
530: 'digital clock',
|
533 |
+
531: 'digital watch',
|
534 |
+
532: 'dining table, board',
|
535 |
+
533: 'dishrag, dishcloth',
|
536 |
+
534: 'dishwasher, dish washer, dishwashing machine',
|
537 |
+
535: 'disk brake, disc brake',
|
538 |
+
536: 'dock, dockage, docking facility',
|
539 |
+
537: 'dogsled, dog sled, dog sleigh',
|
540 |
+
538: 'dome',
|
541 |
+
539: 'doormat, welcome mat',
|
542 |
+
540: 'drilling platform, offshore rig',
|
543 |
+
541: 'drum, membranophone, tympan',
|
544 |
+
542: 'drumstick',
|
545 |
+
543: 'dumbbell',
|
546 |
+
544: 'Dutch oven',
|
547 |
+
545: 'electric fan, blower',
|
548 |
+
546: 'electric guitar',
|
549 |
+
547: 'electric locomotive',
|
550 |
+
548: 'entertainment center',
|
551 |
+
549: 'envelope',
|
552 |
+
550: 'espresso maker',
|
553 |
+
551: 'face powder',
|
554 |
+
552: 'feather boa, boa',
|
555 |
+
553: 'file, file cabinet, filing cabinet',
|
556 |
+
554: 'fireboat',
|
557 |
+
555: 'fire engine, fire truck',
|
558 |
+
556: 'fire screen, fireguard',
|
559 |
+
557: 'flagpole, flagstaff',
|
560 |
+
558: 'flute, transverse flute',
|
561 |
+
559: 'folding chair',
|
562 |
+
560: 'football helmet',
|
563 |
+
561: 'forklift',
|
564 |
+
562: 'fountain',
|
565 |
+
563: 'fountain pen',
|
566 |
+
564: 'four-poster',
|
567 |
+
565: 'freight car',
|
568 |
+
566: 'French horn, horn',
|
569 |
+
567: 'frying pan, frypan, skillet',
|
570 |
+
568: 'fur coat',
|
571 |
+
569: 'garbage truck, dustcart',
|
572 |
+
570: 'gasmask, respirator, gas helmet',
|
573 |
+
571: 'gas pump, gasoline pump, petrol pump, island dispenser',
|
574 |
+
572: 'goblet',
|
575 |
+
573: 'go-kart',
|
576 |
+
574: 'golf ball',
|
577 |
+
575: 'golfcart, golf cart',
|
578 |
+
576: 'gondola',
|
579 |
+
577: 'gong, tam-tam',
|
580 |
+
578: 'gown',
|
581 |
+
579: 'grand piano, grand',
|
582 |
+
580: 'greenhouse, nursery, glasshouse',
|
583 |
+
581: 'grille, radiator grille',
|
584 |
+
582: 'grocery store, grocery, food market, market',
|
585 |
+
583: 'guillotine',
|
586 |
+
584: 'hair slide',
|
587 |
+
585: 'hair spray',
|
588 |
+
586: 'half track',
|
589 |
+
587: 'hammer',
|
590 |
+
588: 'hamper',
|
591 |
+
589: 'hand blower, blow dryer, blow drier, hair dryer, hair drier',
|
592 |
+
590: 'hand-held computer, hand-held microcomputer',
|
593 |
+
591: 'handkerchief, hankie, hanky, hankey',
|
594 |
+
592: 'hard disc, hard disk, fixed disk',
|
595 |
+
593: 'harmonica, mouth organ, harp, mouth harp',
|
596 |
+
594: 'harp',
|
597 |
+
595: 'harvester, reaper',
|
598 |
+
596: 'hatchet',
|
599 |
+
597: 'holster',
|
600 |
+
598: 'home theater, home theatre',
|
601 |
+
599: 'honeycomb',
|
602 |
+
600: 'hook, claw',
|
603 |
+
601: 'hoopskirt, crinoline',
|
604 |
+
602: 'horizontal bar, high bar',
|
605 |
+
603: 'horse cart, horse-cart',
|
606 |
+
604: 'hourglass',
|
607 |
+
605: 'iPod',
|
608 |
+
606: 'iron, smoothing iron',
|
609 |
+
607: "jack-o'-lantern",
|
610 |
+
608: 'jean, blue jean, denim',
|
611 |
+
609: 'jeep, landrover',
|
612 |
+
610: 'jersey, T-shirt, tee shirt',
|
613 |
+
611: 'jigsaw puzzle',
|
614 |
+
612: 'jinrikisha, ricksha, rickshaw',
|
615 |
+
613: 'joystick',
|
616 |
+
614: 'kimono',
|
617 |
+
615: 'knee pad',
|
618 |
+
616: 'knot',
|
619 |
+
617: 'lab coat, laboratory coat',
|
620 |
+
618: 'ladle',
|
621 |
+
619: 'lampshade, lamp shade',
|
622 |
+
620: 'laptop, laptop computer',
|
623 |
+
621: 'lawn mower, mower',
|
624 |
+
622: 'lens cap, lens cover',
|
625 |
+
623: 'letter opener, paper knife, paperknife',
|
626 |
+
624: 'library',
|
627 |
+
625: 'lifeboat',
|
628 |
+
626: 'lighter, light, igniter, ignitor',
|
629 |
+
627: 'limousine, limo',
|
630 |
+
628: 'liner, ocean liner',
|
631 |
+
629: 'lipstick, lip rouge',
|
632 |
+
630: 'Loafer',
|
633 |
+
631: 'lotion',
|
634 |
+
632: 'loudspeaker, speaker, speaker unit, loudspeaker system, speaker system',
|
635 |
+
633: "loupe, jeweler's loupe",
|
636 |
+
634: 'lumbermill, sawmill',
|
637 |
+
635: 'magnetic compass',
|
638 |
+
636: 'mailbag, postbag',
|
639 |
+
637: 'mailbox, letter box',
|
640 |
+
638: 'maillot',
|
641 |
+
639: 'maillot, tank suit',
|
642 |
+
640: 'manhole cover',
|
643 |
+
641: 'maraca',
|
644 |
+
642: 'marimba, xylophone',
|
645 |
+
643: 'mask',
|
646 |
+
644: 'matchstick',
|
647 |
+
645: 'maypole',
|
648 |
+
646: 'maze, labyrinth',
|
649 |
+
647: 'measuring cup',
|
650 |
+
648: 'medicine chest, medicine cabinet',
|
651 |
+
649: 'megalith, megalithic structure',
|
652 |
+
650: 'microphone, mike',
|
653 |
+
651: 'microwave, microwave oven',
|
654 |
+
652: 'military uniform',
|
655 |
+
653: 'milk can',
|
656 |
+
654: 'minibus',
|
657 |
+
655: 'miniskirt, mini',
|
658 |
+
656: 'minivan',
|
659 |
+
657: 'missile',
|
660 |
+
658: 'mitten',
|
661 |
+
659: 'mixing bowl',
|
662 |
+
660: 'mobile home, manufactured home',
|
663 |
+
661: 'Model T',
|
664 |
+
662: 'modem',
|
665 |
+
663: 'monastery',
|
666 |
+
664: 'monitor',
|
667 |
+
665: 'moped',
|
668 |
+
666: 'mortar',
|
669 |
+
667: 'mortarboard',
|
670 |
+
668: 'mosque',
|
671 |
+
669: 'mosquito net',
|
672 |
+
670: 'motor scooter, scooter',
|
673 |
+
671: 'mountain bike, all-terrain bike, off-roader',
|
674 |
+
672: 'mountain tent',
|
675 |
+
673: 'mouse, computer mouse',
|
676 |
+
674: 'mousetrap',
|
677 |
+
675: 'moving van',
|
678 |
+
676: 'muzzle',
|
679 |
+
677: 'nail',
|
680 |
+
678: 'neck brace',
|
681 |
+
679: 'necklace',
|
682 |
+
680: 'nipple',
|
683 |
+
681: 'notebook, notebook computer',
|
684 |
+
682: 'obelisk',
|
685 |
+
683: 'oboe, hautboy, hautbois',
|
686 |
+
684: 'ocarina, sweet potato',
|
687 |
+
685: 'odometer, hodometer, mileometer, milometer',
|
688 |
+
686: 'oil filter',
|
689 |
+
687: 'organ, pipe organ',
|
690 |
+
688: 'oscilloscope, scope, cathode-ray oscilloscope, CRO',
|
691 |
+
689: 'overskirt',
|
692 |
+
690: 'oxcart',
|
693 |
+
691: 'oxygen mask',
|
694 |
+
692: 'packet',
|
695 |
+
693: 'paddle, boat paddle',
|
696 |
+
694: 'paddlewheel, paddle wheel',
|
697 |
+
695: 'padlock',
|
698 |
+
696: 'paintbrush',
|
699 |
+
697: "pajama, pyjama, pj's, jammies",
|
700 |
+
698: 'palace',
|
701 |
+
699: 'panpipe, pandean pipe, syrinx',
|
702 |
+
700: 'paper towel',
|
703 |
+
701: 'parachute, chute',
|
704 |
+
702: 'parallel bars, bars',
|
705 |
+
703: 'park bench',
|
706 |
+
704: 'parking meter',
|
707 |
+
705: 'passenger car, coach, carriage',
|
708 |
+
706: 'patio, terrace',
|
709 |
+
707: 'pay-phone, pay-station',
|
710 |
+
708: 'pedestal, plinth, footstall',
|
711 |
+
709: 'pencil box, pencil case',
|
712 |
+
710: 'pencil sharpener',
|
713 |
+
711: 'perfume, essence',
|
714 |
+
712: 'Petri dish',
|
715 |
+
713: 'photocopier',
|
716 |
+
714: 'pick, plectrum, plectron',
|
717 |
+
715: 'pickelhaube',
|
718 |
+
716: 'picket fence, paling',
|
719 |
+
717: 'pickup, pickup truck',
|
720 |
+
718: 'pier',
|
721 |
+
719: 'piggy bank, penny bank',
|
722 |
+
720: 'pill bottle',
|
723 |
+
721: 'pillow',
|
724 |
+
722: 'ping-pong ball',
|
725 |
+
723: 'pinwheel',
|
726 |
+
724: 'pirate, pirate ship',
|
727 |
+
725: 'pitcher, ewer',
|
728 |
+
726: "plane, carpenter's plane, woodworking plane",
|
729 |
+
727: 'planetarium',
|
730 |
+
728: 'plastic bag',
|
731 |
+
729: 'plate rack',
|
732 |
+
730: 'plow, plough',
|
733 |
+
731: "plunger, plumber's helper",
|
734 |
+
732: 'Polaroid camera, Polaroid Land camera',
|
735 |
+
733: 'pole',
|
736 |
+
734: 'police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria',
|
737 |
+
735: 'poncho',
|
738 |
+
736: 'pool table, billiard table, snooker table',
|
739 |
+
737: 'pop bottle, soda bottle',
|
740 |
+
738: 'pot, flowerpot',
|
741 |
+
739: "potter's wheel",
|
742 |
+
740: 'power drill',
|
743 |
+
741: 'prayer rug, prayer mat',
|
744 |
+
742: 'printer',
|
745 |
+
743: 'prison, prison house',
|
746 |
+
744: 'projectile, missile',
|
747 |
+
745: 'projector',
|
748 |
+
746: 'puck, hockey puck',
|
749 |
+
747: 'punching bag, punch bag, punching ball, punchball',
|
750 |
+
748: 'purse',
|
751 |
+
749: 'quill, quill pen',
|
752 |
+
750: 'quilt, comforter, comfort, puff',
|
753 |
+
751: 'racer, race car, racing car',
|
754 |
+
752: 'racket, racquet',
|
755 |
+
753: 'radiator',
|
756 |
+
754: 'radio, wireless',
|
757 |
+
755: 'radio telescope, radio reflector',
|
758 |
+
756: 'rain barrel',
|
759 |
+
757: 'recreational vehicle, RV, R.V.',
|
760 |
+
758: 'reel',
|
761 |
+
759: 'reflex camera',
|
762 |
+
760: 'refrigerator, icebox',
|
763 |
+
761: 'remote control, remote',
|
764 |
+
762: 'restaurant, eating house, eating place, eatery',
|
765 |
+
763: 'revolver, six-gun, six-shooter',
|
766 |
+
764: 'rifle',
|
767 |
+
765: 'rocking chair, rocker',
|
768 |
+
766: 'rotisserie',
|
769 |
+
767: 'rubber eraser, rubber, pencil eraser',
|
770 |
+
768: 'rugby ball',
|
771 |
+
769: 'rule, ruler',
|
772 |
+
770: 'running shoe',
|
773 |
+
771: 'safe',
|
774 |
+
772: 'safety pin',
|
775 |
+
773: 'saltshaker, salt shaker',
|
776 |
+
774: 'sandal',
|
777 |
+
775: 'sarong',
|
778 |
+
776: 'sax, saxophone',
|
779 |
+
777: 'scabbard',
|
780 |
+
778: 'scale, weighing machine',
|
781 |
+
779: 'school bus',
|
782 |
+
780: 'schooner',
|
783 |
+
781: 'scoreboard',
|
784 |
+
782: 'screen, CRT screen',
|
785 |
+
783: 'screw',
|
786 |
+
784: 'screwdriver',
|
787 |
+
785: 'seat belt, seatbelt',
|
788 |
+
786: 'sewing machine',
|
789 |
+
787: 'shield, buckler',
|
790 |
+
788: 'shoe shop, shoe-shop, shoe store',
|
791 |
+
789: 'shoji',
|
792 |
+
790: 'shopping basket',
|
793 |
+
791: 'shopping cart',
|
794 |
+
792: 'shovel',
|
795 |
+
793: 'shower cap',
|
796 |
+
794: 'shower curtain',
|
797 |
+
795: 'ski',
|
798 |
+
796: 'ski mask',
|
799 |
+
797: 'sleeping bag',
|
800 |
+
798: 'slide rule, slipstick',
|
801 |
+
799: 'sliding door',
|
802 |
+
800: 'slot, one-armed bandit',
|
803 |
+
801: 'snorkel',
|
804 |
+
802: 'snowmobile',
|
805 |
+
803: 'snowplow, snowplough',
|
806 |
+
804: 'soap dispenser',
|
807 |
+
805: 'soccer ball',
|
808 |
+
806: 'sock',
|
809 |
+
807: 'solar dish, solar collector, solar furnace',
|
810 |
+
808: 'sombrero',
|
811 |
+
809: 'soup bowl',
|
812 |
+
810: 'space bar',
|
813 |
+
811: 'space heater',
|
814 |
+
812: 'space shuttle',
|
815 |
+
813: 'spatula',
|
816 |
+
814: 'speedboat',
|
817 |
+
815: "spider web, spider's web",
|
818 |
+
816: 'spindle',
|
819 |
+
817: 'sports car, sport car',
|
820 |
+
818: 'spotlight, spot',
|
821 |
+
819: 'stage',
|
822 |
+
820: 'steam locomotive',
|
823 |
+
821: 'steel arch bridge',
|
824 |
+
822: 'steel drum',
|
825 |
+
823: 'stethoscope',
|
826 |
+
824: 'stole',
|
827 |
+
825: 'stone wall',
|
828 |
+
826: 'stopwatch, stop watch',
|
829 |
+
827: 'stove',
|
830 |
+
828: 'strainer',
|
831 |
+
829: 'streetcar, tram, tramcar, trolley, trolley car',
|
832 |
+
830: 'stretcher',
|
833 |
+
831: 'studio couch, day bed',
|
834 |
+
832: 'stupa, tope',
|
835 |
+
833: 'submarine, pigboat, sub, U-boat',
|
836 |
+
834: 'suit, suit of clothes',
|
837 |
+
835: 'sundial',
|
838 |
+
836: 'sunglass',
|
839 |
+
837: 'sunglasses, dark glasses, shades',
|
840 |
+
838: 'sunscreen, sunblock, sun blocker',
|
841 |
+
839: 'suspension bridge',
|
842 |
+
840: 'swab, swob, mop',
|
843 |
+
841: 'sweatshirt',
|
844 |
+
842: 'swimming trunks, bathing trunks',
|
845 |
+
843: 'swing',
|
846 |
+
844: 'switch, electric switch, electrical switch',
|
845: 'syringe',
846: 'table lamp',
847: 'tank, army tank, armored combat vehicle, armoured combat vehicle',
848: 'tape player',
849: 'teapot',
850: 'teddy, teddy bear',
851: 'television, television system',
852: 'tennis ball',
853: 'thatch, thatched roof',
854: 'theater curtain, theatre curtain',
855: 'thimble',
856: 'thresher, thrasher, threshing machine',
857: 'throne',
858: 'tile roof',
859: 'toaster',
860: 'tobacco shop, tobacconist shop, tobacconist',
861: 'toilet seat',
862: 'torch',
863: 'totem pole',
864: 'tow truck, tow car, wrecker',
865: 'toyshop',
866: 'tractor',
867: 'trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi',
868: 'tray',
869: 'trench coat',
870: 'tricycle, trike, velocipede',
871: 'trimaran',
872: 'tripod',
873: 'triumphal arch',
874: 'trolleybus, trolley coach, trackless trolley',
875: 'trombone',
876: 'tub, vat',
877: 'turnstile',
878: 'typewriter keyboard',
879: 'umbrella',
880: 'unicycle, monocycle',
881: 'upright, upright piano',
882: 'vacuum, vacuum cleaner',
883: 'vase',
884: 'vault',
885: 'velvet',
886: 'vending machine',
887: 'vestment',
888: 'viaduct',
889: 'violin, fiddle',
890: 'volleyball',
891: 'waffle iron',
892: 'wall clock',
893: 'wallet, billfold, notecase, pocketbook',
894: 'wardrobe, closet, press',
895: 'warplane, military plane',
896: 'washbasin, handbasin, washbowl, lavabo, wash-hand basin',
897: 'washer, automatic washer, washing machine',
898: 'water bottle',
899: 'water jug',
900: 'water tower',
901: 'whiskey jug',
902: 'whistle',
903: 'wig',
904: 'window screen',
905: 'window shade',
906: 'Windsor tie',
907: 'wine bottle',
908: 'wing',
909: 'wok',
910: 'wooden spoon',
911: 'wool, woolen, woollen',
912: 'worm fence, snake fence, snake-rail fence, Virginia fence',
913: 'wreck',
914: 'yawl',
915: 'yurt',
916: 'web site, website, internet site, site',
917: 'comic book',
918: 'crossword puzzle, crossword',
919: 'street sign',
920: 'traffic light, traffic signal, stoplight',
921: 'book jacket, dust cover, dust jacket, dust wrapper',
922: 'menu',
923: 'plate',
924: 'guacamole',
925: 'consomme',
926: 'hot pot, hotpot',
927: 'trifle',
928: 'ice cream, icecream',
929: 'ice lolly, lolly, lollipop, popsicle',
930: 'French loaf',
931: 'bagel, beigel',
932: 'pretzel',
933: 'cheeseburger',
934: 'hotdog, hot dog, red hot',
935: 'mashed potato',
936: 'head cabbage',
937: 'broccoli',
938: 'cauliflower',
939: 'zucchini, courgette',
940: 'spaghetti squash',
941: 'acorn squash',
942: 'butternut squash',
943: 'cucumber, cuke',
944: 'artichoke, globe artichoke',
945: 'bell pepper',
946: 'cardoon',
947: 'mushroom',
948: 'Granny Smith',
949: 'strawberry',
950: 'orange',
951: 'lemon',
952: 'fig',
953: 'pineapple, ananas',
954: 'banana',
955: 'jackfruit, jak, jack',
956: 'custard apple',
957: 'pomegranate',
958: 'hay',
959: 'carbonara',
960: 'chocolate sauce, chocolate syrup',
961: 'dough',
962: 'meat loaf, meatloaf',
963: 'pizza, pizza pie',
964: 'potpie',
965: 'burrito',
966: 'red wine',
967: 'espresso',
968: 'cup',
969: 'eggnog',
970: 'alp',
971: 'bubble',
972: 'cliff, drop, drop-off',
973: 'coral reef',
974: 'geyser',
975: 'lakeside, lakeshore',
976: 'promontory, headland, head, foreland',
977: 'sandbar, sand bar',
978: 'seashore, coast, seacoast, sea-coast',
979: 'valley, vale',
980: 'volcano',
981: 'ballplayer, baseball player',
982: 'groom, bridegroom',
983: 'scuba diver',
984: 'rapeseed',
985: 'daisy',
986: "yellow lady's slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum",
987: 'corn',
988: 'acorn',
989: 'hip, rose hip, rosehip',
990: 'buckeye, horse chestnut, conker',
991: 'coral fungus',
992: 'agaric',
993: 'gyromitra',
994: 'stinkhorn, carrion fungus',
995: 'earthstar',
996: 'hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa',
997: 'bolete',
998: 'ear, spike, capitulum',
999: 'toilet tissue, toilet paper, bathroom tissue'
}
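The mapping above is only consumed at display time, turning a predicted class index into a readable ImageNet label. Below is a minimal usage sketch; it assumes the dictionary defined in this file is exported under the name CLS2IDX and that `logits` is a (1, 1000) classifier output, neither of which is shown in this hunk.

# Hypothetical usage sketch: printing the top-k predicted ImageNet labels.
import torch
from CLS2IDX import CLS2IDX  # assumed name of the dictionary defined in this file

def print_top_k(logits: torch.Tensor, k: int = 5) -> None:
    probs = torch.softmax(logits, dim=-1)
    top_probs, top_idx = probs.topk(k, dim=-1)
    for p, i in zip(top_probs[0].tolist(), top_idx[0].tolist()):
        print(f"{CLS2IDX[i]:60s} {p:.3f}")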
SegmentationTest/data/transforms.py
ADDED
@@ -0,0 +1,442 @@
from __future__ import division
import sys
import random
from PIL import Image

try:
    import accimage
except ImportError:
    accimage = None
import numbers
import collections

from torchvision.transforms import functional as F

if sys.version_info < (3, 3):
    Sequence = collections.Sequence
    Iterable = collections.Iterable
else:
    Sequence = collections.abc.Sequence
    Iterable = collections.abc.Iterable

_pil_interpolation_to_str = {
    Image.NEAREST: 'PIL.Image.NEAREST',
    Image.BILINEAR: 'PIL.Image.BILINEAR',
    Image.BICUBIC: 'PIL.Image.BICUBIC',
    Image.LANCZOS: 'PIL.Image.LANCZOS',
    Image.HAMMING: 'PIL.Image.HAMMING',
    Image.BOX: 'PIL.Image.BOX',
}


class Compose(object):
    """Composes several transforms together.

    Args:
        transforms (list of ``Transform`` objects): list of transforms to compose.

    Example:
        >>> transforms.Compose([
        >>>     transforms.CenterCrop(10),
        >>>     transforms.ToTensor(),
        >>> ])
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img, tgt):
        for t in self.transforms:
            img, tgt = t(img, tgt)
        return img, tgt

    def __repr__(self):
        format_string = self.__class__.__name__ + '('
        for t in self.transforms:
            format_string += '\n'
            format_string += '    {0}'.format(t)
        format_string += '\n)'
        return format_string


class Resize(object):
    """Resize the input PIL Image to the given size.

    Args:
        size (sequence or int): Desired output size. If size is a sequence like
            (h, w), output size will be matched to this. If size is an int,
            smaller edge of the image will be matched to this number.
            i.e, if height > width, then image will be rescaled to
            (size * height / width, size)
        interpolation (int, optional): Desired interpolation. Default is
            ``PIL.Image.BILINEAR``
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        assert isinstance(size, int) or (isinstance(size, Iterable) and len(size) == 2)
        self.size = size
        self.interpolation = interpolation

    def __call__(self, img, tgt):
        """
        Args:
            img (PIL Image): Image to be scaled.

        Returns:
            PIL Image: Rescaled image.
        """
        return F.resize(img, self.size, self.interpolation), F.resize(tgt, self.size, Image.NEAREST)

    def __repr__(self):
        interpolate_str = _pil_interpolation_to_str[self.interpolation]
        return self.__class__.__name__ + '(size={0}, interpolation={1})'.format(self.size, interpolate_str)


class CenterCrop(object):
    """Crops the given PIL Image at the center.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made.
    """

    def __init__(self, size):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size

    def __call__(self, img, tgt):
        """
        Args:
            img (PIL Image): Image to be cropped.

        Returns:
            PIL Image: Cropped image.
        """
        return F.center_crop(img, self.size), F.center_crop(tgt, self.size)

    def __repr__(self):
        return self.__class__.__name__ + '(size={0})'.format(self.size)


class RandomCrop(object):
    """Crop the given PIL Image at a random location.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made.
        padding (int or sequence, optional): Optional padding on each border
            of the image. Default is None, i.e no padding. If a sequence of length
            4 is provided, it is used to pad left, top, right, bottom borders
            respectively. If a sequence of length 2 is provided, it is used to
            pad left/right, top/bottom borders, respectively.
        pad_if_needed (boolean): It will pad the image if smaller than the
            desired size to avoid raising an exception.
        fill: Pixel fill value for constant fill. Default is 0. If a tuple of
            length 3, it is used to fill R, G, B channels respectively.
            This value is only used when the padding_mode is constant
        padding_mode: Type of padding. Should be: constant, edge, reflect or symmetric. Default is constant.

            - constant: pads with a constant value, this value is specified with fill

            - edge: pads with the last value on the edge of the image

            - reflect: pads with reflection of image (without repeating the last value on the edge)

                padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode
                will result in [3, 2, 1, 2, 3, 4, 3, 2]

            - symmetric: pads with reflection of image (repeating the last value on the edge)

                padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode
                will result in [2, 1, 1, 2, 3, 4, 4, 3]

    """

    def __init__(self, size, padding=None, pad_if_needed=False, fill=0, padding_mode='constant'):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size
        self.padding = padding
        self.pad_if_needed = pad_if_needed
        self.fill = fill
        self.padding_mode = padding_mode

    @staticmethod
    def get_params(img, output_size):
        """Get parameters for ``crop`` for a random crop.

        Args:
            img (PIL Image): Image to be cropped.
            output_size (tuple): Expected output size of the crop.

        Returns:
            tuple: params (i, j, h, w) to be passed to ``crop`` for random crop.
        """
        w, h = img.size
        th, tw = output_size
        if w == tw and h == th:
            return 0, 0, h, w

        i = random.randint(0, h - th)
        j = random.randint(0, w - tw)
        return i, j, th, tw

    def __call__(self, img, tgt):
        """
        Args:
            img (PIL Image): Image to be cropped.

        Returns:
            PIL Image: Cropped image.
        """
        if self.padding is not None:
            img = F.pad(img, self.padding, self.fill, self.padding_mode)
            tgt = F.pad(tgt, self.padding, self.fill, self.padding_mode)

        # pad the width if needed
        if self.pad_if_needed and img.size[0] < self.size[1]:
            img = F.pad(img, (self.size[1] - img.size[0], 0), self.fill, self.padding_mode)
            tgt = F.pad(tgt, (self.size[1] - img.size[0], 0), self.fill, self.padding_mode)
        # pad the height if needed
        if self.pad_if_needed and img.size[1] < self.size[0]:
            img = F.pad(img, (0, self.size[0] - img.size[1]), self.fill, self.padding_mode)
            tgt = F.pad(tgt, (0, self.size[0] - img.size[1]), self.fill, self.padding_mode)

        i, j, h, w = self.get_params(img, self.size)

        return F.crop(img, i, j, h, w), F.crop(tgt, i, j, h, w)

    def __repr__(self):
        return self.__class__.__name__ + '(size={0}, padding={1})'.format(self.size, self.padding)


class RandomHorizontalFlip(object):
    """Horizontally flip the given PIL Image randomly with a given probability.

    Args:
        p (float): probability of the image being flipped. Default value is 0.5
    """

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, img, tgt):
        """
        Args:
            img (PIL Image): Image to be flipped.

        Returns:
            PIL Image: Randomly flipped image.
        """
        if random.random() < self.p:
            return F.hflip(img), F.hflip(tgt)

        return img, tgt

    def __repr__(self):
        return self.__class__.__name__ + '(p={})'.format(self.p)


class RandomVerticalFlip(object):
    """Vertically flip the given PIL Image randomly with a given probability.

    Args:
        p (float): probability of the image being flipped. Default value is 0.5
    """

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, img, tgt):
        """
        Args:
            img (PIL Image): Image to be flipped.

        Returns:
            PIL Image: Randomly flipped image.
        """
        if random.random() < self.p:
            return F.vflip(img), F.vflip(tgt)
        return img, tgt

    def __repr__(self):
        return self.__class__.__name__ + '(p={})'.format(self.p)


class Lambda(object):
    """Apply a user-defined lambda as a transform.

    Args:
        lambd (function): Lambda/function to be used for transform.
    """

    def __init__(self, lambd):
        assert callable(lambd), repr(type(lambd).__name__) + " object is not callable"
        self.lambd = lambd

    def __call__(self, img, tgt):
        return self.lambd(img, tgt)

    def __repr__(self):
        return self.__class__.__name__ + '()'


class ColorJitter(object):
    """Randomly change the brightness, contrast and saturation of an image.

    Args:
        brightness (float or tuple of float (min, max)): How much to jitter brightness.
            brightness_factor is chosen uniformly from [max(0, 1 - brightness), 1 + brightness]
            or the given [min, max]. Should be non negative numbers.
        contrast (float or tuple of float (min, max)): How much to jitter contrast.
            contrast_factor is chosen uniformly from [max(0, 1 - contrast), 1 + contrast]
            or the given [min, max]. Should be non negative numbers.
        saturation (float or tuple of float (min, max)): How much to jitter saturation.
            saturation_factor is chosen uniformly from [max(0, 1 - saturation), 1 + saturation]
            or the given [min, max]. Should be non negative numbers.
        hue (float or tuple of float (min, max)): How much to jitter hue.
            hue_factor is chosen uniformly from [-hue, hue] or the given [min, max].
            Should have 0 <= hue <= 0.5 or -0.5 <= min <= max <= 0.5.
    """
    def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
        self.brightness = self._check_input(brightness, 'brightness')
        self.contrast = self._check_input(contrast, 'contrast')
        self.saturation = self._check_input(saturation, 'saturation')
        self.hue = self._check_input(hue, 'hue', center=0, bound=(-0.5, 0.5),
                                     clip_first_on_zero=False)

    def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_first_on_zero=True):
        if isinstance(value, numbers.Number):
            if value < 0:
                raise ValueError("If {} is a single number, it must be non negative.".format(name))
            value = [center - value, center + value]
            if clip_first_on_zero:
                value[0] = max(value[0], 0)
        elif isinstance(value, (tuple, list)) and len(value) == 2:
            if not bound[0] <= value[0] <= value[1] <= bound[1]:
                raise ValueError("{} values should be between {}".format(name, bound))
        else:
            raise TypeError("{} should be a single number or a list/tuple with length 2.".format(name))

        # if value is 0 or (1., 1.) for brightness/contrast/saturation
        # or (0., 0.) for hue, do nothing
        if value[0] == value[1] == center:
            value = None
        return value

    @staticmethod
    def get_params(brightness, contrast, saturation, hue):
        """Get a randomized transform to be applied on image.

        Arguments are same as that of __init__.

        Returns:
            Transform which randomly adjusts brightness, contrast and
            saturation in a random order.
        """
        transforms = []

        if brightness is not None:
            brightness_factor = random.uniform(brightness[0], brightness[1])
            transforms.append(Lambda(lambda img, tgt: (F.adjust_brightness(img, brightness_factor), tgt)))

        if contrast is not None:
            contrast_factor = random.uniform(contrast[0], contrast[1])
            transforms.append(Lambda(lambda img, tgt: (F.adjust_contrast(img, contrast_factor), tgt)))

        if saturation is not None:
            saturation_factor = random.uniform(saturation[0], saturation[1])
            transforms.append(Lambda(lambda img, tgt: (F.adjust_saturation(img, saturation_factor), tgt)))

        if hue is not None:
            hue_factor = random.uniform(hue[0], hue[1])
            transforms.append(Lambda(lambda img, tgt: (F.adjust_hue(img, hue_factor), tgt)))

        random.shuffle(transforms)
        transform = Compose(transforms)

        return transform

    def __call__(self, img, tgt):
        """
        Args:
            img (PIL Image): Input image.

        Returns:
            PIL Image: Color jittered image.
        """
        transform = self.get_params(self.brightness, self.contrast,
                                    self.saturation, self.hue)
        return transform(img, tgt)

    def __repr__(self):
        format_string = self.__class__.__name__ + '('
        format_string += 'brightness={0}'.format(self.brightness)
        format_string += ', contrast={0}'.format(self.contrast)
        format_string += ', saturation={0}'.format(self.saturation)
        format_string += ', hue={0})'.format(self.hue)
        return format_string


class Normalize(object):
    """Normalize a tensor image with mean and standard deviation.
    Given mean: ``(M1,...,Mn)`` and std: ``(S1,..,Sn)`` for ``n`` channels, this transform
    will normalize each channel of the input ``torch.*Tensor`` i.e.
    ``input[channel] = (input[channel] - mean[channel]) / std[channel]``

    .. note::
        This transform acts out of place, i.e., it does not mutate the input tensor.

    Args:
        mean (sequence): Sequence of means for each channel.
        std (sequence): Sequence of standard deviations for each channel.
    """

    def __init__(self, mean, std, inplace=False):
        self.mean = mean
        self.std = std
        self.inplace = inplace

    def __call__(self, img, tgt):
        """
        Args:
            tensor (Tensor): Tensor image of size (C, H, W) to be normalized.

        Returns:
            Tensor: Normalized Tensor image.
        """
        # return F.normalize(img, self.mean, self.std, self.inplace), tgt
        return F.normalize(img, self.mean, self.std), tgt

    def __repr__(self):
        return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)


class ToTensor(object):
    """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.

    Converts a PIL Image or numpy.ndarray (H x W x C) in the range
    [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]
    if the PIL Image belongs to one of the modes (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1)
    or if the numpy.ndarray has dtype = np.uint8

    In the other cases, tensors are returned without scaling.
    """

    def __call__(self, img, tgt):
        """
        Args:
            pic (PIL Image or numpy.ndarray): Image to be converted to tensor.

        Returns:
            Tensor: Converted image.
        """
        return F.to_tensor(img), tgt

    def __repr__(self):
        return self.__class__.__name__ + '()'
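These transforms differ from the stock torchvision versions in that every `__call__` takes and returns an (image, target) pair: geometric operations apply identical parameters to both, the segmentation target is always resized with NEAREST interpolation, and photometric operations (ColorJitter, Normalize, ToTensor) touch only the image and pass the target through unchanged. A minimal usage sketch follows; the file paths, crop size, and normalization constants are illustrative assumptions, not values taken from this file.

# Hypothetical usage sketch of the paired image/mask transforms defined above.
from PIL import Image
import SegmentationTest.data.transforms as T

joint = T.Compose([
    T.Resize(256),                    # mask is resized with NEAREST internally
    T.RandomCrop(224),                # same crop window for image and mask
    T.RandomHorizontalFlip(p=0.5),    # same flip decision for both
    T.ColorJitter(brightness=0.2),    # photometric change applied to the image only
    T.ToTensor(),                     # image becomes a tensor, mask passes through
    T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

img = Image.open('example.jpg').convert('RGB')   # illustrative path
mask = Image.open('example_mask.png')            # illustrative path
img_t, mask_out = joint(img, mask)               # every transform returns the pair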
SegmentationTest/imagenet_seg_eval.py
ADDED
@@ -0,0 +1,319 @@
import numpy as np
import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from numpy import *
import argparse
from PIL import Image
import imageio
import os
from tqdm import tqdm
from SegmentationTest.utils.metrices import *

from SegmentationTest.utils import render
from SegmentationTest.utils.saver import Saver
from SegmentationTest.utils.iou import IoU

from SegmentationTest.data.Imagenet import Imagenet_Segmentation

# Uncomment the expected model below

# ViT
from ViT.ViT import vit_base_patch16_224 as vit
# from ViT.ViT import vit_large_patch16_224 as vit

# ViT-AugReg
# from ViT.ViT_new import vit_small_patch16_224 as vit
# from ViT.ViT_new import vit_base_patch16_224 as vit
# from ViT.ViT_new import vit_large_patch16_224 as vit

# DeiT
# from ViT.ViT import deit_base_patch16_224 as vit
# from ViT.ViT import deit_small_patch16_224 as vit


from ViT.explainer import generate_relevance, get_image_with_relevance

from sklearn.metrics import precision_recall_curve
import matplotlib.pyplot as plt

import torch.nn.functional as F

import warnings
warnings.filterwarnings("ignore")

plt.switch_backend('agg')

# hyperparameters
num_workers = 0
batch_size = 1

cls = ['airplane',
       'bicycle',
       'bird',
       'boat',
       'bottle',
       'bus',
       'car',
       'cat',
       'chair',
       'cow',
       'dining table',
       'dog',
       'horse',
       'motobike',
       'person',
       'potted plant',
       'sheep',
       'sofa',
       'train',
       'tv'
       ]

# Args
parser = argparse.ArgumentParser(description='Training multi-class classifier')
parser.add_argument('--arc', type=str, default='vgg', metavar='N',
                    help='Model architecture')
parser.add_argument('--train_dataset', type=str, default='imagenet', metavar='N',
                    help='Testing Dataset')
parser.add_argument('--method', type=str,
                    default='grad_rollout',
                    choices=['rollout', 'lrp', 'transformer_attribution', 'full_lrp', 'lrp_last_layer',
                             'attn_last_layer', 'attn_gradcam'],
                    help='')
parser.add_argument('--thr', type=float, default=0.,
                    help='threshold')
parser.add_argument('--K', type=int, default=1,
                    help='new - top K results')
parser.add_argument('--save-img', action='store_true',
                    default=False,
                    help='')
parser.add_argument('--no-ia', action='store_true',
                    default=False,
                    help='')
parser.add_argument('--no-fx', action='store_true',
                    default=False,
                    help='')
parser.add_argument('--no-fgx', action='store_true',
                    default=False,
                    help='')
parser.add_argument('--no-m', action='store_true',
                    default=False,
                    help='')
parser.add_argument('--no-reg', action='store_true',
                    default=False,
                    help='')
parser.add_argument('--is-ablation', type=bool,
                    default=False,
                    help='')
parser.add_argument('--imagenet-seg-path', type=str, required=True)
parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
args = parser.parse_args()

args.checkname = args.method + '_' + args.arc

alpha = 2

cuda = torch.cuda.is_available()
device = torch.device("cuda" if cuda else "cpu")

# Define Saver
saver = Saver(args)
saver.results_dir = os.path.join(saver.experiment_dir, 'results')
if not os.path.exists(saver.results_dir):
    os.makedirs(saver.results_dir)
if not os.path.exists(os.path.join(saver.results_dir, 'input')):
    os.makedirs(os.path.join(saver.results_dir, 'input'))
if not os.path.exists(os.path.join(saver.results_dir, 'explain')):
    os.makedirs(os.path.join(saver.results_dir, 'explain'))

args.exp_img_path = os.path.join(saver.results_dir, 'explain/img')
if not os.path.exists(args.exp_img_path):
    os.makedirs(args.exp_img_path)
args.exp_np_path = os.path.join(saver.results_dir, 'explain/np')
if not os.path.exists(args.exp_np_path):
    os.makedirs(args.exp_np_path)

# Data
normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
test_img_trans = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    normalize,
])
test_lbl_trans = transforms.Compose([
    transforms.Resize((224, 224), Image.NEAREST),
])

ds = Imagenet_Segmentation(args.imagenet_seg_path,
                           transform=test_img_trans, target_transform=test_lbl_trans)
dl = DataLoader(ds, batch_size=batch_size, shuffle=False, num_workers=1, drop_last=False)

# Model
if args.checkpoint:
    print(f"loading model from checkpoint {args.checkpoint}")
    model = vit().cuda()
    checkpoint = torch.load(args.checkpoint)
    model.load_state_dict(checkpoint['state_dict'])
else:
    model = vit(pretrained=True).cuda()

metric = IoU(2, ignore_index=-1)

iterator = tqdm(dl)

model.eval()


def compute_pred(output):
    pred = output.data.max(1, keepdim=True)[1]  # get the index of the max log-probability
    # pred[0, 0] = 282
    # print('Pred cls : ' + str(pred))
    T = pred.squeeze().cpu().numpy()
    T = np.expand_dims(T, 0)
    T = (T[:, np.newaxis] == np.arange(1000)) * 1.0
    T = torch.from_numpy(T).type(torch.FloatTensor)
    Tt = T.cuda()

    return Tt


def eval_batch(image, labels, evaluator, index):
    evaluator.zero_grad()
    # Save input image
    if args.save_img:
        img = image[0].permute(1, 2, 0).data.cpu().numpy()
        img = 255 * (img - img.min()) / (img.max() - img.min())
        img = img.astype('uint8')
        Image.fromarray(img, 'RGB').save(os.path.join(saver.results_dir, 'input/{}_input.png'.format(index)))
        Image.fromarray((labels.repeat(3, 1, 1).permute(1, 2, 0).data.cpu().numpy() * 255).astype('uint8'), 'RGB').save(
            os.path.join(saver.results_dir, 'input/{}_mask.png'.format(index)))

    image.requires_grad = True

    image = image.requires_grad_()
    predictions = evaluator(image)
    Res = generate_relevance(model, image.cuda())

    # threshold between FG and BG is the mean
    Res = (Res - Res.min()) / (Res.max() - Res.min())

    ret = Res.mean()

    Res_1 = Res.gt(ret).type(Res.type())
    Res_0 = Res.le(ret).type(Res.type())

    Res_1_AP = Res
    Res_0_AP = 1 - Res

    Res_1[Res_1 != Res_1] = 0
    Res_0[Res_0 != Res_0] = 0
    Res_1_AP[Res_1_AP != Res_1_AP] = 0
    Res_0_AP[Res_0_AP != Res_0_AP] = 0

    # TEST
    pred = Res.clamp(min=args.thr) / Res.max()
    pred = pred.view(-1).data.cpu().numpy()
    target = labels.view(-1).data.cpu().numpy()
    # print("target", target.shape)

    output = torch.cat((Res_0, Res_1), 1)
    output_AP = torch.cat((Res_0_AP, Res_1_AP), 1)

    if args.save_img:
        # Save predicted mask
        mask = F.interpolate(Res_1, [64, 64], mode='bilinear')
        mask = mask[0].squeeze().data.cpu().numpy()
        # mask = Res_1[0].squeeze().data.cpu().numpy()
        mask = 255 * mask
        mask = mask.astype('uint8')
        imageio.imsave(os.path.join(args.exp_img_path, 'mask_' + str(index) + '.jpg'), mask)

        relevance = F.interpolate(Res, [64, 64], mode='bilinear')
        relevance = relevance[0].permute(1, 2, 0).data.cpu().numpy()
        # relevance = Res[0].permute(1, 2, 0).data.cpu().numpy()
        hm = np.sum(relevance, axis=-1)
        maps = (render.hm_to_rgb(hm, scaling=3, sigma=1, cmap='seismic') * 255).astype(np.uint8)
        imageio.imsave(os.path.join(args.exp_img_path, 'heatmap_' + str(index) + '.jpg'), maps)

    # Evaluate Segmentation
    batch_inter, batch_union, batch_correct, batch_label = 0, 0, 0, 0
    batch_ap, batch_f1 = 0, 0

    # Segmentation results
    correct, labeled = batch_pix_accuracy(output[0].data.cpu(), labels[0])
    inter, union = batch_intersection_union(output[0].data.cpu(), labels[0], 2)
    batch_correct += correct
    batch_label += labeled
    batch_inter += inter
    batch_union += union
    # print("output", output.shape)
    # print("ap labels", labels.shape)
    # ap = np.nan_to_num(get_ap_scores(output, labels))
    ap = np.nan_to_num(get_ap_scores(output_AP, labels))
    # f1 = np.nan_to_num(get_f1_scores(output[0, 1].data.cpu(), labels[0]))
    batch_ap += ap
    # batch_f1 += f1

    # return batch_correct, batch_label, batch_inter, batch_union, batch_ap, batch_f1, pred, target
    return batch_correct, batch_label, batch_inter, batch_union, batch_ap, pred, target


total_inter, total_union, total_correct, total_label = np.int64(0), np.int64(0), np.int64(0), np.int64(0)
total_ap, total_f1 = [], []

predictions, targets = [], []
for batch_idx, (image, labels) in enumerate(iterator):

    if args.method == "blur":
        images = (image[0].cuda(), image[1].cuda())
    else:
        images = image.cuda()
    labels = labels.cuda()
    # print("image", image.shape)
    # print("labels", labels.shape)

    # correct, labeled, inter, union, ap, f1, pred, target = eval_batch(images, labels, model, batch_idx)
    correct, labeled, inter, union, ap, pred, target = eval_batch(images, labels, model, batch_idx)

    predictions.append(pred)
    targets.append(target)

    total_correct += correct.astype('int64')
    total_label += labeled.astype('int64')
    total_inter += inter.astype('int64')
    total_union += union.astype('int64')
    total_ap += [ap]
    # total_f1 += [f1]
    pixAcc = np.float64(1.0) * total_correct / (np.spacing(1, dtype=np.float64) + total_label)
    IoU = np.float64(1.0) * total_inter / (np.spacing(1, dtype=np.float64) + total_union)
    mIoU = IoU.mean()
    mAp = np.mean(total_ap)
    # mF1 = np.mean(total_f1)
    # iterator.set_description('pixAcc: %.4f, mIoU: %.4f, mAP: %.4f, mF1: %.4f' % (pixAcc, mIoU, mAp, mF1))
    iterator.set_description('pixAcc: %.4f, mIoU: %.4f, mAP: %.4f' % (pixAcc, mIoU, mAp))

predictions = np.concatenate(predictions)
targets = np.concatenate(targets)
pr, rc, thr = precision_recall_curve(targets, predictions)
np.save(os.path.join(saver.experiment_dir, 'precision.npy'), pr)
np.save(os.path.join(saver.experiment_dir, 'recall.npy'), rc)

plt.figure()
plt.plot(rc, pr)
plt.savefig(os.path.join(saver.experiment_dir, 'PR_curve_{}.png'.format(args.method)))

txtfile = os.path.join(saver.experiment_dir, 'result_mIoU_%.4f.txt' % mIoU)
# txtfile = 'result_mIoU_%.4f.txt' % mIoU
fh = open(txtfile, 'w')
print("Mean IoU over %d classes: %.4f\n" % (2, mIoU))
print("Pixel-wise Accuracy: %2.2f%%\n" % (pixAcc * 100))
print("Mean AP over %d classes: %.4f\n" % (2, mAp))
# print("Mean F1 over %d classes: %.4f\n" % (2, mF1))

fh.write("Mean IoU over %d classes: %.4f\n" % (2, mIoU))
fh.write("Pixel-wise Accuracy: %2.2f%%\n" % (pixAcc * 100))
fh.write("Mean AP over %d classes: %.4f\n" % (2, mAp))
# fh.write("Mean F1 over %d classes: %.4f\n" % (2, mF1))
fh.close()
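The foreground/background split inside eval_batch above is simply a per-image mean threshold on the min-max-normalized relevance map. The short sketch below replays that step in isolation on a dummy tensor; the 1x1x224x224 shape is an assumption matching the script's single-image batches, and .float() stands in for the .type(Res.type()) call used in the script.

# Standalone sketch of the FG/BG thresholding used in eval_batch above.
import torch

Res = torch.rand(1, 1, 224, 224)                    # dummy relevance map (assumed shape)
Res = (Res - Res.min()) / (Res.max() - Res.min())   # min-max normalize to [0, 1]
thr = Res.mean()                                    # per-image threshold

Res_1 = Res.gt(thr).float()             # foreground mask: 1 where relevance > mean
Res_0 = Res.le(thr).float()             # background mask
output = torch.cat((Res_0, Res_1), 1)   # 2-channel map scored against the GT mask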
SegmentationTest/utils/__init__.py
ADDED
File without changes
|
SegmentationTest/utils/confusionmatrix.py
ADDED
@@ -0,0 +1,88 @@
import numpy as np
import torch
from . import metric


class ConfusionMatrix(metric.Metric):
    """Constructs a confusion matrix for a multi-class classification problem.
    Does not support multi-label, multi-class problems.
    Keyword arguments:
    - num_classes (int): number of classes in the classification problem.
    - normalized (boolean, optional): Determines whether or not the confusion
    matrix is normalized or not. Default: False.
    Modified from: https://github.com/pytorch/tnt/blob/master/torchnet/meter/confusionmeter.py
    """

    def __init__(self, num_classes, normalized=False):
        super().__init__()

        self.conf = np.ndarray((num_classes, num_classes), dtype=np.int32)
        self.normalized = normalized
        self.num_classes = num_classes
        self.reset()

    def reset(self):
        self.conf.fill(0)

    def add(self, predicted, target):
        """Computes the confusion matrix
        The shape of the confusion matrix is K x K, where K is the number
        of classes.
        Keyword arguments:
        - predicted (Tensor or numpy.ndarray): Can be an N x K tensor/array of
        predicted scores obtained from the model for N examples and K classes,
        or an N-tensor/array of integer values between 0 and K-1.
        - target (Tensor or numpy.ndarray): Can be an N x K tensor/array of
        ground-truth classes for N examples and K classes, or an N-tensor/array
        of integer values between 0 and K-1.
        """
        # If target and/or predicted are tensors, convert them to numpy arrays
        if torch.is_tensor(predicted):
            predicted = predicted.cpu().numpy()
        if torch.is_tensor(target):
            target = target.cpu().numpy()

        assert predicted.shape[0] == target.shape[0], \
            'number of targets and predicted outputs do not match'

        if np.ndim(predicted) != 1:
            assert predicted.shape[1] == self.num_classes, \
                'number of predictions does not match size of confusion matrix'
            predicted = np.argmax(predicted, 1)
        else:
            assert (predicted.max() < self.num_classes) and (predicted.min() >= 0), \
                'predicted values are not between 0 and k-1'

        if np.ndim(target) != 1:
            assert target.shape[1] == self.num_classes, \
                'Onehot target does not match size of confusion matrix'
            assert (target >= 0).all() and (target <= 1).all(), \
                'in one-hot encoding, target values should be 0 or 1'
            assert (target.sum(1) == 1).all(), \
                'multi-label setting is not supported'
            target = np.argmax(target, 1)
        else:
            assert (target.max() < self.num_classes) and (target.min() >= 0), \
                'target values are not between 0 and k-1'

        # hack for bincounting 2 arrays together
        x = predicted + self.num_classes * target
        bincount_2d = np.bincount(
            x.astype(np.int32), minlength=self.num_classes**2)
        assert bincount_2d.size == self.num_classes**2
        conf = bincount_2d.reshape((self.num_classes, self.num_classes))

        self.conf += conf

    def value(self):
        """
        Returns:
            Confusion matrix of K rows and K columns, where rows correspond
            to ground-truth targets and columns correspond to predicted
            targets.
        """
        if self.normalized:
            conf = self.conf.astype(np.float32)
            return conf / conf.sum(1).clip(min=1e-12)[:, None]
        else:
            return self.conf
SegmentationTest/utils/iou.py
ADDED
@@ -0,0 +1,93 @@
import torch
import numpy as np
from . import metric
from .confusionmatrix import ConfusionMatrix


class IoU(metric.Metric):
    """Computes the intersection over union (IoU) per class and corresponding
    mean (mIoU).

    Intersection over union (IoU) is a common evaluation metric for semantic
    segmentation. The predictions are first accumulated in a confusion matrix
    and the IoU is computed from it as follows:

        IoU = true_positive / (true_positive + false_positive + false_negative).

    Keyword arguments:
    - num_classes (int): number of classes in the classification problem
    - normalized (boolean, optional): Determines whether or not the confusion
    matrix is normalized or not. Default: False.
    - ignore_index (int or iterable, optional): Index of the classes to ignore
    when computing the IoU. Can be an int, or any iterable of ints.
    """

    def __init__(self, num_classes, normalized=False, ignore_index=None):
        super().__init__()
        self.conf_metric = ConfusionMatrix(num_classes, normalized)

        if ignore_index is None:
            self.ignore_index = None
        elif isinstance(ignore_index, int):
            self.ignore_index = (ignore_index,)
        else:
            try:
                self.ignore_index = tuple(ignore_index)
            except TypeError:
                raise ValueError("'ignore_index' must be an int or iterable")

    def reset(self):
        self.conf_metric.reset()

    def add(self, predicted, target):
        """Adds the predicted and target pair to the IoU metric.

        Keyword arguments:
        - predicted (Tensor): Can be a (N, K, H, W) tensor of
        predicted scores obtained from the model for N examples and K classes,
        or (N, H, W) tensor of integer values between 0 and K-1.
        - target (Tensor): Can be a (N, K, H, W) tensor of
        target scores for N examples and K classes, or (N, H, W) tensor of
        integer values between 0 and K-1.

        """
        # Dimensions check
        assert predicted.size(0) == target.size(0), \
            'number of targets and predicted outputs do not match'
        assert predicted.dim() == 3 or predicted.dim() == 4, \
            "predictions must be of dimension (N, H, W) or (N, K, H, W)"
        assert target.dim() == 3 or target.dim() == 4, \
            "targets must be of dimension (N, H, W) or (N, K, H, W)"

        # If the tensor is in categorical format convert it to integer format
        if predicted.dim() == 4:
            _, predicted = predicted.max(1)
        if target.dim() == 4:
            _, target = target.max(1)

        self.conf_metric.add(predicted.view(-1), target.view(-1))

    def value(self):
        """Computes the IoU and mean IoU.

        The mean computation ignores NaN elements of the IoU array.

        Returns:
            Tuple: (IoU, mIoU). The first output is the per class IoU,
            for K classes it's numpy.ndarray with K elements. The second output
            is the mean IoU.
        """
        conf_matrix = self.conf_metric.value()
        if self.ignore_index is not None:
            for index in self.ignore_index:
                conf_matrix[:, self.ignore_index] = 0
                conf_matrix[self.ignore_index, :] = 0
        true_positive = np.diag(conf_matrix)
        false_positive = np.sum(conf_matrix, 0) - true_positive
        false_negative = np.sum(conf_matrix, 1) - true_positive

        # Just in case we get a division by 0, ignore/hide the error
        with np.errstate(divide='ignore', invalid='ignore'):
            iou = true_positive / (true_positive + false_positive + false_negative)

        return iou, np.nanmean(iou)
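A minimal usage sketch of the IoU metric above follows. The random tensors and the 2-class setting are placeholders; add() argmaxes the (N, K, H, W) scores over K and accumulates a confusion matrix, and value() then reads per-class IoU and mIoU off it.

# Hypothetical usage sketch of the IoU metric defined above.
import torch
from SegmentationTest.utils.iou import IoU

metric = IoU(num_classes=2)

scores = torch.rand(4, 2, 224, 224)            # (N, K, H, W) predicted class scores
target = torch.randint(0, 2, (4, 224, 224))    # (N, H, W) integer labels in [0, K-1]

metric.add(scores, target)                     # updates the underlying confusion matrix
per_class_iou, miou = metric.value()
print(per_class_iou, miou)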
SegmentationTest/utils/metric.py
ADDED
@@ -0,0 +1,12 @@
class Metric(object):
    """Base class for all metrics.
    From: https://github.com/pytorch/tnt/blob/master/torchnet/meter/meter.py
    """
    def reset(self):
        pass

    def add(self):
        pass

    def value(self):
        pass
SegmentationTest/utils/metrices.py
ADDED
@@ -0,0 +1,208 @@
import numpy as np
import torch
from sklearn.metrics import f1_score, average_precision_score
from sklearn.metrics import precision_recall_curve, roc_curve

SMOOTH = 1e-6
__all__ = ['get_f1_scores', 'get_ap_scores', 'batch_pix_accuracy', 'batch_intersection_union', 'get_iou', 'get_pr',
           'get_roc', 'get_ap_multiclass']


def get_iou(outputs: torch.Tensor, labels: torch.Tensor):
    # You can comment out this line if you are passing tensors of equal shape
    # But if you are passing output from UNet or something it will most probably
    # be with the BATCH x 1 x H x W shape
    outputs = outputs.squeeze(1)  # BATCH x 1 x H x W => BATCH x H x W
    labels = labels.squeeze(1)  # BATCH x 1 x H x W => BATCH x H x W

    intersection = (outputs & labels).float().sum((1, 2))  # Will be zero if Truth=0 or Prediction=0
    union = (outputs | labels).float().sum((1, 2))  # Will be zero if both are 0

    iou = (intersection + SMOOTH) / (union + SMOOTH)  # We smooth our division to avoid 0/0

    return iou.cpu().numpy()


def get_f1_scores(predict, target, ignore_index=-1):
    # Tensor process
    batch_size = predict.shape[0]
    predict = predict.data.cpu().numpy().reshape(-1)
    target = target.data.cpu().numpy().reshape(-1)
    pb = predict[target != ignore_index].reshape(batch_size, -1)
    tb = target[target != ignore_index].reshape(batch_size, -1)

    total = []
    for p, t in zip(pb, tb):
        total.append(np.nan_to_num(f1_score(t, p)))

    return total


def get_roc(predict, target, ignore_index=-1):
    target_expand = target.unsqueeze(1).expand_as(predict)
    target_expand_numpy = target_expand.data.cpu().numpy().reshape(-1)
    # Tensor process
    x = torch.zeros_like(target_expand)
    t = target.unsqueeze(1).clamp(min=0)
    target_1hot = x.scatter_(1, t, 1)
    batch_size = predict.shape[0]
    predict = predict.data.cpu().numpy().reshape(-1)
    target = target_1hot.data.cpu().numpy().reshape(-1)
    pb = predict[target_expand_numpy != ignore_index].reshape(batch_size, -1)
    tb = target[target_expand_numpy != ignore_index].reshape(batch_size, -1)

    total = []
    for p, t in zip(pb, tb):
        total.append(roc_curve(t, p))

    return total


def get_pr(predict, target, ignore_index=-1):
    target_expand = target.unsqueeze(1).expand_as(predict)
    target_expand_numpy = target_expand.data.cpu().numpy().reshape(-1)
    # Tensor process
    x = torch.zeros_like(target_expand)
    t = target.unsqueeze(1).clamp(min=0)
    target_1hot = x.scatter_(1, t, 1)
    batch_size = predict.shape[0]
    predict = predict.data.cpu().numpy().reshape(-1)
    target = target_1hot.data.cpu().numpy().reshape(-1)
    pb = predict[target_expand_numpy != ignore_index].reshape(batch_size, -1)
    tb = target[target_expand_numpy != ignore_index].reshape(batch_size, -1)

    total = []
    for p, t in zip(pb, tb):
        total.append(precision_recall_curve(t, p))

    return total


def get_ap_scores(predict, target, ignore_index=-1):
    total = []
    for pred, tgt in zip(predict, target):
        target_expand = tgt.unsqueeze(0).expand_as(pred)
        target_expand_numpy = target_expand.data.cpu().numpy().reshape(-1)

        # Tensor process
        x = torch.zeros_like(target_expand)
        t = tgt.unsqueeze(0).clamp(min=0).long()
        target_1hot = x.scatter_(0, t, 1)
        predict_flat = pred.data.cpu().numpy().reshape(-1)
        target_flat = target_1hot.data.cpu().numpy().reshape(-1)

        p = predict_flat[target_expand_numpy != ignore_index]
        t = target_flat[target_expand_numpy != ignore_index]

        total.append(np.nan_to_num(average_precision_score(t, p)))

    return total


def get_ap_multiclass(predict, target):
    total = []
    for pred, tgt in zip(predict, target):
        predict_flat = pred.data.cpu().numpy().reshape(-1)
        target_flat = tgt.data.cpu().numpy().reshape(-1)

        total.append(np.nan_to_num(average_precision_score(target_flat, predict_flat)))

    return total


def batch_precision_recall(predict, target, thr=0.5):
    """Batch Precision Recall
    Args:
        predict: input 4D tensor
        target: label 4D tensor
    """
    # _, predict = torch.max(predict, 1)

    predict = predict > thr
    predict = predict.data.cpu().numpy() + 1
    target = target.data.cpu().numpy() + 1

    tp = np.sum(((predict == 2) * (target == 2)) * (target > 0))
    fp = np.sum(((predict == 2) * (target == 1)) * (target > 0))
    fn = np.sum(((predict == 1) * (target == 2)) * (target > 0))

    precision = float(np.nan_to_num(tp / (tp + fp)))
    recall = float(np.nan_to_num(tp / (tp + fn)))

    return precision, recall


def batch_pix_accuracy(predict, target):
    """Batch Pixel Accuracy
    Args:
        predict: input 3D tensor
        target: label 3D tensor
    """

    # for thr in np.linspace(0, 1, slices):

    _, predict = torch.max(predict, 0)
    predict = predict.cpu().numpy() + 1
    target = target.cpu().numpy() + 1
    pixel_labeled = np.sum(target > 0)
    pixel_correct = np.sum((predict == target) * (target > 0))
    assert pixel_correct <= pixel_labeled, \
        "Correct area should be smaller than Labeled"
    return pixel_correct, pixel_labeled


def batch_intersection_union(predict, target, nclass):
    """Batch Intersection of Union
    Args:
        predict: input 3D tensor
        target: label 3D tensor
        nclass: number of categories (int)
    """
    _, predict = torch.max(predict, 0)
    mini = 1
    maxi = nclass
    nbins = nclass
    predict = predict.cpu().numpy() + 1
    target = target.cpu().numpy() + 1

    predict = predict * (target > 0).astype(predict.dtype)
    intersection = predict * (predict == target)
    # areas of intersection and union
    area_inter, _ = np.histogram(intersection, bins=nbins, range=(mini, maxi))
    area_pred, _ = np.histogram(predict, bins=nbins, range=(mini, maxi))
    area_lab, _ = np.histogram(target, bins=nbins, range=(mini, maxi))
    area_union = area_pred + area_lab - area_inter
    assert (area_inter <= area_union).all(), \
        "Intersection area should be smaller than Union area"
    return area_inter, area_union


# ref https://github.com/CSAILVision/sceneparsing/blob/master/evaluationCode/utils_eval.py
def pixel_accuracy(im_pred, im_lab):
    im_pred = np.asarray(im_pred)
    im_lab = np.asarray(im_lab)

    # Remove classes from unlabeled pixels in gt image.
    # We should not penalize detections in unlabeled portions of the image.
    pixel_labeled = np.sum(im_lab > 0)
    pixel_correct = np.sum((im_pred == im_lab) * (im_lab > 0))
    # pixel_accuracy = 1.0 * pixel_correct / pixel_labeled
    return pixel_correct, pixel_labeled


def intersection_and_union(im_pred, im_lab, num_class):
    im_pred = np.asarray(im_pred)
    im_lab = np.asarray(im_lab)
    # Remove classes from unlabeled pixels in gt image.
    im_pred = im_pred * (im_lab > 0)
    # Compute area intersection:
    intersection = im_pred * (im_pred == im_lab)
    area_inter, _ = np.histogram(intersection, bins=num_class - 1,
                                 range=(1, num_class - 1))
    # Compute area union:
    area_pred, _ = np.histogram(im_pred, bins=num_class - 1,
                                range=(1, num_class - 1))
    area_lab, _ = np.histogram(im_lab, bins=num_class - 1,
                               range=(1, num_class - 1))
    area_union = area_pred + area_lab - area_inter
    return area_inter, area_union
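batch_pix_accuracy and batch_intersection_union above expect a (K, H, W) score map and an (H, W) label map for a single image, which is how imagenet_seg_eval.py calls them. The sketch below accumulates them into pixel accuracy and mIoU; the dummy tensors and the 2-class setting are placeholders.

# Hypothetical usage sketch of the per-image segmentation counters above.
import numpy as np
import torch
from SegmentationTest.utils.metrices import batch_pix_accuracy, batch_intersection_union

scores = torch.rand(2, 224, 224)              # (K, H, W) background/foreground scores
labels = torch.randint(0, 2, (224, 224))      # (H, W) ground-truth mask

correct, labeled = batch_pix_accuracy(scores, labels)
inter, union = batch_intersection_union(scores, labels, nclass=2)

pix_acc = correct / (labeled + np.spacing(1))
miou = (inter / (union + np.spacing(1))).mean()
print(pix_acc, miou)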
SegmentationTest/utils/parallel.py
ADDED
@@ -0,0 +1,260 @@
1 |
+
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
2 |
+
## Created by: Hang Zhang
|
3 |
+
## ECE Department, Rutgers University
|
4 |
+
## Email: zhang.hang@rutgers.edu
|
5 |
+
## Copyright (c) 2017
|
6 |
+
##
|
7 |
+
## This source code is licensed under the MIT-style license found in the
|
8 |
+
## LICENSE file in the root directory of this source tree
|
9 |
+
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
10 |
+
|
11 |
+
"""Encoding Data Parallel"""
|
12 |
+
import threading
|
13 |
+
import functools
|
14 |
+
import torch
|
15 |
+
from torch.autograd import Variable, Function
|
16 |
+
import torch.cuda.comm as comm
|
17 |
+
from torch.nn.parallel.data_parallel import DataParallel
|
18 |
+
from torch.nn.parallel.parallel_apply import get_a_var
|
19 |
+
from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast
|
20 |
+
|
21 |
+
torch_ver = torch.__version__[:3]
|
22 |
+
|
23 |
+
__all__ = ['allreduce', 'DataParallelModel', 'DataParallelCriterion',
|
24 |
+
'patch_replication_callback']
|
25 |
+
|
26 |
+
def allreduce(*inputs):
|
27 |
+
"""Cross GPU all reduce autograd operation for calculate mean and
|
28 |
+
variance in SyncBN.
|
29 |
+
"""
|
30 |
+
return AllReduce.apply(*inputs)
|
31 |
+
|
32 |
+
class AllReduce(Function):
|
33 |
+
@staticmethod
|
34 |
+
def forward(ctx, num_inputs, *inputs):
|
35 |
+
ctx.num_inputs = num_inputs
|
36 |
+
ctx.target_gpus = [inputs[i].get_device() for i in range(0, len(inputs), num_inputs)]
|
37 |
+
inputs = [inputs[i:i + num_inputs]
|
38 |
+
for i in range(0, len(inputs), num_inputs)]
|
39 |
+
# sort before reduce sum
|
40 |
+
inputs = sorted(inputs, key=lambda i: i[0].get_device())
|
41 |
+
results = comm.reduce_add_coalesced(inputs, ctx.target_gpus[0])
|
42 |
+
outputs = comm.broadcast_coalesced(results, ctx.target_gpus)
|
43 |
+
return tuple([t for tensors in outputs for t in tensors])
|
44 |
+
|
45 |
+
@staticmethod
|
46 |
+
def backward(ctx, *inputs):
|
47 |
+
inputs = [i.data for i in inputs]
|
48 |
+
inputs = [inputs[i:i + ctx.num_inputs]
|
49 |
+
for i in range(0, len(inputs), ctx.num_inputs)]
|
50 |
+
results = comm.reduce_add_coalesced(inputs, ctx.target_gpus[0])
|
51 |
+
outputs = comm.broadcast_coalesced(results, ctx.target_gpus)
|
52 |
+
return (None,) + tuple([Variable(t) for tensors in outputs for t in tensors])
|
53 |
+
|
54 |
+
|
55 |
+
class Reduce(Function):
|
56 |
+
@staticmethod
|
57 |
+
def forward(ctx, *inputs):
|
58 |
+
ctx.target_gpus = [inputs[i].get_device() for i in range(len(inputs))]
|
59 |
+
inputs = sorted(inputs, key=lambda i: i.get_device())
|
60 |
+
return comm.reduce_add(inputs)
|
61 |
+
|
62 |
+
@staticmethod
|
63 |
+
def backward(ctx, gradOutput):
|
64 |
+
return Broadcast.apply(ctx.target_gpus, gradOutput)
|
65 |
+
|
66 |
+
|
67 |
+
class DataParallelModel(DataParallel):
|
68 |
+
"""Implements data parallelism at the module level.
|
69 |
+
|
70 |
+
This container parallelizes the application of the given module by
|
71 |
+
splitting the input across the specified devices by chunking in the
|
72 |
+
batch dimension.
|
73 |
+
In the forward pass, the module is replicated on each device,
|
74 |
+
and each replica handles a portion of the input. During the backwards pass, gradients from each replica are summed into the original module.
|
75 |
+
Note that the outputs are not gathered; please use the compatible
|
76 |
+
:class:`encoding.parallel.DataParallelCriterion`.
|
77 |
+
|
78 |
+
The batch size should be larger than the number of GPUs used. It should
|
79 |
+
also be an integer multiple of the number of GPUs so that each chunk is
|
80 |
+
the same size (so that each GPU processes the same number of samples).
|
81 |
+
|
82 |
+
Args:
|
83 |
+
module: module to be parallelized
|
84 |
+
device_ids: CUDA devices (default: all devices)
|
85 |
+
|
86 |
+
Reference:
|
87 |
+
Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi,
|
88 |
+
Amit Agrawal. "Context Encoding for Semantic Segmentation."
|
89 |
+
*The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 2018*
|
90 |
+
|
91 |
+
Example::
|
92 |
+
|
93 |
+
>>> net = encoding.nn.DataParallelModel(model, device_ids=[0, 1, 2])
|
94 |
+
>>> y = net(x)
|
95 |
+
"""
|
96 |
+
def gather(self, outputs, output_device):
|
97 |
+
return outputs
|
98 |
+
|
99 |
+
def replicate(self, module, device_ids):
|
100 |
+
modules = super(DataParallelModel, self).replicate(module, device_ids)
|
101 |
+
execute_replication_callbacks(modules)
|
102 |
+
return modules
|
103 |
+
|
104 |
+
|
105 |
+
class DataParallelCriterion(DataParallel):
|
106 |
+
"""
|
107 |
+
Calculate the loss across multiple GPUs, which balances the memory usage for
|
108 |
+
Semantic Segmentation.
|
109 |
+
|
110 |
+
The targets are split across the specified devices by chunking in
|
111 |
+
the batch dimension. Please use together with :class:`encoding.parallel.DataParallelModel`.
|
112 |
+
|
113 |
+
Reference:
|
114 |
+
Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi,
|
115 |
+
Amit Agrawal. "Context Encoding for Semantic Segmentation."
|
116 |
+
*The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 2018*
|
117 |
+
|
118 |
+
Example::
|
119 |
+
|
120 |
+
>>> net = encoding.nn.DataParallelModel(model, device_ids=[0, 1, 2])
|
121 |
+
>>> criterion = encoding.nn.DataParallelCriterion(criterion, device_ids=[0, 1, 2])
|
122 |
+
>>> y = net(x)
|
123 |
+
>>> loss = criterion(y, target)
|
124 |
+
"""
|
125 |
+
def forward(self, inputs, *targets, **kwargs):
|
126 |
+
# inputs should already be scattered
|
127 |
+
# scattering the targets instead
|
128 |
+
if not self.device_ids:
|
129 |
+
return self.module(inputs, *targets, **kwargs)
|
130 |
+
targets, kwargs = self.scatter(targets, kwargs, self.device_ids)
|
131 |
+
if len(self.device_ids) == 1:
|
132 |
+
return self.module(inputs, *targets[0], **kwargs[0])
|
133 |
+
replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
|
134 |
+
outputs = _criterion_parallel_apply(replicas, inputs, targets, kwargs)
|
135 |
+
return Reduce.apply(*outputs) / len(outputs)
|
136 |
+
#return self.gather(outputs, self.output_device).mean()
|
137 |
+
|
138 |
+
|
139 |
+
def _criterion_parallel_apply(modules, inputs, targets, kwargs_tup=None, devices=None):
|
140 |
+
assert len(modules) == len(inputs)
|
141 |
+
assert len(targets) == len(inputs)
|
142 |
+
if kwargs_tup:
|
143 |
+
assert len(modules) == len(kwargs_tup)
|
144 |
+
else:
|
145 |
+
kwargs_tup = ({},) * len(modules)
|
146 |
+
if devices is not None:
|
147 |
+
assert len(modules) == len(devices)
|
148 |
+
else:
|
149 |
+
devices = [None] * len(modules)
|
150 |
+
|
151 |
+
lock = threading.Lock()
|
152 |
+
results = {}
|
153 |
+
if torch_ver != "0.3":
|
154 |
+
grad_enabled = torch.is_grad_enabled()
|
155 |
+
|
156 |
+
def _worker(i, module, input, target, kwargs, device=None):
|
157 |
+
if torch_ver != "0.3":
|
158 |
+
torch.set_grad_enabled(grad_enabled)
|
159 |
+
if device is None:
|
160 |
+
device = get_a_var(input).get_device()
|
161 |
+
try:
|
162 |
+
with torch.cuda.device(device):
|
163 |
+
# this also avoids accidental slicing of `input` if it is a Tensor
|
164 |
+
if not isinstance(input, (list, tuple)):
|
165 |
+
input = (input,)
|
166 |
+
if type(input) != type(target):
|
167 |
+
if isinstance(target, tuple):
|
168 |
+
input = tuple(input)
|
169 |
+
elif isinstance(target, list):
|
170 |
+
input = list(input)
|
171 |
+
else:
|
172 |
+
raise Exception("Types problem")
|
173 |
+
|
174 |
+
output = module(*(input + target), **kwargs)
|
175 |
+
with lock:
|
176 |
+
results[i] = output
|
177 |
+
except Exception as e:
|
178 |
+
with lock:
|
179 |
+
results[i] = e
|
180 |
+
|
181 |
+
if len(modules) > 1:
|
182 |
+
threads = [threading.Thread(target=_worker,
|
183 |
+
args=(i, module, input, target,
|
184 |
+
kwargs, device),)
|
185 |
+
for i, (module, input, target, kwargs, device) in
|
186 |
+
enumerate(zip(modules, inputs, targets, kwargs_tup, devices))]
|
187 |
+
|
188 |
+
for thread in threads:
|
189 |
+
thread.start()
|
190 |
+
for thread in threads:
|
191 |
+
thread.join()
|
192 |
+
else:
|
193 |
+
_worker(0, modules[0], inputs[0], targets[0], kwargs_tup[0], devices[0])
|
194 |
+
|
195 |
+
outputs = []
|
196 |
+
for i in range(len(inputs)):
|
197 |
+
output = results[i]
|
198 |
+
if isinstance(output, Exception):
|
199 |
+
raise output
|
200 |
+
outputs.append(output)
|
201 |
+
return outputs
|
202 |
+
|
203 |
+
|
204 |
+
###########################################################################
|
205 |
+
# Adapted from Synchronized-BatchNorm-PyTorch.
|
206 |
+
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
|
207 |
+
#
|
208 |
+
class CallbackContext(object):
|
209 |
+
pass
|
210 |
+
|
211 |
+
|
212 |
+
def execute_replication_callbacks(modules):
|
213 |
+
"""
|
214 |
+
Execute a replication callback `__data_parallel_replicate__` on each module created
|
215 |
+
by original replication.
|
216 |
+
|
217 |
+
The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
|
218 |
+
|
219 |
+
Note that, since all copies are isomorphic, we assign each sub-module a context
|
220 |
+
(shared among multiple copies of this module on different devices).
|
221 |
+
Through this context, different copies can share some information.
|
222 |
+
|
223 |
+
We guarantee that the callback on the master copy (the first copy) will be called ahead
|
224 |
+
of calling the callback of any slave copies.
|
225 |
+
"""
|
226 |
+
master_copy = modules[0]
|
227 |
+
nr_modules = len(list(master_copy.modules()))
|
228 |
+
ctxs = [CallbackContext() for _ in range(nr_modules)]
|
229 |
+
|
230 |
+
for i, module in enumerate(modules):
|
231 |
+
for j, m in enumerate(module.modules()):
|
232 |
+
if hasattr(m, '__data_parallel_replicate__'):
|
233 |
+
m.__data_parallel_replicate__(ctxs[j], i)
|
234 |
+
|
235 |
+
|
236 |
+
def patch_replication_callback(data_parallel):
|
237 |
+
"""
|
238 |
+
Monkey-patch an existing `DataParallel` object. Add the replication callback.
|
239 |
+
Useful when you have customized `DataParallel` implementation.
|
240 |
+
|
241 |
+
Examples:
|
242 |
+
> sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
|
243 |
+
> sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
|
244 |
+
> patch_replication_callback(sync_bn)
|
245 |
+
# this is equivalent to
|
246 |
+
> sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
|
247 |
+
> sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
|
248 |
+
"""
|
249 |
+
|
250 |
+
assert isinstance(data_parallel, DataParallel)
|
251 |
+
|
252 |
+
old_replicate = data_parallel.replicate
|
253 |
+
|
254 |
+
@functools.wraps(old_replicate)
|
255 |
+
def new_replicate(module, device_ids):
|
256 |
+
modules = old_replicate(module, device_ids)
|
257 |
+
execute_replication_callbacks(modules)
|
258 |
+
return modules
|
259 |
+
|
260 |
+
data_parallel.replicate = new_replicate
|
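For reference, a minimal training-step sketch of how DataParallelModel and DataParallelCriterion are meant to be combined, assuming two visible GPUs and the repo root on PYTHONPATH; the model, criterion, and tensors below are placeholders, not part of this file:

import torch
import torch.nn as nn
from SegmentationTest.utils.parallel import DataParallelModel, DataParallelCriterion

# Placeholder model/criterion; any module/loss pair works the same way.
model = DataParallelModel(nn.Linear(10, 2).cuda(), device_ids=[0, 1])
criterion = DataParallelCriterion(nn.CrossEntropyLoss().cuda(), device_ids=[0, 1])

x = torch.randn(8, 10).cuda()
y = torch.randint(0, 2, (8,)).cuda()

outputs = model(x)            # list of per-GPU outputs (not gathered back to one device)
loss = criterion(outputs, y)  # targets are scattered; per-GPU losses are reduced and averaged
loss.backward()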
SegmentationTest/utils/render.py
ADDED
@@ -0,0 +1,266 @@
1 |
+
import numpy as np
|
2 |
+
import matplotlib.cm
|
3 |
+
import skimage.io
|
4 |
+
import skimage.feature
|
5 |
+
import skimage.filters
|
6 |
+
|
7 |
+
|
8 |
+
def vec2im(V, shape=()):
|
9 |
+
'''
|
10 |
+
Transform an array V into a specified shape - or if no shape is given assume a square output format.
|
11 |
+
|
12 |
+
Parameters
|
13 |
+
----------
|
14 |
+
|
15 |
+
V : numpy.ndarray
|
16 |
+
an array either representing a matrix or vector to be reshaped into a two-dimensional image
|
17 |
+
|
18 |
+
shape : tuple or list
|
19 |
+
optional. contains the shape information for the output array; if not given, the output is assumed to be square
|
20 |
+
|
21 |
+
Returns
|
22 |
+
-------
|
23 |
+
|
24 |
+
W : numpy.ndarray
|
25 |
+
with W.shape = shape or W.shape = [np.sqrt(V.size)]*2
|
26 |
+
|
27 |
+
'''
|
28 |
+
|
29 |
+
if len(shape) < 2:
|
30 |
+
shape = [np.sqrt(V.size)] * 2
|
31 |
+
shape = list(map(int, shape))  # materialize so np.reshape receives a concrete shape (Python 3)
|
32 |
+
return np.reshape(V, shape)
|
33 |
+
|
34 |
+
|
35 |
+
def enlarge_image(img, scaling=3):
|
36 |
+
'''
|
37 |
+
Enlarges a given input matrix by replicating each pixel value scaling times in horizontal and vertical direction.
|
38 |
+
|
39 |
+
Parameters
|
40 |
+
----------
|
41 |
+
|
42 |
+
img : numpy.ndarray
|
43 |
+
array of shape [H x W] OR [H x W x D]
|
44 |
+
|
45 |
+
scaling : int
|
46 |
+
positive integer value > 0
|
47 |
+
|
48 |
+
Returns
|
49 |
+
-------
|
50 |
+
|
51 |
+
out : numpy.ndarray
|
52 |
+
two-dimensional array of shape [scaling*H x scaling*W]
|
53 |
+
OR
|
54 |
+
three-dimensional array of shape [scaling*H x scaling*W x D]
|
55 |
+
depending on the dimensionality of the input
|
56 |
+
'''
|
57 |
+
|
58 |
+
if scaling < 1 or not isinstance(scaling, int):
|
59 |
+
print('scaling factor needs to be an int >= 1')
|
60 |
+
|
61 |
+
if len(img.shape) == 2:
|
62 |
+
H, W = img.shape
|
63 |
+
|
64 |
+
out = np.zeros((scaling * H, scaling * W))
|
65 |
+
for h in range(H):
|
66 |
+
fh = scaling * h
|
67 |
+
for w in range(W):
|
68 |
+
fw = scaling * w
|
69 |
+
out[fh:fh + scaling, fw:fw + scaling] = img[h, w]
|
70 |
+
|
71 |
+
elif len(img.shape) == 3:
|
72 |
+
H, W, D = img.shape
|
73 |
+
|
74 |
+
out = np.zeros((scaling * H, scaling * W, D))
|
75 |
+
for h in range(H):
|
76 |
+
fh = scaling * h
|
77 |
+
for w in range(W):
|
78 |
+
fw = scaling * w
|
79 |
+
out[fh:fh + scaling, fw:fw + scaling, :] = img[h, w, :]
|
80 |
+
|
81 |
+
return out
|
82 |
+
|
83 |
+
|
84 |
+
def repaint_corner_pixels(rgbimg, scaling=3):
|
85 |
+
'''
|
86 |
+
DEPRECATED/OBSOLETE.
|
87 |
+
|
88 |
+
Recolors the top left and bottom right pixel (groups) with the average rgb value of its three neighboring pixel (groups).
|
89 |
+
The recoloring visually masks the opposing pixel values which are a product of stabilizing the scaling.
|
90 |
+
Assumes those image areas will pretty much never show evidence.
|
91 |
+
|
92 |
+
Parameters
|
93 |
+
----------
|
94 |
+
|
95 |
+
rgbimg : numpy.ndarray
|
96 |
+
array of shape [H x W x 3]
|
97 |
+
|
98 |
+
scaling : int
|
99 |
+
positive integer value > 0
|
100 |
+
|
101 |
+
Returns
|
102 |
+
-------
|
103 |
+
|
104 |
+
rgbimg : numpy.ndarray
|
105 |
+
three-dimensional array of shape [scaling*H x scaling*W x 3]
|
106 |
+
'''
|
107 |
+
|
108 |
+
# top left corner.
|
109 |
+
rgbimg[0:scaling, 0:scaling, :] = (rgbimg[0, scaling, :] + rgbimg[scaling, 0, :] + rgbimg[scaling, scaling,
|
110 |
+
:]) / 3.0
|
111 |
+
# bottom right corner
|
112 |
+
rgbimg[-scaling:, -scaling:, :] = (rgbimg[-1, -1 - scaling, :] + rgbimg[-1 - scaling, -1, :] + rgbimg[-1 - scaling,
|
113 |
+
-1 - scaling,
|
114 |
+
:]) / 3.0
|
115 |
+
return rgbimg
|
116 |
+
|
117 |
+
|
118 |
+
def digit_to_rgb(X, scaling=3, shape=(), cmap='binary'):
|
119 |
+
'''
|
120 |
+
Takes as input an intensity array and produces a rgb image due to some color map
|
121 |
+
|
122 |
+
Parameters
|
123 |
+
----------
|
124 |
+
|
125 |
+
X : numpy.ndarray
|
126 |
+
intensity matrix as array of shape [M x N]
|
127 |
+
|
128 |
+
scaling : int
|
129 |
+
optional. positive integer value > 0
|
130 |
+
|
131 |
+
shape : tuple or list of ints, length = 2
|
132 |
+
optional. if not given, X is reshaped to be square.
|
133 |
+
|
134 |
+
cmap : str
|
135 |
+
name of color map of choice. default is 'binary'
|
136 |
+
|
137 |
+
Returns
|
138 |
+
-------
|
139 |
+
|
140 |
+
image : numpy.ndarray
|
141 |
+
three-dimensional array of shape [scaling*H x scaling*W x 3] , where H*W == M*N
|
142 |
+
'''
|
143 |
+
|
144 |
+
# create color map object from name string
|
145 |
+
cmap = getattr(matplotlib.cm, cmap)
|
146 |
+
|
147 |
+
image = enlarge_image(vec2im(X, shape), scaling) # enlarge
|
148 |
+
image = cmap(image.flatten())[..., 0:3].reshape([image.shape[0], image.shape[1], 3]) # colorize, reshape
|
149 |
+
|
150 |
+
return image
|
151 |
+
|
152 |
+
|
153 |
+
def hm_to_rgb(R, X=None, scaling=3, shape=(), sigma=2, cmap='bwr', normalize=True):
|
154 |
+
'''
|
155 |
+
Takes as input an intensity array and produces a rgb image for the represented heatmap.
|
156 |
+
optionally draws the outline of another input on top of it.
|
157 |
+
|
158 |
+
Parameters
|
159 |
+
----------
|
160 |
+
|
161 |
+
R : numpy.ndarray
|
162 |
+
the heatmap to be visualized, shaped [M x N]
|
163 |
+
|
164 |
+
X : numpy.ndarray
|
165 |
+
optional. some input, usually the data point for which the heatmap R is for, which shall serve
|
166 |
+
as a template for a black outline to be drawn on top of the image
|
167 |
+
shaped [M x N]
|
168 |
+
|
169 |
+
scaling: int
|
170 |
+
factor for how much to enlarge the heatmap (to control resolution, and as an inverse way to control outline thickness)
|
171 |
+
after reshaping it using shape.
|
172 |
+
|
173 |
+
shape: tuple or list, length = 2
|
174 |
+
optional. if not given, X is reshaped to be square.
|
175 |
+
|
176 |
+
sigma : double
|
177 |
+
optional. sigma-parameter for the canny algorithm used for edge detection. the found edges are drawn as outlines.
|
178 |
+
|
179 |
+
cmap : str
|
180 |
+
optional. color map of choice
|
181 |
+
|
182 |
+
normalize : bool
|
183 |
+
optional. whether to normalize the heatmap to [-1 1] prior to colorization or not.
|
184 |
+
|
185 |
+
Returns
|
186 |
+
-------
|
187 |
+
|
188 |
+
rgbimg : numpy.ndarray
|
189 |
+
three-dimensional array of shape [scaling*H x scaling*W x 3] , where H*W == M*N
|
190 |
+
'''
|
191 |
+
|
192 |
+
# create color map object from name string
|
193 |
+
cmap = getattr(matplotlib.cm, cmap)
|
194 |
+
|
195 |
+
if normalize:
|
196 |
+
R = R / np.max(np.abs(R)) # normalize to [-1,1] wrt to max relevance magnitude
|
197 |
+
R = (R + 1.) / 2. # shift/normalize to [0,1] for color mapping
|
198 |
+
|
199 |
+
R = enlarge_image(R, scaling)
|
200 |
+
rgb = cmap(R.flatten())[..., 0:3].reshape([R.shape[0], R.shape[1], 3])
|
201 |
+
# rgb = repaint_corner_pixels(rgb, scaling) #obsolete due to directly calling the color map with [0,1]-normalized inputs
|
202 |
+
|
203 |
+
if X is not None:  # compute the outline of the input
|
204 |
+
# X = enlarge_image(vec2im(X,shape), scaling)
|
205 |
+
xdims = X.shape
|
206 |
+
Rdims = R.shape
|
207 |
+
|
208 |
+
# if not np.all(xdims == Rdims):
|
209 |
+
# print 'transformed heatmap and data dimension mismatch. data dimensions differ?'
|
210 |
+
# print 'R.shape = ',Rdims, 'X.shape = ', xdims
|
211 |
+
# print 'skipping drawing of outline\n'
|
212 |
+
# else:
|
213 |
+
# #edges = skimage.filters.canny(X, sigma=sigma)
|
214 |
+
# edges = skimage.feature.canny(X, sigma=sigma)
|
215 |
+
# edges = np.invert(np.dstack([edges]*3))*1.0
|
216 |
+
# rgb *= edges # set outline pixels to black color
|
217 |
+
|
218 |
+
return rgb
|
219 |
+
|
220 |
+
|
221 |
+
def save_image(rgb_images, path, gap=2):
|
222 |
+
'''
|
223 |
+
Takes as input a list of rgb images, places them next to each other with a gap and writes out the result.
|
224 |
+
|
225 |
+
Parameters
|
226 |
+
----------
|
227 |
+
|
228 |
+
rgb_images : list, tuple, or similar collection
|
229 |
+
each item in the collection is expected to be an rgb image of dimensions [H x _ x 3]
|
230 |
+
where the width is variable
|
231 |
+
|
232 |
+
path : str
|
233 |
+
the output path of the assembled image
|
234 |
+
|
235 |
+
gap : int
|
236 |
+
optional. sets the width of a black area of pixels realized as an image shaped [H x gap x 3] in between the input images
|
237 |
+
|
238 |
+
Returns
|
239 |
+
-------
|
240 |
+
|
241 |
+
image : numpy.ndarray
|
242 |
+
the assembled image as written out to path
|
243 |
+
'''
|
244 |
+
|
245 |
+
sz = []
|
246 |
+
image = []
|
247 |
+
for i in range(len(rgb_images)):
|
248 |
+
if not sz:
|
249 |
+
sz = rgb_images[i].shape
|
250 |
+
image = rgb_images[i]
|
251 |
+
gap = np.zeros((sz[0], gap, sz[2]))
|
252 |
+
continue
|
253 |
+
if not sz[0] == rgb_images[i].shape[0] and sz[1] == rgb_images[i].shape[2]:
|
254 |
+
print('image', i, 'differs in size. unable to perform horizontal alignment')
|
255 |
+
print('expected: Hx_xD = {0}x_x{1}'.format(sz[0], sz[1]))
|
256 |
+
print('got : Hx_xD = {0}x_x{1}'.format(rgb_images[i].shape[0], rgb_images[i].shape[1]))
|
257 |
+
print('skipping image\n')
|
258 |
+
else:
|
259 |
+
image = np.hstack((image, gap, rgb_images[i]))
|
260 |
+
|
261 |
+
image *= 255
|
262 |
+
image = image.astype(np.uint8)
|
263 |
+
|
264 |
+
print('saving image to ', path)
|
265 |
+
skimage.io.imsave(path, image)
|
266 |
+
return image
|
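A minimal usage sketch of the rendering helpers above; the random heatmap and output path are placeholders:

import numpy as np
from SegmentationTest.utils.render import hm_to_rgb, save_image

# Toy 14x14 relevance map, upscaled and colorized, then written to disk.
R = np.random.randn(14, 14)
rgb = hm_to_rgb(R, scaling=16, cmap='bwr', normalize=True)   # [224, 224, 3] floats in [0, 1]
save_image([rgb], './heatmap.png', gap=2)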
SegmentationTest/utils/saver.py
ADDED
@@ -0,0 +1,34 @@
import os
import torch
from collections import OrderedDict
import glob


class Saver(object):

    def __init__(self, args):
        self.args = args
        self.directory = os.path.join('run', args.train_dataset, args.checkname)
        self.runs = sorted(glob.glob(os.path.join(self.directory, 'experiment_*')))
        run_id = int(self.runs[-1].split('_')[-1]) + 1 if self.runs else 0

        self.experiment_dir = os.path.join(self.directory, 'experiment_{}'.format(str(run_id)))
        if not os.path.exists(self.experiment_dir):
            os.makedirs(self.experiment_dir)

    def save_checkpoint(self, state, filename='checkpoint.pth.tar'):
        """Saves checkpoint to disk"""
        filename = os.path.join(self.experiment_dir, filename)
        torch.save(state, filename)

    def save_experiment_config(self):
        logfile = os.path.join(self.experiment_dir, 'parameters.txt')
        log_file = open(logfile, 'w')
        p = OrderedDict()
        p['train_dataset'] = self.args.train_dataset
        p['lr'] = self.args.lr
        p['epoch'] = self.args.epochs

        for key, val in p.items():
            log_file.write(key + ':' + str(val) + '\n')
        log_file.close()
|
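A minimal usage sketch of Saver; the argparse namespace is hypothetical and only fills the attributes the class reads:

import argparse
from SegmentationTest.utils.saver import Saver

# Hypothetical args namespace with just the fields Saver touches.
args = argparse.Namespace(train_dataset='imagenet', checkname='vit_base', lr=3e-6, epochs=50)
saver = Saver(args)                    # creates run/imagenet/vit_base/experiment_<n>/
saver.save_experiment_config()         # writes parameters.txt
saver.save_checkpoint({'epoch': 0})    # writes checkpoint.pth.tar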
SegmentationTest/utils/summaries.py
ADDED
@@ -0,0 +1,11 @@
import os
from torch.utils.tensorboard import SummaryWriter


class TensorboardSummary(object):
    def __init__(self, directory):
        self.directory = directory
        self.writer = SummaryWriter(log_dir=os.path.join(self.directory))

    def add_scalar(self, *args):
        self.writer.add_scalar(*args)
|
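Usage is a thin pass-through to torch.utils.tensorboard; a short sketch with a placeholder log directory:

from SegmentationTest.utils.summaries import TensorboardSummary

summary = TensorboardSummary('./runs/example')   # placeholder log dir
summary.add_scalar('train/loss', 0.73, 1)        # tag, value, global step
summary.writer.flush()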
ViT/ViT.py
ADDED
@@ -0,0 +1,308 @@
1 |
+
""" Vision Transformer (ViT) in PyTorch
|
2 |
+
Hacked together by / Copyright 2020 Ross Wightman
|
3 |
+
"""
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
from functools import partial
|
7 |
+
from einops import rearrange
|
8 |
+
|
9 |
+
from ViT.helpers import load_pretrained
|
10 |
+
from ViT.weight_init import trunc_normal_
|
11 |
+
from ViT.layer_helpers import to_2tuple
|
12 |
+
|
13 |
+
|
14 |
+
def _cfg(url='', **kwargs):
|
15 |
+
return {
|
16 |
+
'url': url,
|
17 |
+
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
|
18 |
+
'crop_pct': .9, 'interpolation': 'bicubic',
|
19 |
+
'first_conv': 'patch_embed.proj', 'classifier': 'head',
|
20 |
+
**kwargs
|
21 |
+
}
|
22 |
+
|
23 |
+
|
24 |
+
default_cfgs = {
|
25 |
+
# patch models
|
26 |
+
'vit_small_patch16_224': _cfg(
|
27 |
+
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/vit_small_p16_224-15ec54c9.pth',
|
28 |
+
),
|
29 |
+
'vit_base_patch16_224': _cfg(
|
30 |
+
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth',
|
31 |
+
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
|
32 |
+
),
|
33 |
+
'vit_large_patch16_224': _cfg(
|
34 |
+
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_224-4ee7a4dc.pth',
|
35 |
+
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
|
36 |
+
|
37 |
+
# deit models (FB weights)
|
38 |
+
'deit_tiny_patch16_224': _cfg(
|
39 |
+
url='https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth'),
|
40 |
+
'deit_small_patch16_224': _cfg(
|
41 |
+
url='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth'),
|
42 |
+
'deit_base_patch16_224': _cfg(
|
43 |
+
url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth', ),
|
44 |
+
'deit_base_patch16_384': _cfg(
|
45 |
+
url='', # no weights yet
|
46 |
+
input_size=(3, 384, 384)),
|
47 |
+
}
|
48 |
+
|
49 |
+
class Mlp(nn.Module):
|
50 |
+
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
|
51 |
+
super().__init__()
|
52 |
+
out_features = out_features or in_features
|
53 |
+
hidden_features = hidden_features or in_features
|
54 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
55 |
+
self.act = act_layer()
|
56 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
57 |
+
self.drop = nn.Dropout(drop)
|
58 |
+
|
59 |
+
def forward(self, x):
|
60 |
+
x = self.fc1(x)
|
61 |
+
x = self.act(x)
|
62 |
+
x = self.drop(x)
|
63 |
+
x = self.fc2(x)
|
64 |
+
x = self.drop(x)
|
65 |
+
return x
|
66 |
+
|
67 |
+
|
68 |
+
class Attention(nn.Module):
|
69 |
+
def __init__(self, dim, num_heads=8, qkv_bias=False,attn_drop=0., proj_drop=0.):
|
70 |
+
super().__init__()
|
71 |
+
self.num_heads = num_heads
|
72 |
+
head_dim = dim // num_heads
|
73 |
+
# NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
|
74 |
+
self.scale = head_dim ** -0.5
|
75 |
+
|
76 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
77 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
78 |
+
self.proj = nn.Linear(dim, dim)
|
79 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
80 |
+
|
81 |
+
self.attn_gradients = None
|
82 |
+
self.attention_map = None
|
83 |
+
|
84 |
+
def save_attn_gradients(self, attn_gradients):
|
85 |
+
self.attn_gradients = attn_gradients
|
86 |
+
|
87 |
+
def get_attn_gradients(self):
|
88 |
+
return self.attn_gradients
|
89 |
+
|
90 |
+
def save_attention_map(self, attention_map):
|
91 |
+
self.attention_map = attention_map
|
92 |
+
|
93 |
+
def get_attention_map(self):
|
94 |
+
return self.attention_map
|
95 |
+
|
96 |
+
def forward(self, x, register_hook=False, return_attentions=False):
|
97 |
+
b, n, _, h = *x.shape, self.num_heads
|
98 |
+
|
99 |
+
qkv = self.qkv(x)
|
100 |
+
q, k, v = rearrange(qkv, 'b n (qkv h d) -> qkv b h n d', qkv = 3, h = h)
|
101 |
+
|
102 |
+
dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale
|
103 |
+
|
104 |
+
attn = dots.softmax(dim=-1)
|
105 |
+
attn = self.attn_drop(attn)
|
106 |
+
|
107 |
+
out = torch.einsum('bhij,bhjd->bhid', attn, v)
|
108 |
+
|
109 |
+
self.save_attention_map(attn)
|
110 |
+
if register_hook:
|
111 |
+
attn.register_hook(self.save_attn_gradients)
|
112 |
+
|
113 |
+
out = rearrange(out, 'b h n d -> b n (h d)')
|
114 |
+
out = self.proj(out)
|
115 |
+
out = self.proj_drop(out)
|
116 |
+
if not return_attentions:
|
117 |
+
return out
|
118 |
+
else:
|
119 |
+
return out, attn
|
120 |
+
|
121 |
+
|
122 |
+
class Block(nn.Module):
|
123 |
+
|
124 |
+
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
|
125 |
+
super().__init__()
|
126 |
+
self.norm1 = norm_layer(dim)
|
127 |
+
self.attn = Attention(
|
128 |
+
dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
|
129 |
+
self.norm2 = norm_layer(dim)
|
130 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
131 |
+
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
132 |
+
|
133 |
+
def forward(self, x, register_hook=False, return_attentions=False):
|
134 |
+
if not return_attentions:
|
135 |
+
x = x + self.attn(self.norm1(x), register_hook=register_hook)
|
136 |
+
else:
|
137 |
+
attn_res, attn = self.attn(self.norm1(x), register_hook=register_hook, return_attentions=True)
|
138 |
+
x = x + attn_res
|
139 |
+
x = x + self.mlp(self.norm2(x))
|
140 |
+
if not return_attentions:
|
141 |
+
return x
|
142 |
+
else:
|
143 |
+
return x, attn
|
144 |
+
|
145 |
+
|
146 |
+
class PatchEmbed(nn.Module):
|
147 |
+
""" Image to Patch Embedding
|
148 |
+
"""
|
149 |
+
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
|
150 |
+
super().__init__()
|
151 |
+
img_size = to_2tuple(img_size)
|
152 |
+
patch_size = to_2tuple(patch_size)
|
153 |
+
num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
|
154 |
+
self.img_size = img_size
|
155 |
+
self.patch_size = patch_size
|
156 |
+
self.num_patches = num_patches
|
157 |
+
|
158 |
+
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
|
159 |
+
|
160 |
+
def forward(self, x):
|
161 |
+
B, C, H, W = x.shape
|
162 |
+
# FIXME look at relaxing size constraints
|
163 |
+
assert H == self.img_size[0] and W == self.img_size[1], \
|
164 |
+
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
|
165 |
+
x = self.proj(x).flatten(2).transpose(1, 2)
|
166 |
+
return x
|
167 |
+
|
168 |
+
class VisionTransformer(nn.Module):
|
169 |
+
""" Vision Transformer
|
170 |
+
"""
|
171 |
+
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
|
172 |
+
num_heads=12, mlp_ratio=4., qkv_bias=False, drop_rate=0., attn_drop_rate=0., norm_layer=nn.LayerNorm):
|
173 |
+
super().__init__()
|
174 |
+
self.num_classes = num_classes
|
175 |
+
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
|
176 |
+
self.patch_embed = PatchEmbed(
|
177 |
+
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
|
178 |
+
num_patches = self.patch_embed.num_patches
|
179 |
+
|
180 |
+
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
181 |
+
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
|
182 |
+
self.pos_drop = nn.Dropout(p=drop_rate)
|
183 |
+
|
184 |
+
self.blocks = nn.ModuleList([
|
185 |
+
Block(
|
186 |
+
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
|
187 |
+
drop=drop_rate, attn_drop=attn_drop_rate, norm_layer=norm_layer)
|
188 |
+
for i in range(depth)])
|
189 |
+
self.norm = norm_layer(embed_dim)
|
190 |
+
|
191 |
+
# Classifier head
|
192 |
+
self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
193 |
+
|
194 |
+
trunc_normal_(self.pos_embed, std=.02)
|
195 |
+
trunc_normal_(self.cls_token, std=.02)
|
196 |
+
self.apply(self._init_weights)
|
197 |
+
|
198 |
+
def _init_weights(self, m):
|
199 |
+
if isinstance(m, nn.Linear):
|
200 |
+
trunc_normal_(m.weight, std=.02)
|
201 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
202 |
+
nn.init.constant_(m.bias, 0)
|
203 |
+
elif isinstance(m, nn.LayerNorm):
|
204 |
+
nn.init.constant_(m.bias, 0)
|
205 |
+
nn.init.constant_(m.weight, 1.0)
|
206 |
+
|
207 |
+
@torch.jit.ignore
|
208 |
+
def no_weight_decay(self):
|
209 |
+
return {'pos_embed', 'cls_token'}
|
210 |
+
|
211 |
+
def forward(self, x, register_hook=False, return_attentions=False):
|
212 |
+
if return_attentions:
|
213 |
+
attentions = []
|
214 |
+
|
215 |
+
B = x.shape[0]
|
216 |
+
x = self.patch_embed(x)
|
217 |
+
|
218 |
+
cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
|
219 |
+
x = torch.cat((cls_tokens, x), dim=1)
|
220 |
+
x = x + self.pos_embed
|
221 |
+
x = self.pos_drop(x)
|
222 |
+
|
223 |
+
for blk in self.blocks:
|
224 |
+
if not return_attentions:
|
225 |
+
x = blk(x, register_hook=register_hook)
|
226 |
+
else:
|
227 |
+
x, attn = blk(x, register_hook=register_hook, return_attentions=True)
|
228 |
+
attentions.append(attn)
|
229 |
+
|
230 |
+
x = self.norm(x)
|
231 |
+
x = x[:, 0]
|
232 |
+
x = self.head(x)
|
233 |
+
|
234 |
+
if not return_attentions:
|
235 |
+
return x
|
236 |
+
else:
|
237 |
+
return x, torch.cat(attentions).unsqueeze(0)
|
238 |
+
|
239 |
+
|
240 |
+
def _conv_filter(state_dict, patch_size=16):
|
241 |
+
""" convert patch embedding weight from manual patchify + linear proj to conv"""
|
242 |
+
out_dict = {}
|
243 |
+
for k, v in state_dict.items():
|
244 |
+
if 'patch_embed.proj.weight' in k:
|
245 |
+
v = v.reshape((v.shape[0], 3, patch_size, patch_size))
|
246 |
+
out_dict[k] = v
|
247 |
+
return out_dict
|
248 |
+
|
249 |
+
|
250 |
+
def vit_base_patch16_224(pretrained=False, **kwargs):
|
251 |
+
model = VisionTransformer(
|
252 |
+
patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
|
253 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
|
254 |
+
model.default_cfg = default_cfgs['vit_base_patch16_224']
|
255 |
+
if pretrained:
|
256 |
+
load_pretrained(
|
257 |
+
model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3), filter_fn=_conv_filter)
|
258 |
+
return model
|
259 |
+
|
260 |
+
|
261 |
+
def vit_base_finetuned_patch16_224(pretrained=False, **kwargs):
|
262 |
+
model = VisionTransformer(
|
263 |
+
patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
|
264 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
|
265 |
+
model.default_cfg = default_cfgs['vit_base_patch16_224']  # no separate 'finetuned' entry in default_cfgs; reuse the base config
|
266 |
+
if pretrained:
|
267 |
+
load_pretrained(
|
268 |
+
model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3), filter_fn=_conv_filter)
|
269 |
+
return model
|
270 |
+
|
271 |
+
def vit_large_patch16_224(pretrained=False, **kwargs):
|
272 |
+
model = VisionTransformer(
|
273 |
+
patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
|
274 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
|
275 |
+
model.default_cfg = default_cfgs['vit_large_patch16_224']
|
276 |
+
if pretrained:
|
277 |
+
load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
|
278 |
+
return model
|
279 |
+
|
280 |
+
def deit_tiny_patch16_224(pretrained=False, **kwargs):
|
281 |
+
model = VisionTransformer(
|
282 |
+
patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True,
|
283 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
|
284 |
+
model.default_cfg = default_cfgs['deit_tiny_patch16_224']
|
285 |
+
if pretrained:
|
286 |
+
load_pretrained(
|
287 |
+
model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3), filter_fn=lambda x: x['model'])
|
288 |
+
return model
|
289 |
+
|
290 |
+
def deit_small_patch16_224(pretrained=False, **kwargs):
|
291 |
+
model = VisionTransformer(
|
292 |
+
patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True,
|
293 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
|
294 |
+
model.default_cfg = default_cfgs['deit_small_patch16_224']
|
295 |
+
if pretrained:
|
296 |
+
load_pretrained(
|
297 |
+
model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3), filter_fn=lambda x: x['model'])
|
298 |
+
return model
|
299 |
+
|
300 |
+
def deit_base_patch16_224(pretrained=False, **kwargs):
|
301 |
+
model = VisionTransformer(
|
302 |
+
patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
|
303 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
|
304 |
+
model.default_cfg = default_cfgs['deit_base_patch16_224']
|
305 |
+
if pretrained:
|
306 |
+
load_pretrained(
|
307 |
+
model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3), filter_fn=lambda x: x['model'])
|
308 |
+
return model
|
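A minimal inference sketch for the factory functions above (pretrained weights come from the URLs in default_cfgs; the dummy input is a placeholder):

import torch
from ViT.ViT import vit_base_patch16_224

model = vit_base_patch16_224(pretrained=True).eval()
x = torch.randn(1, 3, 224, 224)                      # dummy image batch
with torch.no_grad():
    logits, attentions = model(x, return_attentions=True)
print(logits.shape)       # torch.Size([1, 1000])
print(attentions.shape)   # per-block attention maps stacked into one tensor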
ViT_new.py → ViT/ViT_new.py
RENAMED
File without changes
|
ViT/__init__.py
ADDED
File without changes
|
ViT/explainer.py
ADDED
@@ -0,0 +1,71 @@
import torch
import numpy as np
import cv2

# rule 5 from paper
def avg_heads(cam, grad):
    cam = cam.reshape(-1, cam.shape[-3], cam.shape[-2], cam.shape[-1])
    grad = grad.reshape(-1, grad.shape[-3], grad.shape[-2], grad.shape[-1])
    cam = grad * cam
    cam = cam.clamp(min=0).mean(dim=1)
    return cam

# rule 6 from paper
def apply_self_attention_rules(R_ss, cam_ss):
    R_ss_addition = torch.matmul(cam_ss, R_ss)
    return R_ss_addition

def upscale_relevance(relevance):
    relevance = relevance.reshape(-1, 1, 14, 14)
    relevance = torch.nn.functional.interpolate(relevance, scale_factor=16, mode='bilinear')

    # normalize between 0 and 1
    relevance = relevance.reshape(relevance.shape[0], -1)
    min = relevance.min(1, keepdim=True)[0]
    max = relevance.max(1, keepdim=True)[0]
    relevance = (relevance - min) / (max - min)

    relevance = relevance.reshape(-1, 1, 224, 224)
    return relevance

def generate_relevance(model, input, index=None):
    # a batch of samples
    batch_size = input.shape[0]
    output = model(input, register_hook=True)
    if index is None:
        index = np.argmax(output.cpu().data.numpy(), axis=-1)
        index = torch.tensor(index)

    one_hot = np.zeros((batch_size, output.shape[-1]), dtype=np.float32)
    one_hot[torch.arange(batch_size), index.data.cpu().numpy()] = 1
    one_hot = torch.from_numpy(one_hot).requires_grad_(True)
    one_hot = torch.sum(one_hot.to(input.device) * output)
    model.zero_grad()

    num_tokens = model.blocks[0].attn.get_attention_map().shape[-1]
    R = torch.eye(num_tokens, num_tokens).cuda()
    R = R.unsqueeze(0).expand(batch_size, num_tokens, num_tokens)
    for i, blk in enumerate(model.blocks):
        grad = torch.autograd.grad(one_hot, [blk.attn.attention_map], retain_graph=True)[0]
        cam = blk.attn.get_attention_map()
        cam = avg_heads(cam, grad)
        R = R + apply_self_attention_rules(R, cam)
    relevance = R[:, 0, 1:]
    return upscale_relevance(relevance)

# create heatmap from mask on image
def show_cam_on_image(img, mask):
    heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
    heatmap = np.float32(heatmap) / 255
    cam = heatmap + np.float32(img)
    cam = cam / np.max(cam)
    return cam


def get_image_with_relevance(image, relevance):
    image = image.permute(1, 2, 0)
    relevance = relevance.permute(1, 2, 0)
    image = (image - image.min()) / (image.max() - image.min())
    image = 255 * image
    vis = image * relevance
    return vis.data.cpu().numpy()
|
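A minimal sketch of how the explainer fits together with the ViT model; image loading/normalization is replaced by a random placeholder, and a CUDA device is assumed because generate_relevance allocates R with .cuda():

import torch
from ViT.ViT import vit_base_patch16_224
from ViT.explainer import generate_relevance, show_cam_on_image

model = vit_base_patch16_224(pretrained=True).cuda().eval()
image = torch.randn(1, 3, 224, 224).cuda()          # placeholder for a normalized image batch

relevance = generate_relevance(model, image)         # [1, 1, 224, 224], values in [0, 1]
mask = relevance[0, 0].detach().cpu().numpy()
img_np = image[0].permute(1, 2, 0).cpu().numpy()
img_np = (img_np - img_np.min()) / (img_np.max() - img_np.min())
vis = show_cam_on_image(img_np, mask)                # heatmap overlaid on the image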
ViT/helpers.py
ADDED
@@ -0,0 +1,295 @@
1 |
+
""" Model creation / weight loading / state_dict helpers
|
2 |
+
|
3 |
+
Hacked together by / Copyright 2020 Ross Wightman
|
4 |
+
"""
|
5 |
+
import logging
|
6 |
+
import os
|
7 |
+
import math
|
8 |
+
from collections import OrderedDict
|
9 |
+
from copy import deepcopy
|
10 |
+
from typing import Callable
|
11 |
+
|
12 |
+
import torch
|
13 |
+
import torch.nn as nn
|
14 |
+
import torch.utils.model_zoo as model_zoo
|
15 |
+
|
16 |
+
_logger = logging.getLogger(__name__)
|
17 |
+
|
18 |
+
|
19 |
+
def load_state_dict(checkpoint_path, use_ema=False):
|
20 |
+
if checkpoint_path and os.path.isfile(checkpoint_path):
|
21 |
+
checkpoint = torch.load(checkpoint_path, map_location='cpu')
|
22 |
+
state_dict_key = 'state_dict'
|
23 |
+
if isinstance(checkpoint, dict):
|
24 |
+
if use_ema and 'state_dict_ema' in checkpoint:
|
25 |
+
state_dict_key = 'state_dict_ema'
|
26 |
+
if state_dict_key and state_dict_key in checkpoint:
|
27 |
+
new_state_dict = OrderedDict()
|
28 |
+
for k, v in checkpoint[state_dict_key].items():
|
29 |
+
# strip `module.` prefix
|
30 |
+
name = k[7:] if k.startswith('module') else k
|
31 |
+
new_state_dict[name] = v
|
32 |
+
state_dict = new_state_dict
|
33 |
+
else:
|
34 |
+
state_dict = checkpoint
|
35 |
+
_logger.info("Loaded {} from checkpoint '{}'".format(state_dict_key, checkpoint_path))
|
36 |
+
return state_dict
|
37 |
+
else:
|
38 |
+
_logger.error("No checkpoint found at '{}'".format(checkpoint_path))
|
39 |
+
raise FileNotFoundError()
|
40 |
+
|
41 |
+
|
42 |
+
def load_checkpoint(model, checkpoint_path, use_ema=False, strict=True):
|
43 |
+
state_dict = load_state_dict(checkpoint_path, use_ema)
|
44 |
+
model.load_state_dict(state_dict, strict=strict)
|
45 |
+
|
46 |
+
|
47 |
+
def resume_checkpoint(model, checkpoint_path, optimizer=None, loss_scaler=None, log_info=True):
|
48 |
+
resume_epoch = None
|
49 |
+
if os.path.isfile(checkpoint_path):
|
50 |
+
checkpoint = torch.load(checkpoint_path, map_location='cpu')
|
51 |
+
if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
|
52 |
+
if log_info:
|
53 |
+
_logger.info('Restoring model state from checkpoint...')
|
54 |
+
new_state_dict = OrderedDict()
|
55 |
+
for k, v in checkpoint['state_dict'].items():
|
56 |
+
name = k[7:] if k.startswith('module') else k
|
57 |
+
new_state_dict[name] = v
|
58 |
+
model.load_state_dict(new_state_dict)
|
59 |
+
|
60 |
+
if optimizer is not None and 'optimizer' in checkpoint:
|
61 |
+
if log_info:
|
62 |
+
_logger.info('Restoring optimizer state from checkpoint...')
|
63 |
+
optimizer.load_state_dict(checkpoint['optimizer'])
|
64 |
+
|
65 |
+
if loss_scaler is not None and loss_scaler.state_dict_key in checkpoint:
|
66 |
+
if log_info:
|
67 |
+
_logger.info('Restoring AMP loss scaler state from checkpoint...')
|
68 |
+
loss_scaler.load_state_dict(checkpoint[loss_scaler.state_dict_key])
|
69 |
+
|
70 |
+
if 'epoch' in checkpoint:
|
71 |
+
resume_epoch = checkpoint['epoch']
|
72 |
+
if 'version' in checkpoint and checkpoint['version'] > 1:
|
73 |
+
resume_epoch += 1 # start at the next epoch, old checkpoints incremented before save
|
74 |
+
|
75 |
+
if log_info:
|
76 |
+
_logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch']))
|
77 |
+
else:
|
78 |
+
model.load_state_dict(checkpoint)
|
79 |
+
if log_info:
|
80 |
+
_logger.info("Loaded checkpoint '{}'".format(checkpoint_path))
|
81 |
+
return resume_epoch
|
82 |
+
else:
|
83 |
+
_logger.error("No checkpoint found at '{}'".format(checkpoint_path))
|
84 |
+
raise FileNotFoundError()
|
85 |
+
|
86 |
+
|
87 |
+
def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=None, strict=True):
|
88 |
+
if cfg is None:
|
89 |
+
cfg = getattr(model, 'default_cfg')
|
90 |
+
if cfg is None or 'url' not in cfg or not cfg['url']:
|
91 |
+
_logger.warning("Pretrained model URL is invalid, using random initialization.")
|
92 |
+
return
|
93 |
+
|
94 |
+
state_dict = model_zoo.load_url(cfg['url'], progress=False, map_location='cpu')
|
95 |
+
|
96 |
+
if filter_fn is not None:
|
97 |
+
state_dict = filter_fn(state_dict)
|
98 |
+
|
99 |
+
if in_chans == 1:
|
100 |
+
conv1_name = cfg['first_conv']
|
101 |
+
_logger.info('Converting first conv (%s) pretrained weights from 3 to 1 channel' % conv1_name)
|
102 |
+
conv1_weight = state_dict[conv1_name + '.weight']
|
103 |
+
# Some weights are in torch.half, ensure it's float for sum on CPU
|
104 |
+
conv1_type = conv1_weight.dtype
|
105 |
+
conv1_weight = conv1_weight.float()
|
106 |
+
O, I, J, K = conv1_weight.shape
|
107 |
+
if I > 3:
|
108 |
+
assert conv1_weight.shape[1] % 3 == 0
|
109 |
+
# For models with space2depth stems
|
110 |
+
conv1_weight = conv1_weight.reshape(O, I // 3, 3, J, K)
|
111 |
+
conv1_weight = conv1_weight.sum(dim=2, keepdim=False)
|
112 |
+
else:
|
113 |
+
conv1_weight = conv1_weight.sum(dim=1, keepdim=True)
|
114 |
+
conv1_weight = conv1_weight.to(conv1_type)
|
115 |
+
state_dict[conv1_name + '.weight'] = conv1_weight
|
116 |
+
elif in_chans != 3:
|
117 |
+
conv1_name = cfg['first_conv']
|
118 |
+
conv1_weight = state_dict[conv1_name + '.weight']
|
119 |
+
conv1_type = conv1_weight.dtype
|
120 |
+
conv1_weight = conv1_weight.float()
|
121 |
+
O, I, J, K = conv1_weight.shape
|
122 |
+
if I != 3:
|
123 |
+
_logger.warning('Deleting first conv (%s) from pretrained weights.' % conv1_name)
|
124 |
+
del state_dict[conv1_name + '.weight']
|
125 |
+
strict = False
|
126 |
+
else:
|
127 |
+
# NOTE this strategy should be better than random init, but there could be other combinations of
|
128 |
+
# the original RGB input layer weights that'd work better for specific cases.
|
129 |
+
_logger.info('Repeating first conv (%s) weights in channel dim.' % conv1_name)
|
130 |
+
repeat = int(math.ceil(in_chans / 3))
|
131 |
+
conv1_weight = conv1_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :]
|
132 |
+
conv1_weight *= (3 / float(in_chans))
|
133 |
+
conv1_weight = conv1_weight.to(conv1_type)
|
134 |
+
state_dict[conv1_name + '.weight'] = conv1_weight
|
135 |
+
|
136 |
+
classifier_name = cfg['classifier']
|
137 |
+
if num_classes == 1000 and cfg['num_classes'] == 1001:
|
138 |
+
# special case for imagenet trained models with extra background class in pretrained weights
|
139 |
+
classifier_weight = state_dict[classifier_name + '.weight']
|
140 |
+
state_dict[classifier_name + '.weight'] = classifier_weight[1:]
|
141 |
+
classifier_bias = state_dict[classifier_name + '.bias']
|
142 |
+
state_dict[classifier_name + '.bias'] = classifier_bias[1:]
|
143 |
+
elif num_classes != cfg['num_classes']:
|
144 |
+
# completely discard fully connected for all other differences between pretrained and created model
|
145 |
+
del state_dict[classifier_name + '.weight']
|
146 |
+
del state_dict[classifier_name + '.bias']
|
147 |
+
strict = False
|
148 |
+
|
149 |
+
model.load_state_dict(state_dict, strict=strict)
|
150 |
+
|
151 |
+
|
152 |
+
def extract_layer(model, layer):
|
153 |
+
layer = layer.split('.')
|
154 |
+
module = model
|
155 |
+
if hasattr(model, 'module') and layer[0] != 'module':
|
156 |
+
module = model.module
|
157 |
+
if not hasattr(model, 'module') and layer[0] == 'module':
|
158 |
+
layer = layer[1:]
|
159 |
+
for l in layer:
|
160 |
+
if hasattr(module, l):
|
161 |
+
if not l.isdigit():
|
162 |
+
module = getattr(module, l)
|
163 |
+
else:
|
164 |
+
module = module[int(l)]
|
165 |
+
else:
|
166 |
+
return module
|
167 |
+
return module
|
168 |
+
|
169 |
+
|
170 |
+
def set_layer(model, layer, val):
|
171 |
+
layer = layer.split('.')
|
172 |
+
module = model
|
173 |
+
if hasattr(model, 'module') and layer[0] != 'module':
|
174 |
+
module = model.module
|
175 |
+
lst_index = 0
|
176 |
+
module2 = module
|
177 |
+
for l in layer:
|
178 |
+
if hasattr(module2, l):
|
179 |
+
if not l.isdigit():
|
180 |
+
module2 = getattr(module2, l)
|
181 |
+
else:
|
182 |
+
module2 = module2[int(l)]
|
183 |
+
lst_index += 1
|
184 |
+
lst_index -= 1
|
185 |
+
for l in layer[:lst_index]:
|
186 |
+
if not l.isdigit():
|
187 |
+
module = getattr(module, l)
|
188 |
+
else:
|
189 |
+
module = module[int(l)]
|
190 |
+
l = layer[lst_index]
|
191 |
+
setattr(module, l, val)
|
192 |
+
|
193 |
+
|
194 |
+
def adapt_model_from_string(parent_module, model_string):
|
195 |
+
separator = '***'
|
196 |
+
state_dict = {}
|
197 |
+
lst_shape = model_string.split(separator)
|
198 |
+
for k in lst_shape:
|
199 |
+
k = k.split(':')
|
200 |
+
key = k[0]
|
201 |
+
shape = k[1][1:-1].split(',')
|
202 |
+
if shape[0] != '':
|
203 |
+
state_dict[key] = [int(i) for i in shape]
|
204 |
+
|
205 |
+
new_module = deepcopy(parent_module)
|
206 |
+
for n, m in parent_module.named_modules():
|
207 |
+
old_module = extract_layer(parent_module, n)
|
208 |
+
if isinstance(old_module, nn.Conv2d) or isinstance(old_module, Conv2dSame):
|
209 |
+
if isinstance(old_module, Conv2dSame):
|
210 |
+
conv = Conv2dSame
|
211 |
+
else:
|
212 |
+
conv = nn.Conv2d
|
213 |
+
s = state_dict[n + '.weight']
|
214 |
+
in_channels = s[1]
|
215 |
+
out_channels = s[0]
|
216 |
+
g = 1
|
217 |
+
if old_module.groups > 1:
|
218 |
+
in_channels = out_channels
|
219 |
+
g = in_channels
|
220 |
+
new_conv = conv(
|
221 |
+
in_channels=in_channels, out_channels=out_channels, kernel_size=old_module.kernel_size,
|
222 |
+
bias=old_module.bias is not None, padding=old_module.padding, dilation=old_module.dilation,
|
223 |
+
groups=g, stride=old_module.stride)
|
224 |
+
set_layer(new_module, n, new_conv)
|
225 |
+
if isinstance(old_module, nn.BatchNorm2d):
|
226 |
+
new_bn = nn.BatchNorm2d(
|
227 |
+
num_features=state_dict[n + '.weight'][0], eps=old_module.eps, momentum=old_module.momentum,
|
228 |
+
affine=old_module.affine, track_running_stats=True)
|
229 |
+
set_layer(new_module, n, new_bn)
|
230 |
+
if isinstance(old_module, nn.Linear):
|
231 |
+
# FIXME extra checks to ensure this is actually the FC classifier layer and not a diff Linear layer?
|
232 |
+
num_features = state_dict[n + '.weight'][1]
|
233 |
+
new_fc = nn.Linear(
|
234 |
+
in_features=num_features, out_features=old_module.out_features, bias=old_module.bias is not None)
|
235 |
+
set_layer(new_module, n, new_fc)
|
236 |
+
if hasattr(new_module, 'num_features'):
|
237 |
+
new_module.num_features = num_features
|
238 |
+
new_module.eval()
|
239 |
+
parent_module.eval()
|
240 |
+
|
241 |
+
return new_module
|
242 |
+
|
243 |
+
|
244 |
+
def adapt_model_from_file(parent_module, model_variant):
|
245 |
+
adapt_file = os.path.join(os.path.dirname(__file__), 'pruned', model_variant + '.txt')
|
246 |
+
with open(adapt_file, 'r') as f:
|
247 |
+
return adapt_model_from_string(parent_module, f.read().strip())
|
248 |
+
|
249 |
+
|
250 |
+
def build_model_with_cfg(
|
251 |
+
model_cls: Callable,
|
252 |
+
variant: str,
|
253 |
+
pretrained: bool,
|
254 |
+
default_cfg: dict,
|
255 |
+
model_cfg: dict = None,
|
256 |
+
feature_cfg: dict = None,
|
257 |
+
pretrained_strict: bool = True,
|
258 |
+
pretrained_filter_fn: Callable = None,
|
259 |
+
**kwargs):
|
260 |
+
pruned = kwargs.pop('pruned', False)
|
261 |
+
features = False
|
262 |
+
feature_cfg = feature_cfg or {}
|
263 |
+
|
264 |
+
if kwargs.pop('features_only', False):
|
265 |
+
features = True
|
266 |
+
feature_cfg.setdefault('out_indices', (0, 1, 2, 3, 4))
|
267 |
+
if 'out_indices' in kwargs:
|
268 |
+
feature_cfg['out_indices'] = kwargs.pop('out_indices')
|
269 |
+
|
270 |
+
model = model_cls(**kwargs) if model_cfg is None else model_cls(cfg=model_cfg, **kwargs)
|
271 |
+
model.default_cfg = deepcopy(default_cfg)
|
272 |
+
|
273 |
+
if pruned:
|
274 |
+
model = adapt_model_from_file(model, variant)
|
275 |
+
|
276 |
+
if pretrained:
|
277 |
+
load_pretrained(
|
278 |
+
model,
|
279 |
+
num_classes=kwargs.get('num_classes', 0),
|
280 |
+
in_chans=kwargs.get('in_chans', 3),
|
281 |
+
filter_fn=pretrained_filter_fn, strict=pretrained_strict)
|
282 |
+
|
283 |
+
if features:
|
284 |
+
feature_cls = FeatureListNet
|
285 |
+
if 'feature_cls' in feature_cfg:
|
286 |
+
feature_cls = feature_cfg.pop('feature_cls')
|
287 |
+
if isinstance(feature_cls, str):
|
288 |
+
feature_cls = feature_cls.lower()
|
289 |
+
if 'hook' in feature_cls:
|
290 |
+
feature_cls = FeatureHookNet
|
291 |
+
else:
|
292 |
+
assert False, f'Unknown feature class {feature_cls}'
|
293 |
+
model = feature_cls(model, **feature_cfg)
|
294 |
+
|
295 |
+
return model
|
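A brief sketch of the checkpoint helper in use; the checkpoint path is a placeholder:

from ViT.ViT import vit_base_patch16_224
from ViT.helpers import load_checkpoint

model = vit_base_patch16_224(pretrained=False)
# Loads a saved state_dict (or a dict with a 'state_dict' key), stripping any 'module.' prefix.
load_checkpoint(model, 'checkpoints/finetuned_vit.tar', use_ema=False, strict=True)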
ViT/layer_helpers.py
ADDED
@@ -0,0 +1,21 @@
""" Layer/Module Helpers
Hacked together by / Copyright 2020 Ross Wightman
"""
from itertools import repeat
import collections.abc


# From PyTorch internals
def _ntuple(n):
    def parse(x):
        if isinstance(x, collections.abc.Iterable):
            return x
        return tuple(repeat(x, n))
    return parse


to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)
to_ntuple = _ntuple
|
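For example, to_2tuple passes iterables through unchanged and repeats scalars:

from ViT.layer_helpers import to_2tuple

print(to_2tuple(224))         # (224, 224)
print(to_2tuple((224, 196)))  # (224, 196) -- already iterable, returned unchanged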
ViT/weight_init.py
ADDED
@@ -0,0 +1,60 @@
import torch
import math
import warnings


def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
                      "The distribution of values may be incorrect.",
                      stacklevel=2)

    with torch.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        l = norm_cdf((a - mean) / std)
        u = norm_cdf((b - mean) / std)

        # Uniformly fill tensor with values from [l, u], then translate to
        # [2l-1, 2u-1].
        tensor.uniform_(2 * l - 1, 2 * u - 1)

        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tensor.erfinv_()

        # Transform to proper mean, std
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)

        # Clamp to ensure it's in the proper range
        tensor.clamp_(min=a, max=b)
        return tensor


def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    # type: (Tensor, float, float, float, float) -> Tensor
    r"""Fills the input Tensor with values drawn from a truncated
    normal distribution. The values are effectively drawn from the
    normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
    with values outside :math:`[a, b]` redrawn until they are within
    the bounds. The method used for generating the random values works
    best when :math:`a \leq \text{mean} \leq b`.
    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value
    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.trunc_normal_(w)
    """
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)
|
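Initialization usage mirrors the docstring example:

import torch
from ViT.weight_init import trunc_normal_

pos_embed = torch.zeros(1, 197, 768)
trunc_normal_(pos_embed, std=.02)   # in-place truncated-normal init, clipped to [-2, 2]
print(pos_embed.abs().max() <= 2)   # True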
imagenet_ablation_gt.py
ADDED
@@ -0,0 +1,590 @@
import argparse
import os
import random
import shutil
import time
import warnings

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.optim
import torch.multiprocessing as mp
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models
from segmentation_dataset import SegmentationDataset, VAL_PARTITION, TRAIN_PARTITION

# Uncomment the expected model below

# ViT
from ViT.ViT import vit_base_patch16_224 as vit
# from ViT.ViT import vit_large_patch16_224 as vit

# ViT-AugReg
# from ViT.ViT_new import vit_small_patch16_224 as vit
# from ViT.ViT_new import vit_base_patch16_224 as vit
# from ViT.ViT_new import vit_large_patch16_224 as vit

# DeiT
# from ViT.ViT import deit_base_patch16_224 as vit
# from ViT.ViT import deit_small_patch16_224 as vit

from ViT.explainer import generate_relevance, get_image_with_relevance
import torchvision
import cv2
from torch.utils.tensorboard import SummaryWriter
import json

model_names = sorted(name for name in models.__dict__
                     if name.islower() and not name.startswith("__")
                     and callable(models.__dict__[name]))
model_names.append("vit")

parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
parser.add_argument('--data', metavar='DATA',
                    help='path to dataset')
parser.add_argument('--seg_data', metavar='SEG_DATA',
                    help='path to segmentation dataset')
parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
                    choices=model_names,
                    help='model architecture: ' +
                         ' | '.join(model_names) +
                         ' (default: resnet18)')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                    help='number of data loading workers (default: 4)')
parser.add_argument('--epochs', default=150, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('-b', '--batch-size', default=8, type=int,
                    metavar='N',
                    help='mini-batch size (default: 256), this is the total '
                         'batch size of all GPUs on the current node when '
                         'using Data Parallel or Distributed Data Parallel')
parser.add_argument('--lr', '--learning-rate', default=3e-6, type=float,
                    metavar='LR', help='initial learning rate', dest='lr')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                    help='momentum')
parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
                    metavar='W', help='weight decay (default: 1e-4)',
                    dest='weight_decay')
parser.add_argument('-p', '--print-freq', default=10, type=int,
                    metavar='N', help='print frequency (default: 10)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                    help='evaluate model on validation set')
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                    help='use pre-trained model')
parser.add_argument('--world-size', default=-1, type=int,
                    help='number of nodes for distributed training')
parser.add_argument('--rank', default=-1, type=int,
                    help='node rank for distributed training')
parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,
                    help='url used to set up distributed training')
parser.add_argument('--dist-backend', default='nccl', type=str,
                    help='distributed backend')
parser.add_argument('--seed', default=None, type=int,
                    help='seed for initializing training.')
parser.add_argument('--gpu', default=None, type=int,
                    help='GPU id to use.')
parser.add_argument('--save_interval', default=20, type=int,
                    help='interval to save segmentation results.')
parser.add_argument('--num_samples', default=3, type=int,
                    help='number of samples per class for training')
parser.add_argument('--multiprocessing-distributed', action='store_true',
                    help='Use multi-processing distributed training to launch '
                         'N processes per node, which has N GPUs. This is the '
                         'fastest way to use PyTorch for either single node or '
                         'multi node data parallel training')
parser.add_argument('--lambda_seg', default=0.8, type=float,
                    help='influence of segmentation loss.')
parser.add_argument('--lambda_acc', default=0.2, type=float,
                    help='influence of accuracy loss.')
parser.add_argument('--experiment_folder', default=None, type=str,
                    help='path to folder to use for experiment.')
parser.add_argument('--dilation', default=0, type=float,
                    help='Use dilation on the segmentation maps.')
parser.add_argument('--lambda_background', default=2, type=float,
                    help='coefficient of loss for segmentation background.')
parser.add_argument('--lambda_foreground', default=0.3, type=float,
                    help='coefficient of loss for segmentation foreground.')
parser.add_argument('--num_classes', default=500, type=int,
                    help='number of classes to use for training.')
parser.add_argument('--temperature', default=1, type=float,
                    help='temperature for softmax (mostly for DeiT).')

best_loss = float('inf')


def main():
    args = parser.parse_args()

    if args.experiment_folder is None:
        args.experiment_folder = f'experiment/' \
                                 f'lr_{args.lr}_seg_{args.lambda_seg}_acc_{args.lambda_acc}' \
                                 f'_bckg_{args.lambda_background}_fgd_{args.lambda_foreground}'
        if args.temperature != 1:
            args.experiment_folder = args.experiment_folder + f'_tempera_{args.temperature}'
        if args.batch_size != 8:
            args.experiment_folder = args.experiment_folder + f'_bs_{args.batch_size}'
        if args.num_classes != 500:
            args.experiment_folder = args.experiment_folder + f'_num_classes_{args.num_classes}'
        if args.num_samples != 3:
            args.experiment_folder = args.experiment_folder + f'_num_samples_{args.num_samples}'
        if args.epochs != 150:
            args.experiment_folder = args.experiment_folder + f'_num_epochs_{args.epochs}'

    if os.path.exists(args.experiment_folder):
        raise Exception(f"Experiment path {args.experiment_folder} already exists!")
    os.mkdir(args.experiment_folder)
    os.mkdir(f'{args.experiment_folder}/train_samples')
    os.mkdir(f'{args.experiment_folder}/val_samples')

    with open(f'{args.experiment_folder}/commandline_args.txt', 'w') as f:
        json.dump(args.__dict__, f, indent=2)

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    if args.gpu is not None:
        warnings.warn('You have chosen a specific GPU. This will completely '
                      'disable data parallelism.')

    if args.dist_url == "env://" and args.world_size == -1:
        args.world_size = int(os.environ["WORLD_SIZE"])

    args.distributed = args.world_size > 1 or args.multiprocessing_distributed

    ngpus_per_node = torch.cuda.device_count()
    if args.multiprocessing_distributed:
        # Since we have ngpus_per_node processes per node, the total world_size
        # needs to be adjusted accordingly
        args.world_size = ngpus_per_node * args.world_size
        # Use torch.multiprocessing.spawn to launch distributed processes: the
        # main_worker process function
        mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
    else:
        # Simply call main_worker function
        main_worker(args.gpu, ngpus_per_node, args)


def main_worker(gpu, ngpus_per_node, args):
    global best_loss
    args.gpu = gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                world_size=args.world_size, rank=args.rank)
    # create model
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True)
    else:
        print("=> creating model '{}'".format(args.arch))
        # model = models.__dict__[args.arch]()
        model = vit(pretrained=True).cuda()
        model.train()
        print("done")

    if not torch.cuda.is_available():
        print('using CPU, this will be slow')
    elif args.distributed:
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            # When using a single GPU per process and per
            # DistributedDataParallel, we need to divide the batch size
            # ourselves based on the total number of GPUs we have
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        else:
            model.cuda()
            # DistributedDataParallel will divide and allocate batch_size to all
            # available GPUs if device_ids are not set
            model = torch.nn.parallel.DistributedDataParallel(model)
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
    else:
        # DataParallel will divide and allocate batch_size to all available GPUs
        if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
            model.features = torch.nn.DataParallel(model.features)
            model.cuda()
        else:
            print("start")
            model = torch.nn.DataParallel(model).cuda()

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda(args.gpu)
    optimizer = torch.optim.AdamW(model.parameters(), args.lr, weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            if args.gpu is None:
                checkpoint = torch.load(args.resume)
            else:
                # Map model to be loaded to specified single gpu.
                loc = 'cuda:{}'.format(args.gpu)
                checkpoint = torch.load(args.resume, map_location=loc)
            args.start_epoch = checkpoint['epoch']
            best_loss = checkpoint['best_loss']
            if args.gpu is not None:
                # best_loss may be from a checkpoint from a different GPU
                best_loss = best_loss.to(args.gpu)
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    train_dataset = SegmentationDataset(args.seg_data, args.data, partition=TRAIN_PARTITION,
                                        train_classes=args.num_classes,
                                        num_samples=args.num_samples)

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_sampler)

    val_dataset = SegmentationDataset(args.seg_data, args.data, partition=VAL_PARTITION,
                                      train_classes=args.num_classes,
                                      num_samples=1)

    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=10, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    if args.evaluate:
        validate(val_loader, model, criterion, 0, args)
        return

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        adjust_learning_rate(optimizer, epoch, args)

        log_dir = os.path.join(args.experiment_folder, 'logs')
        logger = SummaryWriter(log_dir=log_dir)
        args.logger = logger

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, args)

        # evaluate on validation set
        loss1 = validate(val_loader, model, criterion, epoch, args)

        # remember best (lowest) validation loss and save checkpoint
        is_best = loss1 <= best_loss
        best_loss = min(loss1, best_loss)

        if not args.multiprocessing_distributed or (args.multiprocessing_distributed
                                                    and args.rank % ngpus_per_node == 0):
            save_checkpoint({
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_loss': best_loss,
                'optimizer': optimizer.state_dict(),
            }, is_best, folder=args.experiment_folder)


def train(train_loader, model, criterion, optimizer, epoch, args):
    mse_criterion = torch.nn.MSELoss(reduction='mean')

    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    orig_top1 = AverageMeter('Acc@1_orig', ':6.2f')
    orig_top5 = AverageMeter('Acc@5_orig', ':6.2f')
    progress = ProgressMeter(
        len(train_loader),
        [losses, top1, top5, orig_top1, orig_top5],
        prefix="Epoch: [{}]".format(epoch))

    orig_model = vit(pretrained=True).cuda()
    orig_model.eval()

    # switch to train mode
    model.train()

    for i, (seg_map, image_ten, class_name) in enumerate(train_loader):
        if torch.cuda.is_available():
            image_ten = image_ten.cuda(args.gpu, non_blocking=True)
            seg_map = seg_map.cuda(args.gpu, non_blocking=True)
            class_name = class_name.cuda(args.gpu, non_blocking=True)

        # segmentation loss
        relevance = generate_relevance(model, image_ten, index=class_name)

        # invert the GT mask: 1 on background pixels, 0 on foreground pixels
        reverse_seg_map = seg_map.clone()
        reverse_seg_map[reverse_seg_map == 1] = -1
        reverse_seg_map[reverse_seg_map == 0] = 1
        reverse_seg_map[reverse_seg_map == -1] = 0
        background_loss = mse_criterion(relevance * reverse_seg_map, torch.zeros_like(relevance))
        foreground_loss = mse_criterion(relevance * seg_map, seg_map)
        segmentation_loss = args.lambda_background * background_loss
        segmentation_loss += args.lambda_foreground * foreground_loss

        # classification loss
        output = model(image_ten)
        with torch.no_grad():
            output_orig = orig_model(image_ten)

        _, pred = output.topk(1, 1, True, True)
        pred = pred.flatten()

        if args.temperature != 1:
            output = output / args.temperature
        classification_loss = criterion(output, class_name.flatten())

        loss = args.lambda_seg * segmentation_loss + args.lambda_acc * classification_loss

        # debugging output
        if i % args.save_interval == 0:
            orig_relevance = generate_relevance(orig_model, image_ten, index=class_name)
            for j in range(image_ten.shape[0]):
                image = get_image_with_relevance(image_ten[j], torch.ones_like(image_ten[j]))
                new_vis = get_image_with_relevance(image_ten[j], relevance[j])
                old_vis = get_image_with_relevance(image_ten[j], orig_relevance[j])
                gt = get_image_with_relevance(image_ten[j], seg_map[j])
                h_img = cv2.hconcat([image, gt, old_vis, new_vis])
                cv2.imwrite(f'{args.experiment_folder}/train_samples/res_{i}_{j}.jpg', h_img)

        # measure accuracy and record loss
        acc1, acc5 = accuracy(output, class_name, topk=(1, 5))
        losses.update(loss.item(), image_ten.size(0))
        top1.update(acc1[0], image_ten.size(0))
        top5.update(acc5[0], image_ten.size(0))

        # metrics for original vit
        acc1_orig, acc5_orig = accuracy(output_orig, class_name, topk=(1, 5))
        orig_top1.update(acc1_orig[0], image_ten.size(0))
        orig_top5.update(acc5_orig[0], image_ten.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if i % args.print_freq == 0:
            progress.display(i)
            args.logger.add_scalar('{}/{}'.format('train', 'segmentation_loss'), segmentation_loss,
                                   epoch * len(train_loader) + i)
            args.logger.add_scalar('{}/{}'.format('train', 'classification_loss'), classification_loss,
                                   epoch * len(train_loader) + i)
            args.logger.add_scalar('{}/{}'.format('train', 'orig_top1'), acc1_orig,
                                   epoch * len(train_loader) + i)
            args.logger.add_scalar('{}/{}'.format('train', 'top1'), acc1,
                                   epoch * len(train_loader) + i)
            args.logger.add_scalar('{}/{}'.format('train', 'orig_top5'), acc5_orig,
                                   epoch * len(train_loader) + i)
            args.logger.add_scalar('{}/{}'.format('train', 'top5'), acc5,
                                   epoch * len(train_loader) + i)
            args.logger.add_scalar('{}/{}'.format('train', 'tot_loss'), loss,
                                   epoch * len(train_loader) + i)


def validate(val_loader, model, criterion, epoch, args):
    mse_criterion = torch.nn.MSELoss(reduction='mean')

    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    orig_top1 = AverageMeter('Acc@1_orig', ':6.2f')
    orig_top5 = AverageMeter('Acc@5_orig', ':6.2f')
    progress = ProgressMeter(
        len(val_loader),
        [losses, top1, top5, orig_top1, orig_top5],
        prefix="Epoch: [{}]".format(epoch))

    # switch to evaluate mode
    model.eval()

    orig_model = vit(pretrained=True).cuda()
    orig_model.eval()

    with torch.no_grad():
        for i, (seg_map, image_ten, class_name) in enumerate(val_loader):
            if args.gpu is not None:
                image_ten = image_ten.cuda(args.gpu, non_blocking=True)
            if torch.cuda.is_available():
                seg_map = seg_map.cuda(args.gpu, non_blocking=True)
                class_name = class_name.cuda(args.gpu, non_blocking=True)

            # segmentation loss (relevance needs gradients even in eval mode)
            with torch.enable_grad():
                relevance = generate_relevance(model, image_ten, index=class_name)

            # invert the GT mask: 1 on background pixels, 0 on foreground pixels
            reverse_seg_map = seg_map.clone()
            reverse_seg_map[reverse_seg_map == 1] = -1
            reverse_seg_map[reverse_seg_map == 0] = 1
            reverse_seg_map[reverse_seg_map == -1] = 0
            background_loss = mse_criterion(relevance * reverse_seg_map, torch.zeros_like(relevance))
            foreground_loss = mse_criterion(relevance * seg_map, seg_map)
            segmentation_loss = args.lambda_background * background_loss
            segmentation_loss += args.lambda_foreground * foreground_loss

            # classification loss
            with torch.no_grad():
                output = model(image_ten)
                output_orig = orig_model(image_ten)

            _, pred = output.topk(1, 1, True, True)
            pred = pred.flatten()
            if args.temperature != 1:
                output = output / args.temperature
            classification_loss = criterion(output, class_name.flatten())

            loss = args.lambda_seg * segmentation_loss + args.lambda_acc * classification_loss

            # save results
            if i % args.save_interval == 0:
                with torch.enable_grad():
                    orig_relevance = generate_relevance(orig_model, image_ten, index=class_name)
                for j in range(image_ten.shape[0]):
                    image = get_image_with_relevance(image_ten[j], torch.ones_like(image_ten[j]))
                    new_vis = get_image_with_relevance(image_ten[j], relevance[j])
                    old_vis = get_image_with_relevance(image_ten[j], orig_relevance[j])
                    gt = get_image_with_relevance(image_ten[j], seg_map[j])
                    h_img = cv2.hconcat([image, gt, old_vis, new_vis])
                    cv2.imwrite(f'{args.experiment_folder}/val_samples/res_{i}_{j}.jpg', h_img)

            # measure accuracy and record loss
            acc1, acc5 = accuracy(output, class_name, topk=(1, 5))
            losses.update(loss.item(), image_ten.size(0))
            top1.update(acc1[0], image_ten.size(0))
            top5.update(acc5[0], image_ten.size(0))

            # metrics for original vit
            acc1_orig, acc5_orig = accuracy(output_orig, class_name, topk=(1, 5))
            orig_top1.update(acc1_orig[0], image_ten.size(0))
            orig_top5.update(acc5_orig[0], image_ten.size(0))

            if i % args.print_freq == 0:
                progress.display(i)
                args.logger.add_scalar('{}/{}'.format('val', 'segmentation_loss'), segmentation_loss,
                                       epoch * len(val_loader) + i)
                args.logger.add_scalar('{}/{}'.format('val', 'classification_loss'), classification_loss,
                                       epoch * len(val_loader) + i)
                args.logger.add_scalar('{}/{}'.format('val', 'orig_top1'), acc1_orig,
                                       epoch * len(val_loader) + i)
                args.logger.add_scalar('{}/{}'.format('val', 'top1'), acc1,
                                       epoch * len(val_loader) + i)
                args.logger.add_scalar('{}/{}'.format('val', 'orig_top5'), acc5_orig,
                                       epoch * len(val_loader) + i)
                args.logger.add_scalar('{}/{}'.format('val', 'top5'), acc5,
                                       epoch * len(val_loader) + i)
                args.logger.add_scalar('{}/{}'.format('val', 'tot_loss'), loss,
                                       epoch * len(val_loader) + i)

        # TODO: this should also be done with the ProgressMeter
        print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
              .format(top1=top1, top5=top5))

    return losses.avg


def save_checkpoint(state, is_best, folder, filename='checkpoint.pth.tar'):
    torch.save(state, f'{folder}/{filename}')
    if is_best:
        shutil.copyfile(f'{folder}/{filename}', f'{folder}/model_best.pth.tar')


class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)


class ProgressMeter(object):
    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries += [str(meter) for meter in self.meters]
        print('\t'.join(entries))

    def _get_batch_fmtstr(self, num_batches):
        num_digits = len(str(num_batches // 1))
        fmt = '{:' + str(num_digits) + 'd}'
        return '[' + fmt + '/' + fmt.format(num_batches) + ']'


def adjust_learning_rate(optimizer, epoch, args):
    """Decays the learning rate by a factor of 0.85 every 2 epochs"""
    lr = args.lr * (0.85 ** (epoch // 2))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res


if __name__ == '__main__':
    main()

imagenet_classes.json
ADDED
@@ -0,0 +1,1002 @@
1 |
+
{
|
2 |
+
"n01440764": 0,
|
3 |
+
"n01443537": 1,
|
4 |
+
"n01484850": 2,
|
5 |
+
"n01491361": 3,
|
6 |
+
"n01494475": 4,
|
7 |
+
"n01496331": 5,
|
8 |
+
"n01498041": 6,
|
9 |
+
"n01514668": 7,
|
10 |
+
"n01514859": 8,
|
11 |
+
"n01518878": 9,
|
12 |
+
"n01530575": 10,
|
13 |
+
"n01531178": 11,
|
14 |
+
"n01532829": 12,
|
15 |
+
"n01534433": 13,
|
16 |
+
"n01537544": 14,
|
17 |
+
"n01558993": 15,
|
18 |
+
"n01560419": 16,
|
19 |
+
"n01580077": 17,
|
20 |
+
"n01582220": 18,
|
21 |
+
"n01592084": 19,
|
22 |
+
"n01601694": 20,
|
23 |
+
"n01608432": 21,
|
24 |
+
"n01614925": 22,
|
25 |
+
"n01616318": 23,
|
26 |
+
"n01622779": 24,
|
27 |
+
"n01629819": 25,
|
28 |
+
"n01630670": 26,
|
29 |
+
"n01631663": 27,
|
30 |
+
"n01632458": 28,
|
31 |
+
"n01632777": 29,
|
32 |
+
"n01641577": 30,
|
33 |
+
"n01644373": 31,
|
34 |
+
"n01644900": 32,
|
35 |
+
"n01664065": 33,
|
36 |
+
"n01665541": 34,
|
37 |
+
"n01667114": 35,
|
38 |
+
"n01667778": 36,
|
39 |
+
"n01669191": 37,
|
40 |
+
"n01675722": 38,
|
41 |
+
"n01677366": 39,
|
42 |
+
"n01682714": 40,
|
43 |
+
"n01685808": 41,
|
44 |
+
"n01687978": 42,
|
45 |
+
"n01688243": 43,
|
46 |
+
"n01689811": 44,
|
47 |
+
"n01692333": 45,
|
48 |
+
"n01693334": 46,
|
49 |
+
"n01694178": 47,
|
50 |
+
"n01695060": 48,
|
51 |
+
"n01697457": 49,
|
52 |
+
"n01698640": 50,
|
53 |
+
"n01704323": 51,
|
54 |
+
"n01728572": 52,
|
55 |
+
"n01728920": 53,
|
56 |
+
"n01729322": 54,
|
57 |
+
"n01729977": 55,
|
58 |
+
"n01734418": 56,
|
59 |
+
"n01735189": 57,
|
60 |
+
"n01737021": 58,
|
61 |
+
"n01739381": 59,
|
62 |
+
"n01740131": 60,
|
63 |
+
"n01742172": 61,
|
64 |
+
"n01744401": 62,
|
65 |
+
"n01748264": 63,
|
66 |
+
"n01749939": 64,
|
67 |
+
"n01751748": 65,
|
68 |
+
"n01753488": 66,
|
69 |
+
"n01755581": 67,
|
70 |
+
"n01756291": 68,
|
71 |
+
"n01768244": 69,
|
72 |
+
"n01770081": 70,
|
73 |
+
"n01770393": 71,
|
74 |
+
"n01773157": 72,
|
75 |
+
"n01773549": 73,
|
76 |
+
"n01773797": 74,
|
77 |
+
"n01774384": 75,
|
78 |
+
"n01774750": 76,
|
79 |
+
"n01775062": 77,
|
80 |
+
"n01776313": 78,
|
81 |
+
"n01784675": 79,
|
82 |
+
"n01795545": 80,
|
83 |
+
"n01796340": 81,
|
84 |
+
"n01797886": 82,
|
85 |
+
"n01798484": 83,
|
86 |
+
"n01806143": 84,
|
87 |
+
"n01806567": 85,
|
88 |
+
"n01807496": 86,
|
89 |
+
"n01817953": 87,
|
90 |
+
"n01818515": 88,
|
91 |
+
"n01819313": 89,
|
92 |
+
"n01820546": 90,
|
93 |
+
"n01824575": 91,
|
94 |
+
"n01828970": 92,
|
95 |
+
"n01829413": 93,
|
96 |
+
"n01833805": 94,
|
97 |
+
"n01843065": 95,
|
98 |
+
"n01843383": 96,
|
99 |
+
"n01847000": 97,
|
100 |
+
"n01855032": 98,
|
101 |
+
"n01855672": 99,
|
102 |
+
"n01860187": 100,
|
103 |
+
"n01871265": 101,
|
104 |
+
"n01872401": 102,
|
105 |
+
"n01873310": 103,
|
106 |
+
"n01877812": 104,
|
107 |
+
"n01882714": 105,
|
108 |
+
"n01883070": 106,
|
109 |
+
"n01910747": 107,
|
110 |
+
"n01914609": 108,
|
111 |
+
"n01917289": 109,
|
112 |
+
"n01924916": 110,
|
113 |
+
"n01930112": 111,
|
114 |
+
"n01943899": 112,
|
115 |
+
"n01944390": 113,
|
116 |
+
"n01945685": 114,
|
117 |
+
"n01950731": 115,
|
118 |
+
"n01955084": 116,
|
119 |
+
"n01968897": 117,
|
120 |
+
"n01978287": 118,
|
121 |
+
"n01978455": 119,
|
122 |
+
"n01980166": 120,
|
123 |
+
"n01981276": 121,
|
124 |
+
"n01983481": 122,
|
125 |
+
"n01984695": 123,
|
126 |
+
"n01985128": 124,
|
127 |
+
"n01986214": 125,
|
128 |
+
"n01990800": 126,
|
129 |
+
"n02002556": 127,
|
130 |
+
"n02002724": 128,
|
131 |
+
"n02006656": 129,
|
132 |
+
"n02007558": 130,
|
133 |
+
"n02009229": 131,
|
134 |
+
"n02009912": 132,
|
135 |
+
"n02011460": 133,
|
136 |
+
"n02012849": 134,
|
137 |
+
"n02013706": 135,
|
138 |
+
"n02017213": 136,
|
139 |
+
"n02018207": 137,
|
140 |
+
"n02018795": 138,
|
141 |
+
"n02025239": 139,
|
142 |
+
"n02027492": 140,
|
143 |
+
"n02028035": 141,
|
144 |
+
"n02033041": 142,
|
145 |
+
"n02037110": 143,
|
146 |
+
"n02051845": 144,
|
147 |
+
"n02056570": 145,
|
148 |
+
"n02058221": 146,
|
149 |
+
"n02066245": 147,
|
150 |
+
"n02071294": 148,
|
151 |
+
"n02074367": 149,
|
152 |
+
"n02077923": 150,
|
153 |
+
"n02085620": 151,
|
154 |
+
"n02085782": 152,
|
155 |
+
"n02085936": 153,
|
156 |
+
"n02086079": 154,
|
157 |
+
"n02086240": 155,
|
158 |
+
"n02086646": 156,
|
159 |
+
"n02086910": 157,
|
160 |
+
"n02087046": 158,
|
161 |
+
"n02087394": 159,
|
162 |
+
"n02088094": 160,
|
163 |
+
"n02088238": 161,
|
164 |
+
"n02088364": 162,
|
165 |
+
"n02088466": 163,
|
166 |
+
"n02088632": 164,
|
167 |
+
"n02089078": 165,
|
168 |
+
"n02089867": 166,
|
169 |
+
"n02089973": 167,
|
170 |
+
"n02090379": 168,
|
171 |
+
"n02090622": 169,
|
172 |
+
"n02090721": 170,
|
173 |
+
"n02091032": 171,
|
174 |
+
"n02091134": 172,
|
175 |
+
"n02091244": 173,
|
176 |
+
"n02091467": 174,
|
177 |
+
"n02091635": 175,
|
178 |
+
"n02091831": 176,
|
179 |
+
"n02092002": 177,
|
180 |
+
"n02092339": 178,
|
181 |
+
"n02093256": 179,
|
182 |
+
"n02093428": 180,
|
183 |
+
"n02093647": 181,
|
184 |
+
"n02093754": 182,
|
185 |
+
"n02093859": 183,
|
186 |
+
"n02093991": 184,
|
187 |
+
"n02094114": 185,
|
188 |
+
"n02094258": 186,
|
189 |
+
"n02094433": 187,
|
190 |
+
"n02095314": 188,
|
191 |
+
"n02095570": 189,
|
192 |
+
"n02095889": 190,
|
193 |
+
"n02096051": 191,
|
194 |
+
"n02096177": 192,
|
195 |
+
"n02096294": 193,
|
196 |
+
"n02096437": 194,
|
197 |
+
"n02096585": 195,
|
198 |
+
"n02097047": 196,
|
199 |
+
"n02097130": 197,
|
200 |
+
"n02097209": 198,
|
201 |
+
"n02097298": 199,
|
202 |
+
"n02097474": 200,
|
203 |
+
"n02097658": 201,
|
204 |
+
"n02098105": 202,
|
205 |
+
"n02098286": 203,
|
206 |
+
"n02098413": 204,
|
207 |
+
"n02099267": 205,
|
208 |
+
"n02099429": 206,
|
209 |
+
"n02099601": 207,
|
210 |
+
"n02099712": 208,
|
211 |
+
"n02099849": 209,
|
212 |
+
"n02100236": 210,
|
213 |
+
"n02100583": 211,
|
214 |
+
"n02100735": 212,
|
215 |
+
"n02100877": 213,
|
216 |
+
"n02101006": 214,
|
217 |
+
"n02101388": 215,
|
218 |
+
"n02101556": 216,
|
219 |
+
"n02102040": 217,
|
220 |
+
"n02102177": 218,
|
221 |
+
"n02102318": 219,
|
222 |
+
"n02102480": 220,
|
223 |
+
"n02102973": 221,
|
224 |
+
"n02104029": 222,
|
225 |
+
"n02104365": 223,
|
226 |
+
"n02105056": 224,
|
227 |
+
"n02105162": 225,
|
228 |
+
"n02105251": 226,
|
229 |
+
"n02105412": 227,
|
230 |
+
"n02105505": 228,
|
231 |
+
"n02105641": 229,
|
232 |
+
"n02105855": 230,
|
233 |
+
"n02106030": 231,
|
234 |
+
"n02106166": 232,
|
235 |
+
"n02106382": 233,
|
236 |
+
"n02106550": 234,
|
237 |
+
"n02106662": 235,
|
238 |
+
"n02107142": 236,
|
239 |
+
"n02107312": 237,
|
240 |
+
"n02107574": 238,
|
241 |
+
"n02107683": 239,
|
242 |
+
"n02107908": 240,
|
243 |
+
"n02108000": 241,
|
244 |
+
"n02108089": 242,
|
245 |
+
"n02108422": 243,
|
246 |
+
"n02108551": 244,
|
247 |
+
"n02108915": 245,
|
248 |
+
"n02109047": 246,
|
249 |
+
"n02109525": 247,
|
250 |
+
"n02109961": 248,
|
251 |
+
"n02110063": 249,
|
252 |
+
"n02110185": 250,
|
253 |
+
"n02110341": 251,
|
254 |
+
"n02110627": 252,
|
255 |
+
"n02110806": 253,
|
256 |
+
"n02110958": 254,
|
257 |
+
"n02111129": 255,
|
258 |
+
"n02111277": 256,
|
259 |
+
"n02111500": 257,
|
260 |
+
"n02111889": 258,
|
261 |
+
"n02112018": 259,
|
262 |
+
"n02112137": 260,
|
263 |
+
"n02112350": 261,
|
264 |
+
"n02112706": 262,
|
265 |
+
"n02113023": 263,
|
266 |
+
"n02113186": 264,
|
267 |
+
"n02113624": 265,
|
268 |
+
"n02113712": 266,
|
269 |
+
"n02113799": 267,
|
270 |
+
"n02113978": 268,
|
271 |
+
"n02114367": 269,
|
272 |
+
"n02114548": 270,
|
273 |
+
"n02114712": 271,
|
274 |
+
"n02114855": 272,
|
275 |
+
"n02115641": 273,
|
276 |
+
"n02115913": 274,
|
277 |
+
"n02116738": 275,
|
278 |
+
"n02117135": 276,
|
279 |
+
"n02119022": 277,
|
280 |
+
"n02119789": 278,
|
281 |
+
"n02120079": 279,
|
282 |
+
"n02120505": 280,
|
283 |
+
"n02123045": 281,
|
284 |
+
"n02123159": 282,
|
285 |
+
"n02123394": 283,
|
286 |
+
"n02123597": 284,
|
287 |
+
"n02124075": 285,
|
288 |
+
"n02125311": 286,
|
289 |
+
"n02127052": 287,
|
290 |
+
"n02128385": 288,
|
291 |
+
"n02128757": 289,
|
292 |
+
"n02128925": 290,
|
293 |
+
"n02129165": 291,
|
294 |
+
"n02129604": 292,
|
295 |
+
"n02130308": 293,
|
296 |
+
"n02132136": 294,
|
297 |
+
"n02133161": 295,
|
298 |
+
"n02134084": 296,
|
299 |
+
"n02134418": 297,
|
300 |
+
"n02137549": 298,
|
301 |
+
"n02138441": 299,
|
302 |
+
"n02165105": 300,
|
303 |
+
"n02165456": 301,
|
304 |
+
"n02167151": 302,
|
305 |
+
"n02168699": 303,
|
306 |
+
"n02169497": 304,
|
307 |
+
"n02172182": 305,
|
308 |
+
"n02174001": 306,
|
309 |
+
"n02177972": 307,
|
310 |
+
"n02190166": 308,
|
311 |
+
"n02206856": 309,
|
312 |
+
"n02219486": 310,
|
313 |
+
"n02226429": 311,
|
314 |
+
"n02229544": 312,
|
315 |
+
"n02231487": 313,
|
316 |
+
"n02233338": 314,
|
317 |
+
"n02236044": 315,
|
318 |
+
"n02256656": 316,
|
319 |
+
"n02259212": 317,
|
320 |
+
"n02264363": 318,
|
321 |
+
"n02268443": 319,
|
322 |
+
"n02268853": 320,
|
323 |
+
"n02276258": 321,
|
324 |
+
"n02277742": 322,
|
325 |
+
"n02279972": 323,
|
326 |
+
"n02280649": 324,
|
327 |
+
"n02281406": 325,
|
328 |
+
"n02281787": 326,
|
329 |
+
"n02317335": 327,
|
330 |
+
"n02319095": 328,
|
331 |
+
"n02321529": 329,
|
332 |
+
"n02325366": 330,
|
333 |
+
"n02326432": 331,
|
334 |
+
"n02328150": 332,
|
335 |
+
"n02342885": 333,
|
336 |
+
"n02346627": 334,
|
337 |
+
"n02356798": 335,
|
338 |
+
"n02361337": 336,
|
339 |
+
"n02363005": 337,
|
340 |
+
"n02364673": 338,
|
341 |
+
"n02389026": 339,
|
342 |
+
"n02391049": 340,
|
343 |
+
"n02395406": 341,
|
344 |
+
"n02396427": 342,
|
345 |
+
"n02397096": 343,
|
346 |
+
"n02398521": 344,
|
347 |
+
"n02403003": 345,
|
348 |
+
"n02408429": 346,
|
349 |
+
"n02410509": 347,
|
350 |
+
"n02412080": 348,
|
351 |
+
"n02415577": 349,
|
352 |
+
"n02417914": 350,
|
353 |
+
"n02422106": 351,
|
354 |
+
"n02422699": 352,
|
355 |
+
"n02423022": 353,
|
356 |
+
"n02437312": 354,
|
357 |
+
"n02437616": 355,
|
358 |
+
"n02441942": 356,
|
359 |
+
"n02442845": 357,
|
360 |
+
"n02443114": 358,
|
361 |
+
"n02443484": 359,
|
362 |
+
"n02444819": 360,
|
363 |
+
"n02445715": 361,
|
364 |
+
"n02447366": 362,
|
365 |
+
"n02454379": 363,
|
366 |
+
"n02457408": 364,
|
367 |
+
"n02480495": 365,
|
368 |
+
"n02480855": 366,
|
369 |
+
"n02481823": 367,
|
370 |
+
"n02483362": 368,
|
371 |
+
"n02483708": 369,
|
372 |
+
"n02484975": 370,
|
373 |
+
"n02486261": 371,
|
374 |
+
"n02486410": 372,
|
375 |
+
"n02487347": 373,
|
376 |
+
"n02488291": 374,
|
377 |
+
"n02488702": 375,
|
378 |
+
"n02489166": 376,
|
379 |
+
"n02490219": 377,
|
380 |
+
"n02492035": 378,
|
381 |
+
"n02492660": 379,
|
382 |
+
"n02493509": 380,
|
383 |
+
"n02493793": 381,
|
384 |
+
"n02494079": 382,
|
385 |
+
"n02497673": 383,
|
386 |
+
"n02500267": 384,
|
387 |
+
"n02504013": 385,
|
388 |
+
"n02504458": 386,
|
389 |
+
"n02509815": 387,
|
390 |
+
"n02510455": 388,
|
391 |
+
"n02514041": 389,
|
392 |
+
"n02526121": 390,
|
393 |
+
"n02536864": 391,
|
394 |
+
"n02606052": 392,
|
395 |
+
"n02607072": 393,
|
396 |
+
"n02640242": 394,
|
397 |
+
"n02641379": 395,
|
398 |
+
"n02643566": 396,
|
399 |
+
"n02655020": 397,
|
400 |
+
"n02666196": 398,
|
401 |
+
"n02667093": 399,
|
402 |
+
"n02669723": 400,
|
403 |
+
"n02672831": 401,
|
404 |
+
"n02676566": 402,
|
405 |
+
"n02687172": 403,
|
406 |
+
"n02690373": 404,
|
407 |
+
"n02692877": 405,
|
408 |
+
"n02699494": 406,
|
409 |
+
"n02701002": 407,
|
410 |
+
"n02704792": 408,
|
411 |
+
"n02708093": 409,
|
412 |
+
"n02727426": 410,
|
413 |
+
"n02730930": 411,
|
414 |
+
"n02747177": 412,
|
415 |
+
"n02749479": 413,
|
416 |
+
"n02769748": 414,
|
417 |
+
"n02776631": 415,
|
418 |
+
"n02777292": 416,
|
419 |
+
"n02782093": 417,
|
420 |
+
"n02783161": 418,
|
421 |
+
"n02786058": 419,
|
422 |
+
"n02787622": 420,
|
423 |
+
"n02788148": 421,
|
424 |
+
"n02790996": 422,
|
425 |
+
"n02791124": 423,
|
426 |
+
"n02791270": 424,
|
427 |
+
"n02793495": 425,
|
428 |
+
"n02794156": 426,
|
429 |
+
"n02795169": 427,
|
430 |
+
"n02797295": 428,
|
431 |
+
"n02799071": 429,
|
432 |
+
"n02802426": 430,
|
433 |
+
"n02804414": 431,
|
434 |
+
"n02804610": 432,
|
435 |
+
"n02807133": 433,
|
436 |
+
"n02808304": 434,
|
437 |
+
"n02808440": 435,
|
438 |
+
"n02814533": 436,
|
439 |
+
"n02814860": 437,
|
440 |
+
"n02815834": 438,
|
441 |
+
"n02817516": 439,
|
442 |
+
"n02823428": 440,
|
443 |
+
"n02823750": 441,
|
444 |
+
"n02825657": 442,
|
445 |
+
"n02834397": 443,
|
446 |
+
"n02835271": 444,
|
447 |
+
"n02837789": 445,
|
448 |
+
"n02840245": 446,
|
449 |
+
"n02841315": 447,
|
450 |
+
"n02843684": 448,
|
451 |
+
"n02859443": 449,
|
452 |
+
"n02860847": 450,
|
453 |
+
"n02865351": 451,
|
454 |
+
"n02869837": 452,
|
455 |
+
"n02870880": 453,
|
456 |
+
"n02871525": 454,
|
457 |
+
"n02877765": 455,
|
458 |
+
"n02879718": 456,
|
459 |
+
"n02883205": 457,
|
460 |
+
"n02892201": 458,
|
461 |
+
"n02892767": 459,
|
462 |
+
"n02894605": 460,
|
463 |
+
"n02895154": 461,
|
464 |
+
"n02906734": 462,
|
465 |
+
"n02909870": 463,
|
466 |
+
"n02910353": 464,
|
467 |
+
"n02916936": 465,
|
468 |
+
"n02917067": 466,
|
469 |
+
"n02927161": 467,
|
470 |
+
"n02930766": 468,
|
471 |
+
"n02939185": 469,
|
472 |
+
"n02948072": 470,
|
473 |
+
"n02950826": 471,
|
474 |
+
"n02951358": 472,
|
475 |
+
"n02951585": 473,
|
476 |
+
"n02963159": 474,
|
477 |
+
"n02965783": 475,
|
478 |
+
"n02966193": 476,
|
479 |
+
"n02966687": 477,
|
480 |
+
"n02971356": 478,
|
481 |
+
"n02974003": 479,
|
482 |
+
"n02977058": 480,
|
483 |
+
"n02978881": 481,
|
484 |
+
"n02979186": 482,
|
485 |
+
"n02980441": 483,
|
486 |
+
"n02981792": 484,
|
487 |
+
"n02988304": 485,
|
488 |
+
"n02992211": 486,
|
489 |
+
"n02992529": 487,
|
490 |
+
"n02999410": 488,
|
491 |
+
"n03000134": 489,
|
492 |
+
"n03000247": 490,
|
493 |
+
"n03000684": 491,
|
494 |
+
"n03014705": 492,
|
495 |
+
"n03016953": 493,
|
496 |
+
"n03017168": 494,
|
497 |
+
"n03018349": 495,
|
498 |
+
"n03026506": 496,
|
499 |
+
"n03028079": 497,
|
500 |
+
"n03032252": 498,
|
501 |
+
"n03041632": 499,
|
502 |
+
"n03042490": 500,
|
503 |
+
"n03045698": 501,
|
504 |
+
"n03047690": 502,
|
505 |
+
"n03062245": 503,
|
506 |
+
"n03063599": 504,
|
507 |
+
"n03063689": 505,
|
508 |
+
"n03065424": 506,
|
509 |
+
"n03075370": 507,
|
510 |
+
"n03085013": 508,
|
511 |
+
"n03089624": 509,
|
512 |
+
"n03095699": 510,
|
513 |
+
"n03100240": 511,
|
514 |
+
"n03109150": 512,
|
515 |
+
"n03110669": 513,
|
516 |
+
"n03124043": 514,
|
517 |
+
"n03124170": 515,
|
518 |
+
"n03125729": 516,
|
519 |
+
"n03126707": 517,
|
520 |
+
"n03127747": 518,
|
521 |
+
"n03127925": 519,
|
522 |
+
"n03131574": 520,
|
523 |
+
"n03133878": 521,
|
524 |
+
"n03134739": 522,
|
525 |
+
"n03141823": 523,
|
526 |
+
"n03146219": 524,
|
527 |
+
"n03160309": 525,
|
528 |
+
"n03179701": 526,
|
529 |
+
"n03180011": 527,
|
530 |
+
"n03187595": 528,
|
531 |
+
"n03188531": 529,
|
532 |
+
"n03196217": 530,
|
533 |
+
"n03197337": 531,
|
534 |
+
"n03201208": 532,
|
535 |
+
"n03207743": 533,
|
536 |
+
"n03207941": 534,
|
537 |
+
"n03208938": 535,
|
538 |
+
"n03216828": 536,
|
539 |
+
"n03218198": 537,
|
540 |
+
"n03220513": 538,
|
541 |
+
"n03223299": 539,
|
542 |
+
"n03240683": 540,
|
543 |
+
"n03249569": 541,
|
544 |
+
"n03250847": 542,
|
545 |
+
"n03255030": 543,
|
546 |
+
"n03259280": 544,
|
547 |
+
"n03271574": 545,
|
548 |
+
"n03272010": 546,
|
549 |
+
"n03272562": 547,
|
550 |
+
"n03290653": 548,
|
551 |
+
"n03291819": 549,
|
552 |
+
"n03297495": 550,
|
553 |
+
"n03314780": 551,
|
554 |
+
"n03325584": 552,
|
555 |
+
"n03337140": 553,
|
556 |
+
"n03344393": 554,
|
557 |
+
"n03345487": 555,
|
558 |
+
"n03347037": 556,
|
559 |
+
"n03355925": 557,
|
560 |
+
"n03372029": 558,
|
561 |
+
"n03376595": 559,
|
562 |
+
"n03379051": 560,
|
563 |
+
"n03384352": 561,
|
564 |
+
"n03388043": 562,
|
565 |
+
"n03388183": 563,
|
566 |
+
"n03388549": 564,
|
567 |
+
"n03393912": 565,
|
568 |
+
"n03394916": 566,
|
569 |
+
"n03400231": 567,
|
570 |
+
"n03404251": 568,
|
571 |
+
"n03417042": 569,
|
572 |
+
"n03424325": 570,
|
573 |
+
"n03425413": 571,
|
574 |
+
"n03443371": 572,
|
575 |
+
"n03444034": 573,
|
576 |
+
"n03445777": 574,
|
577 |
+
"n03445924": 575,
|
578 |
+
"n03447447": 576,
|
579 |
+
"n03447721": 577,
|
580 |
+
"n03450230": 578,
|
581 |
+
"n03452741": 579,
|
582 |
+
"n03457902": 580,
|
583 |
+
"n03459775": 581,
|
584 |
+
"n03461385": 582,
|
585 |
+
"n03467068": 583,
|
586 |
+
"n03476684": 584,
|
587 |
+
"n03476991": 585,
|
588 |
+
"n03478589": 586,
|
589 |
+
"n03481172": 587,
|
590 |
+
"n03482405": 588,
|
591 |
+
"n03483316": 589,
|
592 |
+
"n03485407": 590,
|
593 |
+
"n03485794": 591,
|
594 |
+
"n03492542": 592,
|
595 |
+
"n03494278": 593,
|
596 |
+
"n03495258": 594,
|
597 |
+
"n03496892": 595,
|
598 |
+
"n03498962": 596,
|
599 |
+
"n03527444": 597,
|
600 |
+
"n03529860": 598,
|
601 |
+
"n03530642": 599,
|
602 |
+
"n03532672": 600,
|
603 |
+
"n03534580": 601,
|
604 |
+
"n03535780": 602,
|
605 |
+
"n03538406": 603,
|
606 |
+
"n03544143": 604,
|
607 |
+
"n03584254": 605,
|
608 |
+
"n03584829": 606,
|
609 |
+
"n03590841": 607,
|
610 |
+
"n03594734": 608,
|
611 |
+
"n03594945": 609,
|
612 |
+
"n03595614": 610,
|
613 |
+
"n03598930": 611,
|
614 |
+
"n03599486": 612,
|
615 |
+
"n03602883": 613,
|
616 |
+
"n03617480": 614,
|
617 |
+
"n03623198": 615,
|
618 |
+
"n03627232": 616,
|
619 |
+
"n03630383": 617,
|
620 |
+
"n03633091": 618,
|
621 |
+
"n03637318": 619,
|
622 |
+
"n03642806": 620,
|
623 |
+
"n03649909": 621,
|
624 |
+
"n03657121": 622,
|
625 |
+
"n03658185": 623,
|
626 |
+
"n03661043": 624,
|
627 |
+
"n03662601": 625,
|
628 |
+
"n03666591": 626,
|
629 |
+
"n03670208": 627,
|
630 |
+
"n03673027": 628,
|
631 |
+
"n03676483": 629,
|
632 |
+
"n03680355": 630,
|
633 |
+
"n03690938": 631,
|
634 |
+
"n03691459": 632,
|
635 |
+
"n03692522": 633,
|
636 |
+
"n03697007": 634,
|
637 |
+
"n03706229": 635,
|
638 |
+
"n03709823": 636,
|
639 |
+
"n03710193": 637,
|
640 |
+
"n03710637": 638,
|
641 |
+
"n03710721": 639,
|
642 |
+
"n03717622": 640,
|
643 |
+
"n03720891": 641,
|
644 |
+
"n03721384": 642,
|
645 |
+
"n03724870": 643,
|
646 |
+
"n03729826": 644,
|
647 |
+
"n03733131": 645,
|
648 |
+
"n03733281": 646,
|
649 |
+
"n03733805": 647,
|
650 |
+
"n03742115": 648,
|
651 |
+
"n03743016": 649,
|
652 |
+
"n03759954": 650,
|
653 |
+
"n03761084": 651,
|
654 |
+
"n03763968": 652,
|
655 |
+
"n03764736": 653,
|
656 |
+
"n03769881": 654,
|
657 |
+
"n03770439": 655,
|
658 |
+
"n03770679": 656,
|
659 |
+
"n03773504": 657,
|
660 |
+
"n03775071": 658,
|
661 |
+
"n03775546": 659,
|
662 |
+
"n03776460": 660,
|
663 |
+
"n03777568": 661,
|
664 |
+
"n03777754": 662,
|
665 |
+
"n03781244": 663,
|
666 |
+
"n03782006": 664,
|
667 |
+
"n03785016": 665,
|
668 |
+
"n03786901": 666,
|
669 |
+
"n03787032": 667,
|
670 |
+
"n03788195": 668,
|
671 |
+
"n03788365": 669,
|
672 |
+
"n03791053": 670,
|
673 |
+
"n03792782": 671,
|
674 |
+
"n03792972": 672,
|
675 |
+
"n03793489": 673,
|
676 |
+
"n03794056": 674,
|
677 |
+
"n03796401": 675,
|
678 |
+
"n03803284": 676,
|
679 |
+
"n03804744": 677,
|
680 |
+
"n03814639": 678,
|
681 |
+
"n03814906": 679,
|
682 |
+
"n03825788": 680,
|
683 |
+
"n03832673": 681,
|
684 |
+
"n03837869": 682,
|
685 |
+
"n03838899": 683,
|
686 |
+
"n03840681": 684,
|
687 |
+
"n03841143": 685,
|
688 |
+
"n03843555": 686,
|
689 |
+
"n03854065": 687,
|
690 |
+
"n03857828": 688,
|
691 |
+
"n03866082": 689,
|
692 |
+
"n03868242": 690,
|
693 |
+
"n03868863": 691,
|
694 |
+
"n03871628": 692,
|
695 |
+
"n03873416": 693,
|
696 |
+
"n03874293": 694,
|
697 |
+
"n03874599": 695,
|
698 |
+
"n03876231": 696,
|
699 |
+
"n03877472": 697,
|
700 |
+
"n03877845": 698,
|
701 |
+
"n03884397": 699,
|
702 |
+
"n03887697": 700,
|
703 |
+
"n03888257": 701,
|
704 |
+
"n03888605": 702,
|
705 |
+
"n03891251": 703,
|
706 |
+
"n03891332": 704,
|
707 |
+
"n03895866": 705,
|
708 |
+
"n03899768": 706,
|
709 |
+
"n03902125": 707,
|
710 |
+
"n03903868": 708,
|
711 |
+
"n03908618": 709,
|
712 |
+
"n03908714": 710,
|
713 |
+
"n03916031": 711,
|
714 |
+
"n03920288": 712,
|
715 |
+
"n03924679": 713,
|
716 |
+
"n03929660": 714,
|
717 |
+
"n03929855": 715,
|
718 |
+
"n03930313": 716,
|
719 |
+
"n03930630": 717,
|
720 |
+
"n03933933": 718,
|
721 |
+
"n03935335": 719,
|
722 |
+
"n03937543": 720,
|
723 |
+
"n03938244": 721,
|
724 |
+
"n03942813": 722,
|
725 |
+
"n03944341": 723,
|
726 |
+
"n03947888": 724,
|
727 |
+
"n03950228": 725,
|
728 |
+
"n03954731": 726,
|
729 |
+
"n03956157": 727,
|
730 |
+
"n03958227": 728,
|
731 |
+
"n03961711": 729,
|
732 |
+
"n03967562": 730,
|
733 |
+
"n03970156": 731,
|
734 |
+
"n03976467": 732,
|
735 |
+
"n03976657": 733,
|
736 |
+
"n03977966": 734,
|
737 |
+
"n03980874": 735,
|
738 |
+
"n03982430": 736,
|
739 |
+
"n03983396": 737,
|
740 |
+
"n03991062": 738,
|
741 |
+
"n03992509": 739,
|
742 |
+
"n03995372": 740,
|
743 |
+
"n03998194": 741,
|
744 |
+
"n04004767": 742,
|
745 |
+
"n04005630": 743,
|
746 |
+
"n04008634": 744,
|
747 |
+
"n04009552": 745,
|
748 |
+
"n04019541": 746,
|
749 |
+
"n04023962": 747,
|
750 |
+
"n04026417": 748,
|
751 |
+
"n04033901": 749,
|
752 |
+
"n04033995": 750,
|
753 |
+
"n04037443": 751,
|
754 |
+
"n04039381": 752,
|
755 |
+
"n04040759": 753,
|
756 |
+
"n04041544": 754,
|
757 |
+
"n04044716": 755,
|
758 |
+
"n04049303": 756,
|
759 |
+
"n04065272": 757,
|
760 |
+
"n04067472": 758,
|
761 |
+
"n04069434": 759,
|
762 |
+
"n04070727": 760,
|
763 |
+
"n04074963": 761,
|
764 |
+
"n04081281": 762,
|
765 |
+
"n04086273": 763,
|
766 |
+
"n04090263": 764,
|
767 |
+
"n04099969": 765,
|
768 |
+
"n04111531": 766,
|
769 |
+
"n04116512": 767,
|
770 |
+
"n04118538": 768,
|
771 |
+
"n04118776": 769,
|
772 |
+
"n04120489": 770,
|
773 |
+
"n04125021": 771,
|
774 |
+
"n04127249": 772,
|
775 |
+
"n04131690": 773,
|
776 |
+
"n04133789": 774,
|
777 |
+
"n04136333": 775,
|
778 |
+
"n04141076": 776,
|
779 |
+
"n04141327": 777,
|
780 |
+
"n04141975": 778,
|
781 |
+
"n04146614": 779,
|
782 |
+
"n04147183": 780,
|
783 |
+
"n04149813": 781,
|
784 |
+
"n04152593": 782,
|
785 |
+
"n04153751": 783,
|
786 |
+
"n04154565": 784,
|
787 |
+
"n04162706": 785,
|
788 |
+
"n04179913": 786,
|
789 |
+
"n04192698": 787,
|
790 |
+
"n04200800": 788,
|
791 |
+
"n04201297": 789,
|
792 |
+
"n04204238": 790,
|
793 |
+
"n04204347": 791,
|
794 |
+
"n04208210": 792,
|
795 |
+
"n04209133": 793,
|
796 |
+
"n04209239": 794,
|
797 |
+
"n04228054": 795,
|
798 |
+
"n04229816": 796,
|
799 |
+
"n04235860": 797,
|
800 |
+
"n04238763": 798,
|
801 |
+
"n04239074": 799,
|
802 |
+
"n04243546": 800,
|
803 |
+
"n04251144": 801,
|
804 |
+
"n04252077": 802,
|
805 |
+
"n04252225": 803,
|
806 |
+
"n04254120": 804,
|
807 |
+
"n04254680": 805,
|
808 |
+
"n04254777": 806,
|
809 |
+
"n04258138": 807,
|
810 |
+
"n04259630": 808,
|
811 |
+
"n04263257": 809,
|
812 |
+
"n04264628": 810,
|
813 |
+
"n04265275": 811,
|
814 |
+
"n04266014": 812,
|
815 |
+
"n04270147": 813,
|
816 |
+
"n04273569": 814,
|
817 |
+
"n04275548": 815,
|
818 |
+
"n04277352": 816,
|
819 |
+
"n04285008": 817,
|
820 |
+
"n04286575": 818,
|
821 |
+
"n04296562": 819,
|
822 |
+
"n04310018": 820,
|
823 |
+
"n04311004": 821,
|
824 |
+
"n04311174": 822,
|
825 |
+
"n04317175": 823,
|
826 |
+
"n04325704": 824,
|
827 |
+
"n04326547": 825,
|
828 |
+
"n04328186": 826,
|
829 |
+
"n04330267": 827,
|
830 |
+
"n04332243": 828,
|
831 |
+
"n04335435": 829,
|
832 |
+
"n04336792": 830,
|
833 |
+
"n04344873": 831,
|
834 |
+
"n04346328": 832,
|
835 |
+
"n04347754": 833,
|
836 |
+
"n04350905": 834,
|
837 |
+
"n04355338": 835,
|
838 |
+
"n04355933": 836,
|
839 |
+
"n04356056": 837,
|
840 |
+
"n04357314": 838,
|
841 |
+
"n04366367": 839,
|
842 |
+
"n04367480": 840,
|
843 |
+
"n04370456": 841,
|
844 |
+
"n04371430": 842,
|
845 |
+
"n04371774": 843,
|
846 |
+
"n04372370": 844,
|
847 |
+
"n04376876": 845,
|
848 |
+
"n04380533": 846,
|
849 |
+
"n04389033": 847,
|
850 |
+
"n04392985": 848,
|
851 |
+
"n04398044": 849,
|
852 |
+
"n04399382": 850,
|
853 |
+
"n04404412": 851,
|
854 |
+
"n04409515": 852,
|
855 |
+
"n04417672": 853,
|
856 |
+
"n04418357": 854,
|
857 |
+
"n04423845": 855,
|
858 |
+
"n04428191": 856,
|
859 |
+
"n04429376": 857,
|
860 |
+
"n04435653": 858,
|
861 |
+
"n04442312": 859,
|
862 |
+
"n04443257": 860,
|
863 |
+
"n04447861": 861,
|
864 |
+
"n04456115": 862,
|
865 |
+
"n04458633": 863,
|
866 |
+
"n04461696": 864,
|
867 |
+
"n04462240": 865,
|
868 |
+
"n04465501": 866,
|
869 |
+
"n04467665": 867,
|
870 |
+
"n04476259": 868,
|
871 |
+
"n04479046": 869,
|
872 |
+
"n04482393": 870,
|
873 |
+
"n04483307": 871,
|
874 |
+
"n04485082": 872,
|
875 |
+
"n04486054": 873,
|
876 |
+
"n04487081": 874,
|
877 |
+
"n04487394": 875,
|
878 |
+
"n04493381": 876,
|
879 |
+
"n04501370": 877,
|
880 |
+
"n04505470": 878,
|
881 |
+
"n04507155": 879,
|
882 |
+
"n04509417": 880,
|
883 |
+
"n04515003": 881,
|
884 |
+
"n04517823": 882,
|
885 |
+
"n04522168": 883,
|
886 |
+
"n04523525": 884,
|
887 |
+
"n04525038": 885,
|
888 |
+
"n04525305": 886,
|
889 |
+
"n04532106": 887,
|
890 |
+
"n04532670": 888,
|
891 |
+
"n04536866": 889,
|
892 |
+
"n04540053": 890,
|
893 |
+
"n04542943": 891,
|
894 |
+
"n04548280": 892,
|
895 |
+
"n04548362": 893,
|
896 |
+
"n04550184": 894,
|
897 |
+
"n04552348": 895,
|
898 |
+
"n04553703": 896,
|
899 |
+
"n04554684": 897,
|
900 |
+
"n04557648": 898,
|
901 |
+
"n04560804": 899,
|
902 |
+
"n04562935": 900,
|
903 |
+
"n04579145": 901,
|
904 |
+
"n04579432": 902,
|
905 |
+
"n04584207": 903,
|
906 |
+
"n04589890": 904,
|
907 |
+
"n04590129": 905,
|
908 |
+
"n04591157": 906,
|
909 |
+
"n04591713": 907,
|
910 |
+
"n04592741": 908,
|
911 |
+
"n04596742": 909,
|
912 |
+
"n04597913": 910,
|
913 |
+
"n04599235": 911,
|
914 |
+
"n04604644": 912,
|
915 |
+
"n04606251": 913,
|
916 |
+
"n04612504": 914,
|
917 |
+
"n04613696": 915,
|
918 |
+
"n06359193": 916,
|
919 |
+
"n06596364": 917,
|
920 |
+
"n06785654": 918,
|
921 |
+
"n06794110": 919,
|
922 |
+
"n06874185": 920,
|
923 |
+
"n07248320": 921,
|
924 |
+
"n07565083": 922,
|
925 |
+
"n07579787": 923,
|
926 |
+
"n07583066": 924,
|
927 |
+
"n07584110": 925,
|
928 |
+
"n07590611": 926,
|
929 |
+
"n07613480": 927,
|
930 |
+
"n07614500": 928,
|
931 |
+
"n07615774": 929,
|
932 |
+
"n07684084": 930,
|
933 |
+
"n07693725": 931,
|
934 |
+
"n07695742": 932,
|
935 |
+
"n07697313": 933,
|
936 |
+
"n07697537": 934,
|
937 |
+
"n07711569": 935,
|
938 |
+
"n07714571": 936,
|
939 |
+
"n07714990": 937,
|
940 |
+
"n07715103": 938,
|
941 |
+
"n07716358": 939,
|
942 |
+
"n07716906": 940,
|
943 |
+
"n07717410": 941,
|
944 |
+
"n07717556": 942,
|
945 |
+
"n07718472": 943,
|
946 |
+
"n07718747": 944,
|
947 |
+
"n07720875": 945,
|
948 |
+
"n07730033": 946,
|
949 |
+
"n07734744": 947,
|
950 |
+
"n07742313": 948,
|
951 |
+
"n07745940": 949,
|
952 |
+
"n07747607": 950,
|
953 |
+
"n07749582": 951,
|
954 |
+
"n07753113": 952,
|
955 |
+
"n07753275": 953,
|
956 |
+
"n07753592": 954,
|
957 |
+
"n07754684": 955,
|
958 |
+
"n07760859": 956,
|
959 |
+
"n07768694": 957,
|
960 |
+
"n07802026": 958,
|
961 |
+
"n07831146": 959,
|
962 |
+
"n07836838": 960,
|
963 |
+
"n07860988": 961,
|
964 |
+
"n07871810": 962,
|
965 |
+
"n07873807": 963,
|
966 |
+
"n07875152": 964,
|
967 |
+
"n07880968": 965,
|
968 |
+
"n07892512": 966,
|
969 |
+
"n07920052": 967,
|
970 |
+
"n07930864": 968,
|
971 |
+
"n07932039": 969,
|
972 |
+
"n09193705": 970,
|
973 |
+
"n09229709": 971,
|
974 |
+
"n09246464": 972,
|
975 |
+
"n09256479": 973,
|
976 |
+
"n09288635": 974,
|
977 |
+
"n09332890": 975,
|
978 |
+
"n09399592": 976,
|
979 |
+
"n09421951": 977,
|
980 |
+
"n09428293": 978,
|
981 |
+
"n09468604": 979,
|
982 |
+
"n09472597": 980,
|
983 |
+
"n09835506": 981,
|
984 |
+
"n10148035": 982,
|
985 |
+
"n10565667": 983,
|
986 |
+
"n11879895": 984,
|
987 |
+
"n11939491": 985,
|
988 |
+
"n12057211": 986,
|
989 |
+
"n12144580": 987,
|
990 |
+
"n12267677": 988,
|
991 |
+
"n12620546": 989,
|
992 |
+
"n12768682": 990,
|
993 |
+
"n12985857": 991,
|
994 |
+
"n12998815": 992,
|
995 |
+
"n13037406": 993,
|
996 |
+
"n13040303": 994,
|
997 |
+
"n13044778": 995,
|
998 |
+
"n13052670": 996,
|
999 |
+
"n13054560": 997,
|
1000 |
+
"n13133613": 998,
|
1001 |
+
"n15075141": 999
|
1002 |
+
}
|
imagenet_eval_robustness.py
ADDED
@@ -0,0 +1,337 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
import random
|
4 |
+
import shutil
|
5 |
+
import time
|
6 |
+
import warnings
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
import torch.nn.parallel
|
11 |
+
import torch.backends.cudnn as cudnn
|
12 |
+
import torch.distributed as dist
|
13 |
+
import torch.optim
|
14 |
+
import torch.multiprocessing as mp
|
15 |
+
import torch.utils.data
|
16 |
+
import torch.utils.data.distributed
|
17 |
+
import torchvision.transforms as transforms
|
18 |
+
import torchvision.datasets as datasets
|
19 |
+
import torchvision.models as models
|
20 |
+
|
21 |
+
# Uncomment the expected model below
|
22 |
+
|
23 |
+
# ViT
|
24 |
+
from ViT.ViT import vit_base_patch16_224 as vit
|
25 |
+
# from ViT.ViT import vit_large_patch16_224 as vit
|
26 |
+
|
27 |
+
# ViT-AugReg
|
28 |
+
# from ViT.ViT_new import vit_small_patch16_224 as vit
|
29 |
+
# from ViT.ViT_new import vit_base_patch16_224 as vit
|
30 |
+
# from ViT.ViT_new import vit_large_patch16_224 as vit
|
31 |
+
|
32 |
+
# DeiT
|
33 |
+
# from ViT.ViT import deit_base_patch16_224 as vit
|
34 |
+
# from ViT.ViT import deit_small_patch16_224 as vit
|
35 |
+
|
36 |
+
from robustness_dataset import RobustnessDataset
|
37 |
+
from objectnet_dataset import ObjectNetDataset
|
38 |
+
model_names = sorted(name for name in models.__dict__
|
39 |
+
if name.islower() and not name.startswith("__")
|
40 |
+
and callable(models.__dict__[name]))
|
41 |
+
model_names.append("vit")
|
42 |
+
|
43 |
+
parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
|
44 |
+
parser.add_argument('--data', metavar='DIR',
|
45 |
+
help='path to dataset')
|
46 |
+
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
|
47 |
+
help='number of data loading workers (default: 4)')
|
48 |
+
parser.add_argument('--epochs', default=150, type=int, metavar='N',
|
49 |
+
help='number of total epochs to run')
|
50 |
+
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
|
51 |
+
help='manual epoch number (useful on restarts)')
|
52 |
+
parser.add_argument('-b', '--batch-size', default=256, type=int,
|
53 |
+
metavar='N',
|
54 |
+
help='mini-batch size (default: 256), this is the total '
|
55 |
+
'batch size of all GPUs on the current node when '
|
56 |
+
'using Data Parallel or Distributed Data Parallel')
|
57 |
+
parser.add_argument('--lr', '--learning-rate', default=5e-4, type=float,
|
58 |
+
metavar='LR', help='initial learning rate', dest='lr')
|
59 |
+
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
|
60 |
+
help='momentum')
|
61 |
+
parser.add_argument('--wd', '--weight-decay', default=0.05, type=float,
|
62 |
+
metavar='W', help='weight decay (default: 1e-4)',
|
63 |
+
dest='weight_decay')
|
64 |
+
parser.add_argument('-p', '--print-freq', default=10, type=int,
|
65 |
+
metavar='N', help='print frequency (default: 10)')
|
66 |
+
parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
|
67 |
+
help='path to latest checkpoint (default: none)')
|
68 |
+
parser.add_argument('--resume', default='', type=str, metavar='PATH',
|
69 |
+
help='path to resume checkpoint (default: none)')
|
70 |
+
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
|
71 |
+
help='evaluate model on validation set')
|
72 |
+
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
|
73 |
+
help='use pre-trained model')
|
74 |
+
parser.add_argument('--world-size', default=-1, type=int,
|
75 |
+
help='number of nodes for distributed training')
|
76 |
+
parser.add_argument('--rank', default=-1, type=int,
|
77 |
+
help='node rank for distributed training')
|
78 |
+
parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,
|
79 |
+
help='url used to set up distributed training')
|
80 |
+
parser.add_argument('--dist-backend', default='nccl', type=str,
|
81 |
+
help='distributed backend')
|
82 |
+
parser.add_argument('--seed', default=None, type=int,
|
83 |
+
help='seed for initializing training. ')
|
84 |
+
parser.add_argument('--gpu', default=None, type=int,
|
85 |
+
help='GPU id to use.')
|
86 |
+
parser.add_argument('--multiprocessing-distributed', action='store_true',
|
87 |
+
help='Use multi-processing distributed training to launch '
|
88 |
+
'N processes per node, which has N GPUs. This is the '
|
89 |
+
'fastest way to use PyTorch for either single node or '
|
90 |
+
'multi node data parallel training')
|
91 |
+
parser.add_argument("--isV2", default=False, action='store_true',
|
92 |
+
help='is dataset imagenet V2.')
|
93 |
+
parser.add_argument("--isSI", default=False, action='store_true',
|
94 |
+
help='is dataset SI-score.')
|
95 |
+
parser.add_argument("--isObjectNet", default=False, action='store_true',
|
96 |
+
help='is dataset SI-score.')
|
97 |
+
|
98 |
+
|
99 |
+
def main():
|
100 |
+
args = parser.parse_args()
|
101 |
+
|
102 |
+
if args.seed is not None:
|
103 |
+
random.seed(args.seed)
|
104 |
+
torch.manual_seed(args.seed)
|
105 |
+
cudnn.deterministic = True
|
106 |
+
warnings.warn('You have chosen to seed training. '
|
107 |
+
'This will turn on the CUDNN deterministic setting, '
|
108 |
+
'which can slow down your training considerably! '
|
109 |
+
'You may see unexpected behavior when restarting '
|
110 |
+
'from checkpoints.')
|
111 |
+
|
112 |
+
if args.gpu is not None:
|
113 |
+
warnings.warn('You have chosen a specific GPU. This will completely '
|
114 |
+
'disable data parallelism.')
|
115 |
+
|
116 |
+
if args.dist_url == "env://" and args.world_size == -1:
|
117 |
+
args.world_size = int(os.environ["WORLD_SIZE"])
|
118 |
+
|
119 |
+
args.distributed = args.world_size > 1 or args.multiprocessing_distributed
|
120 |
+
|
121 |
+
ngpus_per_node = torch.cuda.device_count()
|
122 |
+
if args.multiprocessing_distributed:
|
123 |
+
# Since we have ngpus_per_node processes per node, the total world_size
|
124 |
+
# needs to be adjusted accordingly
|
125 |
+
args.world_size = ngpus_per_node * args.world_size
|
126 |
+
# Use torch.multiprocessing.spawn to launch distributed processes: the
|
127 |
+
# main_worker process function
|
128 |
+
mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
|
129 |
+
else:
|
130 |
+
# Simply call main_worker function
|
131 |
+
main_worker(args.gpu, ngpus_per_node, args)
|
132 |
+
|
133 |
+
|
134 |
+
def main_worker(gpu, ngpus_per_node, args):
|
135 |
+
global best_acc1
|
136 |
+
args.gpu = gpu
|
137 |
+
|
138 |
+
if args.gpu is not None:
|
139 |
+
print("Use GPU: {} for training".format(args.gpu))
|
140 |
+
|
141 |
+
if args.distributed:
|
142 |
+
if args.dist_url == "env://" and args.rank == -1:
|
143 |
+
args.rank = int(os.environ["RANK"])
|
144 |
+
if args.multiprocessing_distributed:
|
145 |
+
# For multiprocessing distributed training, rank needs to be the
|
146 |
+
# global rank among all the processes
|
147 |
+
args.rank = args.rank * ngpus_per_node + gpu
|
148 |
+
dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
|
149 |
+
world_size=args.world_size, rank=args.rank)
|
150 |
+
# create model
|
151 |
+
print("=> creating model")
|
152 |
+
if args.checkpoint:
|
153 |
+
model = vit().cuda()
|
154 |
+
checkpoint = torch.load(args.checkpoint)
|
155 |
+
model.load_state_dict(checkpoint['state_dict'])
|
156 |
+
else:
|
157 |
+
model = vit(pretrained=True).cuda()
|
158 |
+
print("done")
|
159 |
+
|
160 |
+
if not torch.cuda.is_available():
|
161 |
+
print('using CPU, this will be slow')
|
162 |
+
elif args.distributed:
|
163 |
+
# For multiprocessing distributed, DistributedDataParallel constructor
|
164 |
+
# should always set the single device scope, otherwise,
|
165 |
+
# DistributedDataParallel will use all available devices.
|
166 |
+
if args.gpu is not None:
|
167 |
+
torch.cuda.set_device(args.gpu)
|
168 |
+
model.cuda(args.gpu)
|
169 |
+
# When using a single GPU per process and per
|
170 |
+
# DistributedDataParallel, we need to divide the batch size
|
171 |
+
# ourselves based on the total number of GPUs we have
|
172 |
+
args.batch_size = int(args.batch_size / ngpus_per_node)
|
173 |
+
args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
|
174 |
+
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
|
175 |
+
else:
|
176 |
+
model.cuda()
|
177 |
+
# DistributedDataParallel will divide and allocate batch_size to all
|
178 |
+
# available GPUs if device_ids are not set
|
179 |
+
model = torch.nn.parallel.DistributedDataParallel(model)
|
180 |
+
elif args.gpu is not None:
|
181 |
+
torch.cuda.set_device(args.gpu)
|
182 |
+
model = model.cuda(args.gpu)
|
183 |
+
else:
|
184 |
+
print("start")
|
185 |
+
model = torch.nn.DataParallel(model).cuda()
|
186 |
+
|
187 |
+
# optionally resume from a checkpoint
|
188 |
+
if args.resume:
|
189 |
+
if os.path.isfile(args.resume):
|
190 |
+
print("=> loading checkpoint '{}'".format(args.resume))
|
191 |
+
if args.gpu is None:
|
192 |
+
checkpoint = torch.load(args.resume)
|
193 |
+
else:
|
194 |
+
# Map model to be loaded to specified single gpu.
|
195 |
+
loc = 'cuda:{}'.format(args.gpu)
|
196 |
+
checkpoint = torch.load(args.resume, map_location=loc)
|
197 |
+
args.start_epoch = checkpoint['epoch']
|
198 |
+
best_acc1 = checkpoint['best_acc1']
|
199 |
+
if args.gpu is not None:
|
200 |
+
# best_acc1 may be from a checkpoint from a different GPU
|
201 |
+
best_acc1 = best_acc1.to(args.gpu)
|
202 |
+
model.load_state_dict(checkpoint['state_dict'])
|
203 |
+
print("=> loaded checkpoint '{}' (epoch {})"
|
204 |
+
.format(args.resume, checkpoint['epoch']))
|
205 |
+
else:
|
206 |
+
print("=> no checkpoint found at '{}'".format(args.resume))
|
207 |
+
|
208 |
+
cudnn.benchmark = True
|
209 |
+
|
210 |
+
if args.isObjectNet:
|
211 |
+
val_dataset = ObjectNetDataset(args.data)
|
212 |
+
else:
|
213 |
+
val_dataset = RobustnessDataset(args.data, isV2=args.isV2, isSI=args.isSI)
|
214 |
+
|
215 |
+
val_loader = torch.utils.data.DataLoader(
|
216 |
+
val_dataset, batch_size=args.batch_size, shuffle=False,
|
217 |
+
num_workers=args.workers, pin_memory=True)
|
218 |
+
|
219 |
+
if args.evaluate:
|
220 |
+
validate(val_loader, model, args)
|
221 |
+
return
|
222 |
+
|
223 |
+
def validate(val_loader, model, args):
|
224 |
+
batch_time = AverageMeter('Time', ':6.3f')
|
225 |
+
losses = AverageMeter('Loss', ':.4e')
|
226 |
+
top1 = AverageMeter('Acc@1', ':6.2f')
|
227 |
+
top5 = AverageMeter('Acc@5', ':6.2f')
|
228 |
+
progress = ProgressMeter(
|
229 |
+
len(val_loader),
|
230 |
+
[batch_time, losses, top1, top5],
|
231 |
+
prefix='Test: ')
|
232 |
+
|
233 |
+
# switch to evaluate mode
|
234 |
+
model.eval()
|
235 |
+
|
236 |
+
with torch.no_grad():
|
237 |
+
end = time.time()
|
238 |
+
for i, (images, target) in enumerate(val_loader):
|
239 |
+
if args.gpu is not None:
|
240 |
+
images = images.cuda(args.gpu, non_blocking=True)
|
241 |
+
if torch.cuda.is_available():
|
242 |
+
target = target.cuda(args.gpu, non_blocking=True)
|
243 |
+
|
244 |
+
# compute output
|
245 |
+
output = model(images)
|
246 |
+
|
247 |
+
# measure accuracy and record loss
|
248 |
+
acc1, acc5 = accuracy(output, target, topk=(1, 5))
|
249 |
+
top1.update(acc1[0], images.size(0))
|
250 |
+
top5.update(acc5[0], images.size(0))
|
251 |
+
|
252 |
+
# measure elapsed time
|
253 |
+
batch_time.update(time.time() - end)
|
254 |
+
end = time.time()
|
255 |
+
|
256 |
+
if i % args.print_freq == 0:
|
257 |
+
progress.display(i)
|
258 |
+
|
259 |
+
# TODO: this should also be done with the ProgressMeter
|
260 |
+
print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
|
261 |
+
.format(top1=top1, top5=top5))
|
262 |
+
|
263 |
+
return top1.avg
|
264 |
+
|
265 |
+
|
266 |
+
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
|
267 |
+
torch.save(state, filename)
|
268 |
+
if is_best:
|
269 |
+
shutil.copyfile(filename, 'model_best.pth.tar')
|
270 |
+
|
271 |
+
|
272 |
+
class AverageMeter(object):
|
273 |
+
"""Computes and stores the average and current value"""
|
274 |
+
def __init__(self, name, fmt=':f'):
|
275 |
+
self.name = name
|
276 |
+
self.fmt = fmt
|
277 |
+
self.reset()
|
278 |
+
|
279 |
+
def reset(self):
|
280 |
+
self.val = 0
|
281 |
+
self.avg = 0
|
282 |
+
self.sum = 0
|
283 |
+
self.count = 0
|
284 |
+
|
285 |
+
def update(self, val, n=1):
|
286 |
+
self.val = val
|
287 |
+
self.sum += val * n
|
288 |
+
self.count += n
|
289 |
+
self.avg = self.sum / self.count
|
290 |
+
|
291 |
+
def __str__(self):
|
292 |
+
fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
|
293 |
+
return fmtstr.format(**self.__dict__)
|
294 |
+
|
295 |
+
|
296 |
+
class ProgressMeter(object):
|
297 |
+
def __init__(self, num_batches, meters, prefix=""):
|
298 |
+
self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
|
299 |
+
self.meters = meters
|
300 |
+
self.prefix = prefix
|
301 |
+
|
302 |
+
def display(self, batch):
|
303 |
+
entries = [self.prefix + self.batch_fmtstr.format(batch)]
|
304 |
+
entries += [str(meter) for meter in self.meters]
|
305 |
+
print('\t'.join(entries))
|
306 |
+
|
307 |
+
def _get_batch_fmtstr(self, num_batches):
|
308 |
+
num_digits = len(str(num_batches // 1))
|
309 |
+
fmt = '{:' + str(num_digits) + 'd}'
|
310 |
+
return '[' + fmt + '/' + fmt.format(num_batches) + ']'
|
311 |
+
|
312 |
+
def adjust_learning_rate(optimizer, epoch, args):
|
313 |
+
"""Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
|
314 |
+
lr = args.lr * (0.85 ** (epoch // 2))
|
315 |
+
for param_group in optimizer.param_groups:
|
316 |
+
param_group['lr'] = lr
|
317 |
+
|
318 |
+
|
319 |
+
def accuracy(output, target, topk=(1,)):
|
320 |
+
"""Computes the accuracy over the k top predictions for the specified values of k"""
|
321 |
+
with torch.no_grad():
|
322 |
+
maxk = max(topk)
|
323 |
+
batch_size = target.size(0)
|
324 |
+
|
325 |
+
_, pred = output.topk(maxk, 1, True, True)
|
326 |
+
pred = pred.t()
|
327 |
+
correct = pred.eq(target.view(1, -1).expand_as(pred))
|
328 |
+
|
329 |
+
res = []
|
330 |
+
for k in topk:
|
331 |
+
correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
|
332 |
+
res.append(correct_k.mul_(100.0 / batch_size))
|
333 |
+
return res
|
334 |
+
|
335 |
+
|
336 |
+
if __name__ == '__main__':
|
337 |
+
main()
|
imagenet_eval_robustness_per_class.py
ADDED
@@ -0,0 +1,343 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
import random
|
4 |
+
import shutil
|
5 |
+
import time
|
6 |
+
import warnings
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
import torch.nn.parallel
|
11 |
+
import torch.backends.cudnn as cudnn
|
12 |
+
import torch.distributed as dist
|
13 |
+
import torch.optim
|
14 |
+
import torch.multiprocessing as mp
|
15 |
+
import torch.utils.data
|
16 |
+
import torch.utils.data.distributed
|
17 |
+
import torchvision.transforms as transforms
|
18 |
+
import torchvision.datasets as datasets
|
19 |
+
import torchvision.models as models
|
20 |
+
|
21 |
+
# Uncomment the expected model below
|
22 |
+
|
23 |
+
# ViT
|
24 |
+
from ViT.ViT import vit_base_patch16_224 as vit
|
25 |
+
# from ViT.ViT import vit_large_patch16_224 as vit
|
26 |
+
|
27 |
+
# ViT-AugReg
|
28 |
+
# from ViT.ViT_new import vit_small_patch16_224 as vit
|
29 |
+
# from ViT.ViT_new import vit_base_patch16_224 as vit
|
30 |
+
# from ViT.ViT_new import vit_large_patch16_224 as vit
|
31 |
+
|
32 |
+
# DeiT
|
33 |
+
# from ViT.ViT import deit_base_patch16_224 as vit
|
34 |
+
# from ViT.ViT import deit_small_patch16_224 as vit
|
35 |
+
|
36 |
+
from robustness_dataset_per_class import RobustnessDataset
|
37 |
+
from objectnet_dataset import ObjectNetDataset
|
38 |
+
model_names = sorted(name for name in models.__dict__
|
39 |
+
if name.islower() and not name.startswith("__")
|
40 |
+
and callable(models.__dict__[name]))
|
41 |
+
model_names.append("vit")
|
42 |
+
|
43 |
+
parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
|
44 |
+
parser.add_argument('--data', metavar='DIR',
|
45 |
+
help='path to dataset')
|
46 |
+
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
|
47 |
+
help='number of data loading workers (default: 4)')
|
48 |
+
parser.add_argument('--epochs', default=150, type=int, metavar='N',
|
49 |
+
help='number of total epochs to run')
|
50 |
+
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
|
51 |
+
help='manual epoch number (useful on restarts)')
|
52 |
+
parser.add_argument('-b', '--batch-size', default=256, type=int,
|
53 |
+
metavar='N',
|
54 |
+
help='mini-batch size (default: 256), this is the total '
|
55 |
+
'batch size of all GPUs on the current node when '
|
56 |
+
'using Data Parallel or Distributed Data Parallel')
|
57 |
+
parser.add_argument('--lr', '--learning-rate', default=5e-4, type=float,
|
58 |
+
metavar='LR', help='initial learning rate', dest='lr')
|
59 |
+
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
|
60 |
+
help='momentum')
|
61 |
+
parser.add_argument('--wd', '--weight-decay', default=0.05, type=float,
|
62 |
+
metavar='W', help='weight decay (default: 1e-4)',
|
63 |
+
dest='weight_decay')
|
64 |
+
parser.add_argument('-p', '--print-freq', default=10, type=int,
|
65 |
+
metavar='N', help='print frequency (default: 10)')
|
66 |
+
parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
|
67 |
+
help='path to latest checkpoint (default: none)')
|
68 |
+
parser.add_argument('--resume', default='', type=str, metavar='PATH',
|
69 |
+
help='path to resume checkpoint (default: none)')
|
70 |
+
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
|
71 |
+
help='evaluate model on validation set')
|
72 |
+
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
|
73 |
+
help='use pre-trained model')
|
74 |
+
parser.add_argument('--world-size', default=-1, type=int,
|
75 |
+
help='number of nodes for distributed training')
|
76 |
+
parser.add_argument('--rank', default=-1, type=int,
|
77 |
+
help='node rank for distributed training')
|
78 |
+
parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,
|
79 |
+
help='url used to set up distributed training')
|
80 |
+
parser.add_argument('--dist-backend', default='nccl', type=str,
|
81 |
+
help='distributed backend')
|
82 |
+
parser.add_argument('--seed', default=None, type=int,
|
83 |
+
help='seed for initializing training. ')
|
84 |
+
parser.add_argument('--gpu', default=None, type=int,
|
85 |
+
help='GPU id to use.')
|
86 |
+
parser.add_argument('--multiprocessing-distributed', action='store_true',
|
87 |
+
help='Use multi-processing distributed training to launch '
|
88 |
+
'N processes per node, which has N GPUs. This is the '
|
89 |
+
'fastest way to use PyTorch for either single node or '
|
90 |
+
'multi node data parallel training')
|
91 |
+
parser.add_argument("--isV2", default=False, action='store_true',
|
92 |
+
help='is dataset imagenet V2.')
|
93 |
+
parser.add_argument("--isSI", default=False, action='store_true',
|
94 |
+
help='is dataset SI-score.')
|
95 |
+
parser.add_argument("--isObjectNet", default=False, action='store_true',
|
96 |
+
help='is dataset SI-score.')
|
97 |
+
|
98 |
+
|
99 |
+
def main():
|
100 |
+
args = parser.parse_args()
|
101 |
+
|
102 |
+
if args.seed is not None:
|
103 |
+
random.seed(args.seed)
|
104 |
+
torch.manual_seed(args.seed)
|
105 |
+
cudnn.deterministic = True
|
106 |
+
warnings.warn('You have chosen to seed training. '
|
107 |
+
'This will turn on the CUDNN deterministic setting, '
|
108 |
+
'which can slow down your training considerably! '
|
109 |
+
'You may see unexpected behavior when restarting '
|
110 |
+
'from checkpoints.')
|
111 |
+
|
112 |
+
if args.gpu is not None:
|
113 |
+
warnings.warn('You have chosen a specific GPU. This will completely '
|
114 |
+
'disable data parallelism.')
|
115 |
+
|
116 |
+
if args.dist_url == "env://" and args.world_size == -1:
|
117 |
+
args.world_size = int(os.environ["WORLD_SIZE"])
|
118 |
+
|
119 |
+
args.distributed = args.world_size > 1 or args.multiprocessing_distributed
|
120 |
+
|
121 |
+
ngpus_per_node = torch.cuda.device_count()
|
122 |
+
if args.multiprocessing_distributed:
|
123 |
+
# Since we have ngpus_per_node processes per node, the total world_size
|
124 |
+
# needs to be adjusted accordingly
|
125 |
+
args.world_size = ngpus_per_node * args.world_size
|
126 |
+
# Use torch.multiprocessing.spawn to launch distributed processes: the
|
127 |
+
# main_worker process function
|
128 |
+
mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
|
129 |
+
else:
|
130 |
+
# Simply call main_worker function
|
131 |
+
main_worker(args.gpu, ngpus_per_node, args)
|
132 |
+
|
133 |
+
|
134 |
+
def main_worker(gpu, ngpus_per_node, args):
|
135 |
+
global best_acc1
|
136 |
+
args.gpu = gpu
|
137 |
+
|
138 |
+
if args.gpu is not None:
|
139 |
+
print("Use GPU: {} for training".format(args.gpu))
|
140 |
+
|
141 |
+
if args.distributed:
|
142 |
+
if args.dist_url == "env://" and args.rank == -1:
|
143 |
+
args.rank = int(os.environ["RANK"])
|
144 |
+
if args.multiprocessing_distributed:
|
145 |
+
# For multiprocessing distributed training, rank needs to be the
|
146 |
+
# global rank among all the processes
|
147 |
+
args.rank = args.rank * ngpus_per_node + gpu
|
148 |
+
dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
|
149 |
+
world_size=args.world_size, rank=args.rank)
|
150 |
+
# create model
|
151 |
+
print("=> creating model")
|
152 |
+
if args.checkpoint:
|
153 |
+
model = vit().cuda()
|
154 |
+
checkpoint = torch.load(args.checkpoint)
|
155 |
+
model.load_state_dict(checkpoint['state_dict'])
|
156 |
+
else:
|
157 |
+
model = vit(pretrained=True).cuda()
|
158 |
+
print("done")
|
159 |
+
|
160 |
+
if not torch.cuda.is_available():
|
161 |
+
print('using CPU, this will be slow')
|
162 |
+
elif args.distributed:
|
163 |
+
# For multiprocessing distributed, DistributedDataParallel constructor
|
164 |
+
# should always set the single device scope, otherwise,
|
165 |
+
# DistributedDataParallel will use all available devices.
|
166 |
+
if args.gpu is not None:
|
167 |
+
torch.cuda.set_device(args.gpu)
|
168 |
+
model.cuda(args.gpu)
|
169 |
+
# When using a single GPU per process and per
|
170 |
+
# DistributedDataParallel, we need to divide the batch size
|
171 |
+
# ourselves based on the total number of GPUs we have
|
172 |
+
args.batch_size = int(args.batch_size / ngpus_per_node)
|
173 |
+
args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
|
174 |
+
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
|
175 |
+
else:
|
176 |
+
model.cuda()
|
177 |
+
# DistributedDataParallel will divide and allocate batch_size to all
|
178 |
+
# available GPUs if device_ids are not set
|
179 |
+
model = torch.nn.parallel.DistributedDataParallel(model)
|
180 |
+
elif args.gpu is not None:
|
181 |
+
torch.cuda.set_device(args.gpu)
|
182 |
+
model = model.cuda(args.gpu)
|
183 |
+
else:
|
184 |
+
# DataParallel will divide and allocate batch_size to all available GPUs
|
185 |
+
print("start")
|
186 |
+
model = torch.nn.DataParallel(model).cuda()
|
187 |
+
|
188 |
+
# optionally resume from a checkpoint
|
189 |
+
if args.resume:
|
190 |
+
if os.path.isfile(args.resume):
|
191 |
+
print("=> loading checkpoint '{}'".format(args.resume))
|
192 |
+
if args.gpu is None:
|
193 |
+
checkpoint = torch.load(args.resume)
|
194 |
+
else:
|
195 |
+
# Map model to be loaded to specified single gpu.
|
196 |
+
loc = 'cuda:{}'.format(args.gpu)
|
197 |
+
checkpoint = torch.load(args.resume, map_location=loc)
|
198 |
+
args.start_epoch = checkpoint['epoch']
|
199 |
+
best_acc1 = checkpoint['best_acc1']
|
200 |
+
if args.gpu is not None:
|
201 |
+
# best_acc1 may be from a checkpoint from a different GPU
|
202 |
+
best_acc1 = best_acc1.to(args.gpu)
|
203 |
+
model.load_state_dict(checkpoint['state_dict'])
|
204 |
+
print("=> loaded checkpoint '{}' (epoch {})"
|
205 |
+
.format(args.resume, checkpoint['epoch']))
|
206 |
+
else:
|
207 |
+
print("=> no checkpoint found at '{}'".format(args.resume))
|
208 |
+
|
209 |
+
cudnn.benchmark = True
|
210 |
+
|
211 |
+
# Data loading code
|
212 |
+
|
213 |
+
top1_per_class = {}
|
214 |
+
top5_per_class = {}
|
215 |
+
for folder in os.listdir(args.data):
|
216 |
+
val_dataset = RobustnessDataset(args.data, folder=folder, isV2=args.isV2, isSI=args.isSI)
|
217 |
+
print("len: ", len(val_dataset))
|
218 |
+
val_loader = torch.utils.data.DataLoader(
|
219 |
+
val_dataset, batch_size=args.batch_size, shuffle=False,
|
220 |
+
num_workers=args.workers, pin_memory=True)
|
221 |
+
class_name = val_dataset.get_classname()
|
222 |
+
top1, top5 = validate(val_loader, model, args)
|
223 |
+
top1_per_class[class_name] = top1.item()
|
224 |
+
top5_per_class[class_name] = top5.item()
|
225 |
+
|
226 |
+
print("overall top1 per class: ", top1_per_class)
|
227 |
+
print("overall top5 per class: ", top5_per_class)
|
228 |
+
|
229 |
+
def validate(val_loader, model, args):
|
230 |
+
batch_time = AverageMeter('Time', ':6.3f')
|
231 |
+
losses = AverageMeter('Loss', ':.4e')
|
232 |
+
top1 = AverageMeter('Acc@1', ':6.2f')
|
233 |
+
top5 = AverageMeter('Acc@5', ':6.2f')
|
234 |
+
progress = ProgressMeter(
|
235 |
+
len(val_loader),
|
236 |
+
[batch_time, losses, top1, top5],
|
237 |
+
prefix='Test: ')
|
238 |
+
|
239 |
+
# switch to evaluate mode
|
240 |
+
model.eval()
|
241 |
+
|
242 |
+
with torch.no_grad():
|
243 |
+
end = time.time()
|
244 |
+
for i, (images, target) in enumerate(val_loader):
|
245 |
+
if args.gpu is not None:
|
246 |
+
images = images.cuda(args.gpu, non_blocking=True)
|
247 |
+
if torch.cuda.is_available():
|
248 |
+
target = target.cuda(args.gpu, non_blocking=True)
|
249 |
+
|
250 |
+
# compute output
|
251 |
+
output = model(images)
|
252 |
+
|
253 |
+
# measure accuracy and record loss
|
254 |
+
acc1, acc5 = accuracy(output, target, topk=(1, 5))
|
255 |
+
top1.update(acc1[0], images.size(0))
|
256 |
+
top5.update(acc5[0], images.size(0))
|
257 |
+
|
258 |
+
# measure elapsed time
|
259 |
+
batch_time.update(time.time() - end)
|
260 |
+
end = time.time()
|
261 |
+
|
262 |
+
if i % args.print_freq == 0:
|
263 |
+
progress.display(i)
|
264 |
+
|
265 |
+
# TODO: this should also be done with the ProgressMeter
|
266 |
+
print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
|
267 |
+
.format(top1=top1, top5=top5))
|
268 |
+
|
269 |
+
return top1.avg, top5.avg
|
270 |
+
|
271 |
+
|
272 |
+
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
|
273 |
+
torch.save(state, filename)
|
274 |
+
if is_best:
|
275 |
+
shutil.copyfile(filename, 'model_best.pth.tar')
|
276 |
+
|
277 |
+
|
278 |
+
class AverageMeter(object):
|
279 |
+
"""Computes and stores the average and current value"""
|
280 |
+
def __init__(self, name, fmt=':f'):
|
281 |
+
self.name = name
|
282 |
+
self.fmt = fmt
|
283 |
+
self.reset()
|
284 |
+
|
285 |
+
def reset(self):
|
286 |
+
self.val = 0
|
287 |
+
self.avg = 0
|
288 |
+
self.sum = 0
|
289 |
+
self.count = 0
|
290 |
+
|
291 |
+
def update(self, val, n=1):
|
292 |
+
self.val = val
|
293 |
+
self.sum += val * n
|
294 |
+
self.count += n
|
295 |
+
self.avg = self.sum / self.count
|
296 |
+
|
297 |
+
def __str__(self):
|
298 |
+
fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
|
299 |
+
return fmtstr.format(**self.__dict__)
|
300 |
+
|
301 |
+
|
302 |
+
class ProgressMeter(object):
|
303 |
+
def __init__(self, num_batches, meters, prefix=""):
|
304 |
+
self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
|
305 |
+
self.meters = meters
|
306 |
+
self.prefix = prefix
|
307 |
+
|
308 |
+
def display(self, batch):
|
309 |
+
entries = [self.prefix + self.batch_fmtstr.format(batch)]
|
310 |
+
entries += [str(meter) for meter in self.meters]
|
311 |
+
print('\t'.join(entries))
|
312 |
+
|
313 |
+
def _get_batch_fmtstr(self, num_batches):
|
314 |
+
num_digits = len(str(num_batches // 1))
|
315 |
+
fmt = '{:' + str(num_digits) + 'd}'
|
316 |
+
return '[' + fmt + '/' + fmt.format(num_batches) + ']'
|
317 |
+
|
318 |
+
def adjust_learning_rate(optimizer, epoch, args):
|
319 |
+
"""Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
|
320 |
+
lr = args.lr * (0.85 ** (epoch // 2))
|
321 |
+
for param_group in optimizer.param_groups:
|
322 |
+
param_group['lr'] = lr
|
323 |
+
|
324 |
+
|
325 |
+
def accuracy(output, target, topk=(1,)):
|
326 |
+
"""Computes the accuracy over the k top predictions for the specified values of k"""
|
327 |
+
with torch.no_grad():
|
328 |
+
maxk = max(topk)
|
329 |
+
batch_size = target.size(0)
|
330 |
+
|
331 |
+
_, pred = output.topk(maxk, 1, True, True)
|
332 |
+
pred = pred.t()
|
333 |
+
correct = pred.eq(target.view(1, -1).expand_as(pred))
|
334 |
+
|
335 |
+
res = []
|
336 |
+
for k in topk:
|
337 |
+
correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
|
338 |
+
res.append(correct_k.mul_(100.0 / batch_size))
|
339 |
+
return res
|
340 |
+
|
341 |
+
|
342 |
+
if __name__ == '__main__':
|
343 |
+
main()
|
imagenet_finetune.py
ADDED
@@ -0,0 +1,567 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
import random
|
4 |
+
import shutil
|
5 |
+
import time
|
6 |
+
import warnings
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
import torch.nn.parallel
|
11 |
+
import torch.backends.cudnn as cudnn
|
12 |
+
import torch.distributed as dist
|
13 |
+
import torch.optim
|
14 |
+
import torch.multiprocessing as mp
|
15 |
+
import torch.utils.data
|
16 |
+
import torch.utils.data.distributed
|
17 |
+
import torchvision.transforms as transforms
|
18 |
+
import torchvision.datasets as datasets
|
19 |
+
import torchvision.models as models
|
20 |
+
from segmentation_dataset import SegmentationDataset, VAL_PARTITION, TRAIN_PARTITION
|
21 |
+
|
22 |
+
# Uncomment the expected model below
|
23 |
+
|
24 |
+
# ViT
|
25 |
+
from ViT.ViT import vit_base_patch16_224 as vit
|
26 |
+
# from ViT.ViT import vit_large_patch16_224 as vit
|
27 |
+
|
28 |
+
# ViT-AugReg
|
29 |
+
# from ViT.ViT_new import vit_small_patch16_224 as vit
|
30 |
+
# from ViT.ViT_new import vit_base_patch16_224 as vit
|
31 |
+
# from ViT.ViT_new import vit_large_patch16_224 as vit
|
32 |
+
|
33 |
+
# DeiT
|
34 |
+
# from ViT.ViT import deit_base_patch16_224 as vit
|
35 |
+
# from ViT.ViT import deit_small_patch16_224 as vit
|
36 |
+
|
37 |
+
from ViT.explainer import generate_relevance, get_image_with_relevance
|
38 |
+
import torchvision
|
39 |
+
import cv2
|
40 |
+
from torch.utils.tensorboard import SummaryWriter
|
41 |
+
import json
|
42 |
+
|
43 |
+
model_names = sorted(name for name in models.__dict__
|
44 |
+
if name.islower() and not name.startswith("__")
|
45 |
+
and callable(models.__dict__[name]))
|
46 |
+
model_names.append("vit")
|
47 |
+
|
48 |
+
parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
|
49 |
+
parser.add_argument('--data', metavar='DATA',
|
50 |
+
help='path to dataset')
|
51 |
+
parser.add_argument('--seg_data', metavar='SEG_DATA',
|
52 |
+
help='path to segmentation dataset')
|
53 |
+
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
|
54 |
+
help='number of data loading workers (default: 4)')
|
55 |
+
parser.add_argument('--epochs', default=50, type=int, metavar='N',
|
56 |
+
help='number of total epochs to run')
|
57 |
+
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
|
58 |
+
help='manual epoch number (useful on restarts)')
|
59 |
+
parser.add_argument('-b', '--batch-size', default=8, type=int,
|
60 |
+
metavar='N',
|
61 |
+
help='mini-batch size (default: 256), this is the total '
|
62 |
+
'batch size of all GPUs on the current node when '
|
63 |
+
'using Data Parallel or Distributed Data Parallel')
|
64 |
+
parser.add_argument('--lr', '--learning-rate', default=3e-6, type=float,
|
65 |
+
metavar='LR', help='initial learning rate', dest='lr')
|
66 |
+
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
|
67 |
+
help='momentum')
|
68 |
+
parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
|
69 |
+
metavar='W', help='weight decay (default: 1e-4)',
|
70 |
+
dest='weight_decay')
|
71 |
+
parser.add_argument('-p', '--print-freq', default=10, type=int,
|
72 |
+
metavar='N', help='print frequency (default: 10)')
|
73 |
+
parser.add_argument('--resume', default='', type=str, metavar='PATH',
|
74 |
+
help='path to latest checkpoint (default: none)')
|
75 |
+
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
|
76 |
+
help='evaluate model on validation set')
|
77 |
+
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
|
78 |
+
help='use pre-trained model')
|
79 |
+
parser.add_argument('--world-size', default=-1, type=int,
|
80 |
+
help='number of nodes for distributed training')
|
81 |
+
parser.add_argument('--rank', default=-1, type=int,
|
82 |
+
help='node rank for distributed training')
|
83 |
+
parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,
|
84 |
+
help='url used to set up distributed training')
|
85 |
+
parser.add_argument('--dist-backend', default='nccl', type=str,
|
86 |
+
help='distributed backend')
|
87 |
+
parser.add_argument('--gpu', default=None, type=int,
|
88 |
+
help='GPU id to use.')
|
89 |
+
parser.add_argument('--save_interval', default=20, type=int,
|
90 |
+
help='interval to save segmentation results.')
|
91 |
+
parser.add_argument('--num_samples', default=3, type=int,
|
92 |
+
help='number of samples per class for training')
|
93 |
+
parser.add_argument('--multiprocessing-distributed', action='store_true',
|
94 |
+
help='Use multi-processing distributed training to launch '
|
95 |
+
'N processes per node, which has N GPUs. This is the '
|
96 |
+
'fastest way to use PyTorch for either single node or '
|
97 |
+
'multi node data parallel training')
|
98 |
+
parser.add_argument('--lambda_seg', default=0.8, type=float,
|
99 |
+
help='influence of segmentation loss.')
|
100 |
+
parser.add_argument('--lambda_acc', default=0.2, type=float,
|
101 |
+
help='influence of accuracy loss.')
|
102 |
+
parser.add_argument('--experiment_folder', default=None, type=str,
|
103 |
+
help='path to folder to use for experiment.')
|
104 |
+
parser.add_argument('--dilation', default=0, type=float,
|
105 |
+
help='Use dilation on the segmentation maps.')
|
106 |
+
parser.add_argument('--lambda_background', default=2, type=float,
|
107 |
+
help='coefficient of loss for segmentation background.')
|
108 |
+
parser.add_argument('--lambda_foreground', default=0.3, type=float,
|
109 |
+
help='coefficient of loss for segmentation foreground.')
|
110 |
+
parser.add_argument('--num_classes', default=500, type=int,
|
111 |
+
help='coefficient of loss for segmentation foreground.')
|
112 |
+
parser.add_argument('--temperature', default=1, type=float,
|
113 |
+
help='temperature for softmax (mostly for DeiT).')
|
114 |
+
parser.add_argument('--class_seed', default=None, type=int,
|
115 |
+
help='seed to randomly shuffle classes chosen for training.')
|
116 |
+
|
117 |
+
best_loss = float('inf')
|
118 |
+
|
119 |
+
def main():
|
120 |
+
args = parser.parse_args()
|
121 |
+
|
122 |
+
if args.experiment_folder is None:
|
123 |
+
args.experiment_folder = f'experiment/' \
|
124 |
+
f'lr_{args.lr}_seg_{args.lambda_seg}_acc_{args.lambda_acc}' \
|
125 |
+
f'_bckg_{args.lambda_background}_fgd_{args.lambda_foreground}'
|
126 |
+
if args.temperature != 1:
|
127 |
+
args.experiment_folder = args.experiment_folder + f'_tempera_{args.temperature}'
|
128 |
+
if args.batch_size != 8:
|
129 |
+
args.experiment_folder = args.experiment_folder + f'_bs_{args.batch_size}'
|
130 |
+
if args.num_classes != 500:
|
131 |
+
args.experiment_folder = args.experiment_folder + f'_num_classes_{args.num_classes}'
|
132 |
+
if args.num_samples != 3:
|
133 |
+
args.experiment_folder = args.experiment_folder + f'_num_samples_{args.num_samples}'
|
134 |
+
if args.epochs != 150:
|
135 |
+
args.experiment_folder = args.experiment_folder + f'_num_epochs_{args.epochs}'
|
136 |
+
if args.class_seed is not None:
|
137 |
+
args.experiment_folder = args.experiment_folder + f'_seed_{args.class_seed}'
|
138 |
+
|
139 |
+
if os.path.exists(args.experiment_folder):
|
140 |
+
raise Exception(f"Experiment path {args.experiment_folder} already exists!")
|
141 |
+
os.mkdir(args.experiment_folder)
|
142 |
+
os.mkdir(f'{args.experiment_folder}/train_samples')
|
143 |
+
os.mkdir(f'{args.experiment_folder}/val_samples')
|
144 |
+
|
145 |
+
with open(f'{args.experiment_folder}/commandline_args.txt', 'w') as f:
|
146 |
+
json.dump(args.__dict__, f, indent=2)
|
147 |
+
|
148 |
+
if args.gpu is not None:
|
149 |
+
warnings.warn('You have chosen a specific GPU. This will completely '
|
150 |
+
'disable data parallelism.')
|
151 |
+
|
152 |
+
if args.dist_url == "env://" and args.world_size == -1:
|
153 |
+
args.world_size = int(os.environ["WORLD_SIZE"])
|
154 |
+
|
155 |
+
args.distributed = args.world_size > 1 or args.multiprocessing_distributed
|
156 |
+
|
157 |
+
ngpus_per_node = torch.cuda.device_count()
|
158 |
+
if args.multiprocessing_distributed:
|
159 |
+
# Since we have ngpus_per_node processes per node, the total world_size
|
160 |
+
# needs to be adjusted accordingly
|
161 |
+
args.world_size = ngpus_per_node * args.world_size
|
162 |
+
# Use torch.multiprocessing.spawn to launch distributed processes: the
|
163 |
+
# main_worker process function
|
164 |
+
mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
|
165 |
+
else:
|
166 |
+
# Simply call main_worker function
|
167 |
+
main_worker(args.gpu, ngpus_per_node, args)
|
168 |
+
|
169 |
+
|
170 |
+
def main_worker(gpu, ngpus_per_node, args):
|
171 |
+
global best_loss
|
172 |
+
args.gpu = gpu
|
173 |
+
|
174 |
+
if args.gpu is not None:
|
175 |
+
print("Use GPU: {} for training".format(args.gpu))
|
176 |
+
|
177 |
+
if args.distributed:
|
178 |
+
if args.dist_url == "env://" and args.rank == -1:
|
179 |
+
args.rank = int(os.environ["RANK"])
|
180 |
+
if args.multiprocessing_distributed:
|
181 |
+
# For multiprocessing distributed training, rank needs to be the
|
182 |
+
# global rank among all the processes
|
183 |
+
args.rank = args.rank * ngpus_per_node + gpu
|
184 |
+
dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
|
185 |
+
world_size=args.world_size, rank=args.rank)
|
186 |
+
# create model
|
187 |
+
print("=> creating model")
|
188 |
+
model = vit(pretrained=True).cuda()
|
189 |
+
model.train()
|
190 |
+
print("done")
|
191 |
+
|
192 |
+
if not torch.cuda.is_available():
|
193 |
+
print('using CPU, this will be slow')
|
194 |
+
elif args.distributed:
|
195 |
+
# For multiprocessing distributed, DistributedDataParallel constructor
|
196 |
+
# should always set the single device scope, otherwise,
|
197 |
+
# DistributedDataParallel will use all available devices.
|
198 |
+
if args.gpu is not None:
|
199 |
+
torch.cuda.set_device(args.gpu)
|
200 |
+
model.cuda(args.gpu)
|
201 |
+
# When using a single GPU per process and per
|
202 |
+
# DistributedDataParallel, we need to divide the batch size
|
203 |
+
# ourselves based on the total number of GPUs we have
|
204 |
+
args.batch_size = int(args.batch_size / ngpus_per_node)
|
205 |
+
args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
|
206 |
+
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
|
207 |
+
else:
|
208 |
+
model.cuda()
|
209 |
+
# DistributedDataParallel will divide and allocate batch_size to all
|
210 |
+
# available GPUs if device_ids are not set
|
211 |
+
model = torch.nn.parallel.DistributedDataParallel(model)
|
212 |
+
elif args.gpu is not None:
|
213 |
+
torch.cuda.set_device(args.gpu)
|
214 |
+
model = model.cuda(args.gpu)
|
215 |
+
else:
|
216 |
+
# DataParallel will divide and allocate batch_size to all available GPUs
|
217 |
+
print("start")
|
218 |
+
model = torch.nn.DataParallel(model).cuda()
|
219 |
+
|
220 |
+
# define loss function (criterion) and optimizer
|
221 |
+
criterion = nn.CrossEntropyLoss().cuda(args.gpu)
|
222 |
+
optimizer = torch.optim.AdamW(model.parameters(), args.lr, weight_decay=args.weight_decay)
|
223 |
+
|
224 |
+
# optionally resume from a checkpoint
|
225 |
+
if args.resume:
|
226 |
+
if os.path.isfile(args.resume):
|
227 |
+
print("=> loading checkpoint '{}'".format(args.resume))
|
228 |
+
if args.gpu is None:
|
229 |
+
checkpoint = torch.load(args.resume)
|
230 |
+
else:
|
231 |
+
# Map model to be loaded to specified single gpu.
|
232 |
+
loc = 'cuda:{}'.format(args.gpu)
|
233 |
+
checkpoint = torch.load(args.resume, map_location=loc)
|
234 |
+
args.start_epoch = checkpoint['epoch']
|
235 |
+
best_loss = checkpoint['best_loss']
|
236 |
+
if args.gpu is not None:
|
237 |
+
# best_loss may be from a checkpoint from a different GPU
|
238 |
+
best_loss = best_loss.to(args.gpu)
|
239 |
+
model.load_state_dict(checkpoint['state_dict'])
|
240 |
+
optimizer.load_state_dict(checkpoint['optimizer'])
|
241 |
+
print("=> loaded checkpoint '{}' (epoch {})"
|
242 |
+
.format(args.resume, checkpoint['epoch']))
|
243 |
+
else:
|
244 |
+
print("=> no checkpoint found at '{}'".format(args.resume))
|
245 |
+
|
246 |
+
cudnn.benchmark = True
|
247 |
+
|
248 |
+
train_dataset = SegmentationDataset(args.seg_data, args.data, partition=TRAIN_PARTITION, train_classes=args.num_classes,
|
249 |
+
num_samples=args.num_samples, seed=args.class_seed)
|
250 |
+
|
251 |
+
if args.distributed:
|
252 |
+
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
|
253 |
+
else:
|
254 |
+
train_sampler = None
|
255 |
+
|
256 |
+
train_loader = torch.utils.data.DataLoader(
|
257 |
+
train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
|
258 |
+
num_workers=args.workers, pin_memory=True, sampler=train_sampler)
|
259 |
+
|
260 |
+
val_dataset = SegmentationDataset(args.seg_data, args.data, partition=VAL_PARTITION, train_classes=args.num_classes,
|
261 |
+
num_samples=1, seed=args.class_seed)
|
262 |
+
|
263 |
+
val_loader = torch.utils.data.DataLoader(
|
264 |
+
val_dataset, batch_size=10, shuffle=False,
|
265 |
+
num_workers=args.workers, pin_memory=True)
|
266 |
+
|
267 |
+
if args.evaluate:
|
268 |
+
validate(val_loader, model, criterion, 0, args)
|
269 |
+
return
|
270 |
+
|
271 |
+
for epoch in range(args.start_epoch, args.epochs):
|
272 |
+
if args.distributed:
|
273 |
+
train_sampler.set_epoch(epoch)
|
274 |
+
adjust_learning_rate(optimizer, epoch, args)
|
275 |
+
|
276 |
+
log_dir = os.path.join(args.experiment_folder, 'logs')
|
277 |
+
logger = SummaryWriter(log_dir=log_dir)
|
278 |
+
args.logger = logger
|
279 |
+
|
280 |
+
# train for one epoch
|
281 |
+
train(train_loader, model, criterion, optimizer, epoch, args)
|
282 |
+
|
283 |
+
# evaluate on validation set
|
284 |
+
loss1 = validate(val_loader, model, criterion, epoch, args)
|
285 |
+
|
286 |
+
# remember best acc@1 and save checkpoint
|
287 |
+
is_best = loss1 <= best_loss
|
288 |
+
best_loss = min(loss1, best_loss)
|
289 |
+
|
290 |
+
if not args.multiprocessing_distributed or (args.multiprocessing_distributed
|
291 |
+
and args.rank % ngpus_per_node == 0):
|
292 |
+
save_checkpoint({
|
293 |
+
'epoch': epoch + 1,
|
294 |
+
'state_dict': model.state_dict(),
|
295 |
+
'best_loss': best_loss,
|
296 |
+
'optimizer' : optimizer.state_dict(),
|
297 |
+
}, is_best, folder=args.experiment_folder)
|
298 |
+
|
299 |
+
|
300 |
+
def train(train_loader, model, criterion, optimizer, epoch, args):
|
301 |
+
mse_criterion = torch.nn.MSELoss(reduction='mean')
|
302 |
+
|
303 |
+
losses = AverageMeter('Loss', ':.4e')
|
304 |
+
top1 = AverageMeter('Acc@1', ':6.2f')
|
305 |
+
top5 = AverageMeter('Acc@5', ':6.2f')
|
306 |
+
orig_top1 = AverageMeter('Acc@1_orig', ':6.2f')
|
307 |
+
orig_top5 = AverageMeter('Acc@5_orig', ':6.2f')
|
308 |
+
progress = ProgressMeter(
|
309 |
+
len(train_loader),
|
310 |
+
[losses, top1, top5, orig_top1, orig_top5],
|
311 |
+
prefix="Epoch: [{}]".format(epoch))
|
312 |
+
|
313 |
+
orig_model = vit(pretrained=True).cuda()
|
314 |
+
orig_model.eval()
|
315 |
+
|
316 |
+
# switch to train mode
|
317 |
+
model.train()
|
318 |
+
|
319 |
+
for i, (seg_map, image_ten, class_name) in enumerate(train_loader):
|
320 |
+
if torch.cuda.is_available():
|
321 |
+
image_ten = image_ten.cuda(args.gpu, non_blocking=True)
|
322 |
+
seg_map = seg_map.cuda(args.gpu, non_blocking=True)
|
323 |
+
class_name = class_name.cuda(args.gpu, non_blocking=True)
|
324 |
+
|
325 |
+
# segmentation loss
|
326 |
+
relevance = generate_relevance(model, image_ten, index=class_name)
|
327 |
+
|
328 |
+
reverse_seg_map = seg_map.clone()
|
329 |
+
reverse_seg_map[reverse_seg_map == 1] = -1
|
330 |
+
reverse_seg_map[reverse_seg_map == 0] = 1
|
331 |
+
reverse_seg_map[reverse_seg_map == -1] = 0
|
332 |
+
background_loss = mse_criterion(relevance * reverse_seg_map, torch.zeros_like(relevance))
|
333 |
+
foreground_loss = mse_criterion(relevance * seg_map, seg_map)
|
334 |
+
segmentation_loss = args.lambda_background * background_loss
|
335 |
+
segmentation_loss += args.lambda_foreground * foreground_loss
|
336 |
+
|
337 |
+
# classification loss
|
338 |
+
output = model(image_ten)
|
339 |
+
with torch.no_grad():
|
340 |
+
output_orig = orig_model(image_ten)
|
341 |
+
|
342 |
+
_, pred = output.topk(1, 1, True, True)
|
343 |
+
pred = pred.flatten()
|
344 |
+
|
345 |
+
if args.temperature != 1:
|
346 |
+
output = output / args.temperature
|
347 |
+
classification_loss = criterion(output, pred)
|
348 |
+
|
349 |
+
loss = args.lambda_seg * segmentation_loss + args.lambda_acc * classification_loss
|
350 |
+
|
351 |
+
# debugging output
|
352 |
+
if i % args.save_interval == 0:
|
353 |
+
orig_relevance = generate_relevance(orig_model, image_ten, index=class_name)
|
354 |
+
for j in range(image_ten.shape[0]):
|
355 |
+
image = get_image_with_relevance(image_ten[j], torch.ones_like(image_ten[j]))
|
356 |
+
new_vis = get_image_with_relevance(image_ten[j], relevance[j])
|
357 |
+
old_vis = get_image_with_relevance(image_ten[j], orig_relevance[j])
|
358 |
+
gt = get_image_with_relevance(image_ten[j], seg_map[j])
|
359 |
+
h_img = cv2.hconcat([image, gt, old_vis, new_vis])
|
360 |
+
cv2.imwrite(f'{args.experiment_folder}/train_samples/res_{i}_{j}.jpg', h_img)
|
361 |
+
|
362 |
+
# measure accuracy and record loss
|
363 |
+
acc1, acc5 = accuracy(output, class_name, topk=(1, 5))
|
364 |
+
losses.update(loss.item(), image_ten.size(0))
|
365 |
+
top1.update(acc1[0], image_ten.size(0))
|
366 |
+
top5.update(acc5[0], image_ten.size(0))
|
367 |
+
|
368 |
+
# metrics for original vit
|
369 |
+
acc1_orig, acc5_orig = accuracy(output_orig, class_name, topk=(1, 5))
|
370 |
+
orig_top1.update(acc1_orig[0], image_ten.size(0))
|
371 |
+
orig_top5.update(acc5_orig[0], image_ten.size(0))
|
372 |
+
|
373 |
+
# compute gradient and do SGD step
|
374 |
+
optimizer.zero_grad()
|
375 |
+
loss.backward()
|
376 |
+
optimizer.step()
|
377 |
+
|
378 |
+
if i % args.print_freq == 0:
|
379 |
+
progress.display(i)
|
380 |
+
args.logger.add_scalar('{}/{}'.format('train', 'segmentation_loss'), segmentation_loss,
|
381 |
+
epoch*len(train_loader)+i)
|
382 |
+
args.logger.add_scalar('{}/{}'.format('train', 'classification_loss'), classification_loss,
|
383 |
+
epoch * len(train_loader) + i)
|
384 |
+
args.logger.add_scalar('{}/{}'.format('train', 'orig_top1'), acc1_orig,
|
385 |
+
epoch * len(train_loader) + i)
|
386 |
+
args.logger.add_scalar('{}/{}'.format('train', 'top1'), acc1,
|
387 |
+
epoch * len(train_loader) + i)
|
388 |
+
args.logger.add_scalar('{}/{}'.format('train', 'orig_top5'), acc5_orig,
|
389 |
+
epoch * len(train_loader) + i)
|
390 |
+
args.logger.add_scalar('{}/{}'.format('train', 'top5'), acc5,
|
391 |
+
epoch * len(train_loader) + i)
|
392 |
+
args.logger.add_scalar('{}/{}'.format('train', 'tot_loss'), loss,
|
393 |
+
epoch * len(train_loader) + i)
|
394 |
+
|
395 |
+
|
396 |
+
def validate(val_loader, model, criterion, epoch, args):
|
397 |
+
mse_criterion = torch.nn.MSELoss(reduction='mean')
|
398 |
+
|
399 |
+
losses = AverageMeter('Loss', ':.4e')
|
400 |
+
top1 = AverageMeter('Acc@1', ':6.2f')
|
401 |
+
top5 = AverageMeter('Acc@5', ':6.2f')
|
402 |
+
orig_top1 = AverageMeter('Acc@1_orig', ':6.2f')
|
403 |
+
orig_top5 = AverageMeter('Acc@5_orig', ':6.2f')
|
404 |
+
progress = ProgressMeter(
|
405 |
+
len(val_loader),
|
406 |
+
[losses, top1, top5, orig_top1, orig_top5],
|
407 |
+
prefix="Epoch: [{}]".format(val_loader))
|
408 |
+
|
409 |
+
# switch to evaluate mode
|
410 |
+
model.eval()
|
411 |
+
|
412 |
+
orig_model = vit(pretrained=True).cuda()
|
413 |
+
orig_model.eval()
|
414 |
+
|
415 |
+
with torch.no_grad():
|
416 |
+
for i, (seg_map, image_ten, class_name) in enumerate(val_loader):
|
417 |
+
if args.gpu is not None:
|
418 |
+
image_ten = image_ten.cuda(args.gpu, non_blocking=True)
|
419 |
+
if torch.cuda.is_available():
|
420 |
+
seg_map = seg_map.cuda(args.gpu, non_blocking=True)
|
421 |
+
class_name = class_name.cuda(args.gpu, non_blocking=True)
|
422 |
+
|
423 |
+
# segmentation loss
|
424 |
+
with torch.enable_grad():
|
425 |
+
relevance = generate_relevance(model, image_ten, index=class_name)
|
426 |
+
|
427 |
+
reverse_seg_map = seg_map.clone()
|
428 |
+
reverse_seg_map[reverse_seg_map == 1] = -1
|
429 |
+
reverse_seg_map[reverse_seg_map == 0] = 1
|
430 |
+
reverse_seg_map[reverse_seg_map == -1] = 0
|
431 |
+
background_loss = mse_criterion(relevance * reverse_seg_map, torch.zeros_like(relevance))
|
432 |
+
foreground_loss = mse_criterion(relevance * seg_map, seg_map)
|
433 |
+
segmentation_loss = args.lambda_background * background_loss
|
434 |
+
segmentation_loss += args.lambda_foreground * foreground_loss
|
435 |
+
|
436 |
+
# classification loss
|
437 |
+
with torch.no_grad():
|
438 |
+
output = model(image_ten)
|
439 |
+
output_orig = orig_model(image_ten)
|
440 |
+
|
441 |
+
_, pred = output.topk(1, 1, True, True)
|
442 |
+
pred = pred.flatten()
|
443 |
+
if args.temperature != 1:
|
444 |
+
output = output / args.temperature
|
445 |
+
classification_loss = criterion(output, pred)
|
446 |
+
|
447 |
+
loss = args.lambda_seg * segmentation_loss + args.lambda_acc * classification_loss
|
448 |
+
|
449 |
+
# save results
|
450 |
+
if i % args.save_interval == 0:
|
451 |
+
with torch.enable_grad():
|
452 |
+
orig_relevance = generate_relevance(orig_model, image_ten, index=class_name)
|
453 |
+
for j in range(image_ten.shape[0]):
|
454 |
+
image = get_image_with_relevance(image_ten[j], torch.ones_like(image_ten[j]))
|
455 |
+
new_vis = get_image_with_relevance(image_ten[j], relevance[j])
|
456 |
+
old_vis = get_image_with_relevance(image_ten[j], orig_relevance[j])
|
457 |
+
gt = get_image_with_relevance(image_ten[j], seg_map[j])
|
458 |
+
h_img = cv2.hconcat([image, gt, old_vis, new_vis])
|
459 |
+
cv2.imwrite(f'{args.experiment_folder}/val_samples/res_{i}_{j}.jpg', h_img)
|
460 |
+
|
461 |
+
# measure accuracy and record loss
|
462 |
+
acc1, acc5 = accuracy(output, class_name, topk=(1, 5))
|
463 |
+
losses.update(loss.item(), image_ten.size(0))
|
464 |
+
top1.update(acc1[0], image_ten.size(0))
|
465 |
+
top5.update(acc5[0], image_ten.size(0))
|
466 |
+
|
467 |
+
# metrics for original vit
|
468 |
+
acc1_orig, acc5_orig = accuracy(output_orig, class_name, topk=(1, 5))
|
469 |
+
orig_top1.update(acc1_orig[0], image_ten.size(0))
|
470 |
+
orig_top5.update(acc5_orig[0], image_ten.size(0))
|
471 |
+
|
472 |
+
if i % args.print_freq == 0:
|
473 |
+
progress.display(i)
|
474 |
+
args.logger.add_scalar('{}/{}'.format('val', 'segmentation_loss'), segmentation_loss,
|
475 |
+
epoch * len(val_loader) + i)
|
476 |
+
args.logger.add_scalar('{}/{}'.format('val', 'classification_loss'), classification_loss,
|
477 |
+
epoch * len(val_loader) + i)
|
478 |
+
args.logger.add_scalar('{}/{}'.format('val', 'orig_top1'), acc1_orig,
|
479 |
+
epoch * len(val_loader) + i)
|
480 |
+
args.logger.add_scalar('{}/{}'.format('val', 'top1'), acc1,
|
481 |
+
epoch * len(val_loader) + i)
|
482 |
+
args.logger.add_scalar('{}/{}'.format('val', 'orig_top5'), acc5_orig,
|
483 |
+
epoch * len(val_loader) + i)
|
484 |
+
args.logger.add_scalar('{}/{}'.format('val', 'top5'), acc5,
|
485 |
+
epoch * len(val_loader) + i)
|
486 |
+
args.logger.add_scalar('{}/{}'.format('val', 'tot_loss'), loss,
|
487 |
+
epoch * len(val_loader) + i)
|
488 |
+
|
489 |
+
# TODO: this should also be done with the ProgressMeter
|
490 |
+
print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
|
491 |
+
.format(top1=top1, top5=top5))
|
492 |
+
|
493 |
+
return losses.avg
|
494 |
+
|
495 |
+
|
496 |
+
def save_checkpoint(state, is_best, folder, filename='checkpoint.pth.tar'):
|
497 |
+
torch.save(state, f'{folder}/{filename}')
|
498 |
+
if is_best:
|
499 |
+
shutil.copyfile(f'{folder}/{filename}', f'{folder}/model_best.pth.tar')
|
500 |
+
|
501 |
+
|
502 |
+
class AverageMeter(object):
|
503 |
+
"""Computes and stores the average and current value"""
|
504 |
+
def __init__(self, name, fmt=':f'):
|
505 |
+
self.name = name
|
506 |
+
self.fmt = fmt
|
507 |
+
self.reset()
|
508 |
+
|
509 |
+
def reset(self):
|
510 |
+
self.val = 0
|
511 |
+
self.avg = 0
|
512 |
+
self.sum = 0
|
513 |
+
self.count = 0
|
514 |
+
|
515 |
+
def update(self, val, n=1):
|
516 |
+
self.val = val
|
517 |
+
self.sum += val * n
|
518 |
+
self.count += n
|
519 |
+
self.avg = self.sum / self.count
|
520 |
+
|
521 |
+
def __str__(self):
|
522 |
+
fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
|
523 |
+
return fmtstr.format(**self.__dict__)
|
524 |
+
|
525 |
+
|
526 |
+
class ProgressMeter(object):
|
527 |
+
def __init__(self, num_batches, meters, prefix=""):
|
528 |
+
self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
|
529 |
+
self.meters = meters
|
530 |
+
self.prefix = prefix
|
531 |
+
|
532 |
+
def display(self, batch):
|
533 |
+
entries = [self.prefix + self.batch_fmtstr.format(batch)]
|
534 |
+
entries += [str(meter) for meter in self.meters]
|
535 |
+
print('\t'.join(entries))
|
536 |
+
|
537 |
+
def _get_batch_fmtstr(self, num_batches):
|
538 |
+
num_digits = len(str(num_batches // 1))
|
539 |
+
fmt = '{:' + str(num_digits) + 'd}'
|
540 |
+
return '[' + fmt + '/' + fmt.format(num_batches) + ']'
|
541 |
+
|
542 |
+
def adjust_learning_rate(optimizer, epoch, args):
|
543 |
+
"""Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
|
544 |
+
lr = args.lr * (0.85 ** (epoch // 2))
|
545 |
+
for param_group in optimizer.param_groups:
|
546 |
+
param_group['lr'] = lr
|
547 |
+
|
548 |
+
|
549 |
+
def accuracy(output, target, topk=(1,)):
|
550 |
+
"""Computes the accuracy over the k top predictions for the specified values of k"""
|
551 |
+
with torch.no_grad():
|
552 |
+
maxk = max(topk)
|
553 |
+
batch_size = target.size(0)
|
554 |
+
|
555 |
+
_, pred = output.topk(maxk, 1, True, True)
|
556 |
+
pred = pred.t()
|
557 |
+
correct = pred.eq(target.view(1, -1).expand_as(pred))
|
558 |
+
|
559 |
+
res = []
|
560 |
+
for k in topk:
|
561 |
+
correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
|
562 |
+
res.append(correct_k.mul_(100.0 / batch_size))
|
563 |
+
return res
|
564 |
+
|
565 |
+
|
566 |
+
if __name__ == '__main__':
|
567 |
+
main()
|
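Note: the training and validation loops above in imagenet_ablation_gt.py optimize a weighted sum of a relevance-based segmentation loss and a classification loss. The snippet below is a minimal, self-contained sketch of that segmentation term on dummy tensors; the helper name `segmentation_loss_sketch`, the tensor shapes, and the lambda values are illustrative assumptions and are not part of the committed file.

```python
import torch

def segmentation_loss_sketch(relevance, seg_map, lambda_background, lambda_foreground):
    # Sketch of the loss terms used above: background relevance is pushed toward zero,
    # foreground relevance is pushed toward the (binary) segmentation map.
    mse = torch.nn.MSELoss(reduction='mean')
    background = 1.0 - seg_map  # same result as the in-place 1 -> -1 -> 0 swap above
    background_loss = mse(relevance * background, torch.zeros_like(relevance))
    foreground_loss = mse(relevance * seg_map, seg_map)
    return lambda_background * background_loss + lambda_foreground * foreground_loss

# Dummy usage (shapes and lambda values are assumptions for illustration only).
relevance = torch.rand(2, 224, 224)
seg_map = (torch.rand(2, 224, 224) > 0.5).float()
print(segmentation_loss_sketch(relevance, seg_map, 2.0, 0.3).item())
```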
imagenet_finetune_gradmask.py
ADDED
@@ -0,0 +1,586 @@
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
import random
|
4 |
+
import shutil
|
5 |
+
import time
|
6 |
+
import warnings
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
import torch.nn.parallel
|
11 |
+
import torch.backends.cudnn as cudnn
|
12 |
+
import torch.distributed as dist
|
13 |
+
import torch.optim
|
14 |
+
import torch.multiprocessing as mp
|
15 |
+
import torch.utils.data
|
16 |
+
import torch.utils.data.distributed
|
17 |
+
import torchvision.transforms as transforms
|
18 |
+
import torchvision.datasets as datasets
|
19 |
+
import torchvision.models as models
|
20 |
+
from segmentation_dataset import SegmentationDataset, VAL_PARTITION, TRAIN_PARTITION
|
21 |
+
import numpy as np
|
22 |
+
|
23 |
+
# Uncomment the expected model below
|
24 |
+
|
25 |
+
# ViT
|
26 |
+
from ViT.ViT import vit_base_patch16_224 as vit
|
27 |
+
# from ViT.ViT import vit_large_patch16_224 as vit
|
28 |
+
|
29 |
+
# ViT-AugReg
|
30 |
+
# from ViT.ViT_new import vit_small_patch16_224 as vit
|
31 |
+
# from ViT.ViT_new import vit_base_patch16_224 as vit
|
32 |
+
# from ViT.ViT_new import vit_large_patch16_224 as vit
|
33 |
+
|
34 |
+
# DeiT
|
35 |
+
# from ViT.ViT import deit_base_patch16_224 as vit
|
36 |
+
# from ViT.ViT import deit_small_patch16_224 as vit
|
37 |
+
|
38 |
+
from ViT.explainer import generate_relevance, get_image_with_relevance
|
39 |
+
import torchvision
|
40 |
+
import cv2
|
41 |
+
from torch.utils.tensorboard import SummaryWriter
|
42 |
+
import json
|
43 |
+
|
44 |
+
model_names = sorted(name for name in models.__dict__
|
45 |
+
if name.islower() and not name.startswith("__")
|
46 |
+
and callable(models.__dict__[name]))
|
47 |
+
model_names.append("vit")
|
48 |
+
|
49 |
+
parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
|
50 |
+
parser.add_argument('--data', metavar='DATA',
|
51 |
+
help='path to dataset')
|
52 |
+
parser.add_argument('--seg_data', metavar='SEG_DATA',
|
53 |
+
help='path to segmentation dataset')
|
54 |
+
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
|
55 |
+
help='number of data loading workers (default: 4)')
|
56 |
+
parser.add_argument('--epochs', default=50, type=int, metavar='N',
|
57 |
+
help='number of total epochs to run')
|
58 |
+
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
|
59 |
+
help='manual epoch number (useful on restarts)')
|
60 |
+
parser.add_argument('-b', '--batch-size', default=8, type=int,
|
61 |
+
metavar='N',
|
62 |
+
help='mini-batch size (default: 256), this is the total '
|
63 |
+
'batch size of all GPUs on the current node when '
|
64 |
+
'using Data Parallel or Distributed Data Parallel')
|
65 |
+
parser.add_argument('--lr', '--learning-rate', default=3e-6, type=float,
|
66 |
+
metavar='LR', help='initial learning rate', dest='lr')
|
67 |
+
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
|
68 |
+
help='momentum')
|
69 |
+
parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
|
70 |
+
metavar='W', help='weight decay (default: 1e-4)',
|
71 |
+
dest='weight_decay')
|
72 |
+
parser.add_argument('-p', '--print-freq', default=10, type=int,
|
73 |
+
metavar='N', help='print frequency (default: 10)')
|
74 |
+
parser.add_argument('--resume', default='', type=str, metavar='PATH',
|
75 |
+
help='path to latest checkpoint (default: none)')
|
76 |
+
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
|
77 |
+
help='evaluate model on validation set')
|
78 |
+
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
|
79 |
+
help='use pre-trained model')
|
80 |
+
parser.add_argument('--world-size', default=-1, type=int,
|
81 |
+
help='number of nodes for distributed training')
|
82 |
+
parser.add_argument('--rank', default=-1, type=int,
|
83 |
+
help='node rank for distributed training')
|
84 |
+
parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,
|
85 |
+
help='url used to set up distributed training')
|
86 |
+
parser.add_argument('--dist-backend', default='nccl', type=str,
|
87 |
+
help='distributed backend')
|
88 |
+
parser.add_argument('--seed', default=None, type=int,
|
89 |
+
help='seed for initializing training. ')
|
90 |
+
parser.add_argument('--gpu', default=None, type=int,
|
91 |
+
help='GPU id to use.')
|
92 |
+
parser.add_argument('--save_interval', default=20, type=int,
|
93 |
+
help='interval to save segmentation results.')
|
94 |
+
parser.add_argument('--num_samples', default=3, type=int,
|
95 |
+
help='number of samples per class for training')
|
96 |
+
parser.add_argument('--multiprocessing-distributed', action='store_true',
|
97 |
+
help='Use multi-processing distributed training to launch '
|
98 |
+
'N processes per node, which has N GPUs. This is the '
|
99 |
+
'fastest way to use PyTorch for either single node or '
|
100 |
+
'multi node data parallel training')
|
101 |
+
parser.add_argument('--lambda_seg', default=0.8, type=float,
|
102 |
+
help='influence of segmentation loss.')
|
103 |
+
parser.add_argument('--lambda_acc', default=0.2, type=float,
|
104 |
+
help='influence of accuracy loss.')
|
105 |
+
parser.add_argument('--experiment_folder', default=None, type=str,
|
106 |
+
help='path to folder to use for experiment.')
|
107 |
+
parser.add_argument('--num_classes', default=500, type=int,
|
108 |
+
help='number of classes used for training.')
|
109 |
+
parser.add_argument('--temperature', default=1, type=float,
|
110 |
+
help='temperature for softmax (mostly for DeiT).')
|
111 |
+
|
112 |
+
best_loss = float('inf')
|
113 |
+
|
114 |
+
def main():
|
115 |
+
args = parser.parse_args()
|
116 |
+
|
117 |
+
if args.experiment_folder is None:
|
118 |
+
args.experiment_folder = f'experiment/' \
|
119 |
+
f'lr_{args.lr}_seg_{args.lambda_seg}_acc_{args.lambda_acc}'
|
120 |
+
if args.temperature != 1:
|
121 |
+
args.experiment_folder = args.experiment_folder + f'_tempera_{args.temperature}'
|
122 |
+
if args.batch_size != 8:
|
123 |
+
args.experiment_folder = args.experiment_folder + f'_bs_{args.batch_size}'
|
124 |
+
if args.num_classes != 500:
|
125 |
+
args.experiment_folder = args.experiment_folder + f'_num_classes_{args.num_classes}'
|
126 |
+
if args.num_samples != 3:
|
127 |
+
args.experiment_folder = args.experiment_folder + f'_num_samples_{args.num_samples}'
|
128 |
+
if args.epochs != 150:
|
129 |
+
args.experiment_folder = args.experiment_folder + f'_num_epochs_{args.epochs}'
|
130 |
+
|
131 |
+
if os.path.exists(args.experiment_folder):
|
132 |
+
raise Exception(f"Experiment path {args.experiment_folder} already exists!")
|
133 |
+
os.mkdir(args.experiment_folder)
|
134 |
+
os.mkdir(f'{args.experiment_folder}/train_samples')
|
135 |
+
os.mkdir(f'{args.experiment_folder}/val_samples')
|
136 |
+
|
137 |
+
with open(f'{args.experiment_folder}/commandline_args.txt', 'w') as f:
|
138 |
+
json.dump(args.__dict__, f, indent=2)
|
139 |
+
|
140 |
+
if args.seed is not None:
|
141 |
+
random.seed(args.seed)
|
142 |
+
torch.manual_seed(args.seed)
|
143 |
+
cudnn.deterministic = True
|
144 |
+
warnings.warn('You have chosen to seed training. '
|
145 |
+
'This will turn on the CUDNN deterministic setting, '
|
146 |
+
'which can slow down your training considerably! '
|
147 |
+
'You may see unexpected behavior when restarting '
|
148 |
+
'from checkpoints.')
|
149 |
+
|
150 |
+
if args.gpu is not None:
|
151 |
+
warnings.warn('You have chosen a specific GPU. This will completely '
|
152 |
+
'disable data parallelism.')
|
153 |
+
|
154 |
+
if args.dist_url == "env://" and args.world_size == -1:
|
155 |
+
args.world_size = int(os.environ["WORLD_SIZE"])
|
156 |
+
|
157 |
+
args.distributed = args.world_size > 1 or args.multiprocessing_distributed
|
158 |
+
|
159 |
+
ngpus_per_node = torch.cuda.device_count()
|
160 |
+
if args.multiprocessing_distributed:
|
161 |
+
# Since we have ngpus_per_node processes per node, the total world_size
|
162 |
+
# needs to be adjusted accordingly
|
163 |
+
args.world_size = ngpus_per_node * args.world_size
|
164 |
+
# Use torch.multiprocessing.spawn to launch distributed processes: the
|
165 |
+
# main_worker process function
|
166 |
+
mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
|
167 |
+
else:
|
168 |
+
# Simply call main_worker function
|
169 |
+
main_worker(args.gpu, ngpus_per_node, args)
|
170 |
+
|
171 |
+
|
172 |
+
def main_worker(gpu, ngpus_per_node, args):
|
173 |
+
global best_loss
|
174 |
+
args.gpu = gpu
|
175 |
+
|
176 |
+
if args.gpu is not None:
|
177 |
+
print("Use GPU: {} for training".format(args.gpu))
|
178 |
+
|
179 |
+
if args.distributed:
|
180 |
+
if args.dist_url == "env://" and args.rank == -1:
|
181 |
+
args.rank = int(os.environ["RANK"])
|
182 |
+
if args.multiprocessing_distributed:
|
183 |
+
# For multiprocessing distributed training, rank needs to be the
|
184 |
+
# global rank among all the processes
|
185 |
+
args.rank = args.rank * ngpus_per_node + gpu
|
186 |
+
dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
|
187 |
+
world_size=args.world_size, rank=args.rank)
|
188 |
+
# create model
|
189 |
+
print("=> creating model")
|
190 |
+
model = vit(pretrained=True).cuda()
|
191 |
+
model.train()
|
192 |
+
print("done")
|
193 |
+
|
194 |
+
if not torch.cuda.is_available():
|
195 |
+
print('using CPU, this will be slow')
|
196 |
+
elif args.distributed:
|
197 |
+
# For multiprocessing distributed, DistributedDataParallel constructor
|
198 |
+
# should always set the single device scope, otherwise,
|
199 |
+
# DistributedDataParallel will use all available devices.
|
200 |
+
if args.gpu is not None:
|
201 |
+
torch.cuda.set_device(args.gpu)
|
202 |
+
model.cuda(args.gpu)
|
203 |
+
# When using a single GPU per process and per
|
204 |
+
# DistributedDataParallel, we need to divide the batch size
|
205 |
+
# ourselves based on the total number of GPUs we have
|
206 |
+
args.batch_size = int(args.batch_size / ngpus_per_node)
|
207 |
+
args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
|
208 |
+
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
|
209 |
+
else:
|
210 |
+
model.cuda()
|
211 |
+
# DistributedDataParallel will divide and allocate batch_size to all
|
212 |
+
# available GPUs if device_ids are not set
|
213 |
+
model = torch.nn.parallel.DistributedDataParallel(model)
|
214 |
+
elif args.gpu is not None:
|
215 |
+
torch.cuda.set_device(args.gpu)
|
216 |
+
model = model.cuda(args.gpu)
|
217 |
+
else:
|
218 |
+
# DataParallel will divide and allocate batch_size to all available GPUs
|
219 |
+
print("start")
|
220 |
+
model = torch.nn.DataParallel(model).cuda()
|
221 |
+
|
222 |
+
# define loss function (criterion) and optimizer
|
223 |
+
criterion = nn.CrossEntropyLoss().cuda(args.gpu)
|
224 |
+
optimizer = torch.optim.AdamW(model.parameters(), args.lr, weight_decay=args.weight_decay)
|
225 |
+
|
226 |
+
# optionally resume from a checkpoint
|
227 |
+
if args.resume:
|
228 |
+
if os.path.isfile(args.resume):
|
229 |
+
print("=> loading checkpoint '{}'".format(args.resume))
|
230 |
+
if args.gpu is None:
|
231 |
+
checkpoint = torch.load(args.resume)
|
232 |
+
else:
|
233 |
+
# Map model to be loaded to specified single gpu.
|
234 |
+
loc = 'cuda:{}'.format(args.gpu)
|
235 |
+
checkpoint = torch.load(args.resume, map_location=loc)
|
236 |
+
args.start_epoch = checkpoint['epoch']
|
237 |
+
best_loss = checkpoint['best_loss']
|
238 |
+
if args.gpu is not None:
|
239 |
+
# best_loss may be from a checkpoint from a different GPU
|
240 |
+
best_loss = best_loss.to(args.gpu)
|
241 |
+
model.load_state_dict(checkpoint['state_dict'])
|
242 |
+
optimizer.load_state_dict(checkpoint['optimizer'])
|
243 |
+
print("=> loaded checkpoint '{}' (epoch {})"
|
244 |
+
.format(args.resume, checkpoint['epoch']))
|
245 |
+
else:
|
246 |
+
print("=> no checkpoint found at '{}'".format(args.resume))
|
247 |
+
|
248 |
+
cudnn.benchmark = True
|
249 |
+
|
250 |
+
train_dataset = SegmentationDataset(args.seg_data, args.data, partition=TRAIN_PARTITION, train_classes=args.num_classes,
|
251 |
+
num_samples=args.num_samples)
|
252 |
+
|
253 |
+
if args.distributed:
|
254 |
+
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
|
255 |
+
else:
|
256 |
+
train_sampler = None
|
257 |
+
|
258 |
+
train_loader = torch.utils.data.DataLoader(
|
259 |
+
train_dataset, batch_size=args.batch_size, shuffle=False,
|
260 |
+
num_workers=args.workers, pin_memory=True, sampler=train_sampler)
|
261 |
+
|
262 |
+
val_dataset = SegmentationDataset(args.seg_data, args.data, partition=VAL_PARTITION, train_classes=args.num_classes,
|
263 |
+
num_samples=1)
|
264 |
+
|
265 |
+
val_loader = torch.utils.data.DataLoader(
|
266 |
+
val_dataset, batch_size=5, shuffle=False,
|
267 |
+
num_workers=args.workers, pin_memory=True)
|
268 |
+
|
269 |
+
if args.evaluate:
|
270 |
+
validate(val_loader, model, criterion, 0, args)
|
271 |
+
return
|
272 |
+
|
273 |
+
for epoch in range(args.start_epoch, args.epochs):
|
274 |
+
if args.distributed:
|
275 |
+
train_sampler.set_epoch(epoch)
|
276 |
+
adjust_learning_rate(optimizer, epoch, args)
|
277 |
+
|
278 |
+
log_dir = os.path.join(args.experiment_folder, 'logs')
|
279 |
+
logger = SummaryWriter(log_dir=log_dir)
|
280 |
+
args.logger = logger
|
281 |
+
|
282 |
+
# train for one epoch
|
283 |
+
train(train_loader, model, criterion, optimizer, epoch, args)
|
284 |
+
|
285 |
+
# evaluate on validation set
|
286 |
+
loss1 = validate(val_loader, model, criterion, epoch, args)
|
287 |
+
|
288 |
+
# remember best acc@1 and save checkpoint
|
289 |
+
is_best = loss1 < best_loss
|
290 |
+
best_loss = min(loss1, best_loss)
|
291 |
+
|
292 |
+
if not args.multiprocessing_distributed or (args.multiprocessing_distributed
|
293 |
+
and args.rank % ngpus_per_node == 0):
|
294 |
+
save_checkpoint({
|
295 |
+
'epoch': epoch + 1,
|
296 |
+
'state_dict': model.state_dict(),
|
297 |
+
'best_loss': best_loss,
|
298 |
+
'optimizer' : optimizer.state_dict(),
|
299 |
+
}, is_best, folder=args.experiment_folder)
|
300 |
+
|
301 |
+
def train(train_loader, model, criterion, optimizer, epoch, args):
|
302 |
+
mse_criterion = torch.nn.MSELoss(reduction='mean')
|
303 |
+
|
304 |
+
losses = AverageMeter('Loss', ':.4e')
|
305 |
+
top1 = AverageMeter('Acc@1', ':6.2f')
|
306 |
+
top5 = AverageMeter('Acc@5', ':6.2f')
|
307 |
+
orig_top1 = AverageMeter('Acc@1_orig', ':6.2f')
|
308 |
+
orig_top5 = AverageMeter('Acc@5_orig', ':6.2f')
|
309 |
+
progress = ProgressMeter(
|
310 |
+
len(train_loader),
|
311 |
+
[losses, top1, top5, orig_top1, orig_top5],
|
312 |
+
prefix="Epoch: [{}]".format(epoch))
|
313 |
+
|
314 |
+
orig_model = vit(pretrained=True).cuda()
|
315 |
+
orig_model.eval()
|
316 |
+
|
317 |
+
# switch to train mode
|
318 |
+
model.train()
|
319 |
+
|
320 |
+
for i, (seg_map, image_ten, class_name) in enumerate(train_loader):
|
321 |
+
if torch.cuda.is_available():
|
322 |
+
image_ten = image_ten.cuda(args.gpu, non_blocking=True)
|
323 |
+
seg_map = seg_map.cuda(args.gpu, non_blocking=True)
|
324 |
+
class_name = class_name.cuda(args.gpu, non_blocking=True)
|
325 |
+
|
326 |
+
|
327 |
+
image_ten.requires_grad = True
|
328 |
+
output = model(image_ten)
|
329 |
+
|
330 |
+
# segmentation loss
|
331 |
+
batch_size = image_ten.shape[0]
|
332 |
+
index = class_name
|
333 |
+
if index is None:
|
334 |
+
index = np.argmax(output.cpu().data.numpy(), axis=-1)
|
335 |
+
index = torch.tensor(index)
|
336 |
+
|
337 |
+
one_hot = np.zeros((batch_size, output.shape[-1]), dtype=np.float32)
|
338 |
+
one_hot[torch.arange(batch_size), index.data.cpu().numpy()] = 1
|
339 |
+
one_hot = torch.from_numpy(one_hot).requires_grad_(True)
|
340 |
+
one_hot = torch.sum(one_hot.to(image_ten.device) * output)
|
341 |
+
model.zero_grad()
|
342 |
+
|
343 |
+
relevance = torch.autograd.grad(one_hot, image_ten, retain_graph=True)[0]
|
344 |
+
|
345 |
+
reverse_seg_map = seg_map.clone()
|
346 |
+
reverse_seg_map[reverse_seg_map == 1] = -1
|
347 |
+
reverse_seg_map[reverse_seg_map == 0] = 1
|
348 |
+
reverse_seg_map[reverse_seg_map == -1] = 0
|
349 |
+
grad_loss = mse_criterion(relevance * reverse_seg_map, torch.zeros_like(relevance))
|
350 |
+
segmentation_loss = grad_loss
|
351 |
+
|
352 |
+
# classification loss
|
353 |
+
with torch.no_grad():
|
354 |
+
output_orig = orig_model(image_ten)
|
355 |
+
if args.temperature != 1:
|
356 |
+
output = output / args.temperature
|
357 |
+
classification_loss = criterion(output, class_name.flatten())
|
358 |
+
|
359 |
+
loss = args.lambda_seg * segmentation_loss + args.lambda_acc * classification_loss
|
360 |
+
|
361 |
+
# debugging output
|
362 |
+
if i % args.save_interval == 0:
|
363 |
+
orig_relevance = generate_relevance(orig_model, image_ten, index=class_name)
|
364 |
+
for j in range(image_ten.shape[0]):
|
365 |
+
image = get_image_with_relevance(image_ten[j], torch.ones_like(image_ten[j]))
|
366 |
+
new_vis = get_image_with_relevance(image_ten[j]*relevance[j], torch.ones_like(image_ten[j]))
|
367 |
+
old_vis = get_image_with_relevance(image_ten[j], orig_relevance[j])
|
368 |
+
gt = get_image_with_relevance(image_ten[j], seg_map[j])
|
369 |
+
h_img = cv2.hconcat([image, gt, old_vis, new_vis])
|
370 |
+
cv2.imwrite(f'{args.experiment_folder}/train_samples/res_{i}_{j}.jpg', h_img)
|
371 |
+
|
372 |
+
# measure accuracy and record loss
|
373 |
+
acc1, acc5 = accuracy(output, class_name, topk=(1, 5))
|
374 |
+
losses.update(loss.item(), image_ten.size(0))
|
375 |
+
top1.update(acc1[0], image_ten.size(0))
|
376 |
+
top5.update(acc5[0], image_ten.size(0))
|
377 |
+
|
378 |
+
# metrics for original vit
|
379 |
+
acc1_orig, acc5_orig = accuracy(output_orig, class_name, topk=(1, 5))
|
380 |
+
orig_top1.update(acc1_orig[0], image_ten.size(0))
|
381 |
+
orig_top5.update(acc5_orig[0], image_ten.size(0))
|
382 |
+
|
383 |
+
# compute gradient and do SGD step
|
384 |
+
optimizer.zero_grad()
|
385 |
+
loss.backward()
|
386 |
+
optimizer.step()
|
387 |
+
|
388 |
+
if i % args.print_freq == 0:
|
389 |
+
progress.display(i)
|
390 |
+
args.logger.add_scalar('{}/{}'.format('train', 'segmentation_loss'), segmentation_loss,
|
391 |
+
epoch*len(train_loader)+i)
|
392 |
+
args.logger.add_scalar('{}/{}'.format('train', 'classification_loss'), classification_loss,
|
393 |
+
epoch * len(train_loader) + i)
|
394 |
+
args.logger.add_scalar('{}/{}'.format('train', 'orig_top1'), acc1_orig,
|
395 |
+
epoch * len(train_loader) + i)
|
396 |
+
args.logger.add_scalar('{}/{}'.format('train', 'top1'), acc1,
|
397 |
+
epoch * len(train_loader) + i)
|
398 |
+
args.logger.add_scalar('{}/{}'.format('train', 'orig_top5'), acc5_orig,
|
399 |
+
epoch * len(train_loader) + i)
|
400 |
+
args.logger.add_scalar('{}/{}'.format('train', 'top5'), acc5,
|
401 |
+
epoch * len(train_loader) + i)
|
402 |
+
args.logger.add_scalar('{}/{}'.format('train', 'tot_loss'), loss,
|
403 |
+
epoch * len(train_loader) + i)
|
404 |
+
|
405 |
+
|
406 |
+
def validate(val_loader, model, criterion, epoch, args):
|
407 |
+
mse_criterion = torch.nn.MSELoss(reduction='mean')
|
408 |
+
|
409 |
+
losses = AverageMeter('Loss', ':.4e')
|
410 |
+
top1 = AverageMeter('Acc@1', ':6.2f')
|
411 |
+
top5 = AverageMeter('Acc@5', ':6.2f')
|
412 |
+
orig_top1 = AverageMeter('Acc@1_orig', ':6.2f')
|
413 |
+
orig_top5 = AverageMeter('Acc@5_orig', ':6.2f')
|
414 |
+
progress = ProgressMeter(
|
415 |
+
len(val_loader),
|
416 |
+
[losses, top1, top5, orig_top1, orig_top5],
|
417 |
+
prefix="Epoch: [{}]".format(val_loader))
|
418 |
+
|
419 |
+
# switch to evaluate mode
|
420 |
+
model.eval()
|
421 |
+
|
422 |
+
orig_model = vit(pretrained=True).cuda()
|
423 |
+
orig_model.eval()
|
424 |
+
|
425 |
+
with torch.no_grad():
|
426 |
+
for i, (seg_map, image_ten, class_name) in enumerate(val_loader):
|
427 |
+
if args.gpu is not None:
|
428 |
+
image_ten = image_ten.cuda(args.gpu, non_blocking=True)
|
429 |
+
if torch.cuda.is_available():
|
430 |
+
seg_map = seg_map.cuda(args.gpu, non_blocking=True)
|
431 |
+
class_name = class_name.cuda(args.gpu, non_blocking=True)
|
432 |
+
|
433 |
+
with torch.enable_grad():
|
434 |
+
image_ten.requires_grad = True
|
435 |
+
output = model(image_ten)
|
436 |
+
|
437 |
+
# segmentation loss
|
438 |
+
batch_size = image_ten.shape[0]
|
439 |
+
index = class_name
|
440 |
+
if index is None:
|
441 |
+
index = np.argmax(output.cpu().data.numpy(), axis=-1)
|
442 |
+
index = torch.tensor(index)
|
443 |
+
|
444 |
+
one_hot = np.zeros((batch_size, output.shape[-1]), dtype=np.float32)
|
445 |
+
one_hot[torch.arange(batch_size), index.data.cpu().numpy()] = 1
|
446 |
+
one_hot = torch.from_numpy(one_hot).requires_grad_(True)
|
447 |
+
one_hot = torch.sum(one_hot.to(image_ten.device) * output)
|
448 |
+
model.zero_grad()
|
449 |
+
relevance = torch.autograd.grad(one_hot, image_ten)[0]
|
450 |
+
|
451 |
+
reverse_seg_map = seg_map.clone()
|
452 |
+
reverse_seg_map[reverse_seg_map == 1] = -1
|
453 |
+
reverse_seg_map[reverse_seg_map == 0] = 1
|
454 |
+
reverse_seg_map[reverse_seg_map == -1] = 0
|
455 |
+
grad_loss = mse_criterion(relevance * reverse_seg_map, torch.zeros_like(relevance))
|
456 |
+
segmentation_loss = grad_loss
|
457 |
+
|
458 |
+
# classification loss
|
459 |
+
output = model(image_ten)
|
460 |
+
with torch.no_grad():
|
461 |
+
output_orig = orig_model(image_ten)
|
462 |
+
if args.temperature != 1:
|
463 |
+
output = output / args.temperature
|
464 |
+
classification_loss = criterion(output, class_name.flatten())
|
465 |
+
|
466 |
+
loss = args.lambda_seg * segmentation_loss + args.lambda_acc * classification_loss
|
467 |
+
|
468 |
+
# save results
|
469 |
+
if i % args.save_interval == 0:
|
470 |
+
with torch.enable_grad():
|
471 |
+
orig_relevance = generate_relevance(orig_model, image_ten, index=class_name)
|
472 |
+
for j in range(image_ten.shape[0]):
|
473 |
+
image = get_image_with_relevance(image_ten[j], torch.ones_like(image_ten[j]))
|
474 |
+
new_vis = get_image_with_relevance(image_ten[j]*relevance[j], torch.ones_like(image_ten[j]))
|
475 |
+
old_vis = get_image_with_relevance(image_ten[j], orig_relevance[j])
|
476 |
+
gt = get_image_with_relevance(image_ten[j], seg_map[j])
|
477 |
+
h_img = cv2.hconcat([image, gt, old_vis, new_vis])
|
478 |
+
cv2.imwrite(f'{args.experiment_folder}/val_samples/res_{i}_{j}.jpg', h_img)
|
479 |
+
|
480 |
+
# measure accuracy and record loss
|
481 |
+
acc1, acc5 = accuracy(output, class_name, topk=(1, 5))
|
482 |
+
losses.update(loss.item(), image_ten.size(0))
|
483 |
+
top1.update(acc1[0], image_ten.size(0))
|
484 |
+
top5.update(acc5[0], image_ten.size(0))
|
485 |
+
|
486 |
+
# metrics for original vit
|
487 |
+
acc1_orig, acc5_orig = accuracy(output_orig, class_name, topk=(1, 5))
|
488 |
+
orig_top1.update(acc1_orig[0], image_ten.size(0))
|
489 |
+
orig_top5.update(acc5_orig[0], image_ten.size(0))
|
490 |
+
|
491 |
+
if i % args.print_freq == 0:
|
492 |
+
progress.display(i)
|
493 |
+
args.logger.add_scalar('{}/{}'.format('val', 'segmentation_loss'), segmentation_loss,
|
494 |
+
epoch * len(val_loader) + i)
|
495 |
+
args.logger.add_scalar('{}/{}'.format('val', 'classification_loss'), classification_loss,
|
496 |
+
epoch * len(val_loader) + i)
|
497 |
+
args.logger.add_scalar('{}/{}'.format('val', 'orig_top1'), acc1_orig,
|
498 |
+
epoch * len(val_loader) + i)
|
499 |
+
args.logger.add_scalar('{}/{}'.format('val', 'top1'), acc1,
|
500 |
+
epoch * len(val_loader) + i)
|
501 |
+
args.logger.add_scalar('{}/{}'.format('val', 'orig_top5'), acc5_orig,
|
502 |
+
epoch * len(val_loader) + i)
|
503 |
+
args.logger.add_scalar('{}/{}'.format('val', 'top5'), acc5,
|
504 |
+
epoch * len(val_loader) + i)
|
505 |
+
args.logger.add_scalar('{}/{}'.format('val', 'tot_loss'), loss,
|
506 |
+
epoch * len(val_loader) + i)
|
507 |
+
|
508 |
+
# TODO: this should also be done with the ProgressMeter
|
509 |
+
print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
|
510 |
+
.format(top1=top1, top5=top5))
|
511 |
+
|
512 |
+
return losses.avg
|
513 |
+
|
514 |
+
|
515 |
+
def save_checkpoint(state, is_best, folder, filename='checkpoint.pth.tar'):
|
516 |
+
torch.save(state, f'{folder}/{filename}')
|
517 |
+
if is_best:
|
518 |
+
shutil.copyfile(f'{folder}/{filename}', f'{folder}/model_best.pth.tar')
|
519 |
+
|
520 |
+
|
521 |
+
class AverageMeter(object):
|
522 |
+
"""Computes and stores the average and current value"""
|
523 |
+
def __init__(self, name, fmt=':f'):
|
524 |
+
self.name = name
|
525 |
+
self.fmt = fmt
|
526 |
+
self.reset()
|
527 |
+
|
528 |
+
def reset(self):
|
529 |
+
self.val = 0
|
530 |
+
self.avg = 0
|
531 |
+
self.sum = 0
|
532 |
+
self.count = 0
|
533 |
+
|
534 |
+
def update(self, val, n=1):
|
535 |
+
self.val = val
|
536 |
+
self.sum += val * n
|
537 |
+
self.count += n
|
538 |
+
self.avg = self.sum / self.count
|
539 |
+
|
540 |
+
def __str__(self):
|
541 |
+
fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
|
542 |
+
return fmtstr.format(**self.__dict__)
|
543 |
+
|
544 |
+
|
545 |
+
class ProgressMeter(object):
|
546 |
+
def __init__(self, num_batches, meters, prefix=""):
|
547 |
+
self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
|
548 |
+
self.meters = meters
|
549 |
+
self.prefix = prefix
|
550 |
+
|
551 |
+
def display(self, batch):
|
552 |
+
entries = [self.prefix + self.batch_fmtstr.format(batch)]
|
553 |
+
entries += [str(meter) for meter in self.meters]
|
554 |
+
print('\t'.join(entries))
|
555 |
+
|
556 |
+
def _get_batch_fmtstr(self, num_batches):
|
557 |
+
num_digits = len(str(num_batches // 1))
|
558 |
+
fmt = '{:' + str(num_digits) + 'd}'
|
559 |
+
return '[' + fmt + '/' + fmt.format(num_batches) + ']'
|
560 |
+
|
561 |
+
def adjust_learning_rate(optimizer, epoch, args):
|
562 |
+
"""Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
|
563 |
+
lr = args.lr * (0.85 ** (epoch // 2))
|
564 |
+
for param_group in optimizer.param_groups:
|
565 |
+
param_group['lr'] = lr
|
566 |
+
|
567 |
+
|
568 |
+
def accuracy(output, target, topk=(1,)):
|
569 |
+
"""Computes the accuracy over the k top predictions for the specified values of k"""
|
570 |
+
with torch.no_grad():
|
571 |
+
maxk = max(topk)
|
572 |
+
batch_size = target.size(0)
|
573 |
+
|
574 |
+
_, pred = output.topk(maxk, 1, True, True)
|
575 |
+
pred = pred.t()
|
576 |
+
correct = pred.eq(target.view(1, -1).expand_as(pred))
|
577 |
+
|
578 |
+
res = []
|
579 |
+
for k in topk:
|
580 |
+
correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
|
581 |
+
res.append(correct_k.mul_(100.0 / batch_size))
|
582 |
+
return res
|
583 |
+
|
584 |
+
|
585 |
+
if __name__ == '__main__':
|
586 |
+
main()
|
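Note: imagenet_finetune_gradmask.py above differs from the relevance-propagation variant in that the "relevance" is simply the gradient of the summed target-class logits with respect to the input image, and only the background part of that gradient is penalized. Below is a minimal, self-contained sketch of that computation; the toy model, tensor shapes, and helper name are assumptions for illustration and not part of the committed file.

```python
import torch

def gradmask_background_loss(model, image, target, seg_map):
    # Gradient of the summed target-class logits w.r.t. the input image,
    # masked so that only background pixels (seg_map == 0) are penalized.
    image = image.clone().requires_grad_(True)
    output = model(image)
    one_hot = torch.zeros_like(output)
    one_hot[torch.arange(output.shape[0]), target] = 1.0
    score = (one_hot * output).sum()
    # create_graph=True so the penalty itself can be backpropagated when used as a training loss.
    relevance = torch.autograd.grad(score, image, create_graph=True)[0]
    background = 1.0 - seg_map  # same result as the 1 -> -1 -> 0 swap above
    mse = torch.nn.MSELoss(reduction='mean')
    return mse(relevance * background, torch.zeros_like(relevance))

# Dummy usage with a toy classifier (all shapes/values are illustrative).
model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(3 * 8 * 8, 10))
image = torch.rand(2, 3, 8, 8)
target = torch.tensor([1, 7])
seg_map = (torch.rand(2, 3, 8, 8) > 0.5).float()
print(gradmask_background_loss(model, image, target, seg_map).item())
```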
imagenet_finetune_rrr.py
ADDED
@@ -0,0 +1,570 @@
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
import random
|
4 |
+
import shutil
|
5 |
+
import time
|
6 |
+
import warnings
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
import torch.nn.parallel
|
11 |
+
import torch.backends.cudnn as cudnn
|
12 |
+
import torch.distributed as dist
|
13 |
+
import torch.optim
|
14 |
+
import torch.multiprocessing as mp
|
15 |
+
import torch.utils.data
|
16 |
+
import torch.utils.data.distributed
|
17 |
+
import torchvision.transforms as transforms
|
18 |
+
import torchvision.datasets as datasets
|
19 |
+
import torchvision.models as models
|
20 |
+
import torch.nn.functional as F
|
21 |
+
from segmentation_dataset import SegmentationDataset, VAL_PARTITION, TRAIN_PARTITION
|
22 |
+
import numpy as np
|
23 |
+
|
24 |
+
# Uncomment the expected model below
|
25 |
+
|
26 |
+
# ViT
|
27 |
+
from ViT.ViT import vit_base_patch16_224 as vit
|
28 |
+
# from ViT.ViT import vit_large_patch16_224 as vit
|
29 |
+
|
30 |
+
# ViT-AugReg
|
31 |
+
# from ViT.ViT_new import vit_small_patch16_224 as vit
|
32 |
+
# from ViT.ViT_new import vit_base_patch16_224 as vit
|
33 |
+
# from ViT.ViT_new import vit_large_patch16_224 as vit
|
34 |
+
|
35 |
+
# DeiT
|
36 |
+
# from ViT.ViT import deit_base_patch16_224 as vit
|
37 |
+
# from ViT.ViT import deit_small_patch16_224 as vit
|
38 |
+
|
39 |
+
from ViT.explainer import generate_relevance, get_image_with_relevance
|
40 |
+
import torchvision
|
41 |
+
import cv2
|
42 |
+
from torch.utils.tensorboard import SummaryWriter
|
43 |
+
import json
|
44 |
+
|
45 |
+
model_names = sorted(name for name in models.__dict__
|
46 |
+
if name.islower() and not name.startswith("__")
|
47 |
+
and callable(models.__dict__[name]))
|
48 |
+
model_names.append("vit")
|
49 |
+
|
50 |
+
parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
|
51 |
+
parser.add_argument('--data', metavar='DATA',
|
52 |
+
help='path to dataset')
|
53 |
+
parser.add_argument('--seg_data', metavar='SEG_DATA',
|
54 |
+
help='path to segmentation dataset')
|
55 |
+
parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
|
56 |
+
choices=model_names,
|
57 |
+
help='model architecture: ' +
|
58 |
+
' | '.join(model_names) +
|
59 |
+
' (default: resnet18)')
|
60 |
+
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
|
61 |
+
help='number of data loading workers (default: 4)')
|
62 |
+
parser.add_argument('--epochs', default=50, type=int, metavar='N',
|
63 |
+
help='number of total epochs to run')
|
64 |
+
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
|
65 |
+
help='manual epoch number (useful on restarts)')
|
66 |
+
parser.add_argument('-b', '--batch-size', default=8, type=int,
|
67 |
+
metavar='N',
|
68 |
+
help='mini-batch size (default: 256), this is the total '
|
69 |
+
'batch size of all GPUs on the current node when '
|
70 |
+
'using Data Parallel or Distributed Data Parallel')
|
71 |
+
parser.add_argument('--lr', '--learning-rate', default=3e-6, type=float,
|
72 |
+
metavar='LR', help='initial learning rate', dest='lr')
|
73 |
+
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
|
74 |
+
help='momentum')
|
75 |
+
parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
|
76 |
+
metavar='W', help='weight decay (default: 1e-4)',
|
77 |
+
dest='weight_decay')
|
78 |
+
parser.add_argument('-p', '--print-freq', default=10, type=int,
|
79 |
+
metavar='N', help='print frequency (default: 10)')
|
80 |
+
parser.add_argument('--resume', default='', type=str, metavar='PATH',
|
81 |
+
help='path to latest checkpoint (default: none)')
|
82 |
+
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
|
83 |
+
help='evaluate model on validation set')
|
84 |
+
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
|
85 |
+
help='use pre-trained model')
|
86 |
+
parser.add_argument('--world-size', default=-1, type=int,
|
87 |
+
help='number of nodes for distributed training')
|
88 |
+
parser.add_argument('--rank', default=-1, type=int,
|
89 |
+
help='node rank for distributed training')
|
90 |
+
parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,
|
91 |
+
help='url used to set up distributed training')
|
92 |
+
parser.add_argument('--dist-backend', default='nccl', type=str,
|
93 |
+
help='distributed backend')
|
94 |
+
parser.add_argument('--seed', default=None, type=int,
|
95 |
+
help='seed for initializing training. ')
|
96 |
+
parser.add_argument('--gpu', default=None, type=int,
|
97 |
+
help='GPU id to use.')
|
98 |
+
parser.add_argument('--save_interval', default=20, type=int,
|
99 |
+
help='interval to save segmentation results.')
|
100 |
+
parser.add_argument('--num_samples', default=3, type=int,
|
101 |
+
help='number of samples per class for training')
|
102 |
+
parser.add_argument('--multiprocessing-distributed', action='store_true',
|
103 |
+
help='Use multi-processing distributed training to launch '
|
104 |
+
'N processes per node, which has N GPUs. This is the '
|
105 |
+
'fastest way to use PyTorch for either single node or '
|
106 |
+
'multi node data parallel training')
|
107 |
+
parser.add_argument('--lambda_seg', default=0.8, type=float,
|
108 |
+
help='influence of segmentation loss.')
|
109 |
+
parser.add_argument('--lambda_acc', default=0.2, type=float,
|
110 |
+
help='influence of accuracy loss.')
|
111 |
+
parser.add_argument('--experiment_folder', default=None, type=str,
|
112 |
+
help='path to folder to use for experiment.')
|
113 |
+
parser.add_argument('--num_classes', default=500, type=int,
|
114 |
+
help='number of classes used for training.')
|
115 |
+
parser.add_argument('--temperature', default=1, type=float,
|
116 |
+
help='temperature for softmax (mostly for DeiT).')
|
117 |
+
|
118 |
+
best_loss = float('inf')
|
119 |
+
|
120 |
+
def main():
|
121 |
+
args = parser.parse_args()
|
122 |
+
|
123 |
+
if args.experiment_folder is None:
|
124 |
+
args.experiment_folder = f'experiment/' \
|
125 |
+
f'lr_{args.lr}_seg_{args.lambda_seg}_acc_{args.lambda_acc}'
|
126 |
+
if args.temperature != 1:
|
127 |
+
args.experiment_folder = args.experiment_folder + f'_tempera_{args.temperature}'
|
128 |
+
if args.batch_size != 8:
|
129 |
+
args.experiment_folder = args.experiment_folder + f'_bs_{args.batch_size}'
|
130 |
+
if args.num_classes != 500:
|
131 |
+
args.experiment_folder = args.experiment_folder + f'_num_classes_{args.num_classes}'
|
132 |
+
if args.num_samples != 3:
|
133 |
+
args.experiment_folder = args.experiment_folder + f'_num_samples_{args.num_samples}'
|
134 |
+
if args.epochs != 150:
|
135 |
+
args.experiment_folder = args.experiment_folder + f'_num_epochs_{args.epochs}'
|
136 |
+
|
137 |
+
if os.path.exists(args.experiment_folder):
|
138 |
+
raise Exception(f"Experiment path {args.experiment_folder} already exists!")
|
139 |
+
os.mkdir(args.experiment_folder)
|
140 |
+
os.mkdir(f'{args.experiment_folder}/train_samples')
|
141 |
+
os.mkdir(f'{args.experiment_folder}/val_samples')
|
142 |
+
|
143 |
+
with open(f'{args.experiment_folder}/commandline_args.txt', 'w') as f:
|
144 |
+
json.dump(args.__dict__, f, indent=2)
|
145 |
+
|
146 |
+
if args.seed is not None:
|
147 |
+
random.seed(args.seed)
|
148 |
+
torch.manual_seed(args.seed)
|
149 |
+
cudnn.deterministic = True
|
150 |
+
warnings.warn('You have chosen to seed training. '
|
151 |
+
'This will turn on the CUDNN deterministic setting, '
|
152 |
+
'which can slow down your training considerably! '
|
153 |
+
'You may see unexpected behavior when restarting '
|
154 |
+
'from checkpoints.')
|
155 |
+
|
156 |
+
if args.gpu is not None:
|
157 |
+
warnings.warn('You have chosen a specific GPU. This will completely '
|
158 |
+
'disable data parallelism.')
|
159 |
+
|
160 |
+
if args.dist_url == "env://" and args.world_size == -1:
|
161 |
+
args.world_size = int(os.environ["WORLD_SIZE"])
|
162 |
+
|
163 |
+
args.distributed = args.world_size > 1 or args.multiprocessing_distributed
|
164 |
+
|
165 |
+
ngpus_per_node = torch.cuda.device_count()
|
166 |
+
if args.multiprocessing_distributed:
|
167 |
+
# Since we have ngpus_per_node processes per node, the total world_size
|
168 |
+
# needs to be adjusted accordingly
|
169 |
+
args.world_size = ngpus_per_node * args.world_size
|
170 |
+
# Use torch.multiprocessing.spawn to launch distributed processes: the
|
171 |
+
# main_worker process function
|
172 |
+
mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
|
173 |
+
else:
|
174 |
+
# Simply call main_worker function
|
175 |
+
main_worker(args.gpu, ngpus_per_node, args)
|
176 |
+
|
177 |
+
|
178 |
+
def main_worker(gpu, ngpus_per_node, args):
|
179 |
+
global best_loss
|
180 |
+
args.gpu = gpu
|
181 |
+
|
182 |
+
if args.gpu is not None:
|
183 |
+
print("Use GPU: {} for training".format(args.gpu))
|
184 |
+
|
185 |
+
if args.distributed:
|
186 |
+
if args.dist_url == "env://" and args.rank == -1:
|
187 |
+
args.rank = int(os.environ["RANK"])
|
188 |
+
if args.multiprocessing_distributed:
|
189 |
+
# For multiprocessing distributed training, rank needs to be the
|
190 |
+
# global rank among all the processes
|
191 |
+
args.rank = args.rank * ngpus_per_node + gpu
|
192 |
+
dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
|
193 |
+
world_size=args.world_size, rank=args.rank)
|
194 |
+
# create model
|
195 |
+
print("=> creating model")
|
196 |
+
model = vit(pretrained=True).cuda()
|
197 |
+
model.train()
|
198 |
+
print("done")
|
199 |
+
|
200 |
+
if not torch.cuda.is_available():
|
201 |
+
print('using CPU, this will be slow')
|
202 |
+
elif args.distributed:
|
203 |
+
# For multiprocessing distributed, DistributedDataParallel constructor
|
204 |
+
# should always set the single device scope, otherwise,
|
205 |
+
# DistributedDataParallel will use all available devices.
|
206 |
+
if args.gpu is not None:
|
207 |
+
torch.cuda.set_device(args.gpu)
|
208 |
+
model.cuda(args.gpu)
|
209 |
+
# When using a single GPU per process and per
|
210 |
+
# DistributedDataParallel, we need to divide the batch size
|
211 |
+
# ourselves based on the total number of GPUs we have
|
212 |
+
args.batch_size = int(args.batch_size / ngpus_per_node)
|
213 |
+
args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
|
214 |
+
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
|
215 |
+
else:
|
216 |
+
model.cuda()
|
217 |
+
# DistributedDataParallel will divide and allocate batch_size to all
|
218 |
+
# available GPUs if device_ids are not set
|
219 |
+
model = torch.nn.parallel.DistributedDataParallel(model)
|
220 |
+
elif args.gpu is not None:
|
221 |
+
torch.cuda.set_device(args.gpu)
|
222 |
+
model = model.cuda(args.gpu)
|
223 |
+
else:
|
224 |
+
# DataParallel will divide and allocate batch_size to all available GPUs
|
225 |
+
print("start")
|
226 |
+
model = torch.nn.DataParallel(model).cuda()
|
227 |
+
|
228 |
+
# define loss function (criterion) and optimizer
|
229 |
+
criterion = nn.CrossEntropyLoss().cuda(args.gpu)
|
230 |
+
optimizer = torch.optim.AdamW(model.parameters(), args.lr, weight_decay=args.weight_decay)
|
231 |
+
|
232 |
+
# optionally resume from a checkpoint
|
233 |
+
if args.resume:
|
234 |
+
if os.path.isfile(args.resume):
|
235 |
+
print("=> loading checkpoint '{}'".format(args.resume))
|
236 |
+
if args.gpu is None:
|
237 |
+
checkpoint = torch.load(args.resume)
|
238 |
+
else:
|
239 |
+
# Map model to be loaded to specified single gpu.
|
240 |
+
loc = 'cuda:{}'.format(args.gpu)
|
241 |
+
checkpoint = torch.load(args.resume, map_location=loc)
|
242 |
+
args.start_epoch = checkpoint['epoch']
|
243 |
+
best_loss = checkpoint['best_loss']
|
244 |
+
if args.gpu is not None:
|
245 |
+
# best_loss may be from a checkpoint from a different GPU
|
246 |
+
best_loss = best_loss.to(args.gpu)
|
247 |
+
model.load_state_dict(checkpoint['state_dict'])
|
248 |
+
optimizer.load_state_dict(checkpoint['optimizer'])
|
249 |
+
print("=> loaded checkpoint '{}' (epoch {})"
|
250 |
+
.format(args.resume, checkpoint['epoch']))
|
251 |
+
else:
|
252 |
+
print("=> no checkpoint found at '{}'".format(args.resume))
|
253 |
+
|
254 |
+
cudnn.benchmark = True
|
255 |
+
|
256 |
+
train_dataset = SegmentationDataset(args.seg_data, args.data, partition=TRAIN_PARTITION, train_classes=args.num_classes,
|
257 |
+
num_samples=args.num_samples)
|
258 |
+
|
259 |
+
if args.distributed:
|
260 |
+
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
|
261 |
+
else:
|
262 |
+
train_sampler = None
|
263 |
+
|
264 |
+
train_loader = torch.utils.data.DataLoader(
|
265 |
+
train_dataset, batch_size=args.batch_size, shuffle=False,
|
266 |
+
num_workers=args.workers, pin_memory=True, sampler=train_sampler)
|
267 |
+
|
268 |
+
val_dataset = SegmentationDataset(args.seg_data, args.data, partition=VAL_PARTITION, train_classes=args.num_classes,
|
269 |
+
num_samples=1)
|
270 |
+
|
271 |
+
val_loader = torch.utils.data.DataLoader(
|
272 |
+
val_dataset, batch_size=5, shuffle=False,
|
273 |
+
num_workers=args.workers, pin_memory=True)
|
274 |
+
|
275 |
+
if args.evaluate:
|
276 |
+
validate(val_loader, model, criterion, 0, args)
|
277 |
+
return
|
278 |
+
|
279 |
+
for epoch in range(args.start_epoch, args.epochs):
|
280 |
+
if args.distributed:
|
281 |
+
train_sampler.set_epoch(epoch)
|
282 |
+
adjust_learning_rate(optimizer, epoch, args)
|
283 |
+
|
284 |
+
log_dir = os.path.join(args.experiment_folder, 'logs')
|
285 |
+
logger = SummaryWriter(log_dir=log_dir)
|
286 |
+
args.logger = logger
|
287 |
+
|
288 |
+
# train for one epoch
|
289 |
+
train(train_loader, model, criterion, optimizer, epoch, args)
|
290 |
+
|
291 |
+
# evaluate on validation set
|
292 |
+
loss1 = validate(val_loader, model, criterion, epoch, args)
|
293 |
+
|
294 |
+
# remember best acc@1 and save checkpoint
|
295 |
+
is_best = loss1 < best_loss
|
296 |
+
best_loss = min(loss1, best_loss)
|
297 |
+
|
298 |
+
if not args.multiprocessing_distributed or (args.multiprocessing_distributed
|
299 |
+
and args.rank % ngpus_per_node == 0):
|
300 |
+
save_checkpoint({
|
301 |
+
'epoch': epoch + 1,
|
302 |
+
'state_dict': model.state_dict(),
|
303 |
+
'best_loss': best_loss,
|
304 |
+
'optimizer' : optimizer.state_dict(),
|
305 |
+
}, is_best, folder=args.experiment_folder)
|
306 |
+
|
307 |
+
def train(train_loader, model, criterion, optimizer, epoch, args):
|
308 |
+
losses = AverageMeter('Loss', ':.4e')
|
309 |
+
top1 = AverageMeter('Acc@1', ':6.2f')
|
310 |
+
top5 = AverageMeter('Acc@5', ':6.2f')
|
311 |
+
orig_top1 = AverageMeter('Acc@1_orig', ':6.2f')
|
312 |
+
orig_top5 = AverageMeter('Acc@5_orig', ':6.2f')
|
313 |
+
progress = ProgressMeter(
|
314 |
+
len(train_loader),
|
315 |
+
        [losses, top1, top5, orig_top1, orig_top5],
        prefix="Epoch: [{}]".format(epoch))

    orig_model = vit(pretrained=True).cuda()
    orig_model.eval()

    # switch to train mode
    model.train()

    for i, (seg_map, image_ten, class_name) in enumerate(train_loader):
        if torch.cuda.is_available():
            image_ten = image_ten.cuda(args.gpu, non_blocking=True)
            seg_map = seg_map.cuda(args.gpu, non_blocking=True)
            class_name = class_name.cuda(args.gpu, non_blocking=True)

        image_ten.requires_grad = True
        output = model(image_ten)

        # segmentation loss
        EPS = 10e-12
        y_pred = torch.sum(torch.log(F.softmax(output, dim=1) + EPS))
        relevance = torch.autograd.grad(y_pred, image_ten, retain_graph=True)[0]
        reverse_seg_map = seg_map.clone()
        reverse_seg_map[reverse_seg_map == 1] = -1
        reverse_seg_map[reverse_seg_map == 0] = 1
        reverse_seg_map[reverse_seg_map == -1] = 0
        rrr_loss = (relevance * reverse_seg_map)**2
        segmentation_loss = rrr_loss.sum()

        # classification loss
        with torch.no_grad():
            output_orig = orig_model(image_ten)
        if args.temperature != 1:
            output = output / args.temperature
        classification_loss = criterion(output, class_name.flatten())

        loss = args.lambda_seg * segmentation_loss + args.lambda_acc * classification_loss

        # debugging output
        if i % args.save_interval == 0:
            orig_relevance = generate_relevance(orig_model, image_ten, index=class_name)
            for j in range(image_ten.shape[0]):
                image = get_image_with_relevance(image_ten[j], torch.ones_like(image_ten[j]))
                new_vis = get_image_with_relevance(image_ten[j]*relevance[j], torch.ones_like(image_ten[j]))
                old_vis = get_image_with_relevance(image_ten[j], orig_relevance[j])
                gt = get_image_with_relevance(image_ten[j], seg_map[j])
                h_img = cv2.hconcat([image, gt, old_vis, new_vis])
                cv2.imwrite(f'{args.experiment_folder}/train_samples/res_{i}_{j}.jpg', h_img)

        # measure accuracy and record loss
        acc1, acc5 = accuracy(output, class_name, topk=(1, 5))
        losses.update(loss.item(), image_ten.size(0))
        top1.update(acc1[0], image_ten.size(0))
        top5.update(acc5[0], image_ten.size(0))

        # metrics for original vit
        acc1_orig, acc5_orig = accuracy(output_orig, class_name, topk=(1, 5))
        orig_top1.update(acc1_orig[0], image_ten.size(0))
        orig_top5.update(acc5_orig[0], image_ten.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if i % args.print_freq == 0:
            progress.display(i)
            args.logger.add_scalar('{}/{}'.format('train', 'segmentation_loss'), segmentation_loss,
                                   epoch * len(train_loader) + i)
            args.logger.add_scalar('{}/{}'.format('train', 'classification_loss'), classification_loss,
                                   epoch * len(train_loader) + i)
            args.logger.add_scalar('{}/{}'.format('train', 'orig_top1'), acc1_orig,
                                   epoch * len(train_loader) + i)
            args.logger.add_scalar('{}/{}'.format('train', 'top1'), acc1,
                                   epoch * len(train_loader) + i)
            args.logger.add_scalar('{}/{}'.format('train', 'orig_top5'), acc5_orig,
                                   epoch * len(train_loader) + i)
            args.logger.add_scalar('{}/{}'.format('train', 'top5'), acc5,
                                   epoch * len(train_loader) + i)
            args.logger.add_scalar('{}/{}'.format('train', 'tot_loss'), loss,
                                   epoch * len(train_loader) + i)


def validate(val_loader, model, criterion, epoch, args):
    mse_criterion = torch.nn.MSELoss(reduction='mean')

    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    orig_top1 = AverageMeter('Acc@1_orig', ':6.2f')
    orig_top5 = AverageMeter('Acc@5_orig', ':6.2f')
    progress = ProgressMeter(
        len(val_loader),
        [losses, top1, top5, orig_top1, orig_top5],
        prefix="Epoch: [{}]".format(epoch))

    # switch to evaluate mode
    model.eval()

    orig_model = vit(pretrained=True).cuda()
    orig_model.eval()

    with torch.no_grad():
        for i, (seg_map, image_ten, class_name) in enumerate(val_loader):
            if args.gpu is not None:
                image_ten = image_ten.cuda(args.gpu, non_blocking=True)
            if torch.cuda.is_available():
                seg_map = seg_map.cuda(args.gpu, non_blocking=True)
                class_name = class_name.cuda(args.gpu, non_blocking=True)

            with torch.enable_grad():
                image_ten.requires_grad = True
                output = model(image_ten)

                # segmentation loss
                EPS = 10e-12
                y_pred = torch.sum(torch.log(F.softmax(output, dim=1) + EPS))
                relevance = torch.autograd.grad(y_pred, image_ten, retain_graph=True)[0]

            reverse_seg_map = seg_map.clone()
            reverse_seg_map[reverse_seg_map == 1] = -1
            reverse_seg_map[reverse_seg_map == 0] = 1
            reverse_seg_map[reverse_seg_map == -1] = 0
            rrr_loss = (relevance * reverse_seg_map) ** 2
            segmentation_loss = rrr_loss.sum()

            # classification loss
            output = model(image_ten)
            with torch.no_grad():
                output_orig = orig_model(image_ten)
            if args.temperature != 1:
                output = output / args.temperature
            classification_loss = criterion(output, class_name.flatten())

            loss = args.lambda_seg * segmentation_loss + args.lambda_acc * classification_loss

            # save results
            if i % args.save_interval == 0:
                with torch.enable_grad():
                    orig_relevance = generate_relevance(orig_model, image_ten, index=class_name)
                for j in range(image_ten.shape[0]):
                    image = get_image_with_relevance(image_ten[j], torch.ones_like(image_ten[j]))
                    new_vis = get_image_with_relevance(image_ten[j]*relevance[j], torch.ones_like(image_ten[j]))
                    old_vis = get_image_with_relevance(image_ten[j], orig_relevance[j])
                    gt = get_image_with_relevance(image_ten[j], seg_map[j])
                    h_img = cv2.hconcat([image, gt, old_vis, new_vis])
                    cv2.imwrite(f'{args.experiment_folder}/val_samples/res_{i}_{j}.jpg', h_img)

            # measure accuracy and record loss
            acc1, acc5 = accuracy(output, class_name, topk=(1, 5))
            losses.update(loss.item(), image_ten.size(0))
            top1.update(acc1[0], image_ten.size(0))
            top5.update(acc5[0], image_ten.size(0))

            # metrics for original vit
            acc1_orig, acc5_orig = accuracy(output_orig, class_name, topk=(1, 5))
            orig_top1.update(acc1_orig[0], image_ten.size(0))
            orig_top5.update(acc5_orig[0], image_ten.size(0))

            if i % args.print_freq == 0:
                progress.display(i)
                args.logger.add_scalar('{}/{}'.format('val', 'segmentation_loss'), segmentation_loss,
                                       epoch * len(val_loader) + i)
                args.logger.add_scalar('{}/{}'.format('val', 'classification_loss'), classification_loss,
                                       epoch * len(val_loader) + i)
                args.logger.add_scalar('{}/{}'.format('val', 'orig_top1'), acc1_orig,
                                       epoch * len(val_loader) + i)
                args.logger.add_scalar('{}/{}'.format('val', 'top1'), acc1,
                                       epoch * len(val_loader) + i)
                args.logger.add_scalar('{}/{}'.format('val', 'orig_top5'), acc5_orig,
                                       epoch * len(val_loader) + i)
                args.logger.add_scalar('{}/{}'.format('val', 'top5'), acc5,
                                       epoch * len(val_loader) + i)
                args.logger.add_scalar('{}/{}'.format('val', 'tot_loss'), loss,
                                       epoch * len(val_loader) + i)

        # TODO: this should also be done with the ProgressMeter
        print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
              .format(top1=top1, top5=top5))

    return losses.avg


def save_checkpoint(state, is_best, folder, filename='checkpoint.pth.tar'):
    torch.save(state, f'{folder}/{filename}')
    if is_best:
        shutil.copyfile(f'{folder}/{filename}', f'{folder}/model_best.pth.tar')


class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)


class ProgressMeter(object):
    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries += [str(meter) for meter in self.meters]
        print('\t'.join(entries))

    def _get_batch_fmtstr(self, num_batches):
        num_digits = len(str(num_batches // 1))
        fmt = '{:' + str(num_digits) + 'd}'
        return '[' + fmt + '/' + fmt.format(num_batches) + ']'


def adjust_learning_rate(optimizer, epoch, args):
    """Decays the learning rate by a factor of 0.85 every 2 epochs"""
    lr = args.lr * (0.85 ** (epoch // 2))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res


if __name__ == '__main__':
    main()
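For reference, the "RRR" term above penalizes input gradients that fall outside the ground-truth segmentation mask. The following is a minimal standalone sketch of that computation, assuming a generic classifier and a binary foreground mask; the function and variable names are illustrative only and are not part of this repository (it also inverts the mask with 1 - mask, which is equivalent to the -1/0/1 relabeling trick used above for binary maps):

import torch
import torch.nn.functional as F

def rrr_loss_sketch(model, images, fg_mask, eps=1e-11):
    """Input-gradient penalty: relevance outside the foreground mask is pushed toward zero."""
    images = images.clone().requires_grad_(True)
    logits = model(images)
    # sum of log-probabilities, as in the training loop above
    y_pred = torch.sum(torch.log(F.softmax(logits, dim=1) + eps))
    # create_graph=True so the penalty itself can be backpropagated to the weights
    relevance = torch.autograd.grad(y_pred, images, create_graph=True)[0]
    background_mask = 1.0 - fg_mask  # invert the binary segmentation map
    return ((relevance * background_mask) ** 2).sum()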
imagenet_finetune_tokencut.py
ADDED
@@ -0,0 +1,577 @@
import argparse
import os
import random
import shutil
import time
import warnings

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.optim
import torch.multiprocessing as mp
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models
from tokencut_dataset import SegmentationDataset, VAL_PARTITION, TRAIN_PARTITION

# Uncomment the expected model below

# ViT
from ViT.ViT import vit_base_patch16_224 as vit
# from ViT.ViT import vit_large_patch16_224 as vit

# ViT-AugReg
# from ViT.ViT_new import vit_small_patch16_224 as vit
# from ViT.ViT_new import vit_base_patch16_224 as vit
# from ViT.ViT_new import vit_large_patch16_224 as vit

# DeiT
# from ViT.ViT import deit_base_patch16_224 as vit
# from ViT.ViT import deit_small_patch16_224 as vit

from ViT.explainer import generate_relevance, get_image_with_relevance
import torchvision
import cv2
from torch.utils.tensorboard import SummaryWriter
import json

model_names = sorted(name for name in models.__dict__
    if name.islower() and not name.startswith("__")
    and callable(models.__dict__[name]))
model_names.append("vit")

parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
parser.add_argument('--data', metavar='DATA',
                    help='path to dataset')
parser.add_argument('--seg_data', metavar='SEG_DATA',
                    help='path to segmentation dataset')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                    help='number of data loading workers (default: 4)')
parser.add_argument('--epochs', default=150, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('-b', '--batch-size', default=10, type=int,
                    metavar='N',
                    help='mini-batch size (default: 256), this is the total '
                         'batch size of all GPUs on the current node when '
                         'using Data Parallel or Distributed Data Parallel')
parser.add_argument('--lr', '--learning-rate', default=3e-6, type=float,
                    metavar='LR', help='initial learning rate', dest='lr')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                    help='momentum')
parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
                    metavar='W', help='weight decay (default: 1e-4)',
                    dest='weight_decay')
parser.add_argument('-p', '--print-freq', default=10, type=int,
                    metavar='N', help='print frequency (default: 10)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                    help='evaluate model on validation set')
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                    help='use pre-trained model')
parser.add_argument('--world-size', default=-1, type=int,
                    help='number of nodes for distributed training')
parser.add_argument('--rank', default=-1, type=int,
                    help='node rank for distributed training')
parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,
                    help='url used to set up distributed training')
parser.add_argument('--dist-backend', default='nccl', type=str,
                    help='distributed backend')
parser.add_argument('--seed', default=None, type=int,
                    help='seed for initializing training. ')
parser.add_argument('--gpu', default=None, type=int,
                    help='GPU id to use.')
parser.add_argument('--save_interval', default=20, type=int,
                    help='interval to save segmentation results.')
parser.add_argument('--num_samples', default=3, type=int,
                    help='number of samples per class for training')
parser.add_argument('--multiprocessing-distributed', action='store_true',
                    help='Use multi-processing distributed training to launch '
                         'N processes per node, which has N GPUs. This is the '
                         'fastest way to use PyTorch for either single node or '
                         'multi node data parallel training')
parser.add_argument('--lambda_seg', default=0.1, type=float,
                    help='influence of segmentation loss.')
parser.add_argument('--lambda_acc', default=1, type=float,
                    help='influence of accuracy loss.')
parser.add_argument('--experiment_folder', default=None, type=str,
                    help='path to folder to use for experiment.')
parser.add_argument('--dilation', default=0, type=float,
                    help='Use dilation on the segmentation maps.')
parser.add_argument('--lambda_background', default=1, type=float,
                    help='coefficient of loss for segmentation background.')
parser.add_argument('--lambda_foreground', default=0.3, type=float,
                    help='coefficient of loss for segmentation foreground.')
parser.add_argument('--num_classes', default=500, type=int,
                    help='number of classes used for training.')
parser.add_argument('--temperature', default=1, type=float,
                    help='temperature for softmax (mostly for DeiT).')

best_loss = float('inf')

def main():
    args = parser.parse_args()

    if args.experiment_folder is None:
        args.experiment_folder = f'experiment/' \
                                 f'lr_{args.lr}_seg_{args.lambda_seg}_acc_{args.lambda_acc}' \
                                 f'_bckg_{args.lambda_background}_fgd_{args.lambda_foreground}'
        if args.temperature != 1:
            args.experiment_folder = args.experiment_folder + f'_tempera_{args.temperature}'
        if args.batch_size != 8:
            args.experiment_folder = args.experiment_folder + f'_bs_{args.batch_size}'
        if args.num_classes != 500:
            args.experiment_folder = args.experiment_folder + f'_num_classes_{args.num_classes}'
        if args.num_samples != 3:
            args.experiment_folder = args.experiment_folder + f'_num_samples_{args.num_samples}'
        if args.epochs != 150:
            args.experiment_folder = args.experiment_folder + f'_num_epochs_{args.epochs}'

    if os.path.exists(args.experiment_folder):
        raise Exception(f"Experiment path {args.experiment_folder} already exists!")
    os.mkdir(args.experiment_folder)
    os.mkdir(f'{args.experiment_folder}/train_samples')
    os.mkdir(f'{args.experiment_folder}/val_samples')

    with open(f'{args.experiment_folder}/commandline_args.txt', 'w') as f:
        json.dump(args.__dict__, f, indent=2)

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    if args.gpu is not None:
        warnings.warn('You have chosen a specific GPU. This will completely '
                      'disable data parallelism.')

    if args.dist_url == "env://" and args.world_size == -1:
        args.world_size = int(os.environ["WORLD_SIZE"])

    args.distributed = args.world_size > 1 or args.multiprocessing_distributed

    ngpus_per_node = torch.cuda.device_count()
    if args.multiprocessing_distributed:
        # Since we have ngpus_per_node processes per node, the total world_size
        # needs to be adjusted accordingly
        args.world_size = ngpus_per_node * args.world_size
        # Use torch.multiprocessing.spawn to launch distributed processes: the
        # main_worker process function
        mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
    else:
        # Simply call main_worker function
        main_worker(args.gpu, ngpus_per_node, args)


def main_worker(gpu, ngpus_per_node, args):
    global best_loss
    args.gpu = gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                world_size=args.world_size, rank=args.rank)
    # create model
    print("=> creating model")
    model = vit(pretrained=True).cuda()
    model.train()
    print("done")

    if not torch.cuda.is_available():
        print('using CPU, this will be slow')
    elif args.distributed:
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            # When using a single GPU per process and per
            # DistributedDataParallel, we need to divide the batch size
            # ourselves based on the total number of GPUs we have
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        else:
            model.cuda()
            # DistributedDataParallel will divide and allocate batch_size to all
            # available GPUs if device_ids are not set
            model = torch.nn.parallel.DistributedDataParallel(model)
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
    else:
        # DataParallel will divide and allocate batch_size to all available GPUs
        print("start")
        model = torch.nn.DataParallel(model).cuda()

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda(args.gpu)
    optimizer = torch.optim.AdamW(model.parameters(), args.lr, weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            if args.gpu is None:
                checkpoint = torch.load(args.resume)
            else:
                # Map model to be loaded to specified single gpu.
                loc = 'cuda:{}'.format(args.gpu)
                checkpoint = torch.load(args.resume, map_location=loc)
            args.start_epoch = checkpoint['epoch']
            best_loss = checkpoint['best_loss']
            if args.gpu is not None:
                # best_loss may be from a checkpoint from a different GPU
                best_loss = best_loss.to(args.gpu)
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    train_dataset = SegmentationDataset(args.seg_data, args.data, partition=TRAIN_PARTITION, train_classes=args.num_classes,
                                        num_samples=args.num_samples)

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_sampler)

    val_dataset = SegmentationDataset(args.seg_data, args.data, partition=VAL_PARTITION, train_classes=args.num_classes,
                                      num_samples=1)

    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=10, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    if args.evaluate:
        validate(val_loader, model, criterion, 0, args)
        return

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        adjust_learning_rate(optimizer, epoch, args)

        log_dir = os.path.join(args.experiment_folder, 'logs')
        logger = SummaryWriter(log_dir=log_dir)
        args.logger = logger

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, args)

        # evaluate on validation set
        loss1 = validate(val_loader, model, criterion, epoch, args)

        # remember best loss and save checkpoint
        is_best = loss1 <= best_loss
        best_loss = min(loss1, best_loss)

        if not args.multiprocessing_distributed or (args.multiprocessing_distributed
                                                    and args.rank % ngpus_per_node == 0):
            save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_loss': best_loss,
                'optimizer' : optimizer.state_dict(),
            }, is_best, folder=args.experiment_folder)


def train(train_loader, model, criterion, optimizer, epoch, args):
    mse_criterion = torch.nn.MSELoss(reduction='mean')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    orig_top1 = AverageMeter('Acc@1_orig', ':6.2f')
    orig_top5 = AverageMeter('Acc@5_orig', ':6.2f')
    progress = ProgressMeter(
        len(train_loader),
        # [batch_time, data_time, losses, top1, top5, orig_top1, orig_top5],
        [losses, top1, top5, orig_top1, orig_top5],
        prefix="Epoch: [{}]".format(epoch))

    orig_model = vit(pretrained=True).cuda()
    orig_model.eval()

    # switch to train mode
    model.train()

    end = time.time()
    for i, (seg_map, image_ten, class_name) in enumerate(train_loader):

        if torch.cuda.is_available():
            image_ten = image_ten.cuda(args.gpu, non_blocking=True)
            seg_map = seg_map.cuda(args.gpu, non_blocking=True)
            class_name = class_name.cuda(args.gpu, non_blocking=True)

        # compute output

        # segmentation loss
        relevance = generate_relevance(model, image_ten, index=class_name)

        reverse_seg_map = seg_map.clone()
        reverse_seg_map[reverse_seg_map == 1] = -1
        reverse_seg_map[reverse_seg_map == 0] = 1
        reverse_seg_map[reverse_seg_map == -1] = 0
        background_loss = mse_criterion(relevance * reverse_seg_map, torch.zeros_like(relevance))
        foreground_loss = mse_criterion(relevance * seg_map, seg_map)
        segmentation_loss = args.lambda_background * background_loss
        segmentation_loss += args.lambda_foreground * foreground_loss

        # classification loss
        output = model(image_ten)
        with torch.no_grad():
            output_orig = orig_model(image_ten)

        _, pred = output.topk(1, 1, True, True)
        pred = pred.flatten()
        if args.temperature != 1:
            output = output / args.temperature
        classification_loss = criterion(output, pred)

        loss = args.lambda_seg * segmentation_loss + args.lambda_acc * classification_loss

        # debugging output
        if i % args.save_interval == 0:
            orig_relevance = generate_relevance(orig_model, image_ten, index=class_name)
            for j in range(image_ten.shape[0]):
                image = get_image_with_relevance(image_ten[j], torch.ones_like(image_ten[j]))
                new_vis = get_image_with_relevance(image_ten[j], relevance[j])
                old_vis = get_image_with_relevance(image_ten[j], orig_relevance[j])
                gt = get_image_with_relevance(image_ten[j], seg_map[j])
                h_img = cv2.hconcat([image, gt, old_vis, new_vis])
                cv2.imwrite(f'{args.experiment_folder}/train_samples/res_{i}_{j}.jpg', h_img)

        # measure accuracy and record loss
        acc1, acc5 = accuracy(output, class_name, topk=(1, 5))
        losses.update(loss.item(), image_ten.size(0))
        top1.update(acc1[0], image_ten.size(0))
        top5.update(acc5[0], image_ten.size(0))

        # metrics for original vit
        acc1_orig, acc5_orig = accuracy(output_orig, class_name, topk=(1, 5))
        orig_top1.update(acc1_orig[0], image_ten.size(0))
        orig_top5.update(acc5_orig[0], image_ten.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if i % args.print_freq == 0:
            progress.display(i)
            args.logger.add_scalar('{}/{}'.format('train', 'segmentation_loss'), segmentation_loss,
                                   epoch * len(train_loader) + i)
            args.logger.add_scalar('{}/{}'.format('train', 'classification_loss'), classification_loss,
                                   epoch * len(train_loader) + i)
            args.logger.add_scalar('{}/{}'.format('train', 'orig_top1'), acc1_orig,
                                   epoch * len(train_loader) + i)
            args.logger.add_scalar('{}/{}'.format('train', 'top1'), acc1,
                                   epoch * len(train_loader) + i)
            args.logger.add_scalar('{}/{}'.format('train', 'orig_top5'), acc5_orig,
                                   epoch * len(train_loader) + i)
            args.logger.add_scalar('{}/{}'.format('train', 'top5'), acc5,
                                   epoch * len(train_loader) + i)
            args.logger.add_scalar('{}/{}'.format('train', 'tot_loss'), loss,
                                   epoch * len(train_loader) + i)


def validate(val_loader, model, criterion, epoch, args):
    mse_criterion = torch.nn.MSELoss(reduction='mean')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    orig_top1 = AverageMeter('Acc@1_orig', ':6.2f')
    orig_top5 = AverageMeter('Acc@5_orig', ':6.2f')
    progress = ProgressMeter(
        len(val_loader),
        [losses, top1, top5, orig_top1, orig_top5],
        prefix="Epoch: [{}]".format(epoch))

    # switch to evaluate mode
    model.eval()

    orig_model = vit(pretrained=True).cuda()
    orig_model.eval()

    with torch.no_grad():
        end = time.time()
        for i, (seg_map, image_ten, class_name) in enumerate(val_loader):
            if args.gpu is not None:
                image_ten = image_ten.cuda(args.gpu, non_blocking=True)
            if torch.cuda.is_available():
                seg_map = seg_map.cuda(args.gpu, non_blocking=True)
                class_name = class_name.cuda(args.gpu, non_blocking=True)

            # segmentation loss
            with torch.enable_grad():
                relevance = generate_relevance(model, image_ten, index=class_name)

            reverse_seg_map = seg_map.clone()
            reverse_seg_map[reverse_seg_map == 1] = -1
            reverse_seg_map[reverse_seg_map == 0] = 1
            reverse_seg_map[reverse_seg_map == -1] = 0
            background_loss = mse_criterion(relevance * reverse_seg_map, torch.zeros_like(relevance))
            foreground_loss = mse_criterion(relevance * seg_map, seg_map)
            segmentation_loss = args.lambda_background * background_loss
            segmentation_loss += args.lambda_foreground * foreground_loss

            # classification loss
            with torch.no_grad():
                output = model(image_ten)
                output_orig = orig_model(image_ten)

            _, pred = output.topk(1, 1, True, True)
            pred = pred.flatten()
            if args.temperature != 1:
                output = output / args.temperature
            classification_loss = criterion(output, pred)
            loss = args.lambda_seg * segmentation_loss + args.lambda_acc * classification_loss

            # save results
            if i % args.save_interval == 0:
                with torch.enable_grad():
                    orig_relevance = generate_relevance(orig_model, image_ten, index=class_name)
                for j in range(image_ten.shape[0]):
                    image = get_image_with_relevance(image_ten[j], torch.ones_like(image_ten[j]))
                    new_vis = get_image_with_relevance(image_ten[j], relevance[j])
                    old_vis = get_image_with_relevance(image_ten[j], orig_relevance[j])
                    gt = get_image_with_relevance(image_ten[j], seg_map[j])
                    h_img = cv2.hconcat([image, gt, old_vis, new_vis])
                    cv2.imwrite(f'{args.experiment_folder}/val_samples/res_{i}_{j}.jpg', h_img)

            # measure accuracy and record loss
            acc1, acc5 = accuracy(output, class_name, topk=(1, 5))
            losses.update(loss.item(), image_ten.size(0))
            top1.update(acc1[0], image_ten.size(0))
            top5.update(acc5[0], image_ten.size(0))

            # metrics for original vit
            acc1_orig, acc5_orig = accuracy(output_orig, class_name, topk=(1, 5))
            orig_top1.update(acc1_orig[0], image_ten.size(0))
            orig_top5.update(acc5_orig[0], image_ten.size(0))

            if i % args.print_freq == 0:
                progress.display(i)
                args.logger.add_scalar('{}/{}'.format('val', 'segmentation_loss'), segmentation_loss,
                                       epoch * len(val_loader) + i)
                args.logger.add_scalar('{}/{}'.format('val', 'classification_loss'), classification_loss,
                                       epoch * len(val_loader) + i)
                args.logger.add_scalar('{}/{}'.format('val', 'orig_top1'), acc1_orig,
                                       epoch * len(val_loader) + i)
                args.logger.add_scalar('{}/{}'.format('val', 'top1'), acc1,
                                       epoch * len(val_loader) + i)
                args.logger.add_scalar('{}/{}'.format('val', 'orig_top5'), acc5_orig,
                                       epoch * len(val_loader) + i)
                args.logger.add_scalar('{}/{}'.format('val', 'top5'), acc5,
                                       epoch * len(val_loader) + i)
                args.logger.add_scalar('{}/{}'.format('val', 'tot_loss'), loss,
                                       epoch * len(val_loader) + i)

        # TODO: this should also be done with the ProgressMeter
        print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
              .format(top1=top1, top5=top5))

    return losses.avg


def save_checkpoint(state, is_best, folder, filename='checkpoint.pth.tar'):
    torch.save(state, f'{folder}/{filename}')
    if is_best:
        shutil.copyfile(f'{folder}/{filename}', f'{folder}/model_best.pth.tar')


class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)


class ProgressMeter(object):
    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries += [str(meter) for meter in self.meters]
        print('\t'.join(entries))

    def _get_batch_fmtstr(self, num_batches):
        num_digits = len(str(num_batches // 1))
        fmt = '{:' + str(num_digits) + 'd}'
        return '[' + fmt + '/' + fmt.format(num_batches) + ']'


def adjust_learning_rate(optimizer, epoch, args):
    """Decays the learning rate by a factor of 0.85 every 2 epochs"""
    lr = args.lr * (0.85 ** (epoch // 2))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res


if __name__ == '__main__':
    main()
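The segmentation objective in this variant compares the model's relevance map against a TokenCut-style pseudo mask with separate background and foreground MSE terms. Below is a small sketch of just that loss, assuming `relevance` and `seg_map` are same-shaped tensors in [0, 1]; the function name is illustrative, not part of the repository, and the default coefficients mirror the argparse defaults above:

import torch

def relevance_guidance_loss(relevance, seg_map, lambda_background=1.0, lambda_foreground=0.3):
    """Background relevance is pulled toward 0, foreground relevance toward the mask itself."""
    mse = torch.nn.MSELoss(reduction='mean')
    background = 1.0 - seg_map  # invert the binary mask
    background_loss = mse(relevance * background, torch.zeros_like(relevance))
    foreground_loss = mse(relevance * seg_map, seg_map)
    return lambda_background * background_loss + lambda_foreground * foreground_loss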
label_str_to_imagenet_classes.py
ADDED
@@ -0,0 +1,133 @@
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Dictionary mapping labels (strings) to imagenet classes (ints).

Generated manually.
"""

label_str_to_imagenet_classes = {
    'ambulance': 407,
    'armadillo': 363,
    'artichoke': 944,
    'backpack': 414,
    'bagel': 931,
    'balance beam': 416,
    'banana': 954,
    'band-aid': 419,
    'beaker': 438,
    'bell pepper': 945,
    'billiard table': 736,
    'binoculars': 447,
    'broccoli': 937,
    'brown bear': 294,
    'burrito': 965,
    'candle': 470,
    'canoe': 472,
    'cello': 486,
    'cheetah': 293,
    'cocktail shaker': 503,
    'common fig': 952,
    'computer mouse': 673,
    'cowboy hat': 515,
    'cucumber': 943,
    'diaper': 529,
    'digital clock': 530,
    'dumbbell': 543,
    'envelope': 549,
    'eraser': 767,
    'filing cabinet': 553,
    'flowerpot': 738,
    'flute': 558,
    'frying pan': 567,
    'golf ball': 574,
    'goose': 99,
    'guacamole': 924,
    'hair dryer': 589,
    'hair spray': 585,
    'hammer': 587,
    'hamster': 333,
    'harmonica': 593,
    'hedgehog': 334,
    'hippopotamus': 344,
    'hot dog': 934,
    'ipod': 605,
    'jeans': 608,
    'kite': 21,
    'koala': 105,
    'ladle': 618,
    'laptop': 620,
    'lemon': 951,
    'light switch': 844,
    'lighthouse': 437,
    'limousine': 627,
    'lipstick': 629,
    'lynx': 287,
    'magpie': 18,
    'maracas': 641,
    'measuring cup': 647,
    'microwave oven': 651,
    'miniskirt': 655,
    'missile': 657,
    'mixing bowl': 659,
    'mobile phone': 487,
    'mushroom': 947,
    'orange': 950,
    'ostrich': 9,
    'otter': 360,
    'paper towel': 700,
    'pencil case': 709,
    'pig': 341,
    'pillow': 721,
    'pitcher (container)': 725,
    'pizza': 963,
    'plastic bag': 728,
    'polar bear': 296,
    'pomegranate': 957,
    'pretzel': 932,
    'printer': 742,
    'punching bag': 747,
    'racket': 752,
    'red panda': 387,
    'remote control': 761,
    'rugby ball': 768,
    'ruler': 769,
    'saxophone': 776,
    'screwdriver': 784,
    'sea lion': 150,
    'seat belt': 785,
    'skunk': 361,
    'snowmobile': 802,
    'soap dispenser': 804,
    'sock': 806,
    'sombrero': 808,
    'spatula': 813,
    'starfish': 327,
    'strawberry': 949,
    'studio couch': 831,
    'taxi': 468,
    'teapot': 849,
    'teddy bear': 850,
    'tennis ball': 852,
    'toaster': 859,
    'toilet paper': 999,
    'torch': 862,
    'traffic light': 920,
    'vase': 883,
    'volleyball (ball)': 890,
    'washing machine': 897,
    'wok': 909,
    'zebra': 340,
    'zucchini': 939
}
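A quick usage note: the robustness datasets below consume this mapping to turn a folder label string into its ImageNet class index, for example:

from label_str_to_imagenet_classes import label_str_to_imagenet_classes

print(label_str_to_imagenet_classes['zebra'])   # 340
print(label_str_to_imagenet_classes['teapot'])  # 849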
objectnet_dataset.py
ADDED
@@ -0,0 +1,117 @@
import json
from torch.utils import data
from torchvision.datasets import ImageFolder
import torch
import os
from PIL import Image
import numpy as np
import argparse
from tqdm import tqdm
from munkres import Munkres
import multiprocessing
from multiprocessing import Process, Manager
import collections
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
import random
import torchvision
import cv2
from label_str_to_imagenet_classes import label_str_to_imagenet_classes

torch.manual_seed(0)
normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                 std=[0.5, 0.5, 0.5])

transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    normalize,
])

class ObjectNetDataset(ImageFolder):
    def __init__(self, imagenet_path):
        self._imagenet_path = imagenet_path
        self._all_images = []

        o_dataset = ImageFolder(self._imagenet_path)
        # get mappings folder
        mappings_folder = os.path.abspath(
            os.path.join(self._imagenet_path, "../mappings")
        )

        # get ObjectNet label to ImageNet label mapping
        with open(
            os.path.join(mappings_folder, "objectnet_to_imagenet_1k.json")
        ) as file_handle:
            o_label_to_all_i_labels = json.load(file_handle)

        # now remove double i labels to avoid confusion
        o_label_to_i_labels = {
            o_label: all_i_label.split("; ")
            for o_label, all_i_label in o_label_to_all_i_labels.items()
        }

        # some in-between mappings ...
        o_folder_to_o_idx = o_dataset.class_to_idx
        with open(
            os.path.join(mappings_folder, "folder_to_objectnet_label.json")
        ) as file_handle:
            o_folder_o_label = json.load(file_handle)

        # now get mapping from o_label to o_idx
        o_label_to_o_idx = {
            o_label: o_folder_to_o_idx[o_folder]
            for o_folder, o_label in o_folder_o_label.items()
        }

        # some in-between mappings ...
        with open(
            os.path.join(mappings_folder, "pytorch_to_imagenet_2012_id.json")
        ) as file_handle:
            i_idx_to_i_line = json.load(file_handle)
        with open(
            os.path.join(mappings_folder, "imagenet_to_label_2012_v2")
        ) as file_handle:
            i_line_to_i_label = file_handle.readlines()

        i_line_to_i_label = {
            i_line: i_label[:-1]
            for i_line, i_label in enumerate(i_line_to_i_label)
        }

        # now get mapping from i_label to i_idx
        i_label_to_i_idx = {
            i_line_to_i_label[i_line]: int(i_idx)
            for i_idx, i_line in i_idx_to_i_line.items()
        }

        # now get the final mapping of interest!!!
        o_idx_to_i_idxs = {
            o_label_to_o_idx[o_label]: [
                i_label_to_i_idx[i_label] for i_label in i_labels
            ]
            for o_label, i_labels in o_label_to_i_labels.items()
        }

        self._tag_list = []
        # now get a list of files of interest
        for filepath, o_idx in o_dataset.samples:
            if o_idx not in o_idx_to_i_idxs:
                continue
            rel_file = os.path.relpath(filepath, self._imagenet_path)
            if o_idx_to_i_idxs[o_idx][0] not in self._tag_list:
                self._tag_list.append(o_idx_to_i_idxs[o_idx][0])
            self._all_images.append((rel_file, o_idx_to_i_idxs[o_idx][0]))

    def __getitem__(self, item):
        image_path, classification = self._all_images[item]
        image_path = os.path.join(self._imagenet_path, image_path)
        image = Image.open(image_path)
        image = image.convert('RGB')
        image = transform(image)

        return image, classification

    def __len__(self):
        return len(self._all_images)
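For orientation, one possible way to drive this dataset for evaluation; the ObjectNet root path below is a placeholder, and the class itself expects the `mappings` folder to sit next to the image folder, as read in `__init__` above:

import torch
from objectnet_dataset import ObjectNetDataset

dataset = ObjectNetDataset(imagenet_path='/path/to/objectnet/images')  # placeholder path
loader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=False, num_workers=4)

for images, imagenet_targets in loader:
    # images are normalized 224x224 tensors; targets are ImageNet class indices
    pass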
robustness_dataset.py
ADDED
@@ -0,0 +1,66 @@
import json
from torch.utils import data
from torchvision.datasets import ImageFolder
import torch
import os
from PIL import Image
import numpy as np
import argparse
from tqdm import tqdm
from munkres import Munkres
import multiprocessing
from multiprocessing import Process, Manager
import collections
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
import random
import torchvision
import cv2
from label_str_to_imagenet_classes import label_str_to_imagenet_classes

torch.manual_seed(0)

ImageItem = collections.namedtuple('ImageItem', ('image_name', 'tag'))
normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                 std=[0.5, 0.5, 0.5])

transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    normalize,
])

class RobustnessDataset(ImageFolder):
    def __init__(self, imagenet_path, imagenet_classes_path='imagenet_classes.json', isV2=False, isSI=False):
        self._isV2 = isV2
        self._isSI = isSI
        self._imagenet_path = imagenet_path
        with open(imagenet_classes_path, 'r') as f:
            self._imagenet_classes = json.load(f)
        self._tag_list = [tag for tag in os.listdir(self._imagenet_path)]
        self._all_images = []
        for tag in self._tag_list:
            base_dir = os.path.join(self._imagenet_path, tag)
            for i, file in enumerate(os.listdir(base_dir)):
                self._all_images.append(ImageItem(file, tag))

    def __getitem__(self, item):
        image_item = self._all_images[item]
        image_path = os.path.join(self._imagenet_path, image_item.tag, image_item.image_name)
        image = Image.open(image_path)
        image = image.convert('RGB')
        image = transform(image)

        if self._isV2:
            class_name = int(image_item.tag)
        elif self._isSI:
            class_name = int(label_str_to_imagenet_classes[image_item.tag])
        else:
            class_name = int(self._imagenet_classes[image_item.tag])

        return image, class_name

    def __len__(self):
        return len(self._all_images)
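A sketch of how this dataset could be wired into an evaluation loop; the path is a placeholder, and the `isV2`/`isSI` flags switch how the folder tag is mapped to a class index, as in `__getitem__` above:

import torch
from robustness_dataset import RobustnessDataset

dataset = RobustnessDataset('/path/to/robustness/images', isV2=False, isSI=False)  # placeholder path
loader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=False, num_workers=4)

for images, targets in loader:
    # each batch yields normalized 224x224 images and ImageNet class indices
    pass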
robustness_dataset_per_class.py
ADDED
@@ -0,0 +1,65 @@
import json
from torchvision.datasets import ImageFolder
import torch
import os
from PIL import Image
import collections
import torchvision.transforms as transforms
from label_str_to_imagenet_classes import label_str_to_imagenet_classes

torch.manual_seed(0)

ImageItem = collections.namedtuple('ImageItem', ('image_name', 'tag'))

normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                 std=[0.5, 0.5, 0.5])

transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    normalize,
])

class RobustnessDataset(ImageFolder):
    def __init__(self, imagenet_path, folder, imagenet_classes_path='imagenet_classes.json', isV2=False, isSI=False):
        self._isV2 = isV2
        self._isSI = isSI
        self._folder = folder
        self._imagenet_path = imagenet_path
        with open(imagenet_classes_path, 'r') as f:
            self._imagenet_classes = json.load(f)
        self._all_images = []

        base_dir = os.path.join(self._imagenet_path, folder)
        for i, file in enumerate(os.listdir(base_dir)):
            self._all_images.append(ImageItem(file, folder))

    def __getitem__(self, item):
        image_item = self._all_images[item]
        image_path = os.path.join(self._imagenet_path, image_item.tag, image_item.image_name)
        image = Image.open(image_path)
        image = image.convert('RGB')
        image = transform(image)

        if self._isV2:
            class_name = int(image_item.tag)
        elif self._isSI:
            class_name = int(label_str_to_imagenet_classes[image_item.tag])
        else:
            class_name = int(self._imagenet_classes[image_item.tag])

        return image, class_name

    def __len__(self):
        return len(self._all_images)

    def get_classname(self):
        if self._isV2:
            class_name = int(self._folder)
        elif self._isSI:
            class_name = int(label_str_to_imagenet_classes[self._folder])
        else:
            class_name = int(self._imagenet_classes[self._folder])
        return class_name
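The per-class variant is built for one folder at a time, so a per-class evaluation could iterate over folders and use `get_classname()` for the shared target. Sketch only; the path and loop are illustrative:

import os
import torch
from robustness_dataset_per_class import RobustnessDataset

root = '/path/to/robustness/images'  # placeholder path
for folder in os.listdir(root):
    dataset = RobustnessDataset(root, folder)
    target = dataset.get_classname()
    loader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=False)
    # ... evaluate the model on this single class ...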
samples/augreg_base/1_in.png
ADDED
samples/augreg_base/2_in.png
ADDED
samples/augreg_base/3_in.png
ADDED
samples/augreg_base/a.png
ADDED
samples/augreg_base/a_2.png
ADDED
samples/augreg_base/a_3.png
ADDED
samples/catdog.png
ADDED
samples/deit_base/1_in.png
ADDED
samples/deit_base/2_in.png
ADDED
samples/deit_base/3_in.png
ADDED
samples/deit_base/a.png
ADDED
samples/deit_base/a_2.png
ADDED