jcarnero commited on
Commit
18f2566
·
1 Parent(s): e338a46

gradio app. models refactor

Browse files
README.md CHANGED
@@ -3,3 +3,15 @@
3
  Train model for birds classification and gradio app
4
 
5
  Training is done using fastai, deployment mimics its transforms to publish a gradio app that has no fastai dependencies.
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  Train model for birds classification and gradio app
4
 
5
  Training is done using fastai, deployment mimics its transforms to publish a gradio app that has no fastai dependencies.
6
+
7
+ ## Train
8
+
9
+ ```bash
10
+ conda env create -f environment.yml
11
+ ```
12
+
13
+ ```bash
14
+ conda activate fastai
15
+ cd training
16
+ python -m birds.train
17
+ ```
deployment/app.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import gradio as gr
3
+ from PIL import Image
4
+ from model import get_model, apply_weights, copy_weight
5
+ from vocab import vocab
6
+ from transforms import resized_crop_pad, gpu_crop
7
+ from torchvision.transforms import Normalize, ToTensor
8
+
9
+ model = get_model()
10
+ state = torch.load("../models/vit_saved.pth", map_location="cpu")
11
+ apply_weights(model, state, copy_weight)
12
+
13
+ to_tensor = ToTensor()
14
+ norm = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
15
+
16
+
17
+ def classify_image(inp):
18
+ inp = Image.fromarray(inp)
19
+ transformed_input = resized_crop_pad(inp, (460, 460))
20
+ transformed_input = to_tensor(transformed_input).unsqueeze(0)
21
+ transformed_input = gpu_crop(transformed_input, (224, 224))
22
+ transformed_input = norm(transformed_input)
23
+ model.eval()
24
+ with torch.no_grad():
25
+ pred = model(transformed_input)
26
+ pred = torch.argmax(pred, dim=1)
27
+ return vocab[pred]
28
+
29
+
30
+ iface = gr.Interface(
31
+ fn=classify_image,
32
+ inputs=gr.inputs.Image(),
33
+ outputs="text",
34
+ title="Birds Classifier without Fastai",
35
+ description="A birds classifier over 200 species trained with Fastai"
36
+ " and deployed with plain pytorch in Gradio.",
37
+ ).launch()
deployment/model.py CHANGED
@@ -46,8 +46,8 @@ def apply_weights(
46
  application_function: callable,
47
  ):
48
  """
49
- Takes an input state_dict and applies those weights to the `input_model`, potentially
50
- with a modifier function.
51
 
52
  Args:
53
  input_model (`nn.Module`):
@@ -56,7 +56,8 @@ def apply_weights(
56
  A dictionary of weights, the trained model's `state_dict()`
57
  application_function (`callable`):
58
  A function that takes in one parameter and layer name from `input_model`
59
- and the `input_weights`. Should apply the weights from the state dict into `input_model`.
 
60
  """
61
  model_dict = input_model.state_dict()
62
  for name, parameter in model_dict.items():
 
46
  application_function: callable,
47
  ):
48
  """
49
+ Takes an input state_dict and applies those weights to the `input_model`,
50
+ potentially with a modifier function.
51
 
52
  Args:
53
  input_model (`nn.Module`):
 
56
  A dictionary of weights, the trained model's `state_dict()`
57
  application_function (`callable`):
58
  A function that takes in one parameter and layer name from `input_model`
59
+ and the `input_weights`. Should apply the weights from the state dict into
60
+ `input_model`.
61
  """
62
  model_dict = input_model.state_dict()
63
  for name, parameter in model_dict.items():
deployment/requirements.txt CHANGED
@@ -1,4 +1,4 @@
1
- gradio==3.18.0
2
  pillow==9.4.0
3
  timm==0.6.12
4
  torch==1.13.1
 
1
+ gradio==3.20.1
2
  pillow==9.4.0
3
  timm==0.6.12
4
  torch==1.13.1
