# SatwikKambham's picture
# Added app.py and requirements.txt
# 22ae476
import gradio as gr
from huggingface_hub import hf_hub_download
import torch
import torch.nn as nn
from torchvision import transforms
class SimpleResidualBlock(nn.Module):
    """Two-convolution (3x3 -> 3x3) residual block.

    Args:
        in_channels: Nominal channel count of the input. It is only compared
            with ``out_channels`` to pick the stride and the shortcut type;
            the lazy layers infer the real input width on the first forward.
        out_channels: Number of channels produced by the block.
        set_stride: When true *and* the channel counts differ, the first
            convolution (and the shortcut projection) use stride 2.
    """

    def __init__(self, in_channels, out_channels, set_stride=False):
        super().__init__()

        widens = in_channels != out_channels
        if widens and set_stride:
            # A strided conv cannot use "same" padding, so pad by 1 explicitly.
            stride, first_padding = 2, 1
        else:
            stride, first_padding = 1, "same"

        self.conv1 = nn.LazyConv2d(
            out_channels, kernel_size=3, padding=first_padding, stride=stride
        )
        self.conv2 = nn.LazyConv2d(out_channels, kernel_size=3, padding="same")
        self.bn1 = nn.LazyBatchNorm2d()
        self.bn2 = nn.LazyBatchNorm2d()
        self.relu = nn.ReLU()

        # Shortcut: a 1x1 projection whenever the widths differ, else identity.
        self.residual = (
            nn.Sequential(
                nn.LazyConv2d(out_channels, kernel_size=1, stride=stride),
                nn.LazyBatchNorm2d(),
            )
            if widens
            else nn.Identity()
        )

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        hidden = self.conv1(x)
        hidden = self.bn1(hidden)
        hidden = self.relu(hidden)
        hidden = self.bn2(self.conv2(hidden))
        hidden += self.residual(x)
        return self.relu(hidden)
class BottleneckResidualBlock(nn.Module):
    """Bottleneck residual block: 1x1 -> 3x3 -> 1x1 with a 4x channel expansion.

    The block outputs ``out_channels * 4`` channels.

    Args:
        in_channels: Nominal channel count of the input. It is only used to
            decide the stride and the shortcut type; the lazy layers infer the
            real input width on the first forward pass.
        out_channels: Width of the two inner convolutions.
        identity_mapping: Request a parameter-free identity shortcut. It is
            honoured only when the input already has the output's shape
            (``in_channels == out_channels * 4`` and stride 1); otherwise a
            1x1 projection is used so the residual addition stays valid.
        set_stride: When true *and* ``in_channels != out_channels``, the first
            convolution (and the shortcut projection) use stride 2.
    """

    def __init__(
        self, in_channels, out_channels, identity_mapping=False, set_stride=False
    ):
        super().__init__()
        stride = 2 if in_channels != out_channels and set_stride else 1
        expanded_channels = out_channels * 4

        self.conv1 = nn.LazyConv2d(
            out_channels,
            kernel_size=1,
            padding="same" if stride == 1 else 0,
            stride=stride,
        )
        self.conv2 = nn.LazyConv2d(out_channels, kernel_size=3, padding="same")
        self.conv3 = nn.LazyConv2d(expanded_channels, kernel_size=1, padding="same")
        self.bn1 = nn.LazyBatchNorm2d()
        self.bn2 = nn.LazyBatchNorm2d()
        self.bn3 = nn.LazyBatchNorm2d()
        self.relu = nn.ReLU()

        # An identity shortcut is only shape-compatible when the input already
        # has the expanded width and the block does not downsample. Comparing
        # against `out_channels` (as before) selected Identity exactly when the
        # main path produced 4x more channels than the input, breaking the add.
        shape_preserved = stride == 1 and in_channels == expanded_channels
        if identity_mapping and shape_preserved:
            self.residual = nn.Identity()
        else:
            self.residual = nn.Sequential(
                nn.LazyConv2d(expanded_channels, kernel_size=1, stride=stride),
                nn.LazyBatchNorm2d(),
            )

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.residual(x)
        out = self.relu(out)
        return out
