import random import os import numpy as np import gradio as gr from PIL import Image from groundingdino.util.inference import load_model as load_groundingdino_model from groundingdino.util.inference import predict as grounding_dino_predict import groundingdino.datasets.transforms as T import torch from torchvision.ops import box_convert from torchvision.transforms.functional import to_tensor from torchvision.transforms import GaussianBlur import time # ---------------------------- # DINOv2 Classifier Imports # ---------------------------- import torch.nn as nn from torchvision import transforms import pandas as pd from typing import List, Tuple import copy import matplotlib.pyplot as plt # ---------------------------- # DINOv2 Classifier Definitions # ---------------------------- # 1. PadToSquare Class class PadToSquare: """ Pads an image to make it square by adding padding to the shorter side. """ def __init__(self, fill=0): self.fill = fill def __call__(self, img): w, h = img.size max_wh = max(w, h) hp = (max_wh - w) // 2 vp = (max_wh - h) // 2 padding = (hp, vp, max_wh - w - hp, max_wh - h - vp) return transforms.functional.pad(img, padding, fill=self.fill, padding_mode='constant') # 2. DinoVisionTransformerClassifier Class (Modified to include entropy-based approach) class DinoVisionTransformerClassifier(nn.Module): """ DINOv2 Vision Transformer-based classifier with entropy-based "Unknown" class handling. """ def __init__(self, num_classes, hidden_size=256, dropout_p=0.5, negative_slope=0.01): super(DinoVisionTransformerClassifier, self).__init__() # Load DINOv2 model from torch.hub self.transformer = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14', pretrained=True) self.transformer.norm = nn.Identity() # Remove existing normalization if necessary # Batch Normalization after transformer self.batch_norm1 = nn.BatchNorm1d(384) # 384 is the embedding size # Classification head self.classifier = nn.Sequential( nn.Linear(384, hidden_size), nn.BatchNorm1d(hidden_size), nn.LeakyReLU(negative_slope=negative_slope, inplace=True), nn.Dropout(p=dropout_p), nn.Linear(hidden_size, num_classes) ) # Initialize weights self._initialize_weights() def forward(self, x): features = self.transformer(x) # Forward pass through the transformer features = self.batch_norm1(features) # Apply Batch Normalization logits = self.classifier(features) # Forward pass through the classification head return logits, features # Return both logits and features def _initialize_weights(self): # Initialize weights of the classifier layers for m in self.classifier.modules(): if isinstance(m, nn.Linear): nn.init.kaiming_normal_(m.weight, a=0.01, mode='fan_in', nonlinearity='leaky_relu') if m.bias is not None: nn.init.zeros_(m.bias) elif isinstance(m, nn.BatchNorm1d): nn.init.ones_(m.weight) nn.init.zeros_(m.bias) # 3. Model Loading Function (Updated for Entropy-Based Classifier) def load_model(model_path, device): """ Loads the trained model and class information from the saved checkpoint. Args: model_path (str): Path to the saved .pth model file. device (torch.device): Device to load the model onto. Returns: model (nn.Module): The loaded PyTorch model. class_names (List[str]): List of class names. """ if not os.path.exists(model_path): raise FileNotFoundError(f"Model file '{model_path}' does not exist.") checkpoint = torch.load(model_path, map_location=device) class_names = checkpoint['class_names'] num_classes = len(class_names) # Initialize the model architecture model = DinoVisionTransformerClassifier(num_classes=num_classes) model.load_state_dict(checkpoint['model_state_dict']) model.to(device) model.eval() # Set to evaluation mode return model, class_names # 4. Image Preprocessing Function (Updated to accept PIL Image directly) def preprocess_image_pil(pil_image: Image.Image, transform: transforms.Compose) -> torch.Tensor: """ Applies the transformation pipeline to a PIL image. Args: pil_image (PIL.Image.Image): The image to preprocess. transform (transforms.Compose): The transformation pipeline. Returns: torch.Tensor: The preprocessed image tensor. """ return transform(pil_image) # ---------------------------- # Gradio App Definitions # ---------------------------- # Automatically set device based on availability DEVICE = "cuda" if torch.cuda.is_available() else "cpu" print(f"Using device: {DEVICE}") PROMPT = "bug" # Define a custom transform for Gaussian blur (Unused in current context) def gaussian_blur(x, p=0.5, kernel_size_min=3, kernel_size_max=20, sigma_min=0.1, sigma_max=3): if x.ndim == 4: for i in range(x.shape[0]): if random.random() < p: kernel_size = random.randrange(kernel_size_min, kernel_size_max + 1, 2) sigma = random.uniform(sigma_min, sigma_max) x[i] = GaussianBlur(kernel_size=kernel_size, sigma=sigma)(x[i]) return x # Custom Label Function (Unused in current context) def custom_label_func(fpath): # this directs the labels to be 2 levels up from the image folder label = fpath.parents[2].name return label # Image loading function for GroundingDINO def load_image(image_source): transform = T.Compose( [ T.RandomResize([800], max_size=1333), T.ToTensor(), T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), ] ) image_source = image_source.convert("RGB") image_transformed, _ = transform(image_source, None) return image_transformed # Load GroundingDINO object detection model od_model = load_groundingdino_model( model_checkpoint_path="groundingdino_swint_ogc.pth", model_config_path="GroundingDINO_SwinT_OGC.cfg.py", device=DEVICE) print("Object detection model loaded") # Load DINOv2 classifier model (Updated to use the entropy-based classifier) # Update MODEL_PATH to the path where your DINOv2 model checkpoint is stored MODEL_PATH = 'dinov2_classifier_with_vos_unsure.pth' # Updated model path dinov2_model, class_names = load_model(MODEL_PATH, torch.device(DEVICE)) print(f"DINOv2 Classification model loaded with {len(class_names)} classes.") # Optionally, append "Unknown" to class names if needed # Removed the line that appends "Unknown" as the model handles it via thresholding # Replace specific class names if necessary # Example: Replace "Scolotodes_schwarzi" with "Scolytodes_glaber" target = "Scolotodes_schwarzi" if target in class_names: idx = class_names.index(target) class_names[idx] = "Scolytodes_glaber" print(f"Replaced '{target}' with 'Scolytodes_glaber' in class names.") else: print(f"'{target}' not found in class names. No replacement made.") # Define the transformation pipeline for DINOv2 model dinov2_transform = transforms.Compose([ transforms.Resize(224), # Resize smaller edge to 224 PadToSquare(), # Pad to make the image square transforms.Resize((224, 224)), # Resize to 224x224 transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], # Normalize with ImageNet mean [0.229, 0.224, 0.225]) # Normalize with ImageNet std ]) # Object Detection Function def detect_objects(og_image, model=od_model, prompt="bug . insect", device="cpu"): TEXT_PROMPT = prompt BOX_THRESHOLD = 0.15 # 35 Adjusted back to original value TEXT_THRESHOLD = 0.15 # 25 Adjusted back to original value DEVICE = device # cuda or cpu # Convert numpy array to PIL Image if needed if isinstance(og_image, np.ndarray): og_image_obj = Image.fromarray(og_image) else: og_image_obj = og_image # Assuming og_image is already a PIL Image # Transform the image image_transformed = load_image(image_source = og_image_obj) # Model prediction boxes, logits, phrases = grounding_dino_predict( model=model, image=image_transformed, caption=TEXT_PROMPT, box_threshold=BOX_THRESHOLD, text_threshold=TEXT_THRESHOLD, device=DEVICE) # Use og_image_obj directly for further processing width, height = og_image_obj.size # Corrected to (width, height) boxes_norm = boxes * torch.Tensor([width, height, width, height]) xyxy = box_convert( boxes=boxes_norm, in_fmt="cxcywh", out_fmt="xyxy").numpy() img_lst = [] for i in range(len(boxes_norm)): crop_img = og_image_obj.crop((xyxy[i])) img_lst.append(crop_img) print(f"Detected {len(img_lst)} objects.") return img_lst # Inference/Class Prediction Function using the Entropy-Based DINOv2 Classifier def classify_beetle(img: Image.Image, threshold=75.0): """ Classifies the input image using the DINOv2 classifier with entropy-based "Unknown" class. Args: img (PIL.Image.Image): The image to classify. threshold (float): Confidence threshold to assign "Unknown". Returns: dict: The top 3 class labels with their corresponding confidence scores and "Unknown" if applicable. """ # Preprocess the image input_tensor = preprocess_image_pil(img, dinov2_transform).unsqueeze(0).to(torch.device(DEVICE)) print(f"Input tensor shape: {input_tensor.shape}") with torch.no_grad(): outputs, _ = dinov2_model(input_tensor) print(f"Model outputs: {outputs}") probabilities = torch.softmax(outputs, dim=1).cpu().numpy()[0] # p(x) in [0,1] print(f"Probabilities (0-1 scale): {probabilities}") # Calculate entropy # Adding a small epsilon to avoid log(0) epsilon = 1e-12 entropy = -np.sum(probabilities * np.log(probabilities + epsilon)) # Maximum entropy for uniform distribution max_entropy = -np.sum((1.0 / len(probabilities)) * np.log(1.0 / len(probabilities))) normalized_entropy = entropy / max_entropy # Normalize between 0 and 1 unknown_prob = normalized_entropy print(f"Entropy: {entropy}, Normalized Entropy: {normalized_entropy}, Unknown Probability: {unknown_prob}") # Convert probabilities to percentage for display probabilities_percent = np.around(probabilities * 100, decimals=1) print(f"Probabilities (Percentage): {probabilities_percent}") # Get top 3 classes top_indices = np.argsort(probabilities_percent)[-3:][::-1] # Indices of top 3 classes top_probs = probabilities_percent[top_indices] top_classes = [class_names[i] for i in top_indices] # Initialize conf_dict with top 3 classes conf_dict = {top_classes[i]: float(top_probs[i]) for i in range(len(top_classes))} # Assign "Unknown" based on entropy and threshold if top_probs[0] < threshold: conf_dict["Unknown"] = float(np.around(unknown_prob, decimals=1)) print(f"Conf_dict: {conf_dict}") return conf_dict # Main Prediction Function for Gradio def predict_beetle(img): print("Detecting objects in the image...") start_time = time.perf_counter() # Start timing # Detect objects in the image image_lst = detect_objects(og_image=img, model=od_model, prompt=PROMPT, device=DEVICE) print(f"Detected {len(image_lst)} objects.") # Initialize lists to hold results output_lst = [] img_cnt = len(image_lst) for i in range(img_cnt): print(f"Classifying object {i+1}/{img_cnt}...") conf_dict = classify_beetle(image_lst[i]) output_lst.append([image_lst[i], conf_dict]) print(f"Object {i+1} classified.") end_time = time.perf_counter() processing_time = end_time - start_time print(f"Total processing duration: {processing_time:.2f} seconds") return output_lst # ---------------------------- # Gradio Interface Setup # ---------------------------- sample_images_dir = "example_images" # Sample images with labels example_images = [ os.path.join(sample_images_dir, "example1.jpg"), os.path.join(sample_images_dir, "example2.jpg"), os.path.join(sample_images_dir, "example3.jpg"), os.path.join(sample_images_dir, "mixed.jpg") ] # Corresponding labels for the example images example_labels = ["Example Beetles 1", "Example Beetles 2", "Example Beetles 3", "Example Beetles 4"] with gr.Blocks() as demo: gr.Markdown("