# dreamdrone / sd / pnp_utils.py
# Uploaded by imsuperkong ("Upload 6 files"), commit d3bdeec.
import torch
import os
import random
import numpy as np
import ipdb
import torch.nn.functional as F
def seed_everything(seed):
    """Seed every random number generator this module relies on.

    The same ``seed`` is applied, in order, to torch's default generator
    (``torch.manual_seed``), the current CUDA device (``torch.cuda.manual_seed``),
    Python's ``random`` module and NumPy's global RNG.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed,
        random.seed,
        np.random.seed,
    )
    for seeder in seeders:
        seeder(seed)
def register_time(model, t):
    """Stamp the current timestep ``t`` onto every module the injection hooks read.

    The patched ``forward`` functions installed by the ``register_*`` helpers
    consult ``self.t``. This sets ``t`` on:

    * ``up_blocks[1].resnets[1]``,
    * the ``attn1`` of attention blocks 0-2 in up blocks 1-3,
    * the ``attn1`` of attention blocks 0-1 in down blocks 0-2,
    * the ``attn1`` of the mid block's first attention block.

    Modules are updated in exactly that order.
    """
    unet = model.unet

    def _targets():
        yield unet.up_blocks[1].resnets[1]
        for level in (1, 2, 3):
            for idx in (0, 1, 2):
                yield unet.up_blocks[level].attentions[idx].transformer_blocks[0].attn1
        for level in (0, 1, 2):
            for idx in (0, 1):
                yield unet.down_blocks[level].attentions[idx].transformer_blocks[0].attn1
        yield unet.mid_block.attentions[0].transformer_blocks[0].attn1

    for target in _targets():
        target.t = t
def load_source_latents_t(t, latents_path, map_location=None):
    """Load the saved noisy source latents for timestep ``t``.

    Reads ``<latents_path>/noisy_latents_{t}.pt`` with ``torch.load``.

    Args:
        t: timestep interpolated into the file name.
        latents_path: directory holding the per-timestep latent files.
        map_location: forwarded to ``torch.load``. ``None`` (the default) keeps
            ``torch.load``'s own behaviour; e.g. ``'cpu'`` remaps CUDA tensors.

    Returns:
        The object stored in the file (presumably the latent tensor).

    Raises:
        FileNotFoundError: if the expected file does not exist. Raised
            explicitly rather than via ``assert`` so the check survives
            ``python -O``.
    """
    latents_t_path = os.path.join(latents_path, f'noisy_latents_{t}.pt')
    if not os.path.exists(latents_t_path):
        raise FileNotFoundError(f'Missing latents at t {t} path {latents_t_path}')
    # NOTE: torch.load unpickles the file; only load latents from trusted locations.
    return torch.load(latents_t_path, map_location=map_location)
def register_attention_control_efficient(model, injection_schedule):
    """Patch decoder self-attention layers to inject the source queries and keys.

    The ``attn1`` module of up blocks ``{1: [1, 2], 2: [0, 1, 2], 3: [0, 1, 2]}``
    gets a replacement ``forward`` and an ``injection_schedule`` attribute.
    Injection fires only for self-attention (no ``encoder_hidden_states``) when
    ``self.t`` is in the schedule or equals 1000. The batch is then split into
    three chunks of ``B // 3``. The first chunk's queries and keys overwrite those
    of the second ("unconditional") and remaining ("conditional") chunks. Values
    are never modified.

    Args:
        model: object exposing ``model.unet.up_blocks`` (assumed diffusers-style UNet).
        injection_schedule: container of timesteps at which to inject, or None.
    """
    def sa_forward(self):
        if type(self.to_out) is torch.nn.modules.container.ModuleList:
            out_proj = self.to_out[0]
        else:
            out_proj = self.to_out

        def forward(x, encoder_hidden_states=None, attention_mask=None):
            batch_size, _, _ = x.shape
            heads = self.heads
            is_cross = encoder_hidden_states is not None
            context = encoder_hidden_states if is_cross else x
            query = self.to_q(x)
            key = self.to_k(context)
            should_inject = (
                not is_cross
                and self.injection_schedule is not None
                and (self.t in self.injection_schedule or self.t == 1000)
            )
            if should_inject:
                n_src = int(query.shape[0] // 3)
                for projected in (query, key):
                    # unconditional chunk, then the conditional remainder
                    projected[n_src:2 * n_src] = projected[:n_src]
                    projected[2 * n_src:] = projected[:n_src]
            query = self.head_to_batch_dim(query)
            key = self.head_to_batch_dim(key)
            value = self.head_to_batch_dim(self.to_v(context))

            scores = torch.einsum("b i d, b j d -> b i j", query, key) * self.scale
            if attention_mask is not None:
                fill_value = -torch.finfo(scores.dtype).max
                keep = attention_mask.reshape(batch_size, -1)[:, None, :].repeat(heads, 1, 1)
                scores.masked_fill_(~keep, fill_value)
            probs = scores.softmax(dim=-1)
            mixed = torch.einsum("b i j, b j d -> b i d", probs, value)
            return out_proj(self.batch_to_head_dim(mixed))

        return forward

    # Up block 1 skips attention block 0; up blocks 2 and 3 are patched fully.
    patched_blocks = {1: [1, 2], 2: [0, 1, 2], 3: [0, 1, 2]}
    for level, indices in patched_blocks.items():
        for idx in indices:
            attn = model.unet.up_blocks[level].attentions[idx].transformer_blocks[0].attn1
            attn.forward = sa_forward(attn)
            attn.injection_schedule = injection_schedule
def register_attention_control_efficient_kv(model, injection_schedule):
    """Patch decoder self-attention layers to inject the source keys and values.

    Same targets and trigger as the query/key variant. The ``attn1`` module of up
    blocks ``{1: [1, 2], 2: [0, 1, 2], 3: [0, 1, 2]}`` is patched, and injection
    fires for self-attention when ``self.t`` is in the schedule or equals 1000.
    Here the first of three ``B // 3`` chunks supplies the keys and values for
    the other two chunks, while every sample keeps its own queries.

    Args:
        model: object exposing ``model.unet.up_blocks`` (assumed diffusers-style UNet).
        injection_schedule: container of timesteps at which to inject, or None.
    """
    def sa_forward(self):
        if type(self.to_out) is torch.nn.modules.container.ModuleList:
            out_proj = self.to_out[0]
        else:
            out_proj = self.to_out

        def forward(x, encoder_hidden_states=None, attention_mask=None):
            batch_size, _, _ = x.shape
            heads = self.heads
            is_cross = encoder_hidden_states is not None
            context = encoder_hidden_states if is_cross else x
            query = self.head_to_batch_dim(self.to_q(x))
            key = self.to_k(context)
            value = self.to_v(context)
            should_inject = (
                not is_cross
                and self.injection_schedule is not None
                and (self.t in self.injection_schedule or self.t == 1000)
            )
            if should_inject:
                n_src = int(value.shape[0] // 3)
                for projected in (key, value):
                    # unconditional chunk, then the conditional remainder
                    projected[n_src:2 * n_src] = projected[:n_src]
                    projected[2 * n_src:] = projected[:n_src]
            key = self.head_to_batch_dim(key)
            value = self.head_to_batch_dim(value)

            scores = torch.einsum("b i d, b j d -> b i j", query, key) * self.scale
            if attention_mask is not None:
                fill_value = -torch.finfo(scores.dtype).max
                keep = attention_mask.reshape(batch_size, -1)[:, None, :].repeat(heads, 1, 1)
                scores.masked_fill_(~keep, fill_value)
            probs = scores.softmax(dim=-1)
            mixed = torch.einsum("b i j, b j d -> b i d", probs, value)
            return out_proj(self.batch_to_head_dim(mixed))

        return forward

    # Up block 1 skips attention block 0; up blocks 2 and 3 are patched fully.
    patched_blocks = {1: [1, 2], 2: [0, 1, 2], 3: [0, 1, 2]}
    for level, indices in patched_blocks.items():
        for idx in indices:
            attn = model.unet.up_blocks[level].attentions[idx].transformer_blocks[0].attn1
            attn.forward = sa_forward(attn)
            attn.injection_schedule = injection_schedule
def register_conv_control_efficient(model, injection_schedule):
    """Patch ``up_blocks[1].resnets[1]`` to inject the source residual features.

    The replacement ``forward`` re-implements the residual block:

    1. norm1 and activation, then optional up/down-sampling;
    2. conv1, then the time embedding;
    3. norm2 and activation, then dropout and conv2;
    4. the shortcut, then the output scale.

    When ``self.t`` is in the schedule or equals 1000, the ``conv2`` output of the
    first ``B // 3`` samples is copied over the second chunk and the remainder.

    Args:
        model: object exposing ``model.unet.up_blocks[1].resnets[1]``
            (assumed to be a diffusers-style ``ResnetBlock2D``).
        injection_schedule: container of timesteps at which to inject, or None.
    """
    def conv_forward(self):
        def forward(input_tensor, temb):
            residual = input_tensor
            hidden = self.nonlinearity(self.norm1(input_tensor))
            if self.upsample is not None:
                # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
                if hidden.shape[0] >= 64:
                    residual = residual.contiguous()
                    hidden = hidden.contiguous()
                residual = self.upsample(residual)
                hidden = self.upsample(hidden)
            elif self.downsample is not None:
                residual = self.downsample(residual)
                hidden = self.downsample(hidden)

            hidden = self.conv1(hidden)
            if temb is not None:
                # project to (B, C, 1, 1) so it broadcasts over the spatial map
                temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
                if self.time_embedding_norm == "default":
                    hidden = hidden + temb
            hidden = self.norm2(hidden)
            if temb is not None and self.time_embedding_norm == "scale_shift":
                scale, shift = torch.chunk(temb, 2, dim=1)
                hidden = hidden * (1 + scale) + shift
            hidden = self.conv2(self.dropout(self.nonlinearity(hidden)))

            schedule = self.injection_schedule
            if schedule is not None and (self.t in schedule or self.t == 1000):
                n_src = int(hidden.shape[0] // 3)
                # unconditional chunk, then the conditional remainder
                hidden[n_src:2 * n_src] = hidden[:n_src]
                hidden[2 * n_src:] = hidden[:n_src]

            if self.conv_shortcut is not None:
                residual = self.conv_shortcut(residual)
            return (residual + hidden) / self.output_scale_factor

        return forward

    resnet = model.unet.up_blocks[1].resnets[1]
    resnet.forward = conv_forward(resnet)
    resnet.injection_schedule = injection_schedule
def register_attention_control_efficient_kv_2nd_to_1st(model, injection_schedule, mask=None):
    """Patch decoder self-attention so the first half of the batch borrows keys/values.

    Patches the ``attn1`` module of up block 1 (attention blocks 1 and 2) and of
    up block 2 (attention blocks 0-2); up block 3 is left untouched. Injection
    fires for self-attention (no ``encoder_hidden_states``) when ``self.t`` is in
    ``injection_schedule`` or equals 1000. The batch is then split into halves of
    ``B // 2``, and for the first half::

        k = k_second * (1 - m) + k_first * m      (likewise for v)

    Queries are never modified.

    Args:
        model: object exposing ``model.unet.up_blocks`` (assumed diffusers-style UNet).
        injection_schedule: container of timesteps at which to inject, or None.
        mask: 3-D float tensor ``(B_m, H_m, W_m)``. 1 keeps the first half's own
            keys/values and 0 takes the second half's. It becomes the default of
            the patched forward's ``mask`` keyword and can be overridden per call.

    Raises (from the patched forward):
        ValueError: if injection fires while no mask is available.
    """
    def sa_forward(self):
        to_out = self.to_out
        if type(to_out) is torch.nn.modules.container.ModuleList:
            to_out = self.to_out[0]

        # ``mask`` comes after the standard (hidden_states, encoder_hidden_states,
        # attention_mask) parameters so a positional caller of the stock
        # attention signature cannot accidentally bind its second argument to it.
        def forward(x, encoder_hidden_states=None, attention_mask=None, mask=mask):
            batch_size, _, _ = x.shape
            h = self.heads
            is_cross = encoder_hidden_states is not None
            encoder_hidden_states = encoder_hidden_states if is_cross else x
            q = self.head_to_batch_dim(self.to_q(x))
            if not is_cross and self.injection_schedule is not None and (
                    self.t in self.injection_schedule or self.t == 1000):
                if mask is None:
                    raise ValueError(
                        'register_attention_control_efficient_kv_2nd_to_1st: a mask is required for injection')
                # assumes the tokens come from a square feature map
                # (seq_len == side ** 2) -- TODO confirm for non-square latents
                target_size = int(np.sqrt(encoder_hidden_states.shape[1]))
                target_mask = F.interpolate(mask.unsqueeze(1), size=(target_size, target_size))[:, 0, :, :]
                # (B_m, seq_len, 1): one weight per token, broadcast over features
                target_mask = target_mask.view(target_mask.shape[0], -1).unsqueeze(-1).to(x.device)
                k = self.to_k(encoder_hidden_states)
                v = self.to_v(encoder_hidden_states)
                source_batch_size = int(v.shape[0] // 2)
                first = slice(0, source_batch_size)
                second = slice(source_batch_size, 2 * source_batch_size)
                k[first] = k[second] * (1 - target_mask) + k[first] * target_mask
                v[first] = v[second] * (1 - target_mask) + v[first] * target_mask
                k = self.head_to_batch_dim(k)
                v = self.head_to_batch_dim(v)
            else:
                k = self.head_to_batch_dim(self.to_k(encoder_hidden_states))
                v = self.head_to_batch_dim(self.to_v(encoder_hidden_states))
            sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale
            if attention_mask is not None:
                attention_mask = attention_mask.reshape(batch_size, -1)
                max_neg_value = -torch.finfo(sim.dtype).max
                attention_mask = attention_mask[:, None, :].repeat(h, 1, 1)
                sim.masked_fill_(~attention_mask, max_neg_value)
            attn = sim.softmax(dim=-1)
            out = torch.einsum("b i j, b j d -> b i d", attn, v)
            out = self.batch_to_head_dim(out)
            return to_out(out)

        return forward

    # Up block 1 (attention blocks 1-2) and up block 2 (all three); up block 3 is not patched.
    res_dict = {1: [1, 2], 2: [0, 1, 2]}
    for res in res_dict:
        for block in res_dict[res]:
            module = model.unet.up_blocks[res].attentions[block].transformer_blocks[0].attn1
            module.forward = sa_forward(module)
            setattr(module, 'injection_schedule', injection_schedule)
def register_conv_control_efficient_2nd_to_1st(model, injection_schedule, mask=None):
    """Patch ``up_blocks[1].resnets[1]`` so the first half of the batch borrows residual features.

    The replacement ``forward`` re-implements the residual block:

    1. norm1 and activation, then optional up/down-sampling;
    2. conv1, then the time embedding;
    3. norm2 and activation, then dropout and conv2;
    4. the shortcut, then the output scale.

    Injection fires when ``self.t`` is in ``injection_schedule`` or equals 1000.
    The batch is then split into halves of ``B // 2``, and the first half's
    ``conv2`` output becomes ``second * (1 - m) + first * m``.

    Args:
        model: object exposing ``model.unet.up_blocks[1].resnets[1]``
            (assumed to be a diffusers-style ``ResnetBlock2D``).
        injection_schedule: container of timesteps at which to inject, or None.
        mask: 3-D float tensor ``(B_m, H_m, W_m)``. 1 keeps the first half's own
            features and 0 takes the second half's. Required when injection fires.

    Raises (from the patched forward):
        ValueError: if injection fires and ``mask`` is None.
    """
    def conv_forward(self):
        def forward(input_tensor, temb):
            hidden_states = input_tensor
            hidden_states = self.norm1(hidden_states)
            hidden_states = self.nonlinearity(hidden_states)
            if self.upsample is not None:
                # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
                if hidden_states.shape[0] >= 64:
                    input_tensor = input_tensor.contiguous()
                    hidden_states = hidden_states.contiguous()
                input_tensor = self.upsample(input_tensor)
                hidden_states = self.upsample(hidden_states)
            elif self.downsample is not None:
                input_tensor = self.downsample(input_tensor)
                hidden_states = self.downsample(hidden_states)
            hidden_states = self.conv1(hidden_states)
            if temb is not None:
                temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
            if temb is not None and self.time_embedding_norm == "default":
                hidden_states = hidden_states + temb
            hidden_states = self.norm2(hidden_states)
            if temb is not None and self.time_embedding_norm == "scale_shift":
                scale, shift = torch.chunk(temb, 2, dim=1)
                hidden_states = hidden_states * (1 + scale) + shift
            hidden_states = self.nonlinearity(hidden_states)
            hidden_states = self.dropout(hidden_states)
            hidden_states = self.conv2(hidden_states)
            if self.injection_schedule is not None and (self.t in self.injection_schedule or self.t == 1000):
                if mask is None:
                    raise ValueError(
                        'register_conv_control_efficient_2nd_to_1st: a mask is required for injection')
                source_batch_size = int(hidden_states.shape[0] // 2)
                # hidden_states is a (B, C, H, W) feature map. Resize the mask to
                # its spatial size and keep a singleton channel axis so it
                # broadcasts over channels. Flattening to sqrt(seq_len) tokens
                # only suits the attention hooks' (B, seq_len, dim) layout.
                target_mask = F.interpolate(mask.unsqueeze(1), size=tuple(hidden_states.shape[-2:]))
                target_mask = target_mask.to(hidden_states.device)
                first = hidden_states[:source_batch_size]
                # exactly one half-sized chunk, matching the attention variant
                second = hidden_states[source_batch_size:2 * source_batch_size]
                hidden_states[:source_batch_size] = second * (1 - target_mask) + first * target_mask
            if self.conv_shortcut is not None:
                input_tensor = self.conv_shortcut(input_tensor)
            output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
            return output_tensor

        return forward

    conv_module = model.unet.up_blocks[1].resnets[1]
    conv_module.forward = conv_forward(conv_module)
    setattr(conv_module, 'injection_schedule', injection_schedule)
def register_attention_control_efficient_qk_w_mask(model, injection_schedule, mask):
    """Patch decoder self-attention to blend the source queries/keys in under ``mask``.

    Targets the ``attn1`` module of up blocks ``{1: [1, 2], 2: [0, 1, 2], 3: [0, 1, 2]}``.
    Injection fires for self-attention when ``self.t`` is in the schedule or
    equals 1000. The mask is then resized to ``side x side``, with
    ``side = int(sqrt(seq_len))``, and flattened to one weight per token. For the
    second chunk of ``B // 3`` samples and for the remainder::

        q = q_source * m + q_own * (1 - m)        (likewise for k)

    Values are never modified.

    Args:
        model: object exposing ``model.unet.up_blocks`` (assumed diffusers-style UNet).
        injection_schedule: container of timesteps at which to inject, or None.
        mask: 3-D float tensor ``(B_m, H_m, W_m)``; 1 takes the source's q/k.
    """
    def sa_forward(self):
        if type(self.to_out) is torch.nn.modules.container.ModuleList:
            out_proj = self.to_out[0]
        else:
            out_proj = self.to_out

        def forward(x, encoder_hidden_states=None, attention_mask=None):
            batch_size, _, _ = x.shape
            heads = self.heads
            is_cross = encoder_hidden_states is not None
            context = encoder_hidden_states if is_cross else x
            query = self.to_q(x)
            key = self.to_k(context)
            should_inject = (
                not is_cross
                and self.injection_schedule is not None
                and (self.t in self.injection_schedule or self.t == 1000)
            )
            if should_inject:
                # assumes a square token grid (seq_len == side ** 2) -- TODO confirm
                side = int(np.sqrt(context.shape[1]))
                weight = F.interpolate(mask.unsqueeze(1), size=(side, side))[:, 0, :, :]
                weight = weight.view(weight.shape[0], -1).unsqueeze(-1)
                n_src = int(query.shape[0] // 3)
                source_rows = slice(None, n_src)
                # unconditional chunk first, then the conditional remainder
                for rows in (slice(n_src, 2 * n_src), slice(2 * n_src, None)):
                    query[rows] = query[source_rows] * weight + query[rows] * (1 - weight)
                    key[rows] = key[source_rows] * weight + key[rows] * (1 - weight)
            query = self.head_to_batch_dim(query)
            key = self.head_to_batch_dim(key)
            value = self.head_to_batch_dim(self.to_v(context))

            scores = torch.einsum("b i d, b j d -> b i j", query, key) * self.scale
            if attention_mask is not None:
                fill_value = -torch.finfo(scores.dtype).max
                keep = attention_mask.reshape(batch_size, -1)[:, None, :].repeat(heads, 1, 1)
                scores.masked_fill_(~keep, fill_value)
            probs = scores.softmax(dim=-1)
            mixed = torch.einsum("b i j, b j d -> b i d", probs, value)
            return out_proj(self.batch_to_head_dim(mixed))

        return forward

    # Up block 1 skips attention block 0; up blocks 2 and 3 are patched fully.
    patched_blocks = {1: [1, 2], 2: [0, 1, 2], 3: [0, 1, 2]}
    for level, indices in patched_blocks.items():
        for idx in indices:
            attn = model.unet.up_blocks[level].attentions[idx].transformer_blocks[0].attn1
            attn.forward = sa_forward(attn)
            attn.injection_schedule = injection_schedule
def register_attention_control_efficient_kv_w_mask(model, injection_schedule, mask, do_classifier_free_guidance):
    """Patch decoder self-attention to blend the source keys/values in under ``mask``.

    Targets the ``attn1`` module of every attention block (0-2) in up blocks 1-3.
    Injection fires for self-attention when ``self.t`` is in the schedule or
    equals 1000. The mask is then resized to ``side x side``, with
    ``side = int(sqrt(seq_len))``, and flattened to one weight per token. Each
    guided chunk receives::

        k = k_source * m + k_own * (1 - m)        (likewise for v)

    Queries are never modified.

    Args:
        model: object exposing ``model.unet.up_blocks`` (assumed diffusers-style UNet).
        injection_schedule: container of timesteps at which to inject, or None.
        mask: 3-D float tensor ``(B_m, H_m, W_m)``; 1 takes the source's k/v.
        do_classifier_free_guidance: batch layout selector.
            - True: the batch is ``[source, uncond, cond]`` (``B // 3`` each), and
              both the uncond chunk and the cond remainder are blended.
            - False: the batch is ``[source, target]`` (``B // 2`` each), and only
              the target chunk is blended.
    """
    # Without CFG the batch has two chunks. Splitting by 3 there would give an
    # empty source chunk for a 2-sample batch and silently skip injection.
    num_chunks = 3 if do_classifier_free_guidance else 2

    def sa_forward(self):
        to_out = self.to_out
        if type(to_out) is torch.nn.modules.container.ModuleList:
            to_out = self.to_out[0]

        def forward(x, encoder_hidden_states=None, attention_mask=None):
            batch_size, _, _ = x.shape
            h = self.heads
            is_cross = encoder_hidden_states is not None
            encoder_hidden_states = encoder_hidden_states if is_cross else x
            q = self.head_to_batch_dim(self.to_q(x))
            k = self.to_k(encoder_hidden_states)
            v = self.to_v(encoder_hidden_states)
            if not is_cross and self.injection_schedule is not None and (
                    self.t in self.injection_schedule or self.t == 1000):
                # assumes a square token grid (seq_len == side ** 2) -- TODO confirm
                target_size = int(np.sqrt(encoder_hidden_states.shape[1]))
                target_mask = F.interpolate(mask.unsqueeze(1), size=(target_size, target_size))[:, 0, :, :]
                # (B_m, seq_len, 1): one weight per token, broadcast over features
                target_mask = target_mask.view(target_mask.shape[0], -1).unsqueeze(-1).to(v.device)
                source_batch_size = int(v.shape[0] // num_chunks)
                targets = [slice(source_batch_size, 2 * source_batch_size)]  # uncond / target chunk
                if do_classifier_free_guidance:
                    targets.append(slice(2 * source_batch_size, None))  # cond remainder
                for rows in targets:
                    v[rows] = v[:source_batch_size] * target_mask + v[rows] * (1 - target_mask)
                    k[rows] = k[:source_batch_size] * target_mask + k[rows] * (1 - target_mask)
            k = self.head_to_batch_dim(k)
            v = self.head_to_batch_dim(v)
            sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale
            if attention_mask is not None:
                attention_mask = attention_mask.reshape(batch_size, -1)
                max_neg_value = -torch.finfo(sim.dtype).max
                attention_mask = attention_mask[:, None, :].repeat(h, 1, 1)
                sim.masked_fill_(~attention_mask, max_neg_value)
            attn = sim.softmax(dim=-1)
            out = torch.einsum("b i j, b j d -> b i d", attn, v)
            out = self.batch_to_head_dim(out)
            return to_out(out)

        return forward

    # All three attention blocks of up blocks 1, 2 and 3, including block 0 of up block 1.
    res_dict = {1: [0, 1, 2], 2: [0, 1, 2], 3: [0, 1, 2]}
    for res in res_dict:
        for block in res_dict[res]:
            module = model.unet.up_blocks[res].attentions[block].transformer_blocks[0].attn1
            module.forward = sa_forward(module)
            setattr(module, 'injection_schedule', injection_schedule)
def register_conv_control_efficient_w_mask(model, injection_schedule, mask):
    """Patch ``up_blocks[1].resnets[1]`` to blend the source residual features in under ``mask``.

    The replacement ``forward`` re-implements the residual block:

    1. norm1 and activation, then optional up/down-sampling;
    2. conv1, then the time embedding;
    3. norm2 and activation, then dropout and conv2;
    4. the shortcut, then the output scale.

    Injection fires when ``self.t`` is in ``injection_schedule`` or equals 1000.
    The batch is then split into chunks of ``B // 3``, and the second chunk and
    the remainder become ``source * m + own * (1 - m)``.

    Args:
        model: object exposing ``model.unet.up_blocks[1].resnets[1]``
            (assumed to be a diffusers-style ``ResnetBlock2D``).
        injection_schedule: container of timesteps at which to inject, or None.
        mask: 3-D float tensor ``(B_m, H_m, W_m)``; 1 takes the source's features.
    """
    def conv_forward(self):
        def forward(input_tensor, temb):
            hidden_states = input_tensor
            hidden_states = self.norm1(hidden_states)
            hidden_states = self.nonlinearity(hidden_states)
            if self.upsample is not None:
                # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
                if hidden_states.shape[0] >= 64:
                    input_tensor = input_tensor.contiguous()
                    hidden_states = hidden_states.contiguous()
                input_tensor = self.upsample(input_tensor)
                hidden_states = self.upsample(hidden_states)
            elif self.downsample is not None:
                input_tensor = self.downsample(input_tensor)
                hidden_states = self.downsample(hidden_states)
            hidden_states = self.conv1(hidden_states)
            if temb is not None:
                temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
            if temb is not None and self.time_embedding_norm == "default":
                hidden_states = hidden_states + temb
            hidden_states = self.norm2(hidden_states)
            if temb is not None and self.time_embedding_norm == "scale_shift":
                scale, shift = torch.chunk(temb, 2, dim=1)
                hidden_states = hidden_states * (1 + scale) + shift
            hidden_states = self.nonlinearity(hidden_states)
            hidden_states = self.dropout(hidden_states)
            hidden_states = self.conv2(hidden_states)
            if self.injection_schedule is not None and (self.t in self.injection_schedule or self.t == 1000):
                source_batch_size = int(hidden_states.shape[0] // 3)
                # hidden_states is a (B, C, H, W) feature map. Resize the mask to
                # its spatial size and keep a singleton channel axis so it
                # broadcasts over channels. Flattening to sqrt(seq_len) tokens
                # only suits the attention hooks' (B, seq_len, dim) layout.
                target_mask = F.interpolate(mask.unsqueeze(1), size=tuple(hidden_states.shape[-2:]))
                target_mask = target_mask.to(hidden_states.device)
                source = hidden_states[:source_batch_size]
                # inject unconditional
                uncond = hidden_states[source_batch_size:2 * source_batch_size]
                hidden_states[source_batch_size:2 * source_batch_size] = source * target_mask + uncond * (1 - target_mask)
                # inject conditional
                cond = hidden_states[2 * source_batch_size:]
                hidden_states[2 * source_batch_size:] = source * target_mask + cond * (1 - target_mask)
            if self.conv_shortcut is not None:
                input_tensor = self.conv_shortcut(input_tensor)
            output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
            return output_tensor

        return forward

    conv_module = model.unet.up_blocks[1].resnets[1]
    conv_module.forward = conv_forward(conv_module)
    setattr(conv_module, 'injection_schedule', injection_schedule)