Spaces:
Sleeping
Sleeping
import gradio as gr | |
from huggingface_hub import hf_hub_download | |
import torch | |
import torch.nn as nn | |
from torchvision import transforms | |
class SimpleResidualBlock(nn.Module):
    """Basic residual block: two 3x3 convolutions plus a shortcut connection.

    Every layer is lazy, so the real input channel count is inferred on the
    first forward pass. ``in_channels`` is only compared with ``out_channels``
    to decide whether to downsample and whether the shortcut needs a 1x1
    projection.

    Args:
        in_channels: Nominal number of input channels.
        out_channels: Number of output channels.
        set_stride: When truthy AND the channel counts differ, the first 3x3
            convolution and the projection shortcut use stride 2.
    """

    def __init__(self, in_channels, out_channels, set_stride=False):
        super().__init__()
        channels_change = in_channels != out_channels
        stride = 2 if (set_stride and channels_change) else 1
        # PyTorch rejects padding="same" for strided convolutions, so the
        # strided case uses an explicit padding of 1 for the 3x3 kernel.
        first_padding = 1 if stride != 1 else "same"
        self.conv1 = nn.LazyConv2d(
            out_channels, kernel_size=3, padding=first_padding, stride=stride
        )
        self.conv2 = nn.LazyConv2d(out_channels, kernel_size=3, padding="same")
        self.bn1 = nn.LazyBatchNorm2d()
        self.bn2 = nn.LazyBatchNorm2d()
        self.relu = nn.ReLU()
        # A projection (1x1 conv + batch norm) is needed whenever the channel
        # count changes; otherwise the input is added back unchanged.
        self.residual = (
            nn.Sequential(
                nn.LazyConv2d(out_channels, kernel_size=1, stride=stride),
                nn.LazyBatchNorm2d(),
            )
            if channels_change
            else nn.Identity()
        )

    def forward(self, x):
        """Compute relu(bn2(conv2(relu(bn1(conv1(x))))) + residual(x))."""
        y = self.bn1(self.conv1(x))
        y = self.relu(y)
        y = self.bn2(self.conv2(y))
        y += self.residual(x)
        return self.relu(y)
