"""Gradio demo that predicts bone age from a grayscale hand radiograph.

Rebuilds the Lightning training wrapper, loads a trained checkpoint on CPU and
serves a single-image prediction UI.
"""
import os

import gradio as gr
import lightning
import numpy as np
import pandas as pd
import timm
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

BACKBONE = "resnet18d"
IMAGE_HEIGHT, IMAGE_WIDTH = 512, 512

trained_weights_path = "epoch=009.ckpt"
# torch.load unpickles the file: only point this at a trusted checkpoint.
trained_weights = torch.load(trained_weights_path, map_location=torch.device('cpu'))["state_dict"]


# recreate the model
class BoneAgeModel(lightning.LightningModule):
    """Lightning wrapper used at training time; recreated so checkpoint keys match.

    Args:
        net: the backbone network; its output is used directly as the prediction.
        optimizer: optimizer returned from ``configure_optimizers`` (may be None
            when the module is only used for inference).
        scheduler: LR scheduler, stepped every optimizer step (may be None).
        loss_fn: callable ``loss_fn(prediction, target)`` (may be None).
    """

    def __init__(self, net, optimizer, scheduler, loss_fn):
        super().__init__()
        self.net = net
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.loss_fn = loss_fn
        # Per-batch validation losses, averaged and cleared at epoch end.
        self.val_losses = []

    def training_step(self, batch, batch_index):
        """Compute the training loss for a batch dict with keys ``"x"`` and ``"y"``."""
        out = self.net(batch["x"])
        loss = self.loss_fn(out, batch["y"])
        return loss

    def validation_step(self, batch, batch_index):
        """Record the validation loss for a batch dict with keys ``"x"`` and ``"y"``."""
        out = self.net(batch["x"])
        loss = self.loss_fn(out, batch["y"])
        self.val_losses.append(loss.item())

    def on_validation_epoch_end(self, *args, **kwargs):
        """Print the mean of the recorded validation losses and reset the buffer."""
        val_loss = np.mean(self.val_losses)
        self.val_losses = []
        print(f"Validation Loss : {val_loss:0.3f}")

    def configure_optimizers(self):
        """Return the optimizer and a per-step LR scheduler config for Lightning."""
        lr_scheduler = {"scheduler": self.scheduler, "interval": "step"}
        return {"optimizer": self.optimizer, "lr_scheduler": lr_scheduler}


# Build the architecture only. load_state_dict (strict by default) overwrites every
# parameter of `net` from the checkpoint, so downloading ImageNet weights with
# pretrained=True would be wasted network I/O at startup.
net = timm.create_model(BACKBONE, pretrained=False, in_chans=1, num_classes=1)
trained_model = BoneAgeModel(net, None, None, None)
trained_model.load_state_dict(trained_weights)
trained_model.eval()


def _format_bone_age(bone_age):
    """Format a prediction (interpreted as years) as "X years, Y months".

    Rounds on the total number of months so the month part is always 0-11
    (e.g. 9.99 -> "10 years, 0 months", never "9 years, 12 months").
    Negative predictions are clamped to zero.
    """
    total_months = max(round(bone_age * 12), 0)
    years, months = divmod(total_months, 12)
    return f"Predicted Bone Age: {years} years, {months} months"


def predict_bone_age(Radiograph):
    """Predict bone age for one uploaded radiograph.

    Args:
        Radiograph: numpy array from the ``gr.Image(image_mode="L")`` input, or
            ``None`` when the user submits without an image. Presumably a 2-D
            uint8 (H, W) array - TODO confirm for the Gradio version in use.

    Returns:
        A human-readable prediction string, or a prompt to upload an image.
    """
    if Radiograph is None:
        return "Please upload a radiograph."
    img = torch.from_numpy(Radiograph)
    img = img.unsqueeze(0).unsqueeze(0)  # add channel and batch dimensions
    # NOTE(review): pixels are only scaled to [0, 1]; confirm this matches the
    # normalization used by the training Dataset.
    # NOTE(review): gr.Image height/width only set the display size - the array is
    # NOT resized to IMAGE_HEIGHT x IMAGE_WIDTH here. Confirm whether training used
    # fixed-size inputs and resize accordingly if so.
    img = img / 255.
    with torch.inference_mode():
        bone_age = trained_model.net(img)[0].item()
    return _format_bone_age(bone_age)


image = gr.Image(height=IMAGE_HEIGHT, width=IMAGE_WIDTH, image_mode="L")  # L for grayscale
label = gr.Label(show_label=True, label="Bone Age Prediction")
demo = gr.Interface(fn=predict_bone_age, inputs=[image], outputs=label)

if __name__ == "__main__":
    demo.launch(debug=True)