Commit b672c72 by zpn (parent: cbca4e2)

fix: dtype

Files changed (1):
  1. modeling_hf_nomic_bert.py (+5 -1)
modeling_hf_nomic_bert.py CHANGED
@@ -391,6 +391,7 @@ class NomicBertPreTrainedModel(PreTrainedModel):
         num_labels = kwargs.pop("num_labels", None)
         rotary_scaling_factor = kwargs.pop("rotary_scaling_factor", None)
         strict = kwargs.pop("strict", True)
+        dtype = kwargs.pop("dtype", None)
         if rotary_scaling_factor:
             config.rotary_scaling_factor = rotary_scaling_factor
 
@@ -406,6 +407,9 @@ class NomicBertPreTrainedModel(PreTrainedModel):
             model = cls(config, *inputs, add_pooling_layer=False)
         else:
             model = cls(config, *inputs)
+
+        if dtype is not None:
+            model = model.to(dtype=dtype)
         # TODO: fix this
         # Assuming we know what we're doing when loading from disk
         # Prob a bad assumption but i'm tired and want to train this asap
@@ -424,7 +428,7 @@ class NomicBertPreTrainedModel(PreTrainedModel):
             load_return = model.load_state_dict(state_dict, strict=False)
         else:
             # TODO: can probably check config class and see if we need to remap from a bert model
-            state_dict = state_dict_from_pretrained(model_name)
+            state_dict = state_dict_from_pretrained(model_name, dtype=dtype)
             state_dict = remap_bert_state_dict(
                 state_dict,
                 config,
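
For context, a minimal usage sketch of the new kwarg. The model id below is a placeholder, and it assumes the repo's auto-class mapping forwards extra kwargs to this custom from_pretrained; with this commit, passing dtype casts the freshly constructed model via model.to(dtype=dtype) and forwards the same dtype to state_dict_from_pretrained so the loaded weights match.

import torch
from transformers import AutoModel

# Placeholder model id; substitute the repo that ships modeling_hf_nomic_bert.py.
model_id = "nomic-ai/nomic-bert-2048"

# `dtype` is popped inside the custom from_pretrained changed by this commit:
# the model is cast with model.to(dtype=dtype) and the weights are loaded
# with the same dtype via state_dict_from_pretrained(model_name, dtype=dtype).
model = AutoModel.from_pretrained(
    model_id,
    trust_remote_code=True,  # run the custom NomicBert modeling code
    dtype=torch.float16,     # new kwarg handled by this change
)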