gheinrich committed (verified)
Commit 2de3215 · 1 Parent(s): 6d024c6

Update app.py

Files changed (1):
  1. app.py  +185 -35
app.py CHANGED
@@ -1,47 +1,98 @@
 import os
+import requests
 
 # Disable JIT
 os.environ["PYTORCH_JIT"] = "0"
 
 from einops import rearrange
 import gradio as gr
+import numpy as np
 import spaces
-import torch
+import torch
+import torch.nn as nn
 import torch.nn.functional as F
 from PIL import Image, ImageOps
 from transformers import AutoModel, CLIPImageProcessor
+from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
+from segment_anything.modeling.image_encoder import ImageEncoderViT
 
-hf_repo = "nvidia/RADIO-L"
 
-image_processor = CLIPImageProcessor.from_pretrained(hf_repo)
-model = AutoModel.from_pretrained(hf_repo, trust_remote_code=True)
-model.eval()
+class RADIOVenc(nn.Module):
+    def __init__(self, radio: nn.Module, img_enc: ImageEncoderViT, img_size: int = 1024):
+        super().__init__()
+        self.radio = radio
+        self.neck = img_enc.neck
+        self.img_size = img_size
+        self.dtype = radio.input_conditioner.dtype
 
+    def forward(self, x: torch.Tensor):
+        h, w = x.shape[-2:]
 
-title = """RADIO: Reduce All Domains Into One"""
-description = """
-# RADIO
+        if self.dtype is not None:
+            x = x.to(dtype=self.dtype)
 
-AM-RADIO is a framework to distill Large Vision Foundation models into a single one.
-RADIO, a new vision foundation model, excels across visual domains, serving as a superior replacement for vision backbones.
-Integrating CLIP variants, DINOv2, and SAM through distillation, it preserves unique features like text grounding and segmentation correspondence.
-Outperforming teachers in ImageNet zero-shot (+6.8%), kNN (+2.39%), and linear probing segmentation (+3.8%) and vision-language models (LLaVa 1.5 up to 1.5%), it scales to any resolution, supports non-square images.
+        with torch.autocast('cuda', dtype=torch.bfloat16, enabled=self.dtype is None):
+            output = self.radio(x)
+            features = output["sam"].features
 
-# Instructions
+        rows = h // 16
+        cols = w // 16
 
-Simply paste an image or pick one from the gallery of examples and then click the "Submit" button.
-"""
+        features = rearrange(features, 'b (h w) c -> b c h w', h=rows, w=cols)
 
-inputs = [
-    gr.Image(type="pil")
-]
+        features = self.neck(features)
 
-examples = "./samples/"
+        return features
+
+
+def download_file(url, save_path):
+    # Check if the file already exists
+    if os.path.exists(save_path):
+        print(f"File already exists at {save_path}. Skipping download.")
+        return
+
+    print(f"Downloading from {url}")
+
+    # Send a GET request to the URL
+    response = requests.get(url, stream=True)
+
+    # Check if the request was successful
+    if response.status_code == 200:
+        # Open the file in binary write mode
+        with open(save_path, 'wb') as file:
+            # Iterate over the response content in chunks
+            for chunk in response.iter_content(chunk_size=1024):
+                if chunk: # filter out keep-alive new chunks
+                    file.write(chunk)
+        print(f"File downloaded successfully and saved as {save_path}")
+    else:
+        print(f"Failed to download file. HTTP Status Code: {response.status_code}")
+
+
+hf_repo = "nvidia/RADIO-L"
+image_processor = CLIPImageProcessor.from_pretrained(hf_repo)
+
+model_version = "radio_v2.5-l" # for RADIOv2.5-L model (ViT-L/16)
+
+model = torch.hub.load(
+    'NVlabs/RADIO',
+    'radio_model',
+    version=model_version,
+    progress=True,
+    skip_validation=True,
+    adaptor_names='sam',
+    vitdet_window_size=16)
+model.eval()
+
+local_sam_checkpoint_path = "sam_vit_h_4b8939.pth"
+download_file("https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth", local_sam_checkpoint_path)
+sam = sam_model_registry["vit_h"](checkpoint=local_sam_checkpoint_path)
+model._patch_size = 16
+sam.image_encoder = RADIOVenc(model, sam.image_encoder, img_size=1024)
+conditioner = model.make_preprocessor_external()
+sam.pixel_mean = conditioner.norm_mean * 255
+sam.pixel_std = conditioner.norm_std * 255
 
-outputs = [
-    gr.Textbox(label="Feature Shape"),
-    gr.Image(),
-]
 
 def get_robust_pca(features: torch.Tensor, m: float = 2, remove_first_component=False):
     # features: (N, C)
@@ -110,11 +161,11 @@ def get_pca_map(
     return pca_color
 
 
-def pad_image_to_multiple_of_16(image):
-    # Calculate the new dimensions to make them multiples of 16
+def pad_image_to_multiple_of(image, multiple=16):
+    # Calculate the new dimensions to make them multiples
     width, height = image.size
-    new_width = (width + 15) // 16 * 16
-    new_height = (height + 15) // 16 * 16
+    new_width = (width + multiple -1) // multiple * multiple
+    new_height = (height + multiple -1) // multiple * multiple
 
     # Calculate the padding needed on each side
     pad_width = new_width - width
@@ -131,17 +182,83 @@ def pad_image_to_multiple_of_16(image):
     return padded_image
 
 
+def center_crop_resize(image, size=(1024, 1024)):
+    # Get dimensions
+    width, height = image.size
+
+    # Determine the center crop box
+    if width > height:
+        new_width = height
+        new_height = height
+        left = (width - new_width) / 2
+        top = 0
+        right = (width + new_width) / 2
+        bottom = height
+    else:
+        new_width = width
+        new_height = width
+        left = 0
+        top = (height - new_height) / 2
+        right = width
+        bottom = (height + new_height) / 2
+
+    # Crop the image to a square
+    image = image.crop((left, top, right, bottom))
+
+    # Resize the cropped image to the target size
+    image = image.resize(size, Image.LANCZOS)
+
+    return image
+
+
+def visualize_anns(orig_image: np.ndarray, anns):
+    if len(anns) == 0:
+        return orig_image
+    sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
+
+    kernel = torch.ones(1, 1, 5, 5, dtype=torch.float32)
+
+    # RGBA
+    mask = np.ones((sorted_anns[0]['segmentation'].shape[0], sorted_anns[0]['segmentation'].shape[1], 4), dtype=np.float32)
+    mask[:,:,3] = 0
+    for ann in sorted_anns:
+        m = ann['segmentation']
+        color_mask = np.concatenate([np.random.random(3), [0.35]])
+
+        tm = torch.as_tensor(m).reshape(1, 1, *m.shape).float()
+        cvtm = F.conv2d(tm, kernel, padding=2)
+
+        border_mask = (cvtm < 25).flatten(0, 2).numpy()
+
+        mask[m] = color_mask
+        mask[m & border_mask, 3] *= 1.0 / 0.35
+
+    color, alpha = mask[..., :3], mask[..., -1:]
+
+    orig_image = orig_image.astype(np.float32) / 255
+    overlay = alpha * color + (1 - alpha) * orig_image
+
+    overlay = (overlay * 255).astype(np.uint8)
+    return overlay
+
+
+
 @spaces.GPU
 def infer_radio(image):
     """Define the function to generate the output."""
     model.cuda()
-    image=pad_image_to_multiple_of_16(image)
-    width, height = image.size
-    pixel_values = image_processor(images=image, return_tensors='pt').pixel_values
-    pixel_values = pixel_values.to(torch.bfloat16).cuda()
+    conditioner.cuda()
+    sam.cuda()
+    sam_generator = SamAutomaticMaskGenerator(sam, output_mode="binary_mask")
 
-    _, features = model(pixel_values)
+    # PCA feature visalization
+    padded_image=pad_image_to_multiple_of(image, multiple=256)
+    width, height = padded_image.size
+    pixel_values = image_processor(images=padded_image, return_tensors='pt').pixel_values
+    pixel_values = pixel_values.to(torch.bfloat16).cuda()
+    pixel_values = conditioner(pixel_values)
 
+    _, features = model(pixel_values)["backbone"]
 
     num_rows = height // model.patch_size
     num_cols = width // model.patch_size
@@ -150,15 +267,49 @@ def infer_radio(image):
     features = rearrange(features, 'b (h w) c -> b h w c', h=num_rows, w=num_cols).float()
 
     pca_viz = get_pca_map(features, (height, width), interpolation='bilinear')
+
+    # SAM feature visualization
+    resized_image = center_crop_resize(image)
+    image_array = np.array(image)
+    print("image size", image_array.shape)
+    #image_array = np.transpose(image_array, (2, 0, 1))
+    masks = sam_generator.generate(image_array)
+    overlay = visualize_anns(image_array, masks)
 
-    return f"{features.shape}", pca_viz
+    return f"{features.shape}", pca_viz, overlay
+
+
+
+title = """RADIO: Reduce All Domains Into One"""
+
+description = """
+# RADIO
+
+AM-RADIO is a framework to distill Large Vision Foundation models into a single one.
+RADIO, a new vision foundation model, excels across visual domains, serving as a superior replacement for vision backbones.
+Integrating CLIP variants, DINOv2, and SAM through distillation, it preserves unique features like text grounding and segmentation correspondence.
+Outperforming teachers in ImageNet zero-shot (+6.8%), kNN (+2.39%), and linear probing segmentation (+3.8%) and vision-language models (LLaVa 1.5 up to 1.5%), it scales to any resolution, supports non-square images.
+
+# Instructions
+
+Simply paste an image or pick one from the gallery of examples and then click the "Submit" button.
+"""
+
+inputs = [
+    gr.Image(type="pil")
+]
 
+outputs = [
+    gr.Textbox(label="Feature Shape"),
+    gr.Image(label="PCA Feature Visalization"),
+    gr.Image(label="SAM Masks"),
+]
 
 # Create the Gradio interface
 demo = gr.Interface(
     fn=infer_radio,
     inputs=inputs,
+    examples="./samples/",
     outputs=outputs,
     title=title,
    description=description,
@@ -167,4 +318,3 @@ demo = gr.Interface(
 
 if __name__ == "__main__":
     demo.launch()
-
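
For readers who want to exercise the SAM-adaptor path this commit wires up without going through the Gradio UI, here is a minimal sketch. It assumes the new app.py above is importable as a module named `app` (importing it loads RADIO via torch.hub, downloads the SAM checkpoint, and swaps `RADIOVenc` in as the image encoder), that a CUDA device is available, and that "example.jpg" is a placeholder input path.

# Minimal sketch: run the new SAM mask path outside the Space.
import numpy as np
from PIL import Image
from segment_anything import SamAutomaticMaskGenerator

from app import sam, visualize_anns  # assumes app.py above is on the import path

image_array = np.array(Image.open("example.jpg").convert("RGB"))  # placeholder image
sam.cuda()  # RADIOVenc autocasts on 'cuda', so a GPU is expected here
mask_generator = SamAutomaticMaskGenerator(sam, output_mode="binary_mask")
masks = mask_generator.generate(image_array)      # list of dicts with 'segmentation', 'area', ...
overlay = visualize_anns(image_array, masks)      # same overlay rendering used for the "SAM Masks" output
Image.fromarray(overlay).save("sam_masks.png")

For the PCA view, the same forward pass returns a dict keyed by adaptor name: as the diff shows, the "backbone" entry unpacks into a (summary, spatial features) pair, and infer_radio reshapes the spatial tokens to an (h, w, C) grid with h = height // patch_size before calling get_pca_map.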