import gradio as gr
from torchvision.models import resnet50, ResNet50_Weights
from torchvision import transforms
import torch.nn as nn
import torch

# Class names in the order of the classifier's output units. This order must
# match the label indices used when the "best_model" checkpoint was trained.
labels = [
    "benign",
    "malignant",
    "normal",
]

# Preprocessing pipeline, built once at import time instead of on every
# request: resize the short side to 256, center-crop to 224x224, convert to a
# float tensor, then normalize with the ImageNet per-channel mean/std.
_PREPROCESS = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])


def create_model_from_checkpoint(checkpoint_path="best_model", num_classes=3):
    """Build a ResNet-50 classifier and load saved weights into it.

    Args:
        checkpoint_path: Path to a saved ``state_dict`` for this architecture.
        num_classes: Output size of the replacement ``fc`` head; must match
            the checkpoint.

    Returns:
        The model on the CPU, switched to evaluation mode.
    """
    model = resnet50()
    # Swap the 1000-way ImageNet head for one sized to our label set.
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    # map_location="cpu" lets a checkpoint saved from a GPU load on a
    # CPU-only host. NOTE: torch.load unpickles the file -- only load
    # trusted checkpoints.
    state_dict = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(state_dict)
    model.eval()
    return model


def prep_image(img):
    """Turn a PIL image into a normalized (1, 3, 224, 224) batch tensor.

    The image is converted to RGB first. PNG uploads are often grayscale
    ("L") or carry an alpha channel ("RGBA"). Either would give 1 or 4
    channels and break the 3-channel normalization and the model.
    """
    tensor = _PREPROCESS(img.convert("RGB"))
    # Add the leading batch dimension the model expects.
    return tensor.unsqueeze(0)


model = create_model_from_checkpoint(num_classes=len(labels))


def predict(img):
    """Classify ``img`` and return a ``{label: probability}`` mapping.

    Args:
        img: PIL image supplied by the Gradio ``Image`` input.

    Returns:
        dict mapping each class name in ``labels`` to its softmax
        probability, the format ``gr.Label`` consumes.
    """
    batch = prep_image(img)
    with torch.no_grad():
        logits = model(batch)[0]
        probabilities = torch.nn.functional.softmax(logits, dim=0)
    return {label: float(p) for label, p in zip(labels, probabilities)}


ui = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(num_top_classes=len(labels)),
    examples=[
        "benign (52).png",
        "benign (243).png",
        "malignant (127).png",
        "malignant (201).png",
        "normal (81).png",
        "normal (101).png",
    ],
)

if __name__ == "__main__":
    # launch() blocks the main thread when run as a script, so call it
    # exactly once, on the Interface object itself.
    ui.launch(share=True)