SatwikKambham committed on
Commit 22ae476
1 Parent(s): e8561a6

Added app.py and requirements.txt

Files changed (2)
  1. app.py +211 -0
  2. requirements.txt +3 -0
app.py ADDED
@@ -0,0 +1,211 @@
+ import gradio as gr
+ from huggingface_hub import hf_hub_download
+
+ import torch
+ import torch.nn as nn
+ from torchvision import transforms
+
+
+ class SimpleResidualBlock(nn.Module):
+     def __init__(self, in_channels, out_channels, set_stride=False):
+         super().__init__()
+         stride = 2 if in_channels != out_channels and set_stride else 1
+
+         self.conv1 = nn.LazyConv2d(
+             out_channels,
+             kernel_size=3,
+             padding="same" if stride == 1 else 1,
+             stride=stride,
+         )
+         self.conv2 = nn.LazyConv2d(out_channels, kernel_size=3, padding="same")
+
+         self.bn1 = nn.LazyBatchNorm2d()
+         self.bn2 = nn.LazyBatchNorm2d()
+
+         self.relu = nn.ReLU()
+
+         if in_channels != out_channels:
+             self.residual = nn.Sequential(
+                 nn.LazyConv2d(out_channels, kernel_size=1, stride=stride),
+                 nn.LazyBatchNorm2d(),
+             )
+         else:
+             self.residual = nn.Identity()
+
+     def forward(self, x):
+         out = self.relu(self.bn1(self.conv1(x)))
+         out = self.bn2(self.conv2(out))
+         out += self.residual(x)
+         out = self.relu(out)
+         return out
+
+
+ class BottleneckResidualBlock(nn.Module):
+     def __init__(
+         self, in_channels, out_channels, identity_mapping=False, set_stride=False
+     ):
+         super().__init__()
+         stride = 2 if in_channels != out_channels and set_stride else 1
+
+         self.conv1 = nn.LazyConv2d(
+             out_channels,
+             kernel_size=1,
+             padding="same" if stride == 1 else 0,
+             stride=stride,
+         )
+         self.conv2 = nn.LazyConv2d(out_channels, kernel_size=3, padding="same")
+         self.conv3 = nn.LazyConv2d(out_channels * 4, kernel_size=1, padding="same")
+
+         self.bn1 = nn.LazyBatchNorm2d()
+         self.bn2 = nn.LazyBatchNorm2d()
+         self.bn3 = nn.LazyBatchNorm2d()
+
+         self.relu = nn.ReLU()
+
+         if in_channels != out_channels or not identity_mapping:
+             self.residual = nn.Sequential(
+                 nn.LazyConv2d(out_channels * 4, kernel_size=1, stride=stride),
+                 nn.LazyBatchNorm2d(),
+             )
+         else:
+             self.residual = nn.Identity()
+
+     def forward(self, x):
+         out = self.relu(self.bn1(self.conv1(x)))
+         out = self.relu(self.bn2(self.conv2(out)))
+         out = self.bn3(self.conv3(out))
+         out += self.residual(x)
+         out = self.relu(out)
+         return out
+
+
+ RESNET_18 = [2, 2, 2, 2]
+ RESNET_34 = [3, 4, 6, 3]
+ RESNET_50 = [3, 4, 6, 3]
+ RESNET_101 = [3, 4, 23, 3]
+ RESNET_152 = [3, 8, 36, 3]
+
+
+ class ResNet(nn.Module):
+     def __init__(self, arch=RESNET_18, block="simple", num_classes=256):
+         super().__init__()
+         self.conv1 = nn.Sequential(
+             nn.LazyConv2d(64, kernel_size=7, stride=2, padding=3),
+             nn.LazyBatchNorm2d(),
+             nn.ReLU(),
+         )
+         self.maxpool = nn.MaxPool2d(3, stride=2, padding=1)
+         self.conv2 = self._make_layer(64, 64, arch[0], set_stride=False, block=block)
+         self.conv3 = self._make_layer(64, 128, arch[1], block=block)
+         self.conv4 = self._make_layer(128, 256, arch[2], block=block)
+         self.conv5 = self._make_layer(256, 512, arch[3], block=block)
+         self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+         self.flatten = nn.Flatten()
+         self.fc = nn.LazyLinear(num_classes)
+
+     def _make_layer(
+         self, in_channels, out_channels, num_blocks, set_stride=True, block="simple"
+     ):
+         """Block is either 'simple' or 'bottleneck'"""
+         layers = []
+         for i in range(num_blocks):
+             layers.append(
+                 SimpleResidualBlock(in_channels, out_channels, set_stride=set_stride)
+                 if block == "simple"
+                 else BottleneckResidualBlock(
+                     in_channels if i == 0 else out_channels * 4,
+                     out_channels,
+                     set_stride=set_stride,
+                 )
+             )
+             set_stride = False
+         return nn.Sequential(*layers)
+
+     def forward(self, x):
+         out = self.conv1(x)
+         out = self.maxpool(self.conv2(out))
+         out = self.conv3(out)
+         out = self.conv4(out)
+         out = self.conv5(out)
+         out = self.avgpool(out)
+         out = self.flatten(out)
+         out = self.fc(out)
+         return out
+
+     def _init_weights(module):
+         # Initialize weights with Glorot uniform
+         if isinstance(module, nn.Conv2d):
+             nn.init.xavier_uniform_(module.weight)
+             nn.init.zeros_(module.bias)
+         elif isinstance(module, nn.Linear):
+             nn.init.xavier_uniform_(module.weight)
+             nn.init.zeros_(module.bias)
+
+
+ class ImageClassifier:
+     def __init__(self, checkpoint_path):
+         self.checkpoint_path = checkpoint_path
+         self.model = self.load_model(checkpoint_path)
+         self.transform = self.get_transform((244, 244))
+         self.labels = [
+             "airplane",
+             "automobile",
+             "bird",
+             "cat",
+             "deer",
+             "dog",
+             "frog",
+             "horse",
+             "ship",
+             "truck",
+         ]
+
+     def load_model(self, checkpoint_path):
+         classifier = ResNet(
+             arch=RESNET_18,
+             block="simple",
+             num_classes=10,
+         )
+         classifier.load_state_dict(torch.load(checkpoint_path))
+         classifier = classifier.cpu()
+         classifier.eval()
+         return classifier
+
+     def get_transform(self, img_shape):
+         preprocess_transform = transforms.Compose(
+             [
+                 transforms.Resize(img_shape),
+                 transforms.ToTensor(),
+             ]
+         )
+         return preprocess_transform
+
+     def predict(self, image):
+         image_tensor = self.transform(image).unsqueeze(0)
+         with torch.no_grad():
+             logits = self.model(image_tensor)
+             probs = logits.softmax(dim=1)[0]
+         return {label: prob.item() for label, prob in zip(self.labels, probs)}
+
+     def classify(self, input_image):
+         return self.predict(input_image)
+
+
+ def classify(input_image):
+     return classifier.classify(input_image)
+
+
+ checkpoint_path = hf_hub_download(
+     repo_id="SatwikKambham/resnet18-cifar10",
+     filename="model.pt",
+ )
+ classifier = ImageClassifier(checkpoint_path)
+ iface = gr.Interface(
+     classify,
+     inputs=[
+         gr.Image(label="Input Image", type="pil"),
+     ],
+     outputs=gr.Label(num_top_classes=3),
+ )
+
+ iface.launch()
requirements.txt ADDED
@@ -0,0 +1,3 @@
+ torch
+ torchvision
+ huggingface_hub
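
As a quick sanity check of the architecture defined in app.py, a minimal sketch along the following lines confirms that a single RGB image maps to 10 CIFAR-10 logits, with the first forward pass materializing the Lazy* layers. It assumes the ResNet class and RESNET_18 constant from app.py are available in the session (importing app.py directly would also download the checkpoint and launch the Gradio interface), and the 224x224 input size is only illustrative.

import torch

# Assumes ResNet and RESNET_18 from app.py are defined in this session.
model = ResNet(arch=RESNET_18, block="simple", num_classes=10)
model.eval()
with torch.no_grad():
    dummy = torch.randn(1, 3, 224, 224)  # dummy input; first call materializes the Lazy* modules
    logits = model(dummy)
print(logits.shape)  # expected: torch.Size([1, 10])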