deployment/vocab.py ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ vocab = [
2
+ "Acadian_Flycatcher",
3
+ "American_Crow",
4
+ "American_Goldfinch",
5
+ "American_Pipit",
6
+ "American_Redstart",
7
+ "American_Three_Toed_Woodpecker",
8
+ "Anna_Hummingbird",
9
+ "Artic_Tern",
10
+ "Baird_Sparrow",
11
+ "Baltimore_Oriole",
12
+ "Bank_Swallow",
13
+ "Barn_Swallow",
14
+ "Bay_Breasted_Warbler",
15
+ "Belted_Kingfisher",
16
+ "Bewick_Wren",
17
+ "Black_And_White_Warbler",
18
+ "Black_Billed_Cuckoo",
19
+ "Black_Capped_Vireo",
20
+ "Black_Footed_Albatross",
21
+ "Black_Tern",
22
+ "Black_Throated_Blue_Warbler",
23
+ "Black_Throated_Sparrow",
24
+ "Blue_Grosbeak",
25
+ "Blue_Headed_Vireo",
26
+ "Blue_Jay",
27
+ "Blue_Winged_Warbler",
28
+ "Boat_Tailed_Grackle",
29
+ "Bobolink",
30
+ "Bohemian_Waxwing",
31
+ "Brandt_Cormorant",
32
+ "Brewer_Blackbird",
33
+ "Brewer_Sparrow",
34
+ "Bronzed_Cowbird",
35
+ "Brown_Creeper",
36
+ "Brown_Pelican",
37
+ "Brown_Thrasher",
38
+ "Cactus_Wren",
39
+ "California_Gull",
40
+ "Canada_Warbler",
41
+ "Cape_Glossy_Starling",
42
+ "Cape_May_Warbler",
43
+ "Cardinal",
44
+ "Carolina_Wren",
45
+ "Caspian_Tern",
46
+ "Cedar_Waxwing",
47
+ "Cerulean_Warbler",
48
+ "Chestnut_Sided_Warbler",
49
+ "Chipping_Sparrow",
50
+ "Chuck_Will_Widow",
51
+ "Clark_Nutcracker",
52
+ "Clay_Colored_Sparrow",
53
+ "Cliff_Swallow",
54
+ "Common_Raven",
55
+ "Common_Tern",
56
+ "Common_Yellowthroat",
57
+ "Crested_Auklet",
58
+ "Dark_Eyed_Junco",
59
+ "Downy_Woodpecker",
60
+ "Eared_Grebe",
61
+ "Eastern_Towhee",
62
+ "Elegant_Tern",
63
+ "European_Goldfinch",
64
+ "Evening_Grosbeak",
65
+ "Field_Sparrow",
66
+ "Fish_Crow",
67
+ "Florida_Jay",
68
+ "Forsters_Tern",
69
+ "Fox_Sparrow",
70
+ "Frigatebird",
71
+ "Gadwall",
72
+ "Geococcyx",
73
+ "Glaucous_Winged_Gull",
74
+ "Golden_Winged_Warbler",
75
+ "Grasshopper_Sparrow",
76
+ "Gray_Catbird",
77
+ "Gray_Crowned_Rosy_Finch",
78
+ "Gray_Kingbird",
79
+ "Great_Crested_Flycatcher",
80
+ "Great_Grey_Shrike",
81
+ "Green_Jay",
82
+ "Green_Kingfisher",
83
+ "Green_Tailed_Towhee",
84
+ "Green_Violetear",
85
+ "Groove_Billed_Ani",
86
+ "Harris_Sparrow",
87
+ "Heermann_Gull",
88
+ "Henslow_Sparrow",
89
+ "Herring_Gull",
90
+ "Hooded_Merganser",
91
+ "Hooded_Oriole",
92
+ "Hooded_Warbler",
93
+ "Horned_Grebe",
94
+ "Horned_Lark",
95
+ "Horned_Puffin",
96
+ "House_Sparrow",
97
+ "House_Wren",
98
+ "Indigo_Bunting",
99
+ "Ivory_Gull",
100
+ "Kentucky_Warbler",
101
+ "Laysan_Albatross",
102
+ "Lazuli_Bunting",
103
+ "Le_Conte_Sparrow",
104
+ "Least_Auklet",
105
+ "Least_Flycatcher",
106
+ "Least_Tern",
107
+ "Lincoln_Sparrow",
108
+ "Loggerhead_Shrike",
109
+ "Long_Tailed_Jaeger",
110
+ "Louisiana_Waterthrush",
111
+ "Magnolia_Warbler",
112
+ "Mallard",
113
+ "Mangrove_Cuckoo",
114
+ "Marsh_Wren",
115
+ "Mockingbird",
116
+ "Mourning_Warbler",
117
+ "Myrtle_Warbler",
118
+ "Nashville_Warbler",
119
+ "Nelson_Sharp_Tailed_Sparrow",
120
+ "Nighthawk",
121
+ "Northern_Flicker",
122
+ "Northern_Fulmar",
123
+ "Northern_Waterthrush",
124
+ "Olive_Sided_Flycatcher",
125
+ "Orange_Crowned_Warbler",
126
+ "Orchard_Oriole",
127
+ "Ovenbird",
128
+ "Pacific_Loon",
129
+ "Painted_Bunting",
130
+ "Palm_Warbler",
131
+ "Parakeet_Auklet",
132
+ "Pelagic_Cormorant",
133
+ "Philadelphia_Vireo",
134
+ "Pied_Billed_Grebe",
135
+ "Pied_Kingfisher",
136
+ "Pigeon_Guillemot",
137
+ "Pileated_Woodpecker",
138
+ "Pine_Grosbeak",
139
+ "Pine_Warbler",
140
+ "Pomarine_Jaeger",
141
+ "Prairie_Warbler",
142
+ "Prothonotary_Warbler",
143
+ "Purple_Finch",
144
+ "Red_Bellied_Woodpecker",
145
+ "Red_Breasted_Merganser",
146
+ "Red_Cockaded_Woodpecker",
147
+ "Red_Eyed_Vireo",
148
+ "Red_Faced_Cormorant",
149
+ "Red_Headed_Woodpecker",
150
+ "Red_Legged_Kittiwake",
151
+ "Red_Winged_Blackbird",
152
+ "Rhinoceros_Auklet",
153
+ "Ring_Billed_Gull",
154
+ "Ringed_Kingfisher",
155
+ "Rock_Wren",
156
+ "Rose_Breasted_Grosbeak",
157
+ "Ruby_Throated_Hummingbird",
158
+ "Rufous_Hummingbird",
159
+ "Rusty_Blackbird",
160
+ "Sage_Thrasher",
161
+ "Savannah_Sparrow",
162
+ "Sayornis",
163
+ "Scarlet_Tanager",
164
+ "Scissor_Tailed_Flycatcher",
165
+ "Scott_Oriole",
166
+ "Seaside_Sparrow",
167
+ "Shiny_Cowbird",
168
+ "Slaty_Backed_Gull",
169
+ "Song_Sparrow",
170
+ "Sooty_Albatross",
171
+ "Spotted_Catbird",
172
+ "Summer_Tanager",
173
+ "Swainson_Warbler",
174
+ "Tennessee_Warbler",
175
+ "Tree_Sparrow",
176
+ "Tree_Swallow",
177
+ "Tropical_Kingbird",
178
+ "Vermilion_Flycatcher",
179
+ "Vesper_Sparrow",
180
+ "Warbling_Vireo",
181
+ "Western_Grebe",
182
+ "Western_Gull",
183
+ "Western_Meadowlark",
184
+ "Western_Wood_Pewee",
185
+ "Whip_Poor_Will",
186
+ "White_Breasted_Kingfisher",
187
+ "White_Breasted_Nuthatch",
188
+ "White_Crowned_Sparrow",
189
+ "White_Eyed_Vireo",
190
+ "White_Necked_Raven",
191
+ "White_Pelican",
192
+ "White_Throated_Sparrow",
193
+ "Wilson_Warbler",
194
+ "Winter_Wren",
195
+ "Worm_Eating_Warbler",
196
+ "Yellow_Bellied_Flycatcher",
197
+ "Yellow_Billed_Cuckoo",
198
+ "Yellow_Breasted_Chat",
199
+ "Yellow_Headed_Blackbird",
200
+ "Yellow_Throated_Vireo",
201
+ "Yellow_Warbler",
202
+ ]
training/environment.yml → environment.yml RENAMED
@@ -17,6 +17,7 @@ dependencies:
17
  - pip:
