"""Fruit-type inference: load a timm classifier and predict a label for an image."""

import albumentations
import numpy as np
import timm
import torch
from PIL import Image
from torch import nn

# Normalisation applied to every input image before it is fed to the model.
augmentations = albumentations.Compose(
    [
        albumentations.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
            max_pixel_value=255.0,
            always_apply=True,
        )
    ]
)

# Class index (argmax of the model's logits) -> human-readable label.
# NOTE(review): 'Avacados' looks like a typo for 'Avocados'; it is left
# unchanged because callers may compare against this exact string.
target_map = {
    0: 'Cranberry',
    1: 'Musk melon',
    2: 'Pineapple',
    3: 'Watermelon',
    4: 'Orange',
    5: 'Dragon fruit',
    6: 'Bananas',
    7: 'Blue berries',
    8: 'Jack fruit',
    9: 'Avacados',
}


class ImageModelInfer(nn.Module):
    """Thin ``nn.Module`` wrapper around a timm classification model.

    The backbone is stored under the ``model`` attribute, so state-dict keys
    are prefixed with ``model.``.
    """

    def __init__(self, model_path, num_classes):
        """Create the backbone without pretrained weights.

        Args:
            model_path: Model name passed to ``timm.create_model``.
            num_classes: Number of output classes of the classifier head.
        """
        super().__init__()
        self.model = timm.create_model(
            model_path, pretrained=False, num_classes=num_classes
        )

    def forward(self, data):
        """Return the backbone's raw logits for ``data``."""
        return self.model(data)


def prepare_image(image) -> torch.Tensor:
    """Convert an image into a normalised float tensor in channel-first layout.

    Args:
        image: A PIL image or anything ``np.array`` can turn into an
            ``(H, W, 3)`` array. PIL images are converted to RGB first so
            grayscale / RGBA / palette inputs do not break normalisation.

    Returns:
        A ``torch.float`` tensor with the channel axis moved to the front.
    """
    # NOTE(review): no resizing is done here; the caller is presumably expected
    # to supply an image already sized for the model - confirm.
    if isinstance(image, Image.Image):
        image = image.convert("RGB")
    pixels = np.array(image)
    normalised = augmentations(image=pixels)['image']
    # (H, W, C) -> (C, H, W)
    channel_first = np.transpose(normalised, (2, 0, 1)).astype(np.float32)
    return torch.tensor(channel_first, dtype=torch.float)


def predict_fruit_type(img) -> str:
    """Predict the fruit label for a single image.

    Args:
        img: Image accepted by :func:`prepare_image`.

    Returns:
        The ``target_map`` label of the highest-scoring class.
    """
    batch = prepare_image(img).unsqueeze(0)  # add the batch dimension
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        logits = model(batch)[0].numpy()
    # Cast the NumPy integer to a plain int before the dict lookup.
    return target_map[int(np.argmax(logits))]


# Module-level model shared by predict_fruit_type(); loaded once at import time.
model = ImageModelInfer('vgg16', num_classes=10)
# SECURITY NOTE: torch.load unpickles the file - only load checkpoints from a
# trusted source.
model.load_state_dict(
    torch.load('best_loss_0.ckpt', map_location=torch.device('cpu'))['state_dict']
)
# Switch to evaluation mode so dropout / batch-norm layers (if the backbone has
# any) behave deterministically during prediction.
model.eval()