Spaces:
Sleeping
Sleeping
d-matrix
commited on
Commit
•
e8a9343
1
Parent(s):
096defa
Update dmx_perplexity.py
Browse filessupporting model parallel for gpt
- dmx_perplexity.py +7 -0
dmx_perplexity.py
CHANGED
@@ -105,6 +105,13 @@ class DmxPerplexity(evaluate.Metric):
|
|
105 |
target_ids = input_ids.clone()
|
106 |
target_ids[:, :-trg_len] = -100
|
107 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
108 |
with torch.no_grad():
|
109 |
outputs = model(input_ids, labels=target_ids)
|
110 |
if isinstance(outputs, Dict):
|
|
|
105 |
target_ids = input_ids.clone()
|
106 |
target_ids[:, :-trg_len] = -100
|
107 |
|
108 |
+
# Setting device for labels if mdodel has device mapping
|
109 |
+
if hasattr(model, "hf_device_map"):
|
110 |
+
last_device = "cuda:" + str(
|
111 |
+
max(model.hf_device_map.values())
|
112 |
+
)
|
113 |
+
target_ids = target_ids.to(last_device)
|
114 |
+
|
115 |
with torch.no_grad():
|
116 |
outputs = model(input_ids, labels=target_ids)
|
117 |
if isinstance(outputs, Dict):
|