Update model_helper.py
Browse files- model_helper.py +1 -1
model_helper.py
CHANGED
@@ -104,7 +104,7 @@ def load_model_checkpoint(args=None, device='cpu'):
|
|
104 |
print(f"Task: {tm.task_name}, Max Shift Steps: {tm.max_shift_steps}")
|
105 |
|
106 |
# Use GPU if available
|
107 |
-
|
108 |
|
109 |
# Model
|
110 |
model = YourMT3(
|
|
|
104 |
print(f"Task: {tm.task_name}, Max Shift Steps: {tm.max_shift_steps}")
|
105 |
|
106 |
# Use GPU if available
|
107 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
108 |
|
109 |
# Model
|
110 |
model = YourMT3(
|