medaid-simple / app.py
mrmuminov's picture
Update app.py
9a444fd verified
import os

import gradio as gr
import numpy as np
import pydicom
import torch
from PIL import Image
from torchvision import transforms
# Pretrained ResNet-18 from the torchvision hub repository, pinned to tag v0.6.0.
# torch.hub fetches (and caches) the repo and weights; eval() then switches the
# network to evaluation mode for inference.
model = torch.hub.load('pytorch/vision:v0.6.0', 'resnet18', pretrained=True)
model.eval()
# Read a DICOM file and convert its pixel data to an 8-bit PIL image
def preprocess_dicom_to_image(dicom_path):
    """Read a DICOM file and return its pixel data as an 8-bit PIL image.

    Pixel values are min-max scaled so the smallest value maps to 0 and the
    largest to 255. No modality rescale (slope/intercept) or windowing is
    applied.

    Args:
        dicom_path: Passed unchanged to ``pydicom.dcmread`` (a filesystem path
            or a readable binary file-like object).

    Returns:
        A ``PIL.Image.Image`` built from the scaled ``uint8`` array. A constant
        image (max == min) is returned as all zeros.
    """
    dicom = pydicom.dcmread(dicom_path)
    # Scale in float64: subtracting the minimum in the native integer dtype
    # (e.g. int16 for CT) can overflow and wrap around.
    image = dicom.pixel_array.astype(np.float64)
    lo = image.min()
    span = image.max() - lo
    if span > 0:
        image = (image - lo) / span
    else:
        # Constant image: avoid 0/0 (NaN, then an undefined uint8 cast).
        image = np.zeros_like(image)
    # NOTE(review): a multi-frame pixel_array (frames, rows, cols) would be
    # rejected by Image.fromarray; presumably only single frames are expected.
    return Image.fromarray((image * 255).astype(np.uint8))
# Convert a DICOM file into a batched, normalized tensor for the classifier
def preprocess_dicom(dicom_path):
    """Load a DICOM file and turn it into a model-ready input batch.

    The 8-bit image from ``preprocess_dicom_to_image`` is resized to
    224x224, replicated to three channels, and normalized with the ImageNet
    mean/std that torchvision's pretrained classifiers expect.

    Args:
        dicom_path: Forwarded unchanged to ``preprocess_dicom_to_image``.

    Returns:
        A float tensor of shape (1, 3, 224, 224).
    """
    # ResNet-18's first convolution takes 3 input channels. A grayscale ("L")
    # image would produce a 1-channel tensor and fail at inference time.
    image = preprocess_dicom_to_image(dicom_path).convert("RGB")
    transform = transforms.Compose([
        transforms.Resize((224, 224)),  # Resize to model's input size
        transforms.ToTensor(),  # HWC uint8 [0, 255] -> CHW float [0, 1]
        # Statistics documented for torchvision's ImageNet-pretrained models.
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    return transform(image).unsqueeze(0)  # Add batch dimension
# Prediction function: run the model and visualise its raw output
def predict_dicom(dicom_file):
    """Run the model on an uploaded DICOM file and render its output as an image.

    Args:
        dicom_file: The value ``gr.File`` passes to the callback. Depending on
            the Gradio version / ``type`` setting this is a path (str or
            os.PathLike) or a file wrapper exposing the path as ``.name``.
            ``None`` when nothing was uploaded.

    Returns:
        A PIL image of the min-max scaled model output, or ``None`` if no file
        was provided.
    """
    if dicom_file is None:
        return None
    # Accept both plain paths and file wrappers. Note that pathlib.Path has a
    # ``.name`` attribute too, but it is only the basename, so test for
    # path-likes first.
    if isinstance(dicom_file, (str, os.PathLike)):
        source = dicom_file
    else:
        source = getattr(dicom_file, "name", dicom_file)
    input_tensor = preprocess_dicom(source)
    with torch.no_grad():
        output = model(input_tensor)
    # Convert output tensor to image (dummy example, replace as needed).
    # NOTE(review): resnet18 is a classifier, so this is presumably a 1-D score
    # vector rather than a spatial map. Confirm the intended visualisation.
    scores = output.squeeze().cpu().numpy().astype(np.float64)
    lo = scores.min()
    span = scores.max() - lo
    if span > 0:
        scaled = (scores - lo) / span * 255
    else:
        # Constant output: avoid 0/0 producing NaN.
        scaled = np.zeros_like(scores)
    return Image.fromarray(scaled.astype(np.uint8))
# Prediction function used by the UI: display the uploaded DICOM (no inference)
def predict_dicom2(dicom_file):
    """Render an uploaded DICOM file as an 8-bit image.

    Args:
        dicom_file: The value ``gr.File`` passes to the callback. Depending on
            the Gradio version / ``type`` setting this is a path (str or
            os.PathLike) or a file wrapper exposing the path as ``.name``.
            ``None`` when nothing was uploaded.

    Returns:
        The image from ``preprocess_dicom_to_image``, or ``None`` if no file
        was provided.
    """
    if dicom_file is None:
        return None
    # Prefer the wrapper's ``.name`` path (as predict_dicom does). The wrapper's
    # own file handle is not guaranteed to hold the uploaded bytes. Objects
    # without ``.name`` (e.g. an in-memory buffer) are passed through for
    # pydicom to read directly.
    if isinstance(dicom_file, (str, os.PathLike)):
        source = dicom_file
    else:
        source = getattr(dicom_file, "name", dicom_file)
    return preprocess_dicom_to_image(source)
# Assemble the Gradio UI: one DICOM upload in, one rendered image out.
_upload_input = gr.File(label="Upload DICOM File")
interface = gr.Interface(
    fn=predict_dicom2,
    inputs=_upload_input,
    outputs="image",
    title="DICOM Image Prediction",
)

# Start the Gradio web server for this app.
interface.launch()