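"""Gradio demo: box-prompted MedSAM segmentation for NRRD images.

Upload an NRRD file and a bounding-box prompt; the app loads the MedSAM
(ViT-B) checkpoint, encodes the image, and returns the input image with
the box and the predicted mask overlaid.
"""
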
import io

import gradio as gr
import matplotlib.pyplot as plt
import nrrd
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from segment_anything import sam_model_registry
from skimage import transform

def load_nrrd(file_path):
    """Read an NRRD file and return it as an RGB uint8 image plus (H, W)."""
    data, _header = nrrd.read(file_path)

    # If the data is 3D, take the middle slice
    if len(data.shape) == 3:
        middle_slice = data.shape[2] // 2
        img = data[:, :, middle_slice]
    else:
        img = data

    # Normalize the image to the 0-255 range (guard against constant slices)
    img = ((img - img.min()) / np.clip(img.max() - img.min(), a_min=1e-8, a_max=None) * 255).astype(np.uint8)

    # Convert grayscale to 3-channel RGB by replicating channels
    if len(img.shape) == 2:  # Grayscale image (height, width)
        img = np.stack((img,)*3, axis=-1)  # Replicate grayscale channel to get (height, width, 3)

    H, W = img.shape[:2]
    return img, H, W

@torch.no_grad()
def medsam_inference(medsam_model, img_embed, box_1024, H, W):
    """Run box-prompted MedSAM inference; returns a binary (H, W) mask."""
    # A single box arrives as four coordinates on the 1024x1024 grid; the
    # prompt encoder accepts boxes of shape (B, 4), so reshape to (1, 4).
    box_torch = torch.as_tensor(box_1024, dtype=torch.float, device=img_embed.device)
    box_torch = box_torch.reshape(1, 4)
    sparse_embeddings, dense_embeddings = medsam_model.prompt_encoder(
        points=None,
        boxes=box_torch,
        masks=None,
    )
    
    low_res_logits, _ = medsam_model.mask_decoder(
        image_embeddings=img_embed, # (B, 256, 64, 64)
        image_pe=medsam_model.prompt_encoder.get_dense_pe(), # (1, 256, 64, 64)
        sparse_prompt_embeddings=sparse_embeddings, # (B, 2, 256)
        dense_prompt_embeddings=dense_embeddings, # (B, 256, 64, 64)
        multimask_output=False,
    )

    low_res_pred = torch.sigmoid(low_res_logits)  # (1, 1, 256, 256)

    low_res_pred = F.interpolate(
        low_res_pred,
        size=(H, W),
        mode="bilinear",
        align_corners=False,
    )  # (1, 1, H, W)
    low_res_pred = low_res_pred.squeeze().cpu().numpy()  # (H, W)
    medsam_seg = (low_res_pred > 0.5).astype(np.uint8)
    return medsam_seg

def visualize(image, mask, box):
    """Plot the image with the prompt box (left) and the mask overlay (right)."""
    fig, ax = plt.subplots(1, 2, figsize=(10, 5))
    ax[0].imshow(image, cmap='gray')
    ax[0].add_patch(plt.Rectangle((box[0], box[1]), box[2] - box[0], box[3] - box[1], edgecolor="red", facecolor="none"))
    ax[1].imshow(image, cmap='gray')
    ax[1].imshow(mask, alpha=0.5, cmap="jet")
    plt.tight_layout()

    buf = io.BytesIO()
    fig.savefig(buf, format='png')
    plt.close(fig)
    buf.seek(0)
    pil_img = Image.open(buf)
    
    return pil_img 

def process_nrrd(nrrd_file, points):
    """Run box-prompted MedSAM inference on an uploaded NRRD file."""
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Load and preprocess the NRRD file
    image, H, W = load_nrrd(nrrd_file.name)

    # The prompt is a flat 6-element list; indices 2 and 5 hold point-type
    # labels (as in gradio_image_prompter's box encoding) and are ignored.
    if len(points) >= 6:
        x_min, y_min, x_max, y_max = points[0], points[1], points[3], points[4]
    else:
        raise ValueError("Expected at least 6 values for the bounding-box prompt.")
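
    # For example, a box from (50, 60) to (200, 220) would be submitted as
    # [50, 60, 2, 200, 220, 3] (hypothetical values); the label slots are
    # ignored by this app, so any numbers work in those positions.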

    # Resize to the 1024x1024 input the SAM image encoder expects,
    # then rescale intensities to [0, 1]
    image_resized = transform.resize(image, (1024, 1024), order=3, preserve_range=True, anti_aliasing=True).astype(np.uint8)
    image_resized = (image_resized - image_resized.min()) / np.clip(image_resized.max() - image_resized.min(), a_min=1e-8, a_max=None)
    image_tensor = torch.tensor(image_resized).float().permute(2, 0, 1).unsqueeze(0).to(device)

    # Initialize the MedSAM model (ViT-B). Note: the checkpoint is reloaded on
    # every request; for a deployed app, load it once at module scope instead.
    model_checkpoint_path = "medsam_vit_b.pth"  # Replace with the correct path to your checkpoint
    medsam_model = sam_model_registry['vit_b'](checkpoint=model_checkpoint_path)
    medsam_model = medsam_model.to(device)
    medsam_model.eval()

    # Generate image embedding
    with torch.no_grad():
        img_embed = medsam_model.image_encoder(image_tensor)

    # Calculate resized box coordinates
    scale_factors = np.array([1024 / W, 1024 / H, 1024 / W, 1024 / H])
    box_1024 = np.array([x_min, y_min, x_max, y_max]) * scale_factors
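
    # E.g., for a 512x512 input (hypothetical size), an x-coordinate of 100
    # maps to 100 * 1024 / 512 = 200 on the encoder grid.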

    # Perform inference
    mask = medsam_inference(medsam_model, img_embed, box_1024, H, W)

    # Visualization
    visualization = visualize(image, mask, [x_min, y_min, x_max, y_max])
    return visualization

# Set up Gradio interface
iface = gr.Interface(
    fn=process_nrrd,
    inputs=[
        gr.File(label="NRRD File"),
        gr.JSON(label="Bounding Box Coordinates")
    ],
    outputs=[
        gr.Image(type="pil", label="Processed Image")
    ],
    title="ROI Selection with MEDSAM for NRRD Files",
    description="Upload an NRRD file and provide bounding box coordinates for processing."
)

# Launch the interface
iface.launch()
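
# A public share link can be created instead with iface.launch(share=True).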