Update app.py
app.py
CHANGED
@@ -1,47 +1,98 @@
 import os
 
 # Disable JIT
 os.environ["PYTORCH_JIT"] = "0"
 
 from einops import rearrange
 import gradio as gr
 import spaces
-import torch
 import torch.nn.functional as F
 from PIL import Image, ImageOps
 from transformers import AutoModel, CLIPImageProcessor
 
-hf_repo = "nvidia/RADIO-L"
 
-
-
-
 
 
-
-
-# RADIO
 
-
-
-
-Outperforming teachers in ImageNet zero-shot (+6.8%), kNN (+2.39%), and linear probing segmentation (+3.8%) and vision-language models (LLaVa 1.5 up to 1.5%), it scales to any resolution, supports non-square images.
 
-
 
-
-"""
 
-
-    gr.Image(type="pil")
-]
 
-
 
-outputs = [
-    gr.Textbox(label="Feature Shape"),
-    gr.Image(),
-]
 
 def get_robust_pca(features: torch.Tensor, m: float = 2, remove_first_component=False):
     # features: (N, C)
@@ -110,11 +161,11 @@ def get_pca_map(
     return pca_color
 
 
-def pad_image_to_multiple_of_16(image):
-    # Calculate the new dimensions to make them multiples
     width, height = image.size
-    new_width = (width +
-    new_height = (height +
 
     # Calculate the padding needed on each side
     pad_width = new_width - width
@@ -131,17 +182,83 @@ def pad_image_to_multiple_of_16(image):
     return padded_image
 
 
 @spaces.GPU
 def infer_radio(image):
     """Define the function to generate the output."""
     model.cuda()
-
-
-
-    pixel_values = pixel_values.to(torch.bfloat16).cuda()
 
-
 
 
     num_rows = height // model.patch_size
     num_cols = width // model.patch_size
@@ -150,15 +267,49 @@ def infer_radio(image):
     features = rearrange(features, 'b (h w) c -> b h w c', h=num_rows, w=num_cols).float()
 
     pca_viz = get_pca_map(features, (height, width), interpolation='bilinear')
 
-    return f"{features.shape}", pca_viz
 
 
 # Create the Gradio interface
 demo = gr.Interface(
     fn=infer_radio,
     inputs=inputs,
-    examples=
     outputs=outputs,
     title=title,
     description=description,
@@ -167,4 +318,3 @@ demo = gr.Interface(
 
 if __name__ == "__main__":
     demo.launch()
-
Updated app.py (added lines are marked with +; unchanged stretches are collapsed):

 import os
+import requests
 
 # Disable JIT
 os.environ["PYTORCH_JIT"] = "0"
 
 from einops import rearrange
 import gradio as gr
+import numpy as np
 import spaces
+import torch
+import torch.nn as nn
 import torch.nn.functional as F
 from PIL import Image, ImageOps
 from transformers import AutoModel, CLIPImageProcessor
+from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
+from segment_anything.modeling.image_encoder import ImageEncoderViT
 
 
+class RADIOVenc(nn.Module):
+    def __init__(self, radio: nn.Module, img_enc: ImageEncoderViT, img_size: int = 1024):
+        super().__init__()
+        self.radio = radio
+        self.neck = img_enc.neck
+        self.img_size = img_size
+        self.dtype = radio.input_conditioner.dtype
 
+    def forward(self, x: torch.Tensor):
+        h, w = x.shape[-2:]
 
+        if self.dtype is not None:
+            x = x.to(dtype=self.dtype)
 
+        with torch.autocast('cuda', dtype=torch.bfloat16, enabled=self.dtype is None):
+            output = self.radio(x)
+            features = output["sam"].features
 
+        rows = h // 16
+        cols = w // 16
 
+        features = rearrange(features, 'b (h w) c -> b c h w', h=rows, w=cols)
 
+        features = self.neck(features)
 
+        return features
+
+
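RADIOVenc stands in for SAM's ViT-H image encoder: it runs RADIO, reads the features produced by the "sam" adaptor, reshapes the token sequence into a 2-D grid (the hard-coded 16 matches the ViT-L/16 patch size, also set via model._patch_size = 16 below), and passes the grid through the original encoder's neck so the output has the shape SAM's mask decoder expects.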
+def download_file(url, save_path):
+    # Check if the file already exists
+    if os.path.exists(save_path):
+        print(f"File already exists at {save_path}. Skipping download.")
+        return
+
+    print(f"Downloading from {url}")
+
+    # Send a GET request to the URL
+    response = requests.get(url, stream=True)
+
+    # Check if the request was successful
+    if response.status_code == 200:
+        # Open the file in binary write mode
+        with open(save_path, 'wb') as file:
+            # Iterate over the response content in chunks
+            for chunk in response.iter_content(chunk_size=1024):
+                if chunk:  # filter out keep-alive new chunks
+                    file.write(chunk)
+        print(f"File downloaded successfully and saved as {save_path}")
+    else:
+        print(f"Failed to download file. HTTP Status Code: {response.status_code}")
+
+
+hf_repo = "nvidia/RADIO-L"
+image_processor = CLIPImageProcessor.from_pretrained(hf_repo)
+
+model_version = "radio_v2.5-l"  # for RADIOv2.5-L model (ViT-L/16)
+
+model = torch.hub.load(
+    'NVlabs/RADIO',
+    'radio_model',
+    version=model_version,
+    progress=True,
+    skip_validation=True,
+    adaptor_names='sam',
+    vitdet_window_size=16)
+model.eval()
+
+local_sam_checkpoint_path = "sam_vit_h_4b8939.pth"
+download_file("https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth", local_sam_checkpoint_path)
+sam = sam_model_registry["vit_h"](checkpoint=local_sam_checkpoint_path)
+model._patch_size = 16
+sam.image_encoder = RADIOVenc(model, sam.image_encoder, img_size=1024)
+conditioner = model.make_preprocessor_external()
+sam.pixel_mean = conditioner.norm_mean * 255
+sam.pixel_std = conditioner.norm_std * 255
 
 
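For orientation: loaded with adaptor_names='sam', the hub model returns a dict per forward pass; infer_radio below unpacks the "backbone" entry, while RADIOVenc reads the "sam" adaptor. A minimal sketch of that access pattern (illustrative only; assumes a CUDA device and the objects defined above, with input sides that are multiples of 256 to match the padding infer_radio applies below):

model.cuda()
conditioner.cuda()
x = torch.zeros(1, 3, 512, 512, dtype=torch.bfloat16, device="cuda")  # dummy image batch
out = model(conditioner(x))
summary, spatial = out["backbone"]   # spatial patch tokens: (1, (512 // 16) ** 2, C)
sam_feats = out["sam"].features      # adaptor features consumed by RADIOVenc above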
 def get_robust_pca(features: torch.Tensor, m: float = 2, remove_first_component=False):
     # features: (N, C)

(lines 99-160, the unchanged bodies of get_robust_pca and get_pca_map, collapsed)
     return pca_color
 
 
+def pad_image_to_multiple_of(image, multiple=16):
+    # Calculate the new dimensions to make them multiples
     width, height = image.size
+    new_width = (width + multiple - 1) // multiple * multiple
+    new_height = (height + multiple - 1) // multiple * multiple
 
     # Calculate the padding needed on each side
     pad_width = new_width - width

(lines 172-181, the unchanged remainder of the padding helper, collapsed)

     return padded_image
 
 
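The round-up above works by adding multiple - 1 before the integer division; a standalone check of the arithmetic (not part of app.py):

assert (1000 + 256 - 1) // 256 * 256 == 1024   # 1000 rounds up to the next multiple of 256
assert (1024 + 256 - 1) // 256 * 256 == 1024   # exact multiples are left unchanged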
+def center_crop_resize(image, size=(1024, 1024)):
+    # Get dimensions
+    width, height = image.size
+
+    # Determine the center crop box
+    if width > height:
+        new_width = height
+        new_height = height
+        left = (width - new_width) / 2
+        top = 0
+        right = (width + new_width) / 2
+        bottom = height
+    else:
+        new_width = width
+        new_height = width
+        left = 0
+        top = (height - new_height) / 2
+        right = width
+        bottom = (height + new_height) / 2
+
+    # Crop the image to a square
+    image = image.crop((left, top, right, bottom))
+
+    # Resize the cropped image to the target size
+    image = image.resize(size, Image.LANCZOS)
+
+    return image
+
+
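As a usage sketch (not part of app.py; assumes the function above), a 1280x720 input is cropped to its central 720x720 square and then resized:

img = Image.new("RGB", (1280, 720))      # stand-in for a real photo
out = center_crop_resize(img)            # crop box works out to (280, 0, 1000, 720)
print(out.size)                          # (1024, 1024)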
+def visualize_anns(orig_image: np.ndarray, anns):
+    if len(anns) == 0:
+        return orig_image
+    sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
+
+    kernel = torch.ones(1, 1, 5, 5, dtype=torch.float32)
+
+    # RGBA
+    mask = np.ones((sorted_anns[0]['segmentation'].shape[0], sorted_anns[0]['segmentation'].shape[1], 4), dtype=np.float32)
+    mask[:, :, 3] = 0
+    for ann in sorted_anns:
+        m = ann['segmentation']
+        color_mask = np.concatenate([np.random.random(3), [0.35]])
+
+        tm = torch.as_tensor(m).reshape(1, 1, *m.shape).float()
+        cvtm = F.conv2d(tm, kernel, padding=2)
+
+        border_mask = (cvtm < 25).flatten(0, 2).numpy()
+
+        mask[m] = color_mask
+        mask[m & border_mask, 3] *= 1.0 / 0.35
+
+    color, alpha = mask[..., :3], mask[..., -1:]
+
+    orig_image = orig_image.astype(np.float32) / 255
+    overlay = alpha * color + (1 - alpha) * orig_image
+
+    overlay = (overlay * 255).astype(np.uint8)
+    return overlay
+
+
+
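The 5x5 box filter above flags a mask pixel as "border" when its neighbourhood is not entirely inside the mask (sum < 25); those pixels keep full opacity, giving each region a visible outline. A toy check of that test (standalone, not part of app.py):

m = np.zeros((12, 12), dtype=bool)
m[2:11, 2:11] = True                                    # 9x9 square mask
tm = torch.as_tensor(m).reshape(1, 1, 12, 12).float()
border = (F.conv2d(tm, torch.ones(1, 1, 5, 5), padding=2) < 25).flatten(0, 2).numpy()
print((m & border).sum(), (m & ~border).sum())          # 56 edge pixels, 25 interior pixels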
 @spaces.GPU
 def infer_radio(image):
     """Define the function to generate the output."""
     model.cuda()
+    conditioner.cuda()
+    sam.cuda()
+    sam_generator = SamAutomaticMaskGenerator(sam, output_mode="binary_mask")
 
+    # PCA feature visualization
+    padded_image = pad_image_to_multiple_of(image, multiple=256)
+    width, height = padded_image.size
+    pixel_values = image_processor(images=padded_image, return_tensors='pt').pixel_values
+    pixel_values = pixel_values.to(torch.bfloat16).cuda()
+    pixel_values = conditioner(pixel_values)
 
+    _, features = model(pixel_values)["backbone"]
 
     num_rows = height // model.patch_size
     num_cols = width // model.patch_size

(lines 265-266 unchanged, collapsed)

     features = rearrange(features, 'b (h w) c -> b h w c', h=num_rows, w=num_cols).float()
 
     pca_viz = get_pca_map(features, (height, width), interpolation='bilinear')
+
+    # SAM mask visualization
+    resized_image = center_crop_resize(image)
+    image_array = np.array(image)
+    print("image size", image_array.shape)
+    #image_array = np.transpose(image_array, (2, 0, 1))
+    masks = sam_generator.generate(image_array)
+    overlay = visualize_anns(image_array, masks)
 
+    return f"{features.shape}", pca_viz, overlay
+
+
+
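Called directly (outside Gradio), the function maps one PIL image to the three outputs declared below; for example (requires a GPU; the file name is a placeholder):

shape_str, pca_img, sam_overlay = infer_radio(Image.open("samples/example.jpg").convert("RGB"))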
+title = """RADIO: Reduce All Domains Into One"""
+
+description = """
+# RADIO
+
+AM-RADIO is a framework for distilling large vision foundation models into a single one.
+RADIO, a new vision foundation model, excels across visual domains, serving as a superior replacement for vision backbones.
+Integrating CLIP variants, DINOv2, and SAM through distillation, it preserves unique features like text grounding and segmentation correspondence.
+It outperforms its teachers on ImageNet zero-shot (+6.8%), kNN (+2.39%), and linear-probe segmentation (+3.8%), improves vision-language models (LLaVA 1.5 by up to 1.5%), and scales to any resolution, including non-square images.
+
+# Instructions
+
+Simply paste an image or pick one from the gallery of examples, then click the "Submit" button.
+"""
+
+inputs = [
+    gr.Image(type="pil")
+]
 
+outputs = [
+    gr.Textbox(label="Feature Shape"),
+    gr.Image(label="PCA Feature Visualization"),
+    gr.Image(label="SAM Masks"),
+]
 
 # Create the Gradio interface
 demo = gr.Interface(
     fn=infer_radio,
     inputs=inputs,
+    examples="./samples/",
     outputs=outputs,
     title=title,
     description=description,

(lines 316-317 unchanged, collapsed)

 
 if __name__ == "__main__":
     demo.launch()