class BottleneckResidualBlock(nn.Module):
    """Bottleneck residual block: 1x1 reduce -> 3x3 -> 1x1 expand (x4).

    The block emits ``out_channels * 4`` channels. Every layer is lazy, so the
    real input channel count is inferred on the first forward pass;
    ``in_channels`` is only used to choose the stride and the shortcut type.

    Args:
        in_channels: Nominal number of input channels.
        out_channels: Width of the bottleneck; the block outputs 4x this.
        identity_mapping: Allow a parameter-free identity shortcut when the
            input already matches the output shape, i.e.
            ``in_channels == out_channels * 4`` and no striding happens.
            Otherwise (and always when False) a 1x1 conv + batch-norm
            projection shortcut is used.
        set_stride: When truthy AND ``in_channels != out_channels``, the first
            1x1 convolution and the projection shortcut use stride 2.
    """

    def __init__(
        self, in_channels, out_channels, identity_mapping=False, set_stride=False
    ):
        super().__init__()
        expanded_channels = out_channels * 4
        stride = 2 if in_channels != out_channels and set_stride else 1
        self.conv1 = nn.LazyConv2d(
            out_channels,
            kernel_size=1,
            padding="same" if stride == 1 else 0,
            stride=stride,
        )
        self.conv2 = nn.LazyConv2d(out_channels, kernel_size=3, padding="same")
        self.conv3 = nn.LazyConv2d(expanded_channels, kernel_size=1, padding="same")
        self.bn1 = nn.LazyBatchNorm2d()
        self.bn2 = nn.LazyBatchNorm2d()
        self.bn3 = nn.LazyBatchNorm2d()
        self.relu = nn.ReLU()
        # The shortcut is added to conv3's output, which has expanded_channels
        # channels at the (possibly strided) resolution. An identity therefore
        # only fits when the input already has expanded_channels channels and
        # stride is 1; comparing in_channels against out_channels instead
        # would pick the identity exactly when the shapes cannot match.
        use_identity = (
            identity_mapping and in_channels == expanded_channels and stride == 1
        )
        if use_identity:
            self.residual = nn.Identity()
        else:
            self.residual = nn.Sequential(
                nn.LazyConv2d(expanded_channels, kernel_size=1, stride=stride),
                nn.LazyBatchNorm2d(),
            )

    def forward(self, x):
        """Compute relu(bn3(conv3(...)) + residual(x)) for input ``x``."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.residual(x)
        out = self.relu(out)
        return out
# Residual blocks per stage (conv2 .. conv5) for each standard ResNet depth.
# The 50/101/152 layouts presumably pair with block="bottleneck" (as in the
# original ResNet paper); nothing in this file enforces that pairing.
RESNET_18 = [2] * 4
RESNET_34 = [3, 4, 6, 3]
# Same per-stage counts as ResNet-34, kept as a separate list object.
RESNET_50 = list(RESNET_34)
RESNET_101 = [3, 4, 23, 3]
RESNET_152 = [3, 8, 36, 3]
class ResNet(nn.Module):
    """ResNet-style classifier assembled from lazily-initialized layers.

    Pipeline: 7x7 stride-2 stem (``conv1``) -> residual stage ``conv2`` ->
    3x3 stride-2 max pool -> residual stages ``conv3`` .. ``conv5`` ->
    global average pool -> flatten -> linear head (``fc``).

    Stage widths are 64, 128, 256 and 512; every stage except ``conv2``
    passes ``set_stride=True`` to its first block.

    Args:
        arch: Blocks per stage; only the first four entries are read
            (e.g. ``RESNET_18``).
        block: ``"simple"`` builds ``SimpleResidualBlock`` stages; any other
            value builds ``BottleneckResidualBlock`` stages.
        num_classes: Output size of the final linear layer.

    NOTE(review): the max pool runs after ``conv2`` rather than directly
    after the stem as in the reference ResNet. Pooling has no parameters, so
    moving it would not break checkpoint loading, but it would change what
    the trained weights see -- keep it unless the model is retrained.
    """

    def __init__(self, arch=RESNET_18, block="simple", num_classes=256):
        super().__init__()
        self.conv1 = nn.Sequential(
            nn.LazyConv2d(64, kernel_size=7, stride=2, padding=3),
            nn.LazyBatchNorm2d(),
            nn.ReLU(),
        )
        self.maxpool = nn.MaxPool2d(3, stride=2, padding=1)
        build_stage = self._make_layer
        self.conv2 = build_stage(64, 64, arch[0], set_stride=False, block=block)
        self.conv3 = build_stage(64, 128, arch[1], set_stride=True, block=block)
        self.conv4 = build_stage(128, 256, arch[2], set_stride=True, block=block)
        self.conv5 = build_stage(256, 512, arch[3], set_stride=True, block=block)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.flatten = nn.Flatten()
        self.fc = nn.LazyLinear(num_classes)

    def _make_layer(
        self, in_channels, out_channels, num_blocks, set_stride=True, block="simple"
    ):
        """Build one residual stage as an ``nn.Sequential`` of ``num_blocks`` blocks.

        Block is either 'simple' or 'bottleneck' (any value other than
        ``"simple"`` selects the bottleneck block). Only the first block gets
        ``set_stride``; the remaining blocks are built with ``set_stride=False``.

        NOTE(review): simple blocks all receive the stage's ``in_channels``, so
        in a stage that changes width every block (not only the first) gets a
        1x1 projection shortcut. Those weights show up as ``residual.*`` keys
        in ``state_dict``; presumably existing checkpoints contain them, so do
        not change this without retraining.
        """
        stage = []
        for position in range(num_blocks):
            first = position == 0
            stride_flag = set_stride if first else False
            if block == "simple":
                unit = SimpleResidualBlock(
                    in_channels, out_channels, set_stride=stride_flag
                )
            else:
                # Bottleneck blocks emit 4x out_channels, which is what every
                # block after the first in the stage receives.
                unit = BottleneckResidualBlock(
                    in_channels if first else out_channels * 4,
                    out_channels,
                    set_stride=stride_flag,
                )
            stage.append(unit)
        return nn.Sequential(*stage)

    def forward(self, x):
        """Map an image batch ``x`` to ``num_classes`` outputs per example."""
        features = self.conv1(x)
        features = self.conv2(features)
        features = self.maxpool(features)
        for stage in (self.conv3, self.conv4, self.conv5):
            features = stage(features)
        pooled = self.flatten(self.avgpool(features))
        return self.fc(pooled)
def _init_weights(module):
    """Initialize Conv2d/Linear weights with Glorot (Xavier) uniform and zero biases.

    Presumably meant to be passed to ``nn.Module.apply`` (no caller is visible
    here). Other module types are left untouched.

    Layers whose parameters are still uninitialized (lazy layers before their
    first forward pass -- ``nn.LazyConv2d`` / ``nn.LazyLinear`` subclass
    ``nn.Conv2d`` / ``nn.Linear``) are skipped, because their shape is not
    known yet and initializing them would raise.

    Args:
        module: Any ``nn.Module``.
    """
    if not isinstance(module, (nn.Conv2d, nn.Linear)):
        return
    if isinstance(module.weight, nn.UninitializedParameter):
        return
    nn.init.xavier_uniform_(module.weight)
    # Layers built with bias=False have module.bias set to None.
    if module.bias is not None:
        nn.init.zeros_(module.bias)
class ImageClassifier:
    """Single-image inference wrapper around a 10-class ResNet-18 checkpoint.

    ``self.labels`` holds the CIFAR-10 class names; output index ``i`` of the
    network is reported under ``self.labels[i]``.

    Args:
        checkpoint_path: Path to a file readable by ``torch.load`` that holds
            a ``state_dict`` for ``ResNet(arch=RESNET_18, block="simple",
            num_classes=10)``.
    """

    def __init__(self, checkpoint_path):
        self.checkpoint_path = checkpoint_path
        self.model = self.load_model(checkpoint_path)
        # NOTE(review): 244x244 is an unusual size (224 is the common ImageNet
        # choice). The ResNet ends in adaptive average pooling, so any size
        # runs, but confirm this matches the resolution used during training.
        self.transform = self.get_transform((244, 244))
        self.labels = [
            "airplane",
            "automobile",
            "bird",
            "cat",
            "deer",
            "dog",
            "frog",
            "horse",
            "ship",
            "truck",
        ]

    def load_model(self, checkpoint_path):
        """Build the ResNet-18 classifier and load weights from ``checkpoint_path``.

        Returns:
            The model on the CPU, switched to eval mode.
        """
        model = ResNet(
            arch=RESNET_18,
            block="simple",
            num_classes=10,
        )
        # map_location="cpu" lets checkpoints saved from GPU tensors load on
        # CPU-only hosts instead of failing while restoring CUDA storages.
        # NOTE(review): torch.load unpickles the file, and this file is
        # downloaded from the Hub; consider weights_only=True where the
        # installed torch version supports it.
        state_dict = torch.load(checkpoint_path, map_location="cpu")
        model.load_state_dict(state_dict)
        model = model.cpu()
        model.eval()
        return model

    def get_transform(self, img_shape):
        """Return a transform that resizes to ``img_shape`` (h, w) and converts to a tensor.

        No normalization is applied after ``ToTensor``.
        """
        preprocess_transform = transforms.Compose(
            [
                transforms.Resize(img_shape),
                transforms.ToTensor(),
            ]
        )
        return preprocess_transform

    def predict(self, image):
        """Return ``{label: probability}`` for a single image.

        Args:
            image: Typically a PIL image (what ``gr.Image(type="pil")``
                produces). An object with a ``convert`` method whose ``mode``
                is not ``"RGB"`` (e.g. RGBA or grayscale PIL images) is
                converted to RGB first; the checkpoint is presumably trained
                on 3-channel input.

        Raises:
            ValueError: If ``image`` is None (e.g. submitted with no upload).
        """
        if image is None:
            raise ValueError("No input image was provided.")
        convert = getattr(image, "convert", None)
        if callable(convert) and getattr(image, "mode", "RGB") != "RGB":
            image = convert("RGB")
        # Add a batch dimension of size 1.
        image_tensor = self.transform(image).unsqueeze(0)
        with torch.no_grad():
            logits = self.model(image_tensor)
        probs = logits.softmax(dim=1)[0]
        return {label: prob.item() for label, prob in zip(self.labels, probs)}

    def classify(self, input_image):
        """Alias for :meth:`predict`."""
        return self.predict(input_image)
def classify(input_image):
    """Gradio callback: score ``input_image`` with the module-level ``classifier``.

    Returns whatever ``classifier.classify`` returns (a ``{label: probability}``
    dict when ``classifier`` is an ``ImageClassifier``).
    """
    scores = classifier.classify(input_image)
    return scores
# Fetch the trained weights from the Hugging Face Hub and build the shared
# classifier that the ``classify`` callback uses.
checkpoint_path = hf_hub_download(
    repo_id="SatwikKambham/resnet18-cifar10",
    filename="model.pt",
)
classifier = ImageClassifier(checkpoint_path)

# Gradio UI: one PIL image in, the three most probable classes out.
image_input = gr.Image(label="Input Image", type="pil")
top_classes = gr.Label(num_top_classes=3)
iface = gr.Interface(fn=classify, inputs=[image_input], outputs=top_classes)
iface.launch()