18
  - ipykernel
19
  - ipywidgets
 
20
  - timm==0.6.12
21
  - kaggle==1.5.12
22
  - flake8
 
17
  - pip:
18
  - ipykernel
19
  - ipywidgets
20
+ - gradio==3.20.1
21
  - timm==0.6.12
22
  - kaggle==1.5.12
23
  - flake8
{training/models → models}/.gitignore RENAMED
File without changes
training/birds/config.py CHANGED
@@ -1,4 +1,5 @@
1
  DATA_STORAGE_PATH = "../data"
 
2
  DATASET = "200-bird-species-with-11788-images"
3
  OWNER = "veeralakrishna"
4
 
 
1
  DATA_STORAGE_PATH = "../data"
2
+ MODELS_STORAGE_PATH = "../models"
3
  DATASET = "200-bird-species-with-11788-images"
4
  OWNER = "veeralakrishna"
5
 
training/birds/train.py CHANGED
@@ -76,5 +76,6 @@ if __name__ == "__main__":
76
 
77
  learner.fine_tune(7, base_lr=0.001, freeze_epochs=12)
78
 
79
- learner.export("models/vit_exported")
 
80
  learner.save("vit_saved", with_opt=False)
 
76
 
77
  learner.fine_tune(7, base_lr=0.001, freeze_epochs=12)
78
 
79
+ learner.export(Path(config.MODELS_STORAGE_PATH).resolve() / "vit_exported.pkl")
80
+ learner.model_dir = Path(config.MODELS_STORAGE_PATH).resolve()
81
  learner.save("vit_saved", with_opt=False)