import os
import textwrap

import gradio as gr
import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image
from safetensors.torch import load_model, save_model

# Star import kept as-is; this file only uses ResNet50 from it.
from models import *

# Directory containing this script. Bundled assets (weights, example images)
# are resolved against it so the app works regardless of the working directory.
APP_DIR = os.path.dirname(os.path.abspath(__file__))

# ImageNet channel statistics used by Normalize after ToTensor().
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]


class WasteClassifier:
    """Wrap an image-classification model with preprocessing and softmax scoring.

    Args:
        model: A torch module called with a (1, 3, image_size, image_size)
            float tensor. Its output is assumed to be logits of shape
            (1, len(class_names)) -- TODO confirm against the model definition.
        class_names: Labels, in the same order as the model's output columns.
        device: torch.device that input tensors are moved to before inference.
        image_size: Side length images are resized to (square). Defaults to
            384, the value previously hard-coded here.
    """

    def __init__(self, model, class_names, device, image_size=384):
        self.model = model
        self.class_names = class_names
        self.device = device
        self.transform = transforms.Compose(
            [
                transforms.Resize((image_size, image_size)),
                transforms.ToTensor(),
                transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
            ]
        )

    def predict(self, image):
        """Classify a single image.

        Args:
            image: A PIL image, or anything ``Image.fromarray`` accepts
                (e.g. a uint8 NumPy array).

        Returns:
            dict with keys:
              ``predicted_class`` (str): label with the highest probability,
              ``confidence`` (float): that label's softmax probability,
              ``class_probabilities`` (dict[str, float]): probability per label.
        """
        # Set eval mode on every call so inference is unaffected if the model
        # was switched to train mode elsewhere.
        self.model.eval()
        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)
        # Normalize uses 3-channel statistics. Grayscale, palette and alpha
        # images would yield 1- or 4-channel tensors and fail there.
        if image.mode != "RGB":
            image = image.convert("RGB")

        img_tensor = self.transform(image).unsqueeze(0).to(self.device)
        with torch.no_grad():
            logits = self.model(img_tensor)
        probs = torch.nn.functional.softmax(logits, dim=1)[0].cpu().numpy()

        best = int(np.argmax(probs))
        return {
            "predicted_class": self.class_names[best],
            # Plain float (not a NumPy scalar), consistent with
            # class_probabilities and safe to JSON-serialize.
            "confidence": float(probs[best]),
            "class_probabilities": {
                name: float(p) for name, p in zip(self.class_names, probs)
            },
        }


def interface(classifier):
    """Build the Gradio UI around ``classifier``.

    Args:
        classifier: Object exposing ``predict(image) -> dict`` in the format
            returned by ``WasteClassifier.predict``.

    Returns:
        The configured ``gr.Interface``; the caller is responsible for launching it.
    """

    def process_image(image):
        """Format a prediction as human-readable text for the results box."""
        # Gradio passes None when the user submits without an uploaded image.
        if image is None:
            return "Please upload an image to classify."
        results = classifier.predict(image)
        sorted_probs = sorted(
            results["class_probabilities"].items(), key=lambda x: x[1], reverse=True
        )
        lines = [
            f"Predicted Class: {results['predicted_class']}",
            f"Confidence: {results['confidence'] * 100:.2f}%",
            "",
            "Class Probabilities:",
        ]
        lines += [f"{name}: {prob * 100:.2f}%" for name, prob in sorted_probs]
        return "\n".join(lines) + "\n"

    # Offer only the example images that actually exist next to this script,
    # so one missing file does not break example loading for the others.
    example_paths = [
        os.path.join(APP_DIR, name)
        for name in ("example1.jpg", "example2.jpg", "example3.jpg")
    ]
    examples = [[path] for path in example_paths if os.path.exists(path)] or None

    demo = gr.Interface(
        fn=process_image,
        inputs=[gr.Image(type="pil", label="Upload Image")],
        outputs=[gr.Textbox(label="Classification Results")],
        title="Waste Classification System",
        # Dedent so the Markdown-rendered description is not treated as an
        # indented code block.
        description=textwrap.dedent(
            """
            Upload an image of waste to classify it into different categories.
            The model will predict the type of waste and show confidence scores for each category.
            """
        ),
        examples=examples,
        analytics_enabled=False,
        theme="default",
    )
    return demo


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Order must match the model's output columns.
class_names = [
    "Cardboard",
    "Food Organics",
    "Glass",
    "Metal",
    "Miscellaneous Trash",
    "Paper",
    "Plastic",
    "Textile Trash",
    "Vegetation",
]

best_model = ResNet50(num_classes=len(class_names)).to(device)
load_model(best_model, os.path.join(APP_DIR, "bjf8fp.safetensors"))

classifier = WasteClassifier(best_model, class_names, device)
demo = interface(classifier)
# Launched at module level: running (or importing) this file starts the server.
demo.launch()