# Architecture presets: the number of residual blocks in each of the four
# stages (conv2..conv5) of `ResNet`, passed as its `arch` argument.
RESNET_18 = [2, 2, 2, 2]
RESNET_34 = [3, 4, 6, 3]
# Same block counts as RESNET_34; presumably meant to be paired with
# block="bottleneck" (as are the two presets below) -- confirm with callers.
RESNET_50 = [3, 4, 6, 3]
RESNET_101 = [3, 4, 23, 3]
RESNET_152 = [3, 8, 36, 3]
class ResNet(nn.Module):
    """ResNet-style classifier assembled from lazily-shaped layers.

    Args:
        arch: Four block counts, one per residual stage (conv2..conv5).
        block: ``"simple"`` for `SimpleResidualBlock`; any other value selects
            `BottleneckResidualBlock`.
        num_classes: Size of the final linear layer's output.

    Notes:
        * The max-pool runs *after* the first residual stage (conv2), not
          directly after the stem convolution.
        * With ``block="simple"`` every block of a stage is given the stage's
          ``in_channels``, so in a widening stage each block (not only the
          first) gets a projection shortcut.
        NOTE(review): both points change the computation / parameter set, so
        existing checkpoints presumably depend on them -- do not "fix" casually.
    """

    def __init__(self, arch=RESNET_18, block="simple", num_classes=256):
        super().__init__()

        # Stem: 7x7 strided convolution + batch norm + ReLU.
        stem_layers = [
            nn.LazyConv2d(64, kernel_size=7, stride=2, padding=3),
            nn.LazyBatchNorm2d(),
            nn.ReLU(),
        ]
        self.conv1 = nn.Sequential(*stem_layers)
        self.maxpool = nn.MaxPool2d(3, stride=2, padding=1)

        # Residual stages; only conv2 is built without a strided first block.
        self.conv2 = self._make_layer(64, 64, arch[0], block=block, set_stride=False)
        self.conv3 = self._make_layer(64, 128, arch[1], block=block)
        self.conv4 = self._make_layer(128, 256, arch[2], block=block)
        self.conv5 = self._make_layer(256, 512, arch[3], block=block)

        # Classification head.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.flatten = nn.Flatten()
        self.fc = nn.LazyLinear(num_classes)

    def _make_layer(
        self, in_channels, out_channels, num_blocks, set_stride=True, block="simple"
    ):
        """Build one stage of ``num_blocks`` residual blocks.

        ``block`` is either 'simple' or 'bottleneck'. Only the first block of
        the stage receives ``set_stride``; the rest are built with it off.
        """
        blocks = []
        for index in range(num_blocks):
            strided = set_stride if index == 0 else False
            if block == "simple":
                unit = SimpleResidualBlock(
                    in_channels, out_channels, set_stride=strided
                )
            else:
                # After the first bottleneck block the stage carries the
                # expanded (4x) width.
                block_in = in_channels if index == 0 else out_channels * 4
                unit = BottleneckResidualBlock(
                    block_in, out_channels, set_stride=strided
                )
            blocks.append(unit)
        return nn.Sequential(*blocks)

    def forward(self, x):
        """Run the stem, the residual stages and the head; return the logits."""
        pipeline = (
            self.conv1,
            self.conv2,
            self.maxpool,
            self.conv3,
            self.conv4,
            self.conv5,
            self.avgpool,
            self.flatten,
            self.fc,
        )
        out = x
        for stage in pipeline:
            out = stage(out)
        return out
