Spaces:
Running
Running
File size: 7,113 Bytes
5eeb557 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 |
diff --git a/models/hierarchy_inference_model.py b/models/hierarchy_inference_model.py
index 3116307..5de661d 100644
--- a/models/hierarchy_inference_model.py
+++ b/models/hierarchy_inference_model.py
@@ -21,7 +21,7 @@ class VQGANTextureAwareSpatialHierarchyInferenceModel():
def __init__(self, opt):
self.opt = opt
- self.device = torch.device('cuda')
+ self.device = torch.device(opt['device'])
self.is_train = opt['is_train']
self.top_encoder = Encoder(
diff --git a/models/hierarchy_vqgan_model.py b/models/hierarchy_vqgan_model.py
index 4b0d657..0bf4712 100644
--- a/models/hierarchy_vqgan_model.py
+++ b/models/hierarchy_vqgan_model.py
@@ -20,7 +20,7 @@ class HierarchyVQSpatialTextureAwareModel():
def __init__(self, opt):
self.opt = opt
- self.device = torch.device('cuda')
+ self.device = torch.device(opt['device'])
self.top_encoder = Encoder(
ch=opt['top_ch'],
num_res_blocks=opt['top_num_res_blocks'],
diff --git a/models/parsing_gen_model.py b/models/parsing_gen_model.py
index 9440345..15a1ecb 100644
--- a/models/parsing_gen_model.py
+++ b/models/parsing_gen_model.py
@@ -22,7 +22,7 @@ class ParsingGenModel():
def __init__(self, opt):
self.opt = opt
- self.device = torch.device('cuda')
+ self.device = torch.device(opt['device'])
self.is_train = opt['is_train']
self.attr_embedder = ShapeAttrEmbedding(
diff --git a/models/sample_model.py b/models/sample_model.py
index 4c60e3f..5265cd0 100644
--- a/models/sample_model.py
+++ b/models/sample_model.py
@@ -23,7 +23,7 @@ class BaseSampleModel():
def __init__(self, opt):
self.opt = opt
- self.device = torch.device('cuda')
+ self.device = torch.device(opt['device'])
# hierarchical VQVAE
self.decoder = Decoder(
@@ -123,7 +123,7 @@ class BaseSampleModel():
def load_top_pretrain_models(self):
# load pretrained vqgan
- top_vae_checkpoint = torch.load(self.opt['top_vae_path'])
+ top_vae_checkpoint = torch.load(self.opt['top_vae_path'], map_location=self.device)
self.decoder.load_state_dict(
top_vae_checkpoint['decoder'], strict=True)
@@ -137,7 +137,7 @@ class BaseSampleModel():
self.top_post_quant_conv.eval()
def load_bot_pretrain_network(self):
- checkpoint = torch.load(self.opt['bot_vae_path'])
+ checkpoint = torch.load(self.opt['bot_vae_path'], map_location=self.device)
self.bot_decoder_res.load_state_dict(
checkpoint['bot_decoder_res'], strict=True)
self.decoder.load_state_dict(checkpoint['decoder'], strict=True)
@@ -153,7 +153,7 @@ class BaseSampleModel():
def load_pretrained_segm_token(self):
# load pretrained vqgan for segmentation mask
- segm_token_checkpoint = torch.load(self.opt['segm_token_path'])
+ segm_token_checkpoint = torch.load(self.opt['segm_token_path'], map_location=self.device)
self.segm_encoder.load_state_dict(
segm_token_checkpoint['encoder'], strict=True)
self.segm_quantizer.load_state_dict(
@@ -166,7 +166,7 @@ class BaseSampleModel():
self.segm_quant_conv.eval()
def load_index_pred_network(self):
- checkpoint = torch.load(self.opt['pretrained_index_network'])
+ checkpoint = torch.load(self.opt['pretrained_index_network'], map_location=self.device)
self.index_pred_guidance_encoder.load_state_dict(
checkpoint['guidance_encoder'], strict=True)
self.index_pred_decoder.load_state_dict(
@@ -176,7 +176,7 @@ class BaseSampleModel():
self.index_pred_decoder.eval()
def load_sampler_pretrained_network(self):
- checkpoint = torch.load(self.opt['pretrained_sampler'])
+ checkpoint = torch.load(self.opt['pretrained_sampler'], map_location=self.device)
self.sampler_fn.load_state_dict(checkpoint, strict=True)
self.sampler_fn.eval()
@@ -397,7 +397,7 @@ class SampleFromPoseModel(BaseSampleModel):
[185, 210, 205], [130, 165, 180], [225, 141, 151]]
def load_shape_generation_models(self):
- checkpoint = torch.load(self.opt['pretrained_parsing_gen'])
+ checkpoint = torch.load(self.opt['pretrained_parsing_gen'], map_location=self.device)
self.shape_attr_embedder.load_state_dict(
checkpoint['embedder'], strict=True)
diff --git a/models/transformer_model.py b/models/transformer_model.py
index 7db0f3e..4523d17 100644
--- a/models/transformer_model.py
+++ b/models/transformer_model.py
@@ -21,7 +21,7 @@ class TransformerTextureAwareModel():
def __init__(self, opt):
self.opt = opt
- self.device = torch.device('cuda')
+ self.device = torch.device(opt['device'])
self.is_train = opt['is_train']
# VQVAE for image
@@ -317,10 +317,10 @@ class TransformerTextureAwareModel():
def sample_fn(self, temp=1.0, sample_steps=None):
self._denoise_fn.eval()
- b, device = self.image.size(0), 'cuda'
+ b = self.image.size(0)
x_t = torch.ones(
- (b, np.prod(self.shape)), device=device).long() * self.mask_id
- unmasked = torch.zeros_like(x_t, device=device).bool()
+ (b, np.prod(self.shape)), device=self.device).long() * self.mask_id
+ unmasked = torch.zeros_like(x_t, device=self.device).bool()
sample_steps = list(range(1, sample_steps + 1))
texture_mask_flatten = self.texture_tokens.view(-1)
@@ -336,11 +336,11 @@ class TransformerTextureAwareModel():
for t in reversed(sample_steps):
print(f'Sample timestep {t:4d}', end='\r')
- t = torch.full((b, ), t, device=device, dtype=torch.long)
+ t = torch.full((b, ), t, device=self.device, dtype=torch.long)
# where to unmask
changes = torch.rand(
- x_t.shape, device=device) < 1 / t.float().unsqueeze(-1)
+ x_t.shape, device=self.device) < 1 / t.float().unsqueeze(-1)
# don't unmask somewhere already unmasked
changes = torch.bitwise_xor(changes,
torch.bitwise_and(changes, unmasked))
diff --git a/models/vqgan_model.py b/models/vqgan_model.py
index 13a2e70..9c840f1 100644
--- a/models/vqgan_model.py
+++ b/models/vqgan_model.py
@@ -20,7 +20,7 @@ class VQModel():
def __init__(self, opt):
super().__init__()
self.opt = opt
- self.device = torch.device('cuda')
+ self.device = torch.device(opt['device'])
self.encoder = Encoder(
ch=opt['ch'],
num_res_blocks=opt['num_res_blocks'],
@@ -390,7 +390,7 @@ class VQImageSegmTextureModel(VQImageModel):
def __init__(self, opt):
self.opt = opt
- self.device = torch.device('cuda')
+ self.device = torch.device(opt['device'])
self.encoder = Encoder(
ch=opt['ch'],
num_res_blocks=opt['num_res_blocks'],
|