jadechoghari commited on
Commit
e0c4817
1 Parent(s): cd769b1

Create dino_wrapper2.py

Browse files
Files changed (1) hide show
  1. dino_wrapper2.py +51 -0
dino_wrapper2.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+
8
+ import torch.nn as nn
9
+ from transformers import ViTImageProcessor, ViTModel, AutoImageProcessor, AutoModel, Dinov2Model
10
+
11
+ class DinoWrapper(nn.Module):
12
+ """
13
+ Dino v1 wrapper using huggingface transformer implementation.
14
+ """
15
+ def __init__(self, model_name: str, freeze: bool = True):
16
+ super().__init__()
17
+ self.model, self.processor = self._build_dino(model_name)
18
+ if freeze:
19
+ self._freeze()
20
+
21
+ def forward(self, image):
22
+ # image: [N, C, H, W], on cpu
23
+ # RGB image with [0,1] scale and properly sized
24
+ inputs = self.processor(images=image.float(), return_tensors="pt", do_rescale=False, do_resize=False).to(self.model.device)
25
+ # This resampling of positional embedding uses bicubic interpolation
26
+ outputs = self.model(**inputs)
27
+ last_hidden_states = outputs.last_hidden_state
28
+ return last_hidden_states
29
+
30
+ def _freeze(self):
31
+ print(f"======== Freezing DinoWrapper ========")
32
+ self.model.eval()
33
+ for name, param in self.model.named_parameters():
34
+ param.requires_grad = False
35
+
36
+ @staticmethod
37
+ def _build_dino(model_name: str, proxy_error_retries: int = 3, proxy_error_cooldown: int = 5):
38
+ import requests
39
+ try:
40
+ processor = AutoImageProcessor.from_pretrained('facebook/dinov2-base')
41
+ processor.do_center_crop = False
42
+ model = AutoModel.from_pretrained('facebook/dinov2-base')
43
+ return model, processor
44
+ except requests.exceptions.ProxyError as err:
45
+ if proxy_error_retries > 0:
46
+ print(f"Huggingface ProxyError: Retrying in {proxy_error_cooldown} seconds...")
47
+ import time
48
+ time.sleep(proxy_error_cooldown)
49
+ return DinoWrapper._build_dino(model_name, proxy_error_retries - 1, proxy_error_cooldown)
50
+ else:
51
+ raise err