mranzinger commited on
Commit
d014741
1 Parent(s): 95b9c86

Make HF interface compatible

Browse files
Files changed (1) hide show
  1. hf_model.py +30 -6
hf_model.py CHANGED
@@ -124,14 +124,38 @@ class RADIOModel(PreTrainedModel):
124
  def input_conditioner(self) -> InputConditioner:
125
  return self.radio_model.input_conditioner
126
 
127
- @input_conditioner.setter
128
- def input_conditioner(self, v: InputConditioner):
129
- self.radio_model.input_conditioner = v
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
 
131
  def make_preprocessor_external(self) -> Callable[[torch.Tensor], torch.Tensor]:
132
- ret = self.input_conditioner
133
- self.input_conditioner = nn.Identity()
134
- return ret
 
 
 
 
135
 
136
  def forward(self, x: torch.Tensor):
137
  return self.radio_model.forward(x)
 
124
  def input_conditioner(self) -> InputConditioner:
125
  return self.radio_model.input_conditioner
126
 
127
+ @property
128
+ def num_summary_tokens(self) -> int:
129
+ return self.radio_model.num_summary_tokens
130
+
131
+ @property
132
+ def patch_size(self) -> int:
133
+ return self.radio_model.patch_size
134
+
135
+ @property
136
+ def max_resolution(self) -> int:
137
+ return self.radio_model.max_resolution
138
+
139
+ @property
140
+ def preferred_resolution(self) -> Resolution:
141
+ return self.radio_model.preferred_resolution
142
+
143
+ @property
144
+ def window_size(self) -> int:
145
+ return self.radio_model.window_size
146
+
147
+ @property
148
+ def min_resolution_step(self) -> int:
149
+ return self.radio_model.min_resolution_step
150
 
151
  def make_preprocessor_external(self) -> Callable[[torch.Tensor], torch.Tensor]:
152
+ return self.radio_model.make_preprocessor_external()
153
+
154
+ def get_nearest_supported_resolution(self, height: int, width: int) -> Resolution:
155
+ return self.radio_model.get_nearest_supported_resolution(height, width)
156
+
157
+ def switch_to_deploy(self):
158
+ return self.radio_model.switch_to_deploy()
159
 
160
  def forward(self, x: torch.Tensor):
161
  return self.radio_model.forward(x)