# Captured Hugging Face Spaces page header - Space status: Runtime error
# Importing libraries for gradio app | |
import gradio as gr | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from torchvision import models | |
import torchvision.transforms as tt | |
from PIL import Image | |
# Device helpers: pick a compute device and move tensors/models onto it.
def get_default_device():
    """Return ``torch.device('cuda')`` when torch sees a GPU, else ``torch.device('cpu')``."""
    return torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def to_device(data, device):
    """Recursively move *data* onto *device*.

    Lists and tuples are walked element by element and always come back as
    lists. Any other object is moved with ``.to(device, non_blocking=True)``.
    """
    if not isinstance(data, (list, tuple)):
        return data.to(device, non_blocking=True)
    moved = [to_device(elem, device) for elem in data]
    return moved
class DeviceDataLoader:
    """Iterable wrapper that moves every batch from a loader onto a device."""

    def __init__(self, dl, device):
        # dl: iterable of batches (must also support len() for __len__)
        self.dl, self.device = dl, device

    def __iter__(self):
        """Yield each batch of ``self.dl`` after sending it through ``to_device``."""
        yield from (to_device(batch, self.device) for batch in self.dl)

    def __len__(self):
        """Return the batch count reported by the wrapped loader."""
        return len(self.dl)
# Metric helper used by the validation code below.
def accuracy(outputs, labels):
    """Return the fraction of rows whose arg-max in *outputs* equals *labels*.

    The fraction is computed as a Python float and wrapped in a 0-d tensor.
    """
    predicted = torch.max(outputs, dim=1).indices
    n_correct = (predicted == labels).sum().item()
    return torch.tensor(n_correct / len(predicted))
class ImageClassificationBase(nn.Module):
    """``nn.Module`` base that adds per-batch and per-epoch validation helpers."""

    def validation_step(self, batch):
        """Score one ``(images, labels)`` batch.

        Returns a dict with the detached cross-entropy loss (``'val_loss'``)
        and the batch accuracy (``'val_acc'``).
        """
        images, labels = batch
        logits = self(images)
        return {
            'val_loss': F.cross_entropy(logits, labels).detach(),
            'val_acc': accuracy(logits, labels),
        }

    def validation_epoch_end(self, outputs):
        """Average the per-batch ``validation_step`` results into Python floats."""
        summary = {}
        for key in ('val_loss', 'val_acc'):
            summary[key] = torch.stack([step[key] for step in outputs]).mean().item()
        return summary
# Fine-tuned ResNet-50: a torchvision backbone whose final fully connected
# layer is resized to produce `num_classes` outputs.
class IndianFoodModelResnet50(ImageClassificationBase):
    """ResNet-50 classifier with a ``num_classes``-way output layer."""

    def __init__(self, num_classes, pretrained=True):
        super().__init__()
        backbone = models.resnet50(pretrained=pretrained)
        # replace the backbone's default head with one sized to our label set
        in_dim = backbone.fc.in_features
        backbone.fc = nn.Linear(in_dim, num_classes)
        self.network = backbone

    def forward(self, xb):
        """Return the raw class scores (logits) produced by the wrapped ResNet-50."""
        return self.network(xb)
# Validation loop: aggregate metrics over a whole loader.
def evaluate(model, val_loader):
    """Compute averaged validation metrics for *model* over *val_loader*.

    Puts the model in eval mode and runs ``model.validation_step`` on every
    batch. The per-batch results are combined with
    ``model.validation_epoch_end``, whose result is returned.
    """
    model.eval()
    # Metrics only: no gradients are needed, so skip building autograd graphs.
    with torch.no_grad():
        outputs = [model.validation_step(batch) for batch in val_loader]
    return model.validation_epoch_end(outputs)
# Inference device: CUDA when available, otherwise CPU.
device = get_default_device()

# Output labels, indexed by the classifier's output unit. The order must match
# the label order the checkpoint below was trained with.
classes = ['burger', 'butter_naan', 'chai', 'chapati', 'chole_bhature',
           'dal_makhani', 'dhokla', 'fried_rice', 'idli', 'jalebi',
           'kaathi_rolls', 'kadai_paneer', 'kulfi', 'masala_dosa', 'momos',
           'paani_puri', 'pakode', 'pav_bhaji', 'pizza', 'samosa']

# The strict load_state_dict below replaces every parameter and buffer, so
# downloading the ImageNet-pretrained backbone at startup would be wasted work.
model = IndianFoodModelResnet50(len(classes), pretrained=False)
model = to_device(model, device)

# Load the fine-tuned weights. map_location remaps tensors saved on another
# device (e.g. a GPU) onto `device`, so CPU-only hosts can load them too.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
ckp_path = './indianFood-resnet50.pth'
model.load_state_dict(torch.load(ckp_path, map_location=device))
model.eval()
# Preprocessing applied to each input image before prediction.
# `stats` is the per-channel (mean, std) pair passed to Normalize. These are
# the ImageNet statistics used by torchvision's pretrained models.
stats = ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
img_tfms = tt.Compose([
    tt.Resize((224, 224)),  # fixed 224x224 size (aspect ratio not preserved)
    tt.ToTensor(),
    tt.Normalize(mean=stats[0], std=stats[1], inplace=True),
])
def predict_image(image, model, *, labels=None):
    """Classify one preprocessed image tensor and return its label.

    Args:
        image: A single image tensor without a batch dimension (for example
            the output of ``img_tfms``). A batch axis of size 1 is added here.
        model: Classifier returning one score per class for each batch row.
        labels: Optional sequence mapping output indices to names. Defaults to
            the module-level ``classes`` list.

    Returns:
        The label whose score is highest.
    """
    if labels is None:
        labels = classes
    # Send the input to wherever the model's weights live. For a model with
    # no parameters, leave the image on its own device.
    param = next(model.parameters(), None)
    target_device = param.device if param is not None else image.device
    xb = image.unsqueeze(0).to(target_device)
    # Pure inference: disable autograd bookkeeping.
    with torch.no_grad():
        yb = model(xb)
    # Index of the highest score in the single batch row.
    _, preds = torch.max(yb, dim=1)
    return labels[preds[0].item()]
def classify_image(path):
    """Load the image at *path*, preprocess it and return the predicted label.

    Args:
        path: Filesystem path of the image to classify.

    Returns:
        The predicted class name.
    """
    # Close the file promptly. Force 3-channel RGB: uploads may be RGBA,
    # palette or greyscale, and those would not match the 3-channel
    # Normalize statistics in img_tfms.
    with Image.open(path) as img:
        rgb = img.convert('RGB')
    tensor = img_tfms(rgb)
    return predict_image(tensor, model)
# Gradio UI: the uploaded image is handed to classify_image as a file path,
# and the single top label is displayed.
# NOTE(review): gr.inputs / gr.outputs and `interpretation=` are legacy Gradio
# APIs (removed in Gradio 4). If the Space's Gradio version is upgraded,
# switch to gr.Image(type="filepath") and gr.Label(num_top_classes=1).
image = gr.inputs.Image(shape=(224, 224), type="filepath")
label = gr.outputs.Label(num_top_classes=1)

demo = gr.Interface(
    fn=classify_image,
    inputs=image,
    outputs=label,
    examples=[["idli.jpg"], ["naan.jpg"]],
    theme="huggingface",
    interpretation="default",
)

# Start the web server only when executed as a script, not on import.
if __name__ == "__main__":
    demo.launch()