Spaces:
Running
Running
MeYourHint
commited on
Commit
•
2012098
1
Parent(s):
9629d26
Minor
Browse files
models/mask_transformer/transformer.py
CHANGED
@@ -179,9 +179,10 @@ class MaskTransformer(nn.Module):
|
|
179 |
clip_model, clip_preprocess = clip.load(clip_version, device='cpu',
|
180 |
jit=False) # Must set jit=False for training
|
181 |
# Cannot run on cpu
|
182 |
-
|
183 |
-
|
184 |
-
|
|
|
185 |
|
186 |
# Freeze CLIP weights
|
187 |
clip_model.eval()
|
@@ -731,9 +732,10 @@ class ResidualTransformer(nn.Module):
|
|
731 |
clip_model, clip_preprocess = clip.load(clip_version, device='cpu',
|
732 |
jit=False) # Must set jit=False for training
|
733 |
# Cannot run on cpu
|
734 |
-
|
735 |
-
|
736 |
-
|
|
|
737 |
|
738 |
# Freeze CLIP weights
|
739 |
clip_model.eval()
|
|
|
179 |
clip_model, clip_preprocess = clip.load(clip_version, device='cpu',
|
180 |
jit=False) # Must set jit=False for training
|
181 |
# Cannot run on cpu
|
182 |
+
if str(self.evice) != "cpu":
|
183 |
+
clip.model.convert_weights(
|
184 |
+
clip_model) # Actually this line is unnecessary since clip by default already on float16
|
185 |
+
# Date 0707: It's necessary, only unecessary when load directly to gpu. Disable if need to run on cpu
|
186 |
|
187 |
# Freeze CLIP weights
|
188 |
clip_model.eval()
|
|
|
732 |
clip_model, clip_preprocess = clip.load(clip_version, device='cpu',
|
733 |
jit=False) # Must set jit=False for training
|
734 |
# Cannot run on cpu
|
735 |
+
if str(self.evice) != "cpu":
|
736 |
+
clip.model.convert_weights(
|
737 |
+
clip_model) # Actually this line is unnecessary since clip by default already on float16
|
738 |
+
# Date 0707: It's necessary, only unecessary when load directly to gpu. Disable if need to run on cpu
|
739 |
|
740 |
# Freeze CLIP weights
|
741 |
clip_model.eval()
|
options/base_option.py
CHANGED
@@ -12,7 +12,7 @@ class BaseOptions():
|
|
12 |
|
13 |
self.parser.add_argument('--vq_name', type=str, default="rvq_nq1_dc512_nc512", help='Name of the rvq model.')
|
14 |
|
15 |
-
self.parser.add_argument("--gpu_id", type=int, default
|
16 |
self.parser.add_argument('--dataset_name', type=str, default='t2m', help='Dataset Name, {t2m} for humanml3d, {kit} for kit-ml')
|
17 |
self.parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here.')
|
18 |
|
|
|
12 |
|
13 |
self.parser.add_argument('--vq_name', type=str, default="rvq_nq1_dc512_nc512", help='Name of the rvq model.')
|
14 |
|
15 |
+
self.parser.add_argument("--gpu_id", type=int, default=-1, help='GPU id')
|
16 |
self.parser.add_argument('--dataset_name', type=str, default='t2m', help='Dataset Name, {t2m} for humanml3d, {kit} for kit-ml')
|
17 |
self.parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here.')
|
18 |
|