"""Gradio demo: upload a DICOM file and view it (optionally run it through ResNet-18)."""

import os

import gradio as gr
import numpy as np
import pydicom
import torch
from PIL import Image
from torchvision import transforms

# Load your PyTorch model (fetched via torch.hub at import time).
model = torch.hub.load('pytorch/vision:v0.6.0', 'resnet18', pretrained=True).eval()

# torchvision's pretrained ResNet-18 expects 3-channel 224x224 input normalized
# with the ImageNet channel statistics, so grayscale DICOM slices are adapted to
# that format before inference. Built once here instead of on every request.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]
_TRANSFORM = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize to model's input size
    transforms.ToTensor(),
    transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
])


def _to_uint8(array):
    """Min-max scale *array* into the 0..255 range as ``np.uint8``.

    A constant array (max == min) maps to all zeros instead of dividing by zero,
    which would otherwise produce NaNs that cast to arbitrary uint8 values.
    """
    values = np.asarray(array, dtype=np.float64)
    lo = float(np.min(values))
    span = float(np.max(values)) - lo
    if span == 0:
        return np.zeros(values.shape, dtype=np.uint8)
    return ((values - lo) / span * 255).astype(np.uint8)


def _resolve_path(dicom_file):
    """Return something ``pydicom.dcmread`` can open from a Gradio upload value.

    Depending on the Gradio version, ``gr.File`` yields either a path string or
    a temp-file object exposing its path as ``.name``. Strings and
    ``os.PathLike`` objects are returned unchanged (note ``Path.name`` is only
    the basename, so it must not be used); other objects are replaced by their
    ``.name`` if they have one, otherwise passed through (e.g. a file-like).
    """
    if isinstance(dicom_file, (str, os.PathLike)):
        return dicom_file
    return getattr(dicom_file, "name", dicom_file)


def preprocess_dicom_to_image(dicom_path):
    """Read a DICOM file and return its pixel data as an 8-bit PIL image.

    Args:
        dicom_path: Path to the DICOM file, or a raw Gradio upload value.

    Returns:
        PIL.Image.Image: the pixel data min-max scaled to 0..255.
    """
    dicom = pydicom.dcmread(_resolve_path(dicom_path))
    pixels = dicom.pixel_array  # Extract image data
    # Multi-frame files stack frames on the first axis; display the first frame
    # rather than letting PIL misread (frames, H, W) as (H, W, channels).
    if int(getattr(dicom, "NumberOfFrames", 1) or 1) > 1:
        pixels = pixels[0]
    image = _to_uint8(pixels)
    # MONOCHROME1 stores low values as bright, so invert for conventional display.
    if getattr(dicom, "PhotometricInterpretation", "") == "MONOCHROME1":
        image = 255 - image
    return Image.fromarray(image)


def preprocess_dicom(dicom_path):
    """Load a DICOM file as a normalized ``(1, 3, 224, 224)`` float tensor.

    Args:
        dicom_path: Path to the DICOM file, or a raw Gradio upload value.
    """
    image = preprocess_dicom_to_image(dicom_path).convert("RGB")  # model needs 3 channels
    return _TRANSFORM(image).unsqueeze(0)  # Add batch dimension


def predict_dicom(dicom_file):
    """Run ResNet-18 on the uploaded DICOM and render its raw scores as an image.

    Not wired into the interface below (which uses ``predict_dicom2``).
    """
    input_tensor = preprocess_dicom(dicom_file)

    with torch.no_grad():
        output = model(input_tensor)

    # Dummy visualization (replace as needed): squeezing the batch axis leaves a
    # 1-D score vector, which PIL cannot render, so lift it to a 1-pixel-tall row.
    scores = np.atleast_2d(output.squeeze(0).cpu().numpy())
    return Image.fromarray(_to_uint8(scores))


def predict_dicom2(dicom_file):
    """Return the uploaded DICOM's pixel data as a displayable image (no model)."""
    return preprocess_dicom_to_image(dicom_file)


# Create Gradio interface
interface = gr.Interface(
    fn=predict_dicom2,
    inputs=gr.File(label="Upload DICOM File"),
    outputs="image",
    title="DICOM Image Prediction",
)

if __name__ == "__main__":
    # Launch the Gradio app only when run as a script, not on import.
    interface.launch()