wondervictor committed (verified)
Commit: c686e22
Parent: e57f3c4

Update model.py

Files changed (1):
  1. model.py (+45, −45)
model.py CHANGED
@@ -153,29 +153,29 @@ class Model:
         qzshape = [len(c_indices), 8, H // 16, W // 16]
         t1 = time.time()
         print(caption_embs.device)
-        # index_sample = generate(
-        #     self.gpt_model,
-        #     c_indices,
-        #     (H // 16) * (W // 16),
-        #     c_emb_masks,
-        #     condition=condition_img,
-        #     cfg_scale=cfg_scale,
-        #     temperature=temperature,
-        #     top_k=top_k,
-        #     top_p=top_p,
-        #     sample_logits=True,
-        #     control_strength=control_strength,
-        # )
-        # sampling_time = time.time() - t1
-        # print(f"Full sampling takes about {sampling_time:.2f} seconds.")
+        index_sample = generate(
+            self.gpt_model,
+            c_indices,
+            (H // 16) * (W // 16),
+            c_emb_masks,
+            condition=condition_img,
+            cfg_scale=cfg_scale,
+            temperature=temperature,
+            top_k=top_k,
+            top_p=top_p,
+            sample_logits=True,
+            control_strength=control_strength,
+        )
+        sampling_time = time.time() - t1
+        print(f"Full sampling takes about {sampling_time:.2f} seconds.")

-        # t2 = time.time()
-        # print(index_sample.shape)
-        # samples = self.vq_model.decode_code(
-        #     index_sample, qzshape) # output value is between [-1, 1]
-        # decoder_time = time.time() - t2
-        # print(f"decoder takes about {decoder_time:.2f} seconds.")
-        samples = condition_img[0:1]
+        t2 = time.time()
+        print(index_sample.shape)
+        samples = self.vq_model.decode_code(
+            index_sample, qzshape) # output value is between [-1, 1]
+        decoder_time = time.time() - t2
+        print(f"decoder takes about {decoder_time:.2f} seconds.")
+        # samples = condition_img[0:1]
         samples = torch.cat((condition_img[0:1], samples), dim=0)
         samples = 255 * (samples * 0.5 + 0.5)
         samples = [
@@ -247,31 +247,31 @@ class Model:
         c_emb_masks = new_emb_masks
         qzshape = [len(c_indices), 8, H // 16, W // 16]
         t1 = time.time()
-        # index_sample = generate(
-        #     self.gpt_model,
-        #     c_indices,
-        #     (H // 16) * (W // 16),
-        #     c_emb_masks,
-        #     condition=condition_img,
-        #     cfg_scale=cfg_scale,
-        #     temperature=temperature,
-        #     top_k=top_k,
-        #     top_p=top_p,
-        #     sample_logits=True,
-        #     control_strength=control_strength,
-        # )
-        # sampling_time = time.time() - t1
-        # print(f"Full sampling takes about {sampling_time:.2f} seconds.")
+        index_sample = generate(
+            self.gpt_model,
+            c_indices,
+            (H // 16) * (W // 16),
+            c_emb_masks,
+            condition=condition_img,
+            cfg_scale=cfg_scale,
+            temperature=temperature,
+            top_k=top_k,
+            top_p=top_p,
+            sample_logits=True,
+            control_strength=control_strength,
+        )
+        sampling_time = time.time() - t1
+        print(f"Full sampling takes about {sampling_time:.2f} seconds.")

-        # t2 = time.time()
-        # print(index_sample.shape)
-        # samples = self.vq_model.decode_code(index_sample, qzshape)
-        # decoder_time = time.time() - t2
-        # print(f"decoder takes about {decoder_time:.2f} seconds.")
-        # condition_img = condition_img.cpu()
-        # samples = samples.cpu()
+        t2 = time.time()
+        print(index_sample.shape)
+        samples = self.vq_model.decode_code(index_sample, qzshape)
+        decoder_time = time.time() - t2
+        print(f"decoder takes about {decoder_time:.2f} seconds.")
+        condition_img = condition_img.cpu()
+        samples = samples.cpu()

-        samples = condition_img[0:1]
+        # samples = condition_img[0:1]
         samples = torch.cat((condition_img[0:1], samples), dim=0)
         samples = 255 * (samples * 0.5 + 0.5)
         samples = [
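
A note on the re-enabled sampling call: generate is steered by temperature, top_k, and top_p, plus cfg_scale for classifier-free guidance and control_strength for the spatial condition. As a rough sketch of what standard top-k / top-p (nucleus) filtering does with those three knobs at each decoding step (a generic illustration of the technique, not this repository's generate implementation):

    import torch

    def sample_next_token(logits: torch.Tensor, temperature: float = 1.0,
                          top_k: int = 0, top_p: float = 1.0) -> torch.Tensor:
        # Generic top-k / top-p sampling over [batch, vocab] logits; a sketch
        # of the usual technique, not this repo's generate() internals.
        logits = logits / max(temperature, 1e-5)
        if top_k > 0:
            top_k = min(top_k, logits.size(-1))
            kth = torch.topk(logits, top_k, dim=-1).values[..., -1, None]
            logits = logits.masked_fill(logits < kth, float("-inf"))
        if top_p < 1.0:
            sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
            sorted_probs = torch.softmax(sorted_logits, dim=-1)
            # Drop tokens once the probability mass *before* them reaches
            # top_p, which always keeps at least the most likely token.
            drop = sorted_probs.cumsum(dim=-1) - sorted_probs >= top_p
            sorted_logits = sorted_logits.masked_fill(drop, float("-inf"))
            logits = torch.full_like(logits, float("-inf")).scatter(
                -1, sorted_idx, sorted_logits)
        probs = torch.softmax(logits, dim=-1)
        return torch.multinomial(probs, num_samples=1)  # [batch, 1] token ids

Top-k caps how many candidates survive outright, while top-p keeps the smallest prefix of tokens whose cumulative mass reaches top_p; the two filters compose, and lower temperature sharpens the distribution before either is applied.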
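
The qzshape passed to decode_code spells out how the flat token sequence folds back into a latent grid: generate emits (H // 16) * (W // 16) tokens per image, and the decoder reads them as an 8-channel map at 1/16 of the image resolution (that the 8 is the codebook embedding width is an assumption, not stated in this diff). A quick shape check under those assumptions:

    # Hypothetical sizes for illustration; only the relationships matter.
    H, W, batch = 512, 512, 1
    tokens_per_image = (H // 16) * (W // 16)  # flat sequence length sampled by the GPT
    qzshape = [batch, 8, H // 16, W // 16]    # [batch, latent channels, grid H, grid W]
    # The decoder's spatial grid must hold exactly the sampled tokens:
    assert qzshape[2] * qzshape[3] == tokens_per_image
    print(tokens_per_image, qzshape)          # 1024 [1, 8, 32, 32]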
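
Finally, 255 * (samples * 0.5 + 0.5) rescales the decoder output, which the inline comment pins to [-1, 1], into the 8-bit pixel range: -1 maps to 0, 0 to 127.5, and +1 to 255. A self-contained check of that affine map:

    import torch

    decoded = torch.tensor([-1.0, 0.0, 1.0])  # stand-in for decode_code output
    pixels = 255 * (decoded * 0.5 + 0.5)
    print(pixels)  # tensor([  0.0000, 127.5000, 255.0000])
    assert pixels.min() >= 0.0 and pixels.max() <= 255.0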