Upload 3 files
Browse files- configuration_hat.py +1 -1
- modelling_hat.py +0 -1
- tokenization_hat.py +22 -0
configuration_hat.py
CHANGED
@@ -147,4 +147,4 @@ class HATOnnxConfig(OnnxConfig):
|
|
147 |
("input_ids", {0: "batch", 1: "sequence"}),
|
148 |
("attention_mask", {0: "batch", 1: "sequence"}),
|
149 |
]
|
150 |
-
)
|
|
|
147 |
("input_ids", {0: "batch", 1: "sequence"}),
|
148 |
("attention_mask", {0: "batch", 1: "sequence"}),
|
149 |
]
|
150 |
+
)
|
modelling_hat.py
CHANGED
@@ -2357,4 +2357,3 @@ def off_diagonal(x):
|
|
2357 |
assert n == m
|
2358 |
return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten()
|
2359 |
|
2360 |
-
|
|
|
2357 |
assert n == m
|
2358 |
return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten()
|
2359 |
|
|
tokenization_hat.py
CHANGED
@@ -246,4 +246,26 @@ class HATTokenizer:
|
|
246 |
flat_input[:chunk_size-1],
|
247 |
torch.tensor([self.pad_token_id] * max(0, chunk_size - len(flat_input) - 1), dtype=torch.int)
|
248 |
))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
249 |
|
|
|
246 |
flat_input[:chunk_size-1],
|
247 |
torch.tensor([self.pad_token_id] * max(0, chunk_size - len(flat_input) - 1), dtype=torch.int)
|
248 |
))
|
249 |
+
|
250 |
+
@classmethod
|
251 |
+
def register_for_auto_class(cls, auto_class="AutoModel"):
|
252 |
+
"""
|
253 |
+
Register this class with a given auto class. This should only be used for custom models as the ones in the
|
254 |
+
library are already mapped with an auto class.
|
255 |
+
<Tip warning={true}>
|
256 |
+
This API is experimental and may have some slight breaking changes in the next releases.
|
257 |
+
</Tip>
|
258 |
+
Args:
|
259 |
+
auto_class (`str` or `type`, *optional*, defaults to `"TFAutoModel"`):
|
260 |
+
The auto class to register this new model with.
|
261 |
+
"""
|
262 |
+
if not isinstance(auto_class, str):
|
263 |
+
auto_class = auto_class.__name__
|
264 |
+
|
265 |
+
import transformers.models.auto as auto_module
|
266 |
+
|
267 |
+
if not hasattr(auto_module, auto_class):
|
268 |
+
raise ValueError(f"{auto_class} is not a valid auto class.")
|
269 |
+
|
270 |
+
cls._auto_class = auto_class
|
271 |
|