# Author: enesmanan - commit 1082a29 (verified): "add class description"
import os
import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import timm
class BaseModel(nn.Module):
    """Shared base for the classifiers, adding probability-inference helpers."""

    def predict(self, x: torch.Tensor) -> torch.Tensor:
        """Return softmax probabilities for ``x`` with gradient tracking off.

        Dimension 1 of the forward output is treated as the class dimension.
        This does not switch the module to eval mode; callers are
        responsible for calling ``eval()`` first when that matters.
        """
        with torch.no_grad():
            return self(x).softmax(dim=1)

    def get_num_classes(self) -> int:
        """Return the number of output classes; subclasses must override."""
        raise NotImplementedError
class CNNModel(BaseModel):
    """Compact convolutional classifier.

    Three blocks of Conv3x3 -> BatchNorm -> ReLU -> MaxPool(2) with 32, 64
    and 128 filters, then global average pooling and a two-layer MLP head
    with dropout.
    """

    def __init__(self, num_classes: int, input_size: int = 224):
        # ``input_size`` is accepted for interface compatibility but unused:
        # the adaptive average pool makes the head independent of the
        # spatial size of the input.
        super().__init__()
        feature_layers = []
        channels_in = 3
        for channels_out in (32, 64, 128):
            feature_layers += [
                nn.Conv2d(channels_in, channels_out, kernel_size=3, padding=1),
                nn.BatchNorm2d(channels_out),
                nn.ReLU(),
                nn.MaxPool2d(2),
            ]
            channels_in = channels_out
        # Global average pooling down to 1x1.
        feature_layers.append(nn.AdaptiveAvgPool2d(1))
        # Positions inside these Sequentials become state_dict keys
        # (e.g. "conv_layers.0.weight", "classifier.5.bias"), so the layer
        # order must stay fixed for saved checkpoints to keep loading.
        self.conv_layers = nn.Sequential(*feature_layers)
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Dropout(0.5),
            nn.Linear(channels_in, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes),
        )
        self._initialize_weights()

    def _initialize_weights(self):
        """Kaiming-normal convs, unit/zero BatchNorm, N(0, 0.01) linears, zero biases."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, mean=0.0, std=0.01)
                nn.init.zeros_(module.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return raw class logits for the image batch ``x``."""
        return self.classifier(self.conv_layers(x))

    def get_num_classes(self) -> int:
        """Return the output width of the final linear layer."""
        final_layer = self.classifier[-1]
        return final_layer.out_features
class EfficientNetModel(BaseModel):
    """timm backbone (``efficientnet_b0`` by default) with a dropout + linear head.

    The backbone is created with ``num_classes=0`` so that timm drops its own
    classification layer; a new head sized to ``num_classes`` is placed on
    top of the backbone's output features.
    """

    def __init__(
        self,
        num_classes: int,
        model_name: str = "efficientnet_b0",
        pretrained: bool = True
    ):
        super().__init__()
        self.base_model = timm.create_model(
            model_name,
            pretrained=pretrained,
            num_classes=0
        )
        feature_dim = self._infer_feature_dim()
        self.classifier = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(feature_dim, num_classes)
        )

    def _infer_feature_dim(self, input_size: int = 224) -> int:
        """Return the backbone's output width by running one dummy forward pass.

        The probe runs with the backbone in eval mode. In train mode,
        BatchNorm layers update their running mean/variance on every forward
        pass, even under ``torch.no_grad()``. The random dummy batch would
        then silently corrupt the (pretrained) running statistics. The
        backbone's previous train/eval mode is restored afterwards.
        """
        was_training = self.base_model.training
        self.base_model.eval()
        try:
            with torch.no_grad():
                dummy_input = torch.randn(1, 3, input_size, input_size)
                features = self.base_model(dummy_input)
        finally:
            self.base_model.train(was_training)
        return features.shape[1]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return raw class logits for the image batch ``x``."""
        features = self.base_model(x)
        return self.classifier(features)

    def get_num_classes(self) -> int:
        """Return the output width of the final linear layer."""
        return self.classifier[-1].out_features
class AnimalClassifierApp:
    """Gradio app comparing the EfficientNet and CNN animal classifiers.

    Checkpoints are read from ``checkpoints/efficientnet/`` and
    ``checkpoints/cnn/``. Every model that loads successfully is run on each
    uploaded image, and the results are shown side by side.
    """

    # (model name, subplot position in the 1x2 grid, bar colour) per panel.
    _PLOT_PANELS = (
        ("EfficientNet", 1, "skyblue"),
        ("CNN", 2, "lightcoral"),
    )

    def __init__(self):
        """Initialize the application: device, labels, preprocessing, models."""
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Index i of every model's output corresponds to labels[i]; this order
        # presumably matches the one used at training time (TODO confirm).
        self.labels = ["bird", "cat", "dog", "horse"]
        # Resize to 224x224, then normalize with the standard ImageNet
        # channel mean/std.
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ])
        self.models = self.load_models()
        if not self.models:
            print("Warning: No models found in checkpoints directory!")

    def load_models(self):
        """Load both trained models.

        Returns:
            Dict mapping display name ("EfficientNet", "CNN") to a model in
            eval mode on ``self.device``. Only models whose checkpoint exists
            and loads without error are included.
        """
        num_classes = len(self.labels)
        specs = (
            (
                "EfficientNet",
                lambda: EfficientNetModel(num_classes=num_classes),
                os.path.join("checkpoints", "efficientnet", "efficientnet_best_model.pth"),
            ),
            (
                "CNN",
                lambda: CNNModel(num_classes=num_classes),
                os.path.join("checkpoints", "cnn", "cnn_best_model.pth"),
            ),
        )
        models = {}
        for name, factory, path in specs:
            model = self._load_model(name, factory, path)
            if model is not None:
                models[name] = model
        return models

    def _load_model(self, name, factory, path):
        """Build a model with ``factory()`` and load the checkpoint at ``path``.

        Args:
            name: Display name used in log messages.
            factory: Zero-argument callable returning a fresh ``nn.Module``.
            path: Checkpoint file holding either a raw state dict or a dict
                with the state dict under ``"model_state_dict"``.

        Returns:
            The model on ``self.device`` in eval mode, or ``None`` if ``path``
            does not exist or loading fails. Load errors are printed, not raised.
        """
        # Check for the file before building the model, so a missing
        # checkpoint does not trigger needless work such as fetching
        # pretrained backbone weights.
        if not os.path.exists(path):
            return None
        try:
            model = factory()
            checkpoint = torch.load(path, map_location=self.device, weights_only=True)
            state_dict = checkpoint.get("model_state_dict", checkpoint)
            # strict=False tolerates key mismatches. Those mismatches are
            # reported below, because any parameter missing from the
            # checkpoint keeps its untrained initial value.
            incompatible = model.load_state_dict(state_dict, strict=False)
            # load_state_dict copies values into the model's existing
            # parameters, which were created on the CPU. The model itself
            # must be moved to the device that predict() sends inputs to.
            model.to(self.device)
            model.eval()
        except Exception as e:
            print(f"Error loading {name} model: {str(e)}")
            return None
        if incompatible.missing_keys or incompatible.unexpected_keys:
            print(
                f"Warning: {name} checkpoint does not fully match the model "
                f"({len(incompatible.missing_keys)} missing, "
                f"{len(incompatible.unexpected_keys)} unexpected keys)"
            )
        print(f"Successfully loaded {name} model")
        return model

    def predict(self, image: "Image.Image | None"):
        """Classify ``image`` with every loaded model.

        Returns:
            ``[figure, text]``, matching the interface's (Plot, Textbox)
            outputs. When no model is loaded or no image was supplied,
            ``figure`` is ``None`` and ``text`` explains why.
        """
        if not self.models:
            return [None, "No trained models found. Please train the models first."]
        if image is None:
            return [None, "Please upload an image."]
        # Normalize() uses 3-channel mean/std, so RGBA or grayscale images
        # would fail inside the transform unless converted first.
        img_tensor = self.transform(image.convert("RGB")).unsqueeze(0).to(self.device)
        # BaseModel.predict runs the forward pass under no_grad and applies
        # softmax over the class dimension.
        probabilities = {
            name: model.predict(img_tensor).squeeze(0).cpu().numpy()
            for name, model in self.models.items()
        }
        return [self._plot_probabilities(probabilities), self._format_results(probabilities)]

    def _plot_probabilities(self, probabilities):
        """Return a 1x2 figure with one per-class probability bar chart per model."""
        # Use the object-oriented Figure instead of plt.figure(). pyplot keeps
        # every figure it creates alive until plt.close(), which leaks one
        # figure per request, and pyplot's global state is not thread-safe.
        from matplotlib.figure import Figure

        fig = Figure(figsize=(12, 5))
        for name, position, color in self._PLOT_PANELS:
            if name not in probabilities:
                continue
            ax = fig.add_subplot(1, 2, position)
            ax.bar(self.labels, probabilities[name], color=color)
            ax.set_title(f"{name} Predictions")
            ax.set_ylim(0, 1)
            ax.tick_params(axis="x", labelrotation=45)
            ax.set_ylabel("Probability")
        fig.tight_layout()
        return fig

    def _format_results(self, probabilities):
        """Return the per-model text summary shown in the results textbox."""
        lines = ["Model Predictions:", ""]
        for name, probs in probabilities.items():
            top = int(np.argmax(probs))
            lines.append(f"{name}:")
            lines.append(f"Top prediction: {self.labels[top]} ({probs[top]:.2%})")
            lines.append("All probabilities:")
            lines.extend(f" {label}: {p:.2%}" for label, p in zip(self.labels, probs))
            lines.append("")
        return "\n".join(lines) + "\n"

    def create_interface(self):
        """Create the Gradio interface.

        It takes one PIL image and has two outputs (Plot, Textbox), matching
        the ``[figure, text]`` list returned by ``predict``.
        """
        return gr.Interface(
            fn=self.predict,
            inputs=gr.Image(type="pil"),
            outputs=[
                gr.Plot(label="Prediction Probabilities"),
                gr.Textbox(label="Detailed Results", lines=10),
            ],
            title="Animal Classifier - Model Comparison",
            description=(
                "Upload an image of one of these animals: Bird, Cat, Dog, or Horse.\n"
                "The app will compare predictions from both EfficientNet and CNN models.\n\n"
                "Note: For best results, ensure the animal is clearly visible in the image."
            ),
        )
def main():
    """Entry point: build the classifier app and serve its Gradio interface."""
    app = AnimalClassifierApp()
    app.create_interface().launch()


if __name__ == "__main__":
    main()