d-matrix committed on
Commit
e8a9343
1 Parent(s): 096defa

Update dmx_perplexity.py

Browse files

Support model parallelism for GPT models

Files changed (1) hide show
  1. dmx_perplexity.py +7 -0
dmx_perplexity.py CHANGED
@@ -105,6 +105,13 @@ class DmxPerplexity(evaluate.Metric):
105
  target_ids = input_ids.clone()
106
  target_ids[:, :-trg_len] = -100
107
 
 
 
 
 
 
 
 
108
  with torch.no_grad():
109
  outputs = model(input_ids, labels=target_ids)
110
  if isinstance(outputs, Dict):
 
105
  target_ids = input_ids.clone()
106
  target_ids[:, :-trg_len] = -100
107
 
108
+ # Set the device for the labels if the model has a device map
109
+ if hasattr(model, "hf_device_map"):
110
+ last_device = "cuda:" + str(
111
+ max(model.hf_device_map.values())
112
+ )
113
+ target_ids = target_ids.to(last_device)
114
+
115
  with torch.no_grad():
116
  outputs = model(input_ids, labels=target_ids)
117
  if isinstance(outputs, Dict):