mrmuminov commited on
Commit
9a444fd
·
verified ·
1 Parent(s): 1e7bd9c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -3
app.py CHANGED
@@ -8,14 +8,19 @@ from PIL import Image
8
  # Load your PyTorch model
9
  model = torch.hub.load('pytorch/vision:v0.6.0', 'resnet18', pretrained=True).eval()
10
 
 
11
  # Define a function to preprocess the DICOM
12
- def preprocess_dicom(dicom_path):
13
  # Load DICOM file
14
  dicom = pydicom.dcmread(dicom_path)
15
  image = dicom.pixel_array # Extract image data
16
  # Normalize to [0, 1] and convert to PIL Image for transforms
17
  image = (image - np.min(image)) / (np.max(image) - np.min(image))
18
- image = Image.fromarray((image * 255).astype(np.uint8))
 
 
 
 
19
  # Apply transforms
20
  transform = transforms.Compose([
21
  transforms.Resize((224, 224)), # Resize to model's input size
@@ -36,9 +41,13 @@ def predict_dicom(dicom_file):
36
  output_image = Image.fromarray(output_image.astype(np.uint8))
37
  return output_image
38
 
 
 
 
 
39
  # Create Gradio interface
40
  interface = gr.Interface(
41
- fn=predict_dicom,
42
  inputs=gr.File(label="Upload DICOM File"),
43
  outputs="image",
44
  title="DICOM Image Prediction"
 
8
  # Load your PyTorch model
9
  model = torch.hub.load('pytorch/vision:v0.6.0', 'resnet18', pretrained=True).eval()
10
 
11
+
12
  # Define a function to preprocess the DICOM
13
+ def preprocess_dicom_to_image(dicom_path):
14
  # Load DICOM file
15
  dicom = pydicom.dcmread(dicom_path)
16
  image = dicom.pixel_array # Extract image data
17
  # Normalize to [0, 1] and convert to PIL Image for transforms
18
  image = (image - np.min(image)) / (np.max(image) - np.min(image))
19
+ return Image.fromarray((image * 255).astype(np.uint8))
20
+
21
+ # Define a function to preprocess the DICOM
22
+ def preprocess_dicom(dicom_path):
23
+ image = preprocess_dicom_to_image(dicom_path)
24
  # Apply transforms
25
  transform = transforms.Compose([
26
  transforms.Resize((224, 224)), # Resize to model's input size
 
41
  output_image = Image.fromarray(output_image.astype(np.uint8))
42
  return output_image
43
 
44
+ # Prediction function
45
+ def predict_dicom2(dicom_file):
46
+ return preprocess_dicom_to_image(dicom_file)
47
+
48
  # Create Gradio interface
49
  interface = gr.Interface(
50
+ fn=predict_dicom2,
51
  inputs=gr.File(label="Upload DICOM File"),
52
  outputs="image",
53
  title="DICOM Image Prediction"