mranzinger
commited on
Commit
•
d014741
1
Parent(s):
95b9c86
Make HF interface compatible
Browse files- hf_model.py +30 -6
hf_model.py
CHANGED
@@ -124,14 +124,38 @@ class RADIOModel(PreTrainedModel):
|
|
124 |
def input_conditioner(self) -> InputConditioner:
|
125 |
return self.radio_model.input_conditioner
|
126 |
|
127 |
-
@
|
128 |
-
def
|
129 |
-
self.radio_model.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
130 |
|
131 |
def make_preprocessor_external(self) -> Callable[[torch.Tensor], torch.Tensor]:
|
132 |
-
|
133 |
-
|
134 |
-
|
|
|
|
|
|
|
|
|
135 |
|
136 |
def forward(self, x: torch.Tensor):
|
137 |
return self.radio_model.forward(x)
|
|
|
124 |
def input_conditioner(self) -> InputConditioner:
|
125 |
return self.radio_model.input_conditioner
|
126 |
|
127 |
+
@property
|
128 |
+
def num_summary_tokens(self) -> int:
|
129 |
+
return self.radio_model.num_summary_tokens
|
130 |
+
|
131 |
+
@property
|
132 |
+
def patch_size(self) -> int:
|
133 |
+
return self.radio_model.patch_size
|
134 |
+
|
135 |
+
@property
|
136 |
+
def max_resolution(self) -> int:
|
137 |
+
return self.radio_model.max_resolution
|
138 |
+
|
139 |
+
@property
|
140 |
+
def preferred_resolution(self) -> Resolution:
|
141 |
+
return self.radio_model.preferred_resolution
|
142 |
+
|
143 |
+
@property
|
144 |
+
def window_size(self) -> int:
|
145 |
+
return self.radio_model.window_size
|
146 |
+
|
147 |
+
@property
|
148 |
+
def min_resolution_step(self) -> int:
|
149 |
+
return self.radio_model.min_resolution_step
|
150 |
|
151 |
def make_preprocessor_external(self) -> Callable[[torch.Tensor], torch.Tensor]:
|
152 |
+
return self.radio_model.make_preprocessor_external()
|
153 |
+
|
154 |
+
def get_nearest_supported_resolution(self, height: int, width: int) -> Resolution:
|
155 |
+
return self.radio_model.get_nearest_supported_resolution(height, width)
|
156 |
+
|
157 |
+
def switch_to_deploy(self):
|
158 |
+
return self.radio_model.switch_to_deploy()
|
159 |
|
160 |
def forward(self, x: torch.Tensor):
|
161 |
return self.radio_model.forward(x)
|