byeongjun-park committed on
Commit 0c0d385
1 Parent(s): 02b0827

HarmonyView update

Files changed (3)
  1. .idea/workspace.xml +30 -17
  2. app.py +5 -4
  3. ldm/models/diffusion/sync_dreamer.py +45 -27
.idea/workspace.xml CHANGED
@@ -4,15 +4,10 @@
  <option name="autoReloadType" value="SELECTIVE" />
  </component>
  <component name="ChangeListManager">
- <list default="true" id="a993d736-6297-4164-9c29-6b2ab1055a96" name="변경" comment="error resolve">
- <change afterPath="$PROJECT_DIR$/hf_demo/examples/cat.png" afterDir="false" />
- <change afterPath="$PROJECT_DIR$/hf_demo/examples/crab.png" afterDir="false" />
- <change afterPath="$PROJECT_DIR$/hf_demo/examples/elephant.png" afterDir="false" />
- <change afterPath="$PROJECT_DIR$/hf_demo/examples/flower.png" afterDir="false" />
- <change afterPath="$PROJECT_DIR$/hf_demo/examples/forest.png" afterDir="false" />
- <change afterPath="$PROJECT_DIR$/hf_demo/examples/monkey.png" afterDir="false" />
- <change afterPath="$PROJECT_DIR$/hf_demo/examples/teapot.png" afterDir="false" />
+ <list default="true" id="a993d736-6297-4164-9c29-6b2ab1055a96" name="변경" comment="change title">
  <change beforePath="$PROJECT_DIR$/.idea/workspace.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/workspace.xml" afterDir="false" />
+ <change beforePath="$PROJECT_DIR$/app.py" beforeDir="false" afterPath="$PROJECT_DIR$/app.py" afterDir="false" />
+ <change beforePath="$PROJECT_DIR$/ldm/models/diffusion/sync_dreamer.py" beforeDir="false" afterPath="$PROJECT_DIR$/ldm/models/diffusion/sync_dreamer.py" afterDir="false" />
  </list>
  <option name="SHOW_DIALOG" value="false" />
  <option name="HIGHLIGHT_CONFLICTS" value="true" />
@@ -34,14 +29,14 @@
  <option name="hideEmptyMiddlePackages" value="true" />
  <option name="showLibraryContents" value="true" />
  </component>
- <component name="PropertiesComponent">{
- &quot;keyToString&quot;: {
- &quot;RunOnceActivity.OpenProjectViewOnStart&quot;: &quot;true&quot;,
- &quot;RunOnceActivity.ShowReadmeOnStart&quot;: &quot;true&quot;,
- &quot;git-widget-placeholder&quot;: &quot;main&quot;,
- &quot;last_opened_file_path&quot;: &quot;/home/byeongjun/PycharmProjects/HarmonyView&quot;
+ <component name="PropertiesComponent"><![CDATA[{
+ "keyToString": {
+ "RunOnceActivity.OpenProjectViewOnStart": "true",
+ "RunOnceActivity.ShowReadmeOnStart": "true",
+ "git-widget-placeholder": "main",
+ "last_opened_file_path": "/home/byeongjun/PycharmProjects/cvpr2024"
  }
- }</component>
+ }]]></component>
  <component name="RecentsManager">
  <key name="CopyFile.RECENT_KEYS">
  <recent name="$PROJECT_DIR$" />
@@ -104,7 +99,23 @@
  <option name="project" value="LOCAL" />
  <updated>1703061633630</updated>
  </task>
- <option name="localTasksCounter" value="6" />
+ <task id="LOCAL-00006" summary="add example code">
+ <option name="closed" value="true" />
+ <created>1703069567948</created>
+ <option name="number" value="00006" />
+ <option name="presentableId" value="LOCAL-00006" />
+ <option name="project" value="LOCAL" />
+ <updated>1703069567948</updated>
+ </task>
+ <task id="LOCAL-00007" summary="change title">
+ <option name="closed" value="true" />
+ <created>1703070569206</created>
+ <option name="number" value="00007" />
+ <option name="presentableId" value="LOCAL-00007" />
+ <option name="project" value="LOCAL" />
+ <updated>1703070569206</updated>
+ </task>
+ <option name="localTasksCounter" value="8" />
  <servers />
  </component>
  <component name="Vcs.Log.Tabs.Properties">
@@ -120,6 +131,8 @@
  </component>
  <component name="VcsManagerConfiguration">
  <MESSAGE value="error resolve" />
- <option name="LAST_COMMIT_MESSAGE" value="error resolve" />
+ <MESSAGE value="add example code" />
+ <MESSAGE value="change title" />
+ <option name="LAST_COMMIT_MESSAGE" value="change title" />
  </component>
  </project>
app.py CHANGED
@@ -79,7 +79,7 @@ def resize_inputs(image_input, crop_size):
  results = add_margin(ref_img_, size=256)
  return results

- def generate(model, sample_steps, batch_view_num, sample_num, cfg_scale, seed, image_input, elevation_input):
+ def generate(model, sample_steps, batch_view_num, sample_num, cfg_scale_1, cfg_scale_2, seed, image_input, elevation_input):
  if deployed:
  assert isinstance(model, SyncMultiviewDiffusion)
  seed=int(seed)
@@ -104,7 +104,7 @@ def generate(model, sample_steps, batch_view_num, sample_num, cfg_scale, seed, i

  if deployed:
  sampler = SyncDDIMSampler(model, sample_steps)
- x_sample = model.sample(sampler, data, cfg_scale, batch_view_num)
+ x_sample = model.sample(sampler, data, (cfg_scale_1, cfg_scale_2), batch_view_num)
  else:
  x_sample = torch.zeros(sample_num, 16, 3, 256, 256)

@@ -225,7 +225,8 @@ def run_demo():
  input_block = gr.Image(type='pil', image_mode='RGBA', label="Input to SyncDreamer", height=256, interactive=False)
  elevation.render()
  with gr.Accordion('Advanced options', open=False):
- cfg_scale = gr.Slider(1.0, 5.0, 2.0, step=0.1, label='Classifier free guidance', interactive=True)
+ cfg_scale_1 = gr.Slider(1.0, 5.0, 2.0, step=0.1, label='Classifier free guidance', interactive=True)
+ cfg_scale_2 = gr.Slider(0.5, 1.5, 1.0, step=0.1, label='Classifier free guidance', interactive=True)
  sample_num = gr.Slider(1, 2, 1, step=1, label='Sample num', interactive=False, info='How many instance (16 images per instance)')
  sample_steps = gr.Slider(10, 300, 50, step=10, label='Sample steps', interactive=False)
  batch_view_num = gr.Slider(1, 16, 16, step=1, label='Batch num', interactive=True)
@@ -252,7 +253,7 @@ def run_demo():
  # crop_btn.click(fn=resize_inputs, inputs=[sam_block, crop_size], outputs=[input_block], queue=False)\
  # .success(fn=partial(update_guide, _USER_GUIDE2), outputs=[guide_text], queue=False)

- run_btn.click(partial(generate, model), inputs=[sample_steps, batch_view_num, sample_num, cfg_scale, seed, input_block, elevation], outputs=[output_block], queue=True)\
+ run_btn.click(partial(generate, model), inputs=[sample_steps, batch_view_num, sample_num, cfg_scale_1, cfg_scale_2, seed, input_block, elevation], outputs=[output_block], queue=True)\
  .success(fn=partial(update_guide, _USER_GUIDE3), outputs=[guide_text], queue=False)

  demo.queue().launch(share=False, max_threads=80) # auth=("admin", os.environ['PASSWD'])
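The app.py change above replaces the single `cfg_scale` slider with two sliders, `cfg_scale_1` and `cfg_scale_2`, whose values are forwarded to `model.sample` as the tuple `(cfg_scale_1, cfg_scale_2)`. Below is a minimal, self-contained Gradio sketch of that two-slider wiring; the stub callback and slider labels are illustrative only, not the demo's actual code.

import gradio as gr

def generate_stub(cfg_scale_1, cfg_scale_2):
    # In the real demo these two floats are packed into (cfg_scale_1, cfg_scale_2)
    # and passed on to model.sample; here we simply echo them.
    return f"guidance scales: {(cfg_scale_1, cfg_scale_2)}"

with gr.Blocks() as sketch:
    cfg_scale_1 = gr.Slider(1.0, 5.0, 2.0, step=0.1, label='Classifier free guidance (scale 1)')
    cfg_scale_2 = gr.Slider(0.5, 1.5, 1.0, step=0.1, label='Classifier free guidance (scale 2)')
    run_btn = gr.Button('Run')
    out = gr.Textbox()
    run_btn.click(generate_stub, inputs=[cfg_scale_1, cfg_scale_2], outputs=[out])

# sketch.launch()  # uncomment to try the sliders locally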
ldm/models/diffusion/sync_dreamer.py CHANGED
@@ -110,6 +110,7 @@ class UNetWrapper(nn.Module):
  v_[k] = torch.cat([v, torch.zeros_like(v)], 0)

  x_concat_ = torch.cat([x_concat, torch.zeros_like(x_concat)], 0)
+
  if self.use_zero_123:
  # zero123 does not multiply this when encoding, maybe a bug for zero123
  first_stage_scale_factor = 0.18215
@@ -119,6 +120,24 @@
  s = s_uc + unconditional_scale * (s - s_uc)
  return s

+ def predict_with_decomposed_unconditional_scales(self, x, t, clip_embed, volume_feats, x_concat, unconditional_scales):
+ x_ = torch.cat([x] * 3, 0)
+ t_ = torch.cat([t] * 3, 0)
+ clip_embed_ = torch.cat([clip_embed, torch.zeros_like(clip_embed), clip_embed], 0)
+ x_concat_ = torch.cat([x_concat, torch.zeros_like(x_concat), x_concat*4], 0)
+
+ v_ = {}
+ for k, v in volume_feats.items():
+ v_[k] = torch.cat([v, v, torch.zeros_like(v)], 0)
+
+ if self.use_zero_123:
+ # zero123 does not multiply this when encoding, maybe a bug for zero123
+ first_stage_scale_factor = 0.18215
+ x_concat_[:, :4] = x_concat_[:, :4] / first_stage_scale_factor
+ x_ = torch.cat([x_, x_concat_], 1)
+ s, s_uc1, s_uc2 = self.diffusion_model(x_, t_, clip_embed_, source_dict=v_).chunk(3)
+ s = s + unconditional_scales[0] * (s - s_uc1) + unconditional_scales[1] * (s - s_uc2)
+ return s

 class SpatialVolumeNet(nn.Module):
 def __init__(self, time_dim, view_dim, view_num,
@@ -156,13 +175,12 @@
  device = x.device

  spatial_volume_verts = torch.linspace(-self.spatial_volume_length, self.spatial_volume_length, V, dtype=torch.float32, device=device)
- spatial_volume_verts = torch.stack(torch.meshgrid(spatial_volume_verts, spatial_volume_verts, spatial_volume_verts), -1)
+ spatial_volume_verts = torch.stack(torch.meshgrid(spatial_volume_verts, spatial_volume_verts, spatial_volume_verts, indexing='ij'), -1)
  spatial_volume_verts = spatial_volume_verts.reshape(1, V ** 3, 3)[:, :, (2, 1, 0)]
  spatial_volume_verts = spatial_volume_verts.view(1, V, V, V, 3).permute(0, 4, 1, 2, 3).repeat(B, 1, 1, 1, 1)

  # encode source features
  t_embed_ = t_embed.view(B, 1, self.time_dim).repeat(1, N, 1).view(B, N, self.time_dim)
- # v_embed_ = v_embed.view(1, N, self.view_dim).repeat(B, 1, 1).view(B, N, self.view_dim)
  v_embed_ = v_embed
  target_Ks = target_Ks.unsqueeze(0).repeat(B, 1, 1, 1)
  target_poses = target_poses.unsqueeze(0).repeat(B, 1, 1, 1)
@@ -227,7 +245,8 @@ class SyncMultiviewDiffusion(pl.LightningModule):
  view_num=16, image_size=256,
  cfg_scale=3.0, output_num=8, batch_view_num=4,
  drop_conditions=False, drop_scheme='default',
- clip_image_encoder_path="/apdcephfs/private_rondyliu/projects/clip/ViT-L-14.pt"):
+ clip_image_encoder_path="/apdcephfs/private_rondyliu/projects/clip/ViT-L-14.pt",
+ sample_type='ddim', sample_steps=200):
  super().__init__()

  self.finetune_unet = finetune_unet
@@ -255,7 +274,10 @@
  self.scheduler_config = scheduler_config

  latent_size = image_size//8
- self.ddim = SyncDDIMSampler(self, 200, "uniform", 1.0, latent_size=latent_size)
+ if sample_type=='ddim':
+ self.sampler = SyncDDIMSampler(self, sample_steps , "uniform", 1.0, latent_size=latent_size)
+ else:
+ raise NotImplementedError

  def _init_clip_projection(self):
  self.cc_projection = nn.Linear(772, 768)
@@ -468,9 +490,9 @@
  x_noisy = sqrt_alphas_cumprod_ * x_start + sqrt_one_minus_alphas_cumprod_ * noise
  return x_noisy, noise

- def sample(self, sampler, batch, cfg_scale, batch_view_num, return_inter_results=False, inter_interval=50, inter_view_interval=2):
+ def sample(self, sampler, batch, cfg_scale, return_inter_results=False, inter_interval=50, inter_view_interval=2):
  _, clip_embed, input_info = self.prepare(batch)
- x_sample, inter = sampler.sample(input_info, clip_embed, unconditional_scale=cfg_scale, log_every_t=inter_interval, batch_view_num=batch_view_num)
+ x_sample, inter = sampler.sample(input_info, clip_embed, unconditional_scale=cfg_scale, log_every_t=inter_interval)

  N = x_sample.shape[1]
  x_sample = torch.stack([self.decode_first_stage(x_sample[:, ni]) for ni in range(N)], 1)
@@ -509,7 +531,7 @@
  step = self.global_step
  batch_ = {}
  for k, v in batch.items(): batch_[k] = v[:self.output_num]
- x_sample = self.sample(batch_, self.cfg_scale, self.batch_view_num)
+ x_sample = self.sample(self.sampler, batch_, self.cfg_scale)
  output_dir = Path(self.image_dir) / 'images' / 'val'
  output_dir.mkdir(exist_ok=True, parents=True)
  self.log_image(x_sample, batch, step, output_dir=output_dir)
@@ -588,7 +610,7 @@ class SyncDDIMSampler:
  return x_prev

  @torch.no_grad()
- def denoise_apply(self, x_target_noisy, input_info, clip_embed, time_steps, index, unconditional_scale, batch_view_num=1, is_step0=False):
+ def denoise_apply(self, x_target_noisy, input_info, clip_embed, time_steps, index, unconditional_scale, is_step0=False):
  """
  @param x_target_noisy: B,N,4,H,W
  @param input_info:
@@ -596,7 +618,6 @@
  @param time_steps: B,
  @param index: int
  @param unconditional_scale:
- @param batch_view_num: int
  @param is_step0: bool
  @return:
  """
@@ -608,37 +629,34 @@
  t_embed = self.model.embed_time(time_steps) # B,t_dim
  spatial_volume = self.model.spatial_volume.construct_spatial_volume(x_target_noisy, t_embed, v_embed, self.model.poses, self.model.Ks)

- e_t = []
- target_indices = torch.arange(N) # N
- for ni in range(0, N, batch_view_num):
- x_target_noisy_ = x_target_noisy[:, ni:ni + batch_view_num]
- VN = x_target_noisy_.shape[1]
- x_target_noisy_ = x_target_noisy_.reshape(B*VN,C,H,W)
-
- time_steps_ = repeat_to_batch(time_steps, B, VN)
- target_indices_ = target_indices[ni:ni+batch_view_num].unsqueeze(0).repeat(B,1)
- clip_embed_, volume_feats_, x_concat_ = self.model.get_target_view_feats(x_input, spatial_volume, clip_embed, t_embed, v_embed, target_indices_)
- if unconditional_scale!=1.0:
+ target_indices_ = torch.arange(N).unsqueeze(0).repeat(B, 1)
+ x_target_noisy_ = x_target_noisy.reshape(B*N,C,H,W)
+
+ time_steps_ = repeat_to_batch(time_steps, B, N)
+ clip_embed_, volume_feats_, x_concat_ = self.model.get_target_view_feats(x_input, spatial_volume, clip_embed, t_embed, v_embed, target_indices_)
+
+ if type(unconditional_scale) == float: ## CFG
+ if unconditional_scale != 1.0:
  noise = self.model.model.predict_with_unconditional_scale(x_target_noisy_, time_steps_, clip_embed_, volume_feats_, x_concat_, unconditional_scale)
  else:
  noise = self.model.model(x_target_noisy_, time_steps_, clip_embed_, volume_feats_, x_concat_, is_train=False)
- e_t.append(noise.view(B,VN,4,H,W))
+ else: ## DG
+ noise = self.model.model.predict_with_decomposed_unconditional_scales(x_target_noisy_, time_steps_, clip_embed_, volume_feats_, x_concat_, unconditional_scale)

- e_t = torch.cat(e_t, 1)
- x_prev = self.denoise_apply_impl(x_target_noisy, index, e_t, is_step0)
+ noise = noise.reshape(B, N, 4, H, W)
+ x_prev = self.denoise_apply_impl(x_target_noisy, index, noise, is_step0)
  return x_prev

  @torch.no_grad()
- def sample(self, input_info, clip_embed, unconditional_scale=1.0, log_every_t=50, batch_view_num=1):
+ def sample(self, input_info, clip_embed, unconditional_scale, log_every_t=50):
  """
  @param input_info: x, elevation
  @param clip_embed: B,M,768
  @param unconditional_scale:
  @param log_every_t:
- @param batch_view_num:
  @return:
  """
- print(f"unconditional scale {unconditional_scale:.1f}")
+
  C, H, W = 4, self.latent_size, self.latent_size
  B = clip_embed.shape[0]
  N = self.model.view_num
@@ -654,7 +672,7 @@
  for i, step in enumerate(iterator):
  index = total_steps - i - 1 # index in ddim state
  time_steps = torch.full((B,), step, device=device, dtype=torch.long)
- x_target_noisy = self.denoise_apply(x_target_noisy, input_info, clip_embed, time_steps, index, unconditional_scale, batch_view_num=batch_view_num, is_step0=index==0)
+ x_target_noisy = self.denoise_apply(x_target_noisy, input_info, clip_embed, time_steps, index, unconditional_scale, is_step0=index==0)
  if index % log_every_t == 0 or index == total_steps - 1:
  intermediates['x_inter'].append(x_target_noisy)

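The central addition in sync_dreamer.py is `predict_with_decomposed_unconditional_scales`, which runs the UNet on three stacked copies of the input, each keeping or dropping a different subset of the conditions, and blends the three predictions with the two scales coming from the demo sliders. The blend is the line `s = s + unconditional_scales[0] * (s - s_uc1) + unconditional_scales[1] * (s - s_uc2)`. Below is a minimal, self-contained sketch of just that combination step with illustrative tensor shapes; it is not the repository's API, only the arithmetic.

import torch

def combine_decomposed_guidance(s, s_uc1, s_uc2, scales):
    # s:      noise prediction with all conditions kept
    # s_uc1:  prediction with one set of conditions dropped
    # s_uc2:  prediction with another set of conditions dropped
    # scales: (cfg_scale_1, cfg_scale_2) from the demo sliders
    w1, w2 = scales
    return s + w1 * (s - s_uc1) + w2 * (s - s_uc2)

# Example with a dummy batch of 2 latents (4 x 32 x 32 each)
s = torch.randn(2, 4, 32, 32)
s_uc1, s_uc2 = torch.randn_like(s), torch.randn_like(s)
out = combine_decomposed_guidance(s, s_uc1, s_uc2, (2.0, 1.0))
print(out.shape)  # torch.Size([2, 4, 32, 32])

Setting either scale to 0.0 simply removes the corresponding guidance term. In `denoise_apply`, a plain float for `unconditional_scale` keeps the original single-scale CFG path, while a tuple routes through this decomposed-guidance path.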