"""Model loading and single-image inference helpers for the classifier."""

from contextlib import nullcontext
from functools import lru_cache

import torch
from efficientnet_pytorch import EfficientNet
from torchvision import models
from torchvision import transforms

from CustomModels import DinoVisionClassifier

# Class index -> human-readable label (German waste-sorting category names).
classes = {0: 'Glas', 1: 'Organic', 2: 'Papier', 3: 'Restmüll', 4: 'Wertstoff'}

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Default preprocessing: fixed 256x256 bicubic resize, tensor conversion and
# per-channel normalisation (the commonly used ImageNet mean/std values).
transform = transforms.Compose(
    [
        transforms.Resize((256, 256), interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ]
)

# Preprocessing used for the DinoVisionClassifier: resize the shorter side to
# 256, take a 224x224 centre crop, then the same normalisation as above.
transform_dinov2 = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ]
)


@lru_cache
def load_specific_model(model_name):
    """Build the requested architecture, load its checkpoint and return it.

    Results are memoised with ``lru_cache``, so each distinct ``model_name``
    is constructed and loaded from disk only once per process.

    Args:
        model_name: One of ``"EfficientNet-B3"``, ``"EfficientNet-B4"``,
            ``"vgg19"``, ``"resnet50"`` or ``"dinov2_vits14"``.

    Returns:
        The model in eval mode, moved to the module-level ``device``.

    Raises:
        ValueError: If ``model_name`` is not one of the supported names.
    """
    num_classes = len(classes)

    if model_name == "EfficientNet-B3":
        # NOTE(review): from_pretrained loads pretrained weights that are
        # immediately overwritten by the checkpoint below -- confirm whether
        # the extra download is intended.
        model = EfficientNet.from_pretrained("efficientnet-b3", num_classes=num_classes)
        weights_path = "src/models/eff_b3_model.pt"
    elif model_name == "EfficientNet-B4":
        model = EfficientNet.from_pretrained("efficientnet-b4", num_classes=num_classes)
        weights_path = "src/models/eff_b4.pt"
    elif model_name == "vgg19":
        model = models.vgg19()
        # The whole classifier stack is replaced by one linear layer sized
        # from the first layer of the stock classifier.
        in_features = model.classifier[0].in_features
        model.classifier = torch.nn.Linear(in_features, num_classes)
        weights_path = "src/models/vgg19.pt"
    elif model_name == "resnet50":
        model = models.resnet50()
        in_features = model.fc.in_features
        model.fc = torch.nn.Linear(in_features, num_classes)
        weights_path = "src/models/resnet50.pt"
    elif model_name == "dinov2_vits14":
        backbone = torch.hub.load('facebookresearch/dinov2', "dinov2_vits14")
        model = DinoVisionClassifier(backbone, num_classes=num_classes)
        weights_path = "src/models/dinov2_vits14_0.054_98.00.pth"
    else:
        # Fail loudly instead of crashing later on ``None.eval()``.
        raise ValueError(
            f"Unknown model name {model_name!r}; expected one of "
            "'EfficientNet-B3', 'EfficientNet-B4', 'vgg19', 'resnet50', "
            "'dinov2_vits14'"
        )

    # Checkpoints are always deserialised onto the CPU first; the final
    # ``.to(device)`` moves the model to the GPU when one is available.
    model.load_state_dict(torch.load(weights_path, map_location="cpu"))
    print(f"Loaded model {model_name}")
    return model.eval().to(device)


def inference(model, inp):
    """Classify a single image and return per-class confidences.

    Args:
        model: A classifier whose forward pass returns a batch of logits, one
            row per input. ``DinoVisionClassifier`` instances are preprocessed
            with ``transform_dinov2``, everything else with ``transform``.
        inp: The input image; presumably a PIL image, as the torchvision
            transforms start with ``Resize`` / ``ToTensor`` -- confirm against
            the caller.

    Returns:
        A dict mapping each label in ``classes`` to its softmax probability
        as a Python ``float``.
    """
    model.eval()

    preprocess = transform_dinov2 if isinstance(model, DinoVisionClassifier) else transform
    batch = preprocess(inp).unsqueeze(0).to(device)

    # Mixed precision is only enabled when CUDA is available; on the CPU the
    # forward pass runs under ``no_grad`` alone.
    amp_context = torch.cuda.amp.autocast() if torch.cuda.is_available() else nullcontext()
    with torch.no_grad(), amp_context:
        logits = model(batch)[0]
        prediction = torch.nn.functional.softmax(logits, dim=0).cpu().numpy()

    return {label: float(prediction[index]) for index, label in classes.items()}