def _init_weights(module):
    """Initialize a layer with Glorot (Xavier) uniform weights and a zero bias.

    Only ``nn.Conv2d`` and ``nn.Linear`` instances are touched; every other
    module type is left unchanged. Layers built with ``bias=False`` are
    supported (their missing bias is simply skipped).

    NOTE(review): lazy layers presumably need one forward pass to materialise
    their parameters before this can be applied to them -- confirm at the
    call site.

    Args:
        module: The module to initialize, e.g. as visited by ``Module.apply``.
    """
    if not isinstance(module, (nn.Conv2d, nn.Linear)):
        return
    nn.init.xavier_uniform_(module.weight)
    # `bias` is None when the layer was created with bias=False.
    if module.bias is not None:
        nn.init.zeros_(module.bias)
class ImageClassifier:
    """Inference wrapper around a ResNet-18 checkpoint with ten output classes.

    Attributes:
        checkpoint_path: Path of the state-dict file the model was loaded from.
        model: The network, on CPU and in eval mode.
        transform: Preprocessing applied to every input image.
        labels: Class names, index-aligned with the model's logits.
    """

    def __init__(self, checkpoint_path):
        """Load the model from ``checkpoint_path`` and set up preprocessing."""
        self.checkpoint_path = checkpoint_path
        self.model = self.load_model(checkpoint_path)
        # NOTE(review): 244 rather than the customary 224 -- kept as-is because
        # it presumably matches how the checkpoint was trained; confirm.
        self.transform = self.get_transform((244, 244))
        self.labels = [
            "airplane",
            "automobile",
            "bird",
            "cat",
            "deer",
            "dog",
            "frog",
            "horse",
            "ship",
            "truck",
        ]

    def load_model(self, checkpoint_path):
        """Build a ResNet-18 and load its weights from ``checkpoint_path``.

        Returns:
            The model on CPU, switched to eval mode.
        """
        classifier = ResNet(
            arch=RESNET_18,
            block="simple",
            num_classes=10,
        )
        # map_location="cpu" lets a checkpoint saved from a GPU run load on a
        # CPU-only host instead of failing while restoring CUDA tensors.
        # SECURITY: torch.load unpickles the file; only load checkpoints from a
        # trusted source (consider weights_only=True where PyTorch supports it).
        state_dict = torch.load(checkpoint_path, map_location="cpu")
        classifier.load_state_dict(state_dict)
        classifier = classifier.cpu()
        classifier.eval()
        return classifier

    def get_transform(self, img_shape):
        """Return the preprocessing pipeline: resize to ``img_shape``, then to tensor."""
        preprocess_transform = transforms.Compose(
            [
                transforms.Resize(img_shape),
                transforms.ToTensor(),
            ]
        )
        return preprocess_transform

    def predict(self, image):
        """Classify one image.

        Args:
            image: The input image. Objects exposing ``convert`` (e.g. PIL
                images) are converted to RGB first, so greyscale or RGBA
                inputs yield the three channels the network expects.

        Returns:
            A dict mapping each label to its softmax probability (a float).
        """
        if hasattr(image, "convert"):
            image = image.convert("RGB")
        # Add the batch dimension the model expects.
        image_tensor = self.transform(image).unsqueeze(0)
        with torch.no_grad():
            logits = self.model(image_tensor)
        probs = logits.softmax(dim=1)[0]
        return dict(zip(self.labels, probs.tolist()))

    def classify(self, input_image):
        """Alias of `predict`."""
        return self.predict(input_image)
def classify(input_image):
    """Gradio callback: score ``input_image`` with the module-level ``classifier``."""
    predictions = classifier.classify(input_image)
    return predictions
# Fetch the trained weights from the Hugging Face Hub; the call returns the
# local path that the classifier then loads.
checkpoint_path = hf_hub_download(
    filename="model.pt",
    repo_id="SatwikKambham/resnet18-cifar10",
)
classifier = ImageClassifier(checkpoint_path)

# Web UI: one PIL image in, the three most probable labels out.
iface = gr.Interface(
    fn=classify,
    inputs=[gr.Image(type="pil", label="Input Image")],
    outputs=gr.Label(num_top_classes=3),
)
iface.launch()