d-matrix commited on
Commit
b0c3beb
1 Parent(s): 6d20fa3

drop device setting for already parallel models

Browse files
Files changed (1) hide show
  1. dmx_perplexity.py +7 -2
dmx_perplexity.py CHANGED
@@ -40,6 +40,7 @@ Examples:
40
  46.05925369262695
41
  """
42
 
 
43
  @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
44
  class DmxPerplexity(evaluate.Metric):
45
  def _info(self):
@@ -89,9 +90,13 @@ class DmxPerplexity(evaluate.Metric):
89
  max_seq_len = model.config.n_positions
90
  else:
91
  max_seq_len = 2048
92
-
93
- if not hasattr(model, "hf_device_map"):
 
 
94
  model = model.to(device)
 
 
95
  encodings = tokenizer("\n\n".join(references), return_tensors="pt")
96
 
97
  stride = max_seq_len
 
40
  46.05925369262695
41
  """
42
 
43
+
44
  @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
45
  class DmxPerplexity(evaluate.Metric):
46
  def _info(self):
 
90
  max_seq_len = model.config.n_positions
91
  else:
92
  max_seq_len = 2048
93
+
94
+ if not hasattr(model, "hf_device_map") and (
95
+ not hasattr(model, "model_parallel") or not model.model_parallel
96
+ ):
97
  model = model.to(device)
98
+
99
+ model.eval()
100
  encodings = tokenizer("\n\n".join(references), return_tensors="pt")
101
 
102
  stride = max_seq_len