import os
import requests
# Disable JIT
os.environ["PYTORCH_JIT"] = "0"
from einops import rearrange
import gradio as gr
import numpy as np
import spaces
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image, ImageOps
from transformers import AutoModel, CLIPImageProcessor
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
from segment_anything.modeling.image_encoder import ImageEncoderViT
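
# RADIOVenc wraps a RADIO model so it can stand in for SAM's ViT-H image encoder:
# the RADIO "sam" adaptor emits patch tokens matching the ViT-H trunk output, and
# SAM's convolutional neck is reused to project them into the embedding that the
# mask decoder expects.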
class RADIOVenc(nn.Module):
    def __init__(self, radio: nn.Module, img_enc: ImageEncoderViT, img_size: int = 1024):
        super().__init__()
        self.radio = radio
        self.neck = img_enc.neck
        self.img_size = img_size
        self.dtype = radio.input_conditioner.dtype

    def forward(self, x: torch.Tensor):
        h, w = x.shape[-2:]
        if self.dtype is not None:
            x = x.to(dtype=self.dtype)
        with torch.autocast('cuda', dtype=torch.bfloat16, enabled=self.dtype is None):
            output = self.radio(x)
        features = output["sam"].features
        # RADIO uses 16x16 patches: reshape the token sequence into a 2D feature
        # map before running SAM's neck.
        rows = h // 16
        cols = w // 16
        features = rearrange(features, 'b (h w) c -> b c h w', h=rows, w=cols)
        features = self.neck(features)
        return features
def download_file(url, save_path):
    # Check if the file already exists
    if os.path.exists(save_path):
        print(f"File already exists at {save_path}. Skipping download.")
        return

    print(f"Downloading from {url}")
    # Send a GET request to the URL
    response = requests.get(url, stream=True)
    # Check if the request was successful
    if response.status_code == 200:
        # Open the file in binary write mode
        with open(save_path, 'wb') as file:
            # Iterate over the response content in chunks
            for chunk in response.iter_content(chunk_size=1024):
                if chunk:  # filter out keep-alive new chunks
                    file.write(chunk)
        print(f"File downloaded successfully and saved as {save_path}")
    else:
        print(f"Failed to download file. HTTP Status Code: {response.status_code}")
hf_repo = "nvidia/RADIO-L"
image_processor = CLIPImageProcessor.from_pretrained(hf_repo)
model_version = "radio_v2.5-l" # for RADIOv2.5-L model (ViT-L/16)
model = torch.hub.load(
    'NVlabs/RADIO',
    'radio_model',
    version=model_version,
    progress=True,
    skip_validation=True,
    adaptor_names='sam')
model.eval()
local_sam_checkpoint_path = "sam_vit_h_4b8939.pth"
download_file("https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth", local_sam_checkpoint_path)
sam = sam_model_registry["vit_h"](checkpoint=local_sam_checkpoint_path)
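# Swap SAM's ViT-H image encoder for the RADIO wrapper, and match SAM's pixel
# normalization to RADIO's input conditioner (SAM expects mean/std in 0-255 units).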
model._patch_size = 16
sam.image_encoder = RADIOVenc(model, sam.image_encoder, img_size=1024)
conditioner = model.make_preprocessor_external()
sam.pixel_mean = conditioner.norm_mean * 255
sam.pixel_std = conditioner.norm_std * 255
def get_robust_pca(features: torch.Tensor, m: float = 2, remove_first_component=False):
    # features: (N, C)
    # m: threshold, in units of the median absolute deviation, beyond which values are treated as outliers
    assert len(features.shape) == 2, "features should be (N, C)"
    reduction_mat = torch.pca_lowrank(features, q=3, niter=20)[2]
    colors = features @ reduction_mat
    if remove_first_component:
        colors_min = colors.min(dim=0).values
        colors_max = colors.max(dim=0).values
        tmp_colors = (colors - colors_min) / (colors_max - colors_min)
        fg_mask = tmp_colors[..., 0] < 0.2
        reduction_mat = torch.pca_lowrank(features[fg_mask], q=3, niter=20)[2]
        colors = features @ reduction_mat
    else:
        fg_mask = torch.ones_like(colors[:, 0]).bool()
    d = torch.abs(colors[fg_mask] - torch.median(colors[fg_mask], dim=0).values)
    mdev = torch.median(d, dim=0).values
    s = d / mdev
    try:
        rins = colors[fg_mask][s[:, 0] < m, 0]
        gins = colors[fg_mask][s[:, 1] < m, 1]
        bins = colors[fg_mask][s[:, 2] < m, 2]
        rgb_min = torch.tensor([rins.min(), gins.min(), bins.min()])
        rgb_max = torch.tensor([rins.max(), gins.max(), bins.max()])
    except Exception:
        # Fall back to the unfiltered colors if outlier rejection leaves a channel empty.
        rins = colors
        gins = colors
        bins = colors
        rgb_min = torch.tensor([rins.min(), gins.min(), bins.min()])
        rgb_max = torch.tensor([rins.max(), gins.max(), bins.max()])
    return reduction_mat, rgb_min.to(reduction_mat), rgb_max.to(reduction_mat)
def get_pca_map(
    feature_map: torch.Tensor,
    img_size,
    interpolation="bicubic",
    return_pca_stats=False,
    pca_stats=None,
):
    """
    feature_map: (1, h, w, C) is the feature map of a single image.
    """
    if feature_map.shape[0] != 1:
        # make it (1, h, w, C)
        feature_map = feature_map[None]
    if pca_stats is None:
        reduct_mat, color_min, color_max = get_robust_pca(
            feature_map.reshape(-1, feature_map.shape[-1])
        )
    else:
        reduct_mat, color_min, color_max = pca_stats
    pca_color = feature_map @ reduct_mat
    pca_color = (pca_color - color_min) / (color_max - color_min)
    pca_color = pca_color.clamp(0, 1)
    pca_color = F.interpolate(
        pca_color.permute(0, 3, 1, 2),
        size=img_size,
        mode=interpolation,
    ).permute(0, 2, 3, 1)
    pca_color = pca_color.cpu().numpy().squeeze(0)
    if return_pca_stats:
        return pca_color, (reduct_mat, color_min, color_max)
    return pca_color
def pad_image_to_multiple_of(image, multiple=16):
    # Calculate the new dimensions to make them multiples
    width, height = image.size
    new_width = (width + multiple - 1) // multiple * multiple
    new_height = (height + multiple - 1) // multiple * multiple
    # Calculate the padding needed on each side
    pad_width = new_width - width
    pad_height = new_height - height
    left = pad_width // 2
    right = pad_width - left
    top = pad_height // 2
    bottom = pad_height - top
    # Apply the padding
    padded_image = ImageOps.expand(image, (left, top, right, bottom), fill='black')
    return padded_image
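
# For example, a 500x375 input padded with multiple=256 becomes 512x512:
# 6 px of padding on the left and right, 68 px on top and 69 px on the bottom.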
def center_crop_resize(image, size=(1024, 1024)):
    # Get dimensions
    width, height = image.size
    # Determine the center crop box
    if width > height:
        new_width = height
        new_height = height
        left = (width - new_width) / 2
        top = 0
        right = (width + new_width) / 2
        bottom = height
    else:
        new_width = width
        new_height = width
        left = 0
        top = (height - new_height) / 2
        right = width
        bottom = (height + new_height) / 2
    # Crop the image to a square
    image = image.crop((left, top, right, bottom))
    # Resize the cropped image to the target size
    image = image.resize(size, Image.LANCZOS)
    return image
def visualize_anns(orig_image: np.ndarray, anns):
    if len(anns) == 0:
        return orig_image
    sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
    kernel = torch.ones(1, 1, 5, 5, dtype=torch.float32)
    # RGBA
    mask = np.ones((sorted_anns[0]['segmentation'].shape[0], sorted_anns[0]['segmentation'].shape[1], 4), dtype=np.float32)
    mask[:, :, 3] = 0
    for ann in sorted_anns:
        m = ann['segmentation']
        color_mask = np.concatenate([np.random.random(3), [0.35]])
        # Pixels whose 5x5 neighborhood is not entirely inside the mask form the
        # border; their alpha is raised to 1.0 so the outline is fully opaque.
        tm = torch.as_tensor(m).reshape(1, 1, *m.shape).float()
        cvtm = F.conv2d(tm, kernel, padding=2)
        border_mask = (cvtm < 25).flatten(0, 2).numpy()
        mask[m] = color_mask
        mask[m & border_mask, 3] *= 1.0 / 0.35
    color, alpha = mask[..., :3], mask[..., -1:]
    orig_image = orig_image.astype(np.float32) / 255
    overlay = alpha * color + (1 - alpha) * orig_image
    overlay = (overlay * 255).astype(np.uint8)
    return overlay
@spaces.GPU
def infer_radio(image):
    """Generate the PCA feature visualization, the SAM mask overlay, and the backbone feature shape."""
    model.cuda()
    conditioner.cuda()
    sam.cuda()
    sam_generator = SamAutomaticMaskGenerator(sam, output_mode="binary_mask")

    # PCA feature visualization
    padded_image = pad_image_to_multiple_of(image, multiple=256)
    width, height = padded_image.size
    pixel_values = image_processor(images=padded_image, return_tensors='pt').pixel_values
    pixel_values = pixel_values.to(torch.bfloat16).cuda()
    pixel_values = conditioner(pixel_values)
    _, features = model(pixel_values)["backbone"]
    num_rows = height // model.patch_size
    num_cols = width // model.patch_size
    features = features.detach()
    features = rearrange(features, 'b (h w) c -> b h w c', h=num_rows, w=num_cols).float()
    pca_viz = get_pca_map(features, (height, width), interpolation='bilinear')

    # SAM mask generation and overlay
    resized_image = center_crop_resize(image)  # currently unused; SAM runs on the original image
    image_array = np.array(image)
    print("image size", image_array.shape)
    #image_array = np.transpose(image_array, (2, 0, 1))
    masks = sam_generator.generate(image_array)
    overlay = visualize_anns(image_array, masks)

    return pca_viz, overlay, f"{features.shape}"
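
# A minimal sketch of exercising the pipeline without the Gradio UI, assuming a
# CUDA device and a local test image (the file name below is hypothetical):
#
#   img = Image.open("samples/example.jpg").convert("RGB")
#   pca_viz, overlay, shape_str = infer_radio(img)
#   Image.fromarray(overlay).save("sam_overlay.png")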
title = """RADIO: Reduce All Domains Into One"""
description = """
# RADIO

[AM-RADIO](https://github.com/NVlabs/RADIO) is a framework for distilling multiple large vision foundation models into a single one.
RADIO, a new vision foundation model, excels across visual domains and serves as a superior drop-in replacement for existing vision backbones.
By integrating CLIP variants, DINOv2, and SAM through distillation, it preserves distinctive capabilities such as text grounding and segmentation correspondence.
It outperforms its teachers on ImageNet zero-shot (+6.8%), kNN (+2.39%), and linear-probing segmentation (+3.8%), improves vision-language models (LLaVa 1.5, up to 1.5%), scales to any resolution, and supports non-square images.

# Instructions

Paste an image into the input box or pick one from the gallery of examples, then click the "Submit" button.
The RADIO backbone features are projected to 3 channels with PCA and displayed as an RGB image.
The SAM features are processed by the SAM mask decoder and shown as an overlay on top of the input image.
"""
inputs = [
    gr.Image(type="pil")
]

outputs = [
    gr.Image(label="PCA Feature Visualization"),
    gr.Image(label="SAM Masks"),
    gr.Textbox(label="Feature Shape"),
]
# Create the Gradio interface
demo = gr.Interface(
    fn=infer_radio,
    inputs=inputs,
    examples="./samples/",
    outputs=outputs,
    title=title,
    description=description,
    cache_examples=False
)

if __name__ == "__main__":
    demo.launch()