Spaces:
Sleeping
Sleeping
byeongjun-park
commited on
Commit
•
0c0d385
1
Parent(s):
02b0827
HarmonyView update
Browse files- .idea/workspace.xml +30 -17
- app.py +5 -4
- ldm/models/diffusion/sync_dreamer.py +45 -27
.idea/workspace.xml
CHANGED
@@ -4,15 +4,10 @@
|
|
4 |
<option name="autoReloadType" value="SELECTIVE" />
|
5 |
</component>
|
6 |
<component name="ChangeListManager">
|
7 |
-
<list default="true" id="a993d736-6297-4164-9c29-6b2ab1055a96" name="변경" comment="
|
8 |
-
<change afterPath="$PROJECT_DIR$/hf_demo/examples/cat.png" afterDir="false" />
|
9 |
-
<change afterPath="$PROJECT_DIR$/hf_demo/examples/crab.png" afterDir="false" />
|
10 |
-
<change afterPath="$PROJECT_DIR$/hf_demo/examples/elephant.png" afterDir="false" />
|
11 |
-
<change afterPath="$PROJECT_DIR$/hf_demo/examples/flower.png" afterDir="false" />
|
12 |
-
<change afterPath="$PROJECT_DIR$/hf_demo/examples/forest.png" afterDir="false" />
|
13 |
-
<change afterPath="$PROJECT_DIR$/hf_demo/examples/monkey.png" afterDir="false" />
|
14 |
-
<change afterPath="$PROJECT_DIR$/hf_demo/examples/teapot.png" afterDir="false" />
|
15 |
<change beforePath="$PROJECT_DIR$/.idea/workspace.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/workspace.xml" afterDir="false" />
|
|
|
|
|
16 |
</list>
|
17 |
<option name="SHOW_DIALOG" value="false" />
|
18 |
<option name="HIGHLIGHT_CONFLICTS" value="true" />
|
@@ -34,14 +29,14 @@
|
|
34 |
<option name="hideEmptyMiddlePackages" value="true" />
|
35 |
<option name="showLibraryContents" value="true" />
|
36 |
</component>
|
37 |
-
<component name="PropertiesComponent"
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
}
|
44 |
-
}
|
45 |
<component name="RecentsManager">
|
46 |
<key name="CopyFile.RECENT_KEYS">
|
47 |
<recent name="$PROJECT_DIR$" />
|
@@ -104,7 +99,23 @@
|
|
104 |
<option name="project" value="LOCAL" />
|
105 |
<updated>1703061633630</updated>
|
106 |
</task>
|
107 |
-
<
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
108 |
<servers />
|
109 |
</component>
|
110 |
<component name="Vcs.Log.Tabs.Properties">
|
@@ -120,6 +131,8 @@
|
|
120 |
</component>
|
121 |
<component name="VcsManagerConfiguration">
|
122 |
<MESSAGE value="error resolve" />
|
123 |
-
<
|
|
|
|
|
124 |
</component>
|
125 |
</project>
|
|
|
4 |
<option name="autoReloadType" value="SELECTIVE" />
|
5 |
</component>
|
6 |
<component name="ChangeListManager">
|
7 |
+
<list default="true" id="a993d736-6297-4164-9c29-6b2ab1055a96" name="변경" comment="change title">
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
<change beforePath="$PROJECT_DIR$/.idea/workspace.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/workspace.xml" afterDir="false" />
|
9 |
+
<change beforePath="$PROJECT_DIR$/app.py" beforeDir="false" afterPath="$PROJECT_DIR$/app.py" afterDir="false" />
|
10 |
+
<change beforePath="$PROJECT_DIR$/ldm/models/diffusion/sync_dreamer.py" beforeDir="false" afterPath="$PROJECT_DIR$/ldm/models/diffusion/sync_dreamer.py" afterDir="false" />
|
11 |
</list>
|
12 |
<option name="SHOW_DIALOG" value="false" />
|
13 |
<option name="HIGHLIGHT_CONFLICTS" value="true" />
|
|
|
29 |
<option name="hideEmptyMiddlePackages" value="true" />
|
30 |
<option name="showLibraryContents" value="true" />
|
31 |
</component>
|
32 |
+
<component name="PropertiesComponent"><![CDATA[{
|
33 |
+
"keyToString": {
|
34 |
+
"RunOnceActivity.OpenProjectViewOnStart": "true",
|
35 |
+
"RunOnceActivity.ShowReadmeOnStart": "true",
|
36 |
+
"git-widget-placeholder": "main",
|
37 |
+
"last_opened_file_path": "/home/byeongjun/PycharmProjects/cvpr2024"
|
38 |
}
|
39 |
+
}]]></component>
|
40 |
<component name="RecentsManager">
|
41 |
<key name="CopyFile.RECENT_KEYS">
|
42 |
<recent name="$PROJECT_DIR$" />
|
|
|
99 |
<option name="project" value="LOCAL" />
|
100 |
<updated>1703061633630</updated>
|
101 |
</task>
|
102 |
+
<task id="LOCAL-00006" summary="add example code">
|
103 |
+
<option name="closed" value="true" />
|
104 |
+
<created>1703069567948</created>
|
105 |
+
<option name="number" value="00006" />
|
106 |
+
<option name="presentableId" value="LOCAL-00006" />
|
107 |
+
<option name="project" value="LOCAL" />
|
108 |
+
<updated>1703069567948</updated>
|
109 |
+
</task>
|
110 |
+
<task id="LOCAL-00007" summary="change title">
|
111 |
+
<option name="closed" value="true" />
|
112 |
+
<created>1703070569206</created>
|
113 |
+
<option name="number" value="00007" />
|
114 |
+
<option name="presentableId" value="LOCAL-00007" />
|
115 |
+
<option name="project" value="LOCAL" />
|
116 |
+
<updated>1703070569206</updated>
|
117 |
+
</task>
|
118 |
+
<option name="localTasksCounter" value="8" />
|
119 |
<servers />
|
120 |
</component>
|
121 |
<component name="Vcs.Log.Tabs.Properties">
|
|
|
131 |
</component>
|
132 |
<component name="VcsManagerConfiguration">
|
133 |
<MESSAGE value="error resolve" />
|
134 |
+
<MESSAGE value="add example code" />
|
135 |
+
<MESSAGE value="change title" />
|
136 |
+
<option name="LAST_COMMIT_MESSAGE" value="change title" />
|
137 |
</component>
|
138 |
</project>
|
app.py
CHANGED
@@ -79,7 +79,7 @@ def resize_inputs(image_input, crop_size):
|
|
79 |
results = add_margin(ref_img_, size=256)
|
80 |
return results
|
81 |
|
82 |
-
def generate(model, sample_steps, batch_view_num, sample_num,
|
83 |
if deployed:
|
84 |
assert isinstance(model, SyncMultiviewDiffusion)
|
85 |
seed=int(seed)
|
@@ -104,7 +104,7 @@ def generate(model, sample_steps, batch_view_num, sample_num, cfg_scale, seed, i
|
|
104 |
|
105 |
if deployed:
|
106 |
sampler = SyncDDIMSampler(model, sample_steps)
|
107 |
-
x_sample = model.sample(sampler, data,
|
108 |
else:
|
109 |
x_sample = torch.zeros(sample_num, 16, 3, 256, 256)
|
110 |
|
@@ -225,7 +225,8 @@ def run_demo():
|
|
225 |
input_block = gr.Image(type='pil', image_mode='RGBA', label="Input to SyncDreamer", height=256, interactive=False)
|
226 |
elevation.render()
|
227 |
with gr.Accordion('Advanced options', open=False):
|
228 |
-
|
|
|
229 |
sample_num = gr.Slider(1, 2, 1, step=1, label='Sample num', interactive=False, info='How many instance (16 images per instance)')
|
230 |
sample_steps = gr.Slider(10, 300, 50, step=10, label='Sample steps', interactive=False)
|
231 |
batch_view_num = gr.Slider(1, 16, 16, step=1, label='Batch num', interactive=True)
|
@@ -252,7 +253,7 @@ def run_demo():
|
|
252 |
# crop_btn.click(fn=resize_inputs, inputs=[sam_block, crop_size], outputs=[input_block], queue=False)\
|
253 |
# .success(fn=partial(update_guide, _USER_GUIDE2), outputs=[guide_text], queue=False)
|
254 |
|
255 |
-
run_btn.click(partial(generate, model), inputs=[sample_steps, batch_view_num, sample_num,
|
256 |
.success(fn=partial(update_guide, _USER_GUIDE3), outputs=[guide_text], queue=False)
|
257 |
|
258 |
demo.queue().launch(share=False, max_threads=80) # auth=("admin", os.environ['PASSWD'])
|
|
|
79 |
results = add_margin(ref_img_, size=256)
|
80 |
return results
|
81 |
|
82 |
+
def generate(model, sample_steps, batch_view_num, sample_num, cfg_scale_1, cfg_scale_2, seed, image_input, elevation_input):
|
83 |
if deployed:
|
84 |
assert isinstance(model, SyncMultiviewDiffusion)
|
85 |
seed=int(seed)
|
|
|
104 |
|
105 |
if deployed:
|
106 |
sampler = SyncDDIMSampler(model, sample_steps)
|
107 |
+
x_sample = model.sample(sampler, data, (cfg_scale_1, cfg_scale_2), batch_view_num)
|
108 |
else:
|
109 |
x_sample = torch.zeros(sample_num, 16, 3, 256, 256)
|
110 |
|
|
|
225 |
input_block = gr.Image(type='pil', image_mode='RGBA', label="Input to SyncDreamer", height=256, interactive=False)
|
226 |
elevation.render()
|
227 |
with gr.Accordion('Advanced options', open=False):
|
228 |
+
cfg_scale_1 = gr.Slider(1.0, 5.0, 2.0, step=0.1, label='Classifier free guidance', interactive=True)
|
229 |
+
cfg_scale_2 = gr.Slider(0.5, 1.5, 1.0, step=0.1, label='Classifier free guidance', interactive=True)
|
230 |
sample_num = gr.Slider(1, 2, 1, step=1, label='Sample num', interactive=False, info='How many instance (16 images per instance)')
|
231 |
sample_steps = gr.Slider(10, 300, 50, step=10, label='Sample steps', interactive=False)
|
232 |
batch_view_num = gr.Slider(1, 16, 16, step=1, label='Batch num', interactive=True)
|
|
|
253 |
# crop_btn.click(fn=resize_inputs, inputs=[sam_block, crop_size], outputs=[input_block], queue=False)\
|
254 |
# .success(fn=partial(update_guide, _USER_GUIDE2), outputs=[guide_text], queue=False)
|
255 |
|
256 |
+
run_btn.click(partial(generate, model), inputs=[sample_steps, batch_view_num, sample_num, cfg_scale_1, cfg_scale_2, seed, input_block, elevation], outputs=[output_block], queue=True)\
|
257 |
.success(fn=partial(update_guide, _USER_GUIDE3), outputs=[guide_text], queue=False)
|
258 |
|
259 |
demo.queue().launch(share=False, max_threads=80) # auth=("admin", os.environ['PASSWD'])
|
ldm/models/diffusion/sync_dreamer.py
CHANGED
@@ -110,6 +110,7 @@ class UNetWrapper(nn.Module):
|
|
110 |
v_[k] = torch.cat([v, torch.zeros_like(v)], 0)
|
111 |
|
112 |
x_concat_ = torch.cat([x_concat, torch.zeros_like(x_concat)], 0)
|
|
|
113 |
if self.use_zero_123:
|
114 |
# zero123 does not multiply this when encoding, maybe a bug for zero123
|
115 |
first_stage_scale_factor = 0.18215
|
@@ -119,6 +120,24 @@ class UNetWrapper(nn.Module):
|
|
119 |
s = s_uc + unconditional_scale * (s - s_uc)
|
120 |
return s
|
121 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
122 |
|
123 |
class SpatialVolumeNet(nn.Module):
|
124 |
def __init__(self, time_dim, view_dim, view_num,
|
@@ -156,13 +175,12 @@ class SpatialVolumeNet(nn.Module):
|
|
156 |
device = x.device
|
157 |
|
158 |
spatial_volume_verts = torch.linspace(-self.spatial_volume_length, self.spatial_volume_length, V, dtype=torch.float32, device=device)
|
159 |
-
spatial_volume_verts = torch.stack(torch.meshgrid(spatial_volume_verts, spatial_volume_verts, spatial_volume_verts), -1)
|
160 |
spatial_volume_verts = spatial_volume_verts.reshape(1, V ** 3, 3)[:, :, (2, 1, 0)]
|
161 |
spatial_volume_verts = spatial_volume_verts.view(1, V, V, V, 3).permute(0, 4, 1, 2, 3).repeat(B, 1, 1, 1, 1)
|
162 |
|
163 |
# encode source features
|
164 |
t_embed_ = t_embed.view(B, 1, self.time_dim).repeat(1, N, 1).view(B, N, self.time_dim)
|
165 |
-
# v_embed_ = v_embed.view(1, N, self.view_dim).repeat(B, 1, 1).view(B, N, self.view_dim)
|
166 |
v_embed_ = v_embed
|
167 |
target_Ks = target_Ks.unsqueeze(0).repeat(B, 1, 1, 1)
|
168 |
target_poses = target_poses.unsqueeze(0).repeat(B, 1, 1, 1)
|
@@ -227,7 +245,8 @@ class SyncMultiviewDiffusion(pl.LightningModule):
|
|
227 |
view_num=16, image_size=256,
|
228 |
cfg_scale=3.0, output_num=8, batch_view_num=4,
|
229 |
drop_conditions=False, drop_scheme='default',
|
230 |
-
clip_image_encoder_path="/apdcephfs/private_rondyliu/projects/clip/ViT-L-14.pt"
|
|
|
231 |
super().__init__()
|
232 |
|
233 |
self.finetune_unet = finetune_unet
|
@@ -255,7 +274,10 @@ class SyncMultiviewDiffusion(pl.LightningModule):
|
|
255 |
self.scheduler_config = scheduler_config
|
256 |
|
257 |
latent_size = image_size//8
|
258 |
-
|
|
|
|
|
|
|
259 |
|
260 |
def _init_clip_projection(self):
|
261 |
self.cc_projection = nn.Linear(772, 768)
|
@@ -468,9 +490,9 @@ class SyncMultiviewDiffusion(pl.LightningModule):
|
|
468 |
x_noisy = sqrt_alphas_cumprod_ * x_start + sqrt_one_minus_alphas_cumprod_ * noise
|
469 |
return x_noisy, noise
|
470 |
|
471 |
-
def sample(self, sampler, batch, cfg_scale,
|
472 |
_, clip_embed, input_info = self.prepare(batch)
|
473 |
-
x_sample, inter = sampler.sample(input_info, clip_embed, unconditional_scale=cfg_scale, log_every_t=inter_interval
|
474 |
|
475 |
N = x_sample.shape[1]
|
476 |
x_sample = torch.stack([self.decode_first_stage(x_sample[:, ni]) for ni in range(N)], 1)
|
@@ -509,7 +531,7 @@ class SyncMultiviewDiffusion(pl.LightningModule):
|
|
509 |
step = self.global_step
|
510 |
batch_ = {}
|
511 |
for k, v in batch.items(): batch_[k] = v[:self.output_num]
|
512 |
-
x_sample = self.sample(batch_, self.cfg_scale
|
513 |
output_dir = Path(self.image_dir) / 'images' / 'val'
|
514 |
output_dir.mkdir(exist_ok=True, parents=True)
|
515 |
self.log_image(x_sample, batch, step, output_dir=output_dir)
|
@@ -588,7 +610,7 @@ class SyncDDIMSampler:
|
|
588 |
return x_prev
|
589 |
|
590 |
@torch.no_grad()
|
591 |
-
def denoise_apply(self, x_target_noisy, input_info, clip_embed, time_steps, index, unconditional_scale,
|
592 |
"""
|
593 |
@param x_target_noisy: B,N,4,H,W
|
594 |
@param input_info:
|
@@ -596,7 +618,6 @@ class SyncDDIMSampler:
|
|
596 |
@param time_steps: B,
|
597 |
@param index: int
|
598 |
@param unconditional_scale:
|
599 |
-
@param batch_view_num: int
|
600 |
@param is_step0: bool
|
601 |
@return:
|
602 |
"""
|
@@ -608,37 +629,34 @@ class SyncDDIMSampler:
|
|
608 |
t_embed = self.model.embed_time(time_steps) # B,t_dim
|
609 |
spatial_volume = self.model.spatial_volume.construct_spatial_volume(x_target_noisy, t_embed, v_embed, self.model.poses, self.model.Ks)
|
610 |
|
611 |
-
|
612 |
-
|
613 |
-
|
614 |
-
|
615 |
-
|
616 |
-
|
617 |
-
|
618 |
-
|
619 |
-
target_indices_ = target_indices[ni:ni+batch_view_num].unsqueeze(0).repeat(B,1)
|
620 |
-
clip_embed_, volume_feats_, x_concat_ = self.model.get_target_view_feats(x_input, spatial_volume, clip_embed, t_embed, v_embed, target_indices_)
|
621 |
-
if unconditional_scale!=1.0:
|
622 |
noise = self.model.model.predict_with_unconditional_scale(x_target_noisy_, time_steps_, clip_embed_, volume_feats_, x_concat_, unconditional_scale)
|
623 |
else:
|
624 |
noise = self.model.model(x_target_noisy_, time_steps_, clip_embed_, volume_feats_, x_concat_, is_train=False)
|
625 |
-
|
|
|
626 |
|
627 |
-
|
628 |
-
x_prev = self.denoise_apply_impl(x_target_noisy, index,
|
629 |
return x_prev
|
630 |
|
631 |
@torch.no_grad()
|
632 |
-
def sample(self, input_info, clip_embed, unconditional_scale
|
633 |
"""
|
634 |
@param input_info: x, elevation
|
635 |
@param clip_embed: B,M,768
|
636 |
@param unconditional_scale:
|
637 |
@param log_every_t:
|
638 |
-
@param batch_view_num:
|
639 |
@return:
|
640 |
"""
|
641 |
-
|
642 |
C, H, W = 4, self.latent_size, self.latent_size
|
643 |
B = clip_embed.shape[0]
|
644 |
N = self.model.view_num
|
@@ -654,7 +672,7 @@ class SyncDDIMSampler:
|
|
654 |
for i, step in enumerate(iterator):
|
655 |
index = total_steps - i - 1 # index in ddim state
|
656 |
time_steps = torch.full((B,), step, device=device, dtype=torch.long)
|
657 |
-
x_target_noisy = self.denoise_apply(x_target_noisy, input_info, clip_embed, time_steps, index, unconditional_scale,
|
658 |
if index % log_every_t == 0 or index == total_steps - 1:
|
659 |
intermediates['x_inter'].append(x_target_noisy)
|
660 |
|
|
|
110 |
v_[k] = torch.cat([v, torch.zeros_like(v)], 0)
|
111 |
|
112 |
x_concat_ = torch.cat([x_concat, torch.zeros_like(x_concat)], 0)
|
113 |
+
|
114 |
if self.use_zero_123:
|
115 |
# zero123 does not multiply this when encoding, maybe a bug for zero123
|
116 |
first_stage_scale_factor = 0.18215
|
|
|
120 |
s = s_uc + unconditional_scale * (s - s_uc)
|
121 |
return s
|
122 |
|
123 |
+
def predict_with_decomposed_unconditional_scales(self, x, t, clip_embed, volume_feats, x_concat, unconditional_scales):
|
124 |
+
x_ = torch.cat([x] * 3, 0)
|
125 |
+
t_ = torch.cat([t] * 3, 0)
|
126 |
+
clip_embed_ = torch.cat([clip_embed, torch.zeros_like(clip_embed), clip_embed], 0)
|
127 |
+
x_concat_ = torch.cat([x_concat, torch.zeros_like(x_concat), x_concat*4], 0)
|
128 |
+
|
129 |
+
v_ = {}
|
130 |
+
for k, v in volume_feats.items():
|
131 |
+
v_[k] = torch.cat([v, v, torch.zeros_like(v)], 0)
|
132 |
+
|
133 |
+
if self.use_zero_123:
|
134 |
+
# zero123 does not multiply this when encoding, maybe a bug for zero123
|
135 |
+
first_stage_scale_factor = 0.18215
|
136 |
+
x_concat_[:, :4] = x_concat_[:, :4] / first_stage_scale_factor
|
137 |
+
x_ = torch.cat([x_, x_concat_], 1)
|
138 |
+
s, s_uc1, s_uc2 = self.diffusion_model(x_, t_, clip_embed_, source_dict=v_).chunk(3)
|
139 |
+
s = s + unconditional_scales[0] * (s - s_uc1) + unconditional_scales[1] * (s - s_uc2)
|
140 |
+
return s
|
141 |
|
142 |
class SpatialVolumeNet(nn.Module):
|
143 |
def __init__(self, time_dim, view_dim, view_num,
|
|
|
175 |
device = x.device
|
176 |
|
177 |
spatial_volume_verts = torch.linspace(-self.spatial_volume_length, self.spatial_volume_length, V, dtype=torch.float32, device=device)
|
178 |
+
spatial_volume_verts = torch.stack(torch.meshgrid(spatial_volume_verts, spatial_volume_verts, spatial_volume_verts, indexing='ij'), -1)
|
179 |
spatial_volume_verts = spatial_volume_verts.reshape(1, V ** 3, 3)[:, :, (2, 1, 0)]
|
180 |
spatial_volume_verts = spatial_volume_verts.view(1, V, V, V, 3).permute(0, 4, 1, 2, 3).repeat(B, 1, 1, 1, 1)
|
181 |
|
182 |
# encode source features
|
183 |
t_embed_ = t_embed.view(B, 1, self.time_dim).repeat(1, N, 1).view(B, N, self.time_dim)
|
|
|
184 |
v_embed_ = v_embed
|
185 |
target_Ks = target_Ks.unsqueeze(0).repeat(B, 1, 1, 1)
|
186 |
target_poses = target_poses.unsqueeze(0).repeat(B, 1, 1, 1)
|
|
|
245 |
view_num=16, image_size=256,
|
246 |
cfg_scale=3.0, output_num=8, batch_view_num=4,
|
247 |
drop_conditions=False, drop_scheme='default',
|
248 |
+
clip_image_encoder_path="/apdcephfs/private_rondyliu/projects/clip/ViT-L-14.pt",
|
249 |
+
sample_type='ddim', sample_steps=200):
|
250 |
super().__init__()
|
251 |
|
252 |
self.finetune_unet = finetune_unet
|
|
|
274 |
self.scheduler_config = scheduler_config
|
275 |
|
276 |
latent_size = image_size//8
|
277 |
+
if sample_type=='ddim':
|
278 |
+
self.sampler = SyncDDIMSampler(self, sample_steps , "uniform", 1.0, latent_size=latent_size)
|
279 |
+
else:
|
280 |
+
raise NotImplementedError
|
281 |
|
282 |
def _init_clip_projection(self):
|
283 |
self.cc_projection = nn.Linear(772, 768)
|
|
|
490 |
x_noisy = sqrt_alphas_cumprod_ * x_start + sqrt_one_minus_alphas_cumprod_ * noise
|
491 |
return x_noisy, noise
|
492 |
|
493 |
+
def sample(self, sampler, batch, cfg_scale, return_inter_results=False, inter_interval=50, inter_view_interval=2):
|
494 |
_, clip_embed, input_info = self.prepare(batch)
|
495 |
+
x_sample, inter = sampler.sample(input_info, clip_embed, unconditional_scale=cfg_scale, log_every_t=inter_interval)
|
496 |
|
497 |
N = x_sample.shape[1]
|
498 |
x_sample = torch.stack([self.decode_first_stage(x_sample[:, ni]) for ni in range(N)], 1)
|
|
|
531 |
step = self.global_step
|
532 |
batch_ = {}
|
533 |
for k, v in batch.items(): batch_[k] = v[:self.output_num]
|
534 |
+
x_sample = self.sample(self.sampler, batch_, self.cfg_scale)
|
535 |
output_dir = Path(self.image_dir) / 'images' / 'val'
|
536 |
output_dir.mkdir(exist_ok=True, parents=True)
|
537 |
self.log_image(x_sample, batch, step, output_dir=output_dir)
|
|
|
610 |
return x_prev
|
611 |
|
612 |
@torch.no_grad()
|
613 |
+
def denoise_apply(self, x_target_noisy, input_info, clip_embed, time_steps, index, unconditional_scale, is_step0=False):
|
614 |
"""
|
615 |
@param x_target_noisy: B,N,4,H,W
|
616 |
@param input_info:
|
|
|
618 |
@param time_steps: B,
|
619 |
@param index: int
|
620 |
@param unconditional_scale:
|
|
|
621 |
@param is_step0: bool
|
622 |
@return:
|
623 |
"""
|
|
|
629 |
t_embed = self.model.embed_time(time_steps) # B,t_dim
|
630 |
spatial_volume = self.model.spatial_volume.construct_spatial_volume(x_target_noisy, t_embed, v_embed, self.model.poses, self.model.Ks)
|
631 |
|
632 |
+
target_indices_ = torch.arange(N).unsqueeze(0).repeat(B, 1)
|
633 |
+
x_target_noisy_ = x_target_noisy.reshape(B*N,C,H,W)
|
634 |
+
|
635 |
+
time_steps_ = repeat_to_batch(time_steps, B, N)
|
636 |
+
clip_embed_, volume_feats_, x_concat_ = self.model.get_target_view_feats(x_input, spatial_volume, clip_embed, t_embed, v_embed, target_indices_)
|
637 |
+
|
638 |
+
if type(unconditional_scale) == float: ## CFG
|
639 |
+
if unconditional_scale != 1.0:
|
|
|
|
|
|
|
640 |
noise = self.model.model.predict_with_unconditional_scale(x_target_noisy_, time_steps_, clip_embed_, volume_feats_, x_concat_, unconditional_scale)
|
641 |
else:
|
642 |
noise = self.model.model(x_target_noisy_, time_steps_, clip_embed_, volume_feats_, x_concat_, is_train=False)
|
643 |
+
else: ## DG
|
644 |
+
noise = self.model.model.predict_with_decomposed_unconditional_scales(x_target_noisy_, time_steps_, clip_embed_, volume_feats_, x_concat_, unconditional_scale)
|
645 |
|
646 |
+
noise = noise.reshape(B, N, 4, H, W)
|
647 |
+
x_prev = self.denoise_apply_impl(x_target_noisy, index, noise, is_step0)
|
648 |
return x_prev
|
649 |
|
650 |
@torch.no_grad()
|
651 |
+
def sample(self, input_info, clip_embed, unconditional_scale, log_every_t=50):
|
652 |
"""
|
653 |
@param input_info: x, elevation
|
654 |
@param clip_embed: B,M,768
|
655 |
@param unconditional_scale:
|
656 |
@param log_every_t:
|
|
|
657 |
@return:
|
658 |
"""
|
659 |
+
|
660 |
C, H, W = 4, self.latent_size, self.latent_size
|
661 |
B = clip_embed.shape[0]
|
662 |
N = self.model.view_num
|
|
|
672 |
for i, step in enumerate(iterator):
|
673 |
index = total_steps - i - 1 # index in ddim state
|
674 |
time_steps = torch.full((B,), step, device=device, dtype=torch.long)
|
675 |
+
x_target_noisy = self.denoise_apply(x_target_noisy, input_info, clip_embed, time_steps, index, unconditional_scale, is_step0=index==0)
|
676 |
if index % log_every_t == 0 or index == total_steps - 1:
|
677 |
intermediates['x_inter'].append(x_target_noisy)
|
678 |
|