Fill-Mask
Transformers
PyTorch
Safetensors
English
nomic_bert
custom_code
zpn commited on
Commit
1fe3d3f
·
verified ·
1 Parent(s): b476a30

state_dict fixes

Browse files
Files changed (1) hide show
  1. modeling_hf_nomic_bert.py +8 -10
modeling_hf_nomic_bert.py CHANGED
@@ -59,19 +59,17 @@ def state_dict_from_pretrained(model_name, safe_serialization=False, device=None
59
  is_sharded = True
60
  load_safe = True
61
  else: # Try loading from HF hub instead of from local files
62
- weight_name = WEIGHTS_NAME if not safe_serialization else SAFE_WEIGHTS_NAME
63
- resolved_archive_file = cached_file(
64
- model_name, weight_name, token=True, _raise_exceptions_for_missing_entries=False
65
- )
66
- if resolved_archive_file is None:
67
- weight_index = WEIGHTS_INDEX_NAME if not safe_serialization else SAFE_WEIGHTS_INDEX_NAME
68
  resolved_archive_file = cached_file(
69
- model_name, weight_index, token=True, _raise_exceptions_for_missing_entries=False
70
  )
71
  if resolved_archive_file is not None:
72
- is_sharded = True
73
-
74
- load_safe = safe_serialization
 
 
75
 
76
  if resolved_archive_file is None:
77
  raise EnvironmentError(f"Model name {model_name} was not found.")
 
59
  is_sharded = True
60
  load_safe = True
61
  else: # Try loading from HF hub instead of from local files
62
+ resolved_archive_file = None
63
+ for weight_name in [WEIGHTS_NAME, SAFE_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_INDEX_NAME]:
 
 
 
 
64
  resolved_archive_file = cached_file(
65
+ model_name, weight_name, token=True, _raise_exceptions_for_missing_entries=False
66
  )
67
  if resolved_archive_file is not None:
68
+ if weight_name in [SAFE_WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME]:
69
+ load_safe = True
70
+ if weight_name in [WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_INDEX_NAME]:
71
+ is_sharded = True
72
+ break
73
 
74
  if resolved_archive_file is None:
75
  raise EnvironmentError(f"Model name {model_name} was not found.")