fix: honor the `dtype` kwarg — cast the model to it and pass it to `state_dict_from_pretrained`
Browse files
modeling_hf_nomic_bert.py
CHANGED
@@ -391,6 +391,7 @@ class NomicBertPreTrainedModel(PreTrainedModel):
|
|
391 |
num_labels = kwargs.pop("num_labels", None)
|
392 |
rotary_scaling_factor = kwargs.pop("rotary_scaling_factor", None)
|
393 |
strict = kwargs.pop("strict", True)
|
|
|
394 |
if rotary_scaling_factor:
|
395 |
config.rotary_scaling_factor = rotary_scaling_factor
|
396 |
|
@@ -406,6 +407,9 @@ class NomicBertPreTrainedModel(PreTrainedModel):
|
|
406 |
model = cls(config, *inputs, add_pooling_layer=False)
|
407 |
else:
|
408 |
model = cls(config, *inputs)
|
|
|
|
|
|
|
409 |
# TODO: fix this
|
410 |
# Assuming we know what we're doing when loading from disk
|
411 |
# Prob a bad assumption but i'm tired and want to train this asap
|
@@ -424,7 +428,7 @@ class NomicBertPreTrainedModel(PreTrainedModel):
|
|
424 |
load_return = model.load_state_dict(state_dict, strict=False)
|
425 |
else:
|
426 |
# TODO: can probably check config class and see if we need to remap from a bert model
|
427 |
-
state_dict = state_dict_from_pretrained(model_name)
|
428 |
state_dict = remap_bert_state_dict(
|
429 |
state_dict,
|
430 |
config,
|
|
|
391 |
num_labels = kwargs.pop("num_labels", None)
|
392 |
rotary_scaling_factor = kwargs.pop("rotary_scaling_factor", None)
|
393 |
strict = kwargs.pop("strict", True)
|
394 |
+
dtype = kwargs.pop("dtype", None)
|
395 |
if rotary_scaling_factor:
|
396 |
config.rotary_scaling_factor = rotary_scaling_factor
|
397 |
|
|
|
407 |
model = cls(config, *inputs, add_pooling_layer=False)
|
408 |
else:
|
409 |
model = cls(config, *inputs)
|
410 |
+
|
411 |
+
if dtype is not None:
|
412 |
+
model = model.to(dtype=dtype)
|
413 |
# TODO: fix this
|
414 |
# Assuming we know what we're doing when loading from disk
|
415 |
# Prob a bad assumption but i'm tired and want to train this asap
|
|
|
428 |
load_return = model.load_state_dict(state_dict, strict=False)
|
429 |
else:
|
430 |
# TODO: can probably check config class and see if we need to remap from a bert model
|
431 |
+
state_dict = state_dict_from_pretrained(model_name, dtype=dtype)
|
432 |
state_dict = remap_bert_state_dict(
|
433 |
state_dict,
|
434 |
config,
|