# NAI compatible

import torch


class HypernetworkModule(torch.nn.Module):
    """Residual two-layer linear block computing ``x + linear(x) * multiplier``.

    The inner network maps ``dim -> dim * 2 -> dim`` with no activation
    between the two layers.

    Args:
        dim: Size of the last dimension of the input (and of the output).
        multiplier: Scale applied to the inner network's output before it is
            added back onto the input.
    """

    def __init__(self, dim, multiplier=1.0):
        super().__init__()

        linear1 = torch.nn.Linear(dim, dim * 2)
        linear2 = torch.nn.Linear(dim * 2, dim)
        # Small-variance weights and zero biases: the block starts out close
        # to the identity because its output is added onto the input.
        for layer in (linear1, linear2):
            layer.weight.data.normal_(mean=0.0, std=0.01)
            layer.bias.data.zero_()

        # The attribute name and layer order define the state-dict keys
        # ("linear.0.*", "linear.1.*"); do not rename.
        self.linear = torch.nn.Sequential(linear1, linear2)
        self.multiplier = multiplier

    def forward(self, x):
        """Return ``x`` plus the scaled output of the inner network."""
        return x + self.linear(x) * self.multiplier


class Hypernetwork(torch.nn.Module):
    """Holds one pair of ``HypernetworkModule`` per supported context size.

    ``apply_to_stable_diffusion`` / ``apply_to_diffusers`` store this object
    on the attention layers of a U-Net as the attribute ``hypernetwork``; how
    those layers consume it is defined outside this file.
    """

    # Context sizes (last dimension of ``context``) that have a module pair.
    enable_sizes = [320, 640, 768, 1280]

    # Parameter names used by the old checkpoint layout -> current names.
    _LEGACY_KEY_MAP = {
        'linear1.bias': 'linear.0.bias',
        'linear1.weight': 'linear.0.weight',
        'linear2.bias': 'linear.1.bias',
        'linear2.weight': 'linear.1.weight',
    }

    def __init__(self, multiplier=1.0) -> None:
        super().__init__()
        # NOTE(review): this plain list shadows the inherited
        # ``torch.nn.Module.modules()`` method on instances. The name is kept
        # because code outside this file may read ``self.modules``.
        # Entry ``i`` is the pair of modules for ``enable_sizes[i]``.
        self.modules = []
        for size in Hypernetwork.enable_sizes:
            pair = (HypernetworkModule(size, multiplier), HypernetworkModule(size, multiplier))
            self.modules.append(pair)
            # Register explicitly so the pair's parameters are tracked by this
            # module; a plain list is not seen by ``torch.nn.Module``.
            self.register_module(f"{size}_0", pair[0])
            self.register_module(f"{size}_1", pair[1])

    def _attach_to_transformer(self, transformer, size_of):
        """Set ``hypernetwork`` on ``attn1``/``attn2`` of every transformer block.

        ``size_of(attn)`` returns the attention layer's context size. Layers
        with a supported size get ``self``; all others get ``None``.
        """
        for tf_block in transformer.transformer_blocks:
            for attn in (tf_block.attn1, tf_block.attn2):
                supported = size_of(attn) in Hypernetwork.enable_sizes
                attn.hypernetwork = self if supported else None

    def apply_to_stable_diffusion(self, text_encoder, vae, unet):
        """Attach this hypernetwork to a U-Net exposing ``input_blocks``,
        ``middle_block`` and ``output_blocks``.

        ``text_encoder`` and ``vae`` are accepted but not used. The context
        size of each attention layer is read from ``attn.context_dim``.
        Returns ``None``.
        """
        blocks = unet.input_blocks + [unet.middle_block] + unet.output_blocks
        for block in blocks:
            for subblk in block:
                # Matched by class name so the defining package need not be imported.
                if 'SpatialTransformer' in str(type(subblk)):
                    self._attach_to_transformer(subblk, lambda attn: attn.context_dim)

    def apply_to_diffusers(self, text_encoder, vae, unet):
        """Attach this hypernetwork to a U-Net exposing ``down_blocks``,
        ``mid_block`` and ``up_blocks``.

        ``text_encoder`` and ``vae` are accepted but not used. The context
        size of each attention layer is read from ``attn.to_k.in_features``.

        Returns:
            Always ``True``.
        """
        blocks = unet.down_blocks + [unet.mid_block] + unet.up_blocks
        for block in blocks:
            if not hasattr(block, 'attentions'):
                continue
            for subblk in block.attentions:
                type_name = str(type(subblk))
                # Class name differs by version: 0.6.0 and 0.7~
                if 'SpatialTransformer' in type_name or 'Transformer2DModel' in type_name:
                    self._attach_to_transformer(subblk, lambda attn: attn.to_k.in_features)
        return True  # TODO error checking

    def forward(self, x, context):
        """Apply the module pair matching ``context.shape[-1]`` to ``context``.

        Args:
            x: Unused; kept for interface compatibility.
            context: Tensor whose last dimension is one of ``enable_sizes``.

        Returns:
            A 2-tuple: the first and the second module of the pair, each
            applied to ``context``.

        Raises:
            AssertionError: If the last dimension of ``context`` is not supported.
        """
        size = context.shape[-1]
        assert size in Hypernetwork.enable_sizes, f"unsupported context size: {size}"
        first, second = self.modules[Hypernetwork.enable_sizes.index(size)]
        # Call the modules rather than ``.forward`` so registered hooks run.
        return first(context), second(context)

    @staticmethod
    def _rename_legacy_keys(sd):
        """Return ``sd`` with old-layout parameter names mapped to current ones.

        ``sd`` itself is returned (not copied) when it holds no legacy key, so
        current-format input reaches ``load_state_dict`` untouched.
        """
        legacy = Hypernetwork._LEGACY_KEY_MAP
        if not any(key in legacy for key in sd):
            return sd
        return {legacy.get(key, key): value for key, value in sd.items()}

    def load_from_state_dict(self, state_dict):
        """Load weights from a mapping of ``{size: [sd_first, sd_second]}``.

        Per-module dicts may use either the current parameter names
        (``linear.0.*`` / ``linear.1.*``) or the old ones (``linear1.*`` /
        ``linear2.*``). Entries whose key is not an ``int`` are skipped.

        Args:
            state_dict: The mapping to load. Legacy names found directly at
                its top level are renamed in place; the per-module dicts are
                never mutated.

        Returns:
            Always ``True``.

        Raises:
            ValueError: If an integer key is not in ``enable_sizes``.
            RuntimeError: From ``load_state_dict(strict=True)`` on missing,
                unexpected or mis-shaped parameters.
        """
        # old ver to new ver
        # Top-level rename kept for backward compatibility with the previous
        # behaviour; such flat entries are still skipped by the loop below.
        for key_from, key_to in Hypernetwork._LEGACY_KEY_MAP.items():
            if key_from in state_dict:
                state_dict[key_to] = state_dict.pop(key_from)

        for size, sd in state_dict.items():
            if not isinstance(size, int):
                continue
            first, second = self.modules[Hypernetwork.enable_sizes.index(size)]
            # The legacy names live in the per-module dicts, so they have to
            # be renamed here for old checkpoints to pass the strict load.
            first.load_state_dict(self._rename_legacy_keys(sd[0]), strict=True)
            second.load_state_dict(self._rename_legacy_keys(sd[1]), strict=True)
        return True

    def get_state_dict(self):
        """Return ``{size: [state_dict_first, state_dict_second]}`` for every size."""
        return {
            size: [first.state_dict(), second.state_dict()]
            for size, (first, second) in zip(Hypernetwork.enable_sizes, self.modules)
        }