# demo-ml / app.py
# spuuntries
# fix: abs path on model
# 0a3a9fc
import gradio as gr
import torch
from PIL import Image
import torchvision.transforms as transforms
import numpy as np
from safetensors.torch import load_model, save_model
from models import *
import os
class WasteClassifier:
    """Wrap a trained classification model for single-image inference.

    The instance keeps references to the model, the label list and the target
    device, plus the preprocessing pipeline applied to every input image.
    """

    def __init__(self, model, class_names, device):
        """Store the collaborators and build the preprocessing pipeline.

        Args:
            model: Model object; ``predict`` calls ``model.eval()`` and
                ``model(batch)`` and applies a softmax over dim 1 of the output.
            class_names: Sequence of labels, index-aligned with the model's
                output scores.
            device: Device the input batch is moved to before the forward pass.
        """
        self.model = model
        self.class_names = class_names
        self.device = device
        # Resize to 384x384, convert to a tensor, then normalize each of the
        # three channels with the given mean/std.
        self.transform = transforms.Compose(
            [
                transforms.Resize((384, 384)),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )

    def predict(self, image):
        """Classify one image and return the prediction with per-class scores.

        Args:
            image: A PIL image, or an array accepted by ``Image.fromarray``.

        Returns:
            A dict with the keys ``"predicted_class"`` (label string),
            ``"confidence"`` (float probability of that label) and
            ``"class_probabilities"`` (label -> float probability).

        Raises:
            ValueError: If ``image`` is ``None``.
        """
        if image is None:
            raise ValueError("No image provided for classification.")

        self.model.eval()

        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)
        # The normalization step is configured for exactly three channels, so
        # grayscale, palette and RGBA inputs must be converted first.
        if image.mode != "RGB":
            image = image.convert("RGB")

        # Add the batch dimension the model call expects and move to device.
        batch = self.transform(image).unsqueeze(0).to(self.device)

        with torch.no_grad():
            outputs = self.model(batch)
            probabilities = torch.nn.functional.softmax(outputs, dim=1)

        return self._summarize(probabilities[0].cpu().numpy())

    def _summarize(self, probs):
        """Turn a 1-D sequence of class probabilities into the result dict."""
        best_index = int(np.argmax(probs))
        return {
            "predicted_class": self.class_names[best_index],
            # Plain Python floats keep the result JSON-serializable and
            # consistent with the per-class values below.
            "confidence": float(probs[best_index]),
            "class_probabilities": {
                class_name: float(prob)
                for class_name, prob in zip(self.class_names, probs)
            },
        }
def interface(classifier):
    """Build the Gradio UI that runs ``classifier`` on an uploaded image.

    Args:
        classifier: Object whose ``predict(image)`` returns a dict with the
            keys ``"predicted_class"``, ``"confidence"`` and
            ``"class_probabilities"``.

    Returns:
        The configured ``gr.Interface``. It is not launched here.
    """

    def process_image(image):
        """Return the classifier's prediction as a plain-text report."""
        # Gradio hands over None when the form is submitted without an image.
        if image is None:
            return "Please upload an image to classify."

        results = classifier.predict(image)
        ranked = sorted(
            results["class_probabilities"].items(),
            key=lambda item: item[1],
            reverse=True,
        )
        lines = [
            f"Predicted Class: {results['predicted_class']}",
            f"Confidence: {results['confidence']*100:.2f}%",
            "",
            "Class Probabilities:",
        ]
        lines.extend(f"{class_name}: {prob*100:.2f}%" for class_name, prob in ranked)
        # Every line, including the last one, ends with a newline.
        return "\n".join(lines) + "\n"

    # Offer only the example images that are actually present; a listed but
    # missing file would otherwise break the examples panel.
    available_examples = [
        [file_name]
        for file_name in ("example1.jpg", "example2.jpg", "example3.jpg")
        if os.path.exists(file_name)
    ]

    demo = gr.Interface(
        fn=process_image,
        inputs=[gr.Image(type="pil", label="Upload Image")],
        outputs=[gr.Textbox(label="Classification Results")],
        title="Waste Classification System",
        description=(
            "\n"
            "Upload an image of waste to classify it into different categories.\n"
            "The model will predict the type of waste and show confidence scores for each category.\n"
        ),
        examples=available_examples or None,
        analytics_enabled=False,
        theme="default",
    )
    return demo
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Labels, in the order they are paired with the model's output probabilities.
class_names = [
    "Cardboard",
    "Food Organics",
    "Glass",
    "Metal",
    "Miscellaneous Trash",
    "Paper",
    "Plastic",
    "Textile Trash",
    "Vegetation",
]

# Build the network with the label count as its number of classes and move it
# to the selected device before loading the weights into it.
best_model = ResNet50(num_classes=len(class_names))
best_model = best_model.to(device)

# The weights file is resolved next to this script, so loading does not depend
# on the current working directory.
load_model(
    best_model,
    os.path.join(os.path.dirname(os.path.abspath(__file__)), "bjf8fp.safetensors"),
)

classifier = WasteClassifier(best_model, class_names, device)
demo = interface(classifier)

# Start the web server only when this file is executed as a script; importing
# the module still exposes ``demo`` without launching it.
if __name__ == "__main__":
    demo.launch()