File size: 1,385 Bytes
6da9572
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import torch

from diffusers.models.embeddings import TimestepEmbedding, Timesteps


def unet_add_coded_conds(unet, added_number_count=1, *, cond_proj_dim=256, time_embed_dim=1280):
    """Patch ``unet`` in place so it is additionally conditioned on numeric codes.

    Three things are installed on the given object:

    * ``unet.add_time_proj`` / ``unet.add_embedding`` -- a ``Timesteps``
      projection followed by a ``TimestepEmbedding`` that turns the codes
      into one embedding vector per sample.
    * ``unet.get_aug_embed`` -- replaced by a function that builds that
      embedding from ``added_cond_kwargs["coded_conds"]``.
    * ``unet.forward`` -- wrapped so callers pass the codes as
      ``cross_attention_kwargs["coded_conds"]``; the wrapper removes that
      entry, tiles the codes along the batch axis to the batch size of
      ``sample`` and forwards them through ``added_cond_kwargs``.

    Args:
        unet: Model object to patch. It must expose ``forward``; the new
            attributes are simply assigned onto it.
        added_number_count: How many code values each sample carries. The
            embedding layer's input width is
            ``cond_proj_dim * added_number_count``.
            NOTE(review): this presumably means ``coded_conds`` is shaped
            ``(batch, added_number_count)`` -- confirm against the caller.
        cond_proj_dim: Channel count given to ``Timesteps`` for each code
            value (previously hard-coded to 256).
        time_embed_dim: Output width of the ``TimestepEmbedding``
            (previously hard-coded to 1280).
            NOTE(review): presumably this has to equal the width of the
            UNet's own time embedding, since the result is returned as its
            ``aug_emb`` -- confirm for non-default UNets.

    Returns:
        None. ``unet`` is modified in place.
    """
    unet.add_time_proj = Timesteps(cond_proj_dim, True, 0)
    unet.add_embedding = TimestepEmbedding(cond_proj_dim * added_number_count, time_embed_dim)

    def get_aug_embed(emb, encoder_hidden_states, added_cond_kwargs):
        # Only ``coded_conds`` is consulted; ``encoder_hidden_states`` is
        # accepted solely to keep the signature of the hook being replaced.
        coded_conds = added_cond_kwargs.get("coded_conds")
        batch_size = coded_conds.shape[0]
        # Project every code value independently, then regroup the results
        # so each sample ends up with a single concatenated row.
        time_embeds = unet.add_time_proj(coded_conds.flatten())
        time_embeds = time_embeds.reshape((batch_size, -1))
        # Match the dtype/device of the embedding this will accompany.
        time_embeds = time_embeds.to(emb)
        return unet.add_embedding(time_embeds)

    unet.get_aug_embed = get_aug_embed

    # Captured before patching so the wrapper can delegate to the real forward.
    unet_original_forward = unet.forward

    def hooked_unet_forward(sample, timestep, encoder_hidden_states, **kwargs):
        # Work on a copy so the caller's dict is never mutated; tolerate the
        # argument being omitted or None and fail with a clear message below.
        cross_attention_kwargs = dict(kwargs.get('cross_attention_kwargs') or {})
        if 'coded_conds' not in cross_attention_kwargs:
            raise KeyError(
                "cross_attention_kwargs must contain 'coded_conds' for a UNet "
                "patched by unet_add_coded_conds"
            )
        coded_conds = cross_attention_kwargs.pop('coded_conds')
        kwargs['cross_attention_kwargs'] = cross_attention_kwargs

        # Tile the codes so their batch matches ``sample``. Floor division
        # alone would silently yield a wrong-sized (or empty) batch, so the
        # sizes are validated first.
        sample_batch = sample.shape[0]
        cond_batch = coded_conds.shape[0]
        if cond_batch == 0 or sample_batch == 0 or sample_batch % cond_batch != 0:
            raise ValueError(
                f"coded_conds batch size ({cond_batch}) must be non-zero and evenly "
                f"divide the sample batch size ({sample_batch})"
            )
        coded_conds = torch.cat([coded_conds] * (sample_batch // cond_batch), dim=0).to(sample.device)

        # Keep any added_cond_kwargs the caller supplied; only our key is set.
        added_cond_kwargs = dict(kwargs.get('added_cond_kwargs') or {})
        added_cond_kwargs['coded_conds'] = coded_conds
        kwargs['added_cond_kwargs'] = added_cond_kwargs
        return unet_original_forward(sample, timestep, encoder_hidden_states, **kwargs)

    unet.forward = hooked_unet_forward