Spaces:
Paused
Paused
import comfy.utils | |
import folder_paths | |
import torch | |
import logging | |
def load_hypernetwork_patch(path, strength):
    """Build an attention patch from a hypernetwork checkpoint file.

    The checkpoint is read with ``comfy.utils.load_torch_file`` (safe-load
    mode). Top-level keys that convert to ``int`` are treated as attention
    dimensions; each maps to a pair of state-dict-like mappings. Index 0
    becomes the network applied to ``k`` and index 1 the one applied to ``v``.
    The remaining top-level keys (``activation_func``, ``is_layer_norm``,
    ``use_dropout``, ``activate_output``, ``last_layer_dropout``) select how
    each network is assembled.

    Args:
        path: Location of the hypernetwork checkpoint.
        strength: Multiplier applied to the hypernetwork output before it is
            added to ``k`` and ``v``.

    Returns:
        A callable ``hypernetwork_patch`` instance taking
        ``(q, k, v, extra_options)`` and returning ``(q, k, v)``, or ``None``
        when the checkpoint names an unsupported activation function.
    """
    sd = comfy.utils.load_torch_file(path, safe_load=True)
    activation_func = sd.get('activation_func', 'linear')
    is_layer_norm = sd.get('is_layer_norm', False)
    use_dropout = sd.get('use_dropout', False)
    activate_output = sd.get('activate_output', False)
    last_layer_dropout = sd.get('last_layer_dropout', False)

    valid_activation = {
        "linear": torch.nn.Identity,
        "relu": torch.nn.ReLU,
        "leakyrelu": torch.nn.LeakyReLU,
        "elu": torch.nn.ELU,
        "swish": torch.nn.Hardswish,
        "tanh": torch.nn.Tanh,
        "sigmoid": torch.nn.Sigmoid,
        "softsign": torch.nn.Softsign,
        "mish": torch.nn.Mish,
    }

    if activation_func not in valid_activation:
        logging.error("Unsupported Hypernetwork format, if you report it I might implement it. {} {} {} {} {} {}".format(path, activation_func, is_layer_norm, use_dropout, activate_output, last_layer_dropout))
        return None

    weight_suffix = ".weight"
    out = {}

    for d in sd:
        # Only integer-like keys describe attention dimensions; the option
        # keys read above (and anything else non-numeric) are skipped.
        try:
            dim = int(d)
        except (TypeError, ValueError):
            continue

        output = []
        for index in [0, 1]:
            # Look up with the original key so numeric string keys work too.
            attn_weights = sd[d][index]
            linears = [
                key[:-len(weight_suffix)]
                for key in attn_weights.keys()
                if key.endswith(weight_suffix)
            ]

            layers = []
            # Manual index: a layer-norm entry consumes an extra slot, so the
            # step per iteration is not constant.
            # NOTE: last/penultimate positions are measured over every
            # ".weight" entry, so layer-norm entries are counted as well.
            i = 0
            while i < len(linears):
                lin_name = linears[i]
                last_layer = (i == (len(linears) - 1))
                penultimate_layer = (i == (len(linears) - 2))

                lin_weight = attn_weights['{}.weight'.format(lin_name)]
                lin_bias = attn_weights['{}.bias'.format(lin_name)]
                # Linear weights are stored as (out_features, in_features).
                layer = torch.nn.Linear(lin_weight.shape[1], lin_weight.shape[0])
                layer.load_state_dict({"weight": lin_weight, "bias": lin_bias})
                layers.append(layer)

                if activation_func != "linear":
                    if (not last_layer) or (activate_output):
                        layers.append(valid_activation[activation_func]())

                if is_layer_norm:
                    # The entry right after a linear layer holds its LayerNorm.
                    i += 1
                    ln_name = linears[i]
                    ln_weight = attn_weights['{}.weight'.format(ln_name)]
                    ln_bias = attn_weights['{}.bias'.format(ln_name)]
                    ln = torch.nn.LayerNorm(ln_weight.shape[0])
                    ln.load_state_dict({"weight": ln_weight, "bias": ln_bias})
                    layers.append(ln)

                if use_dropout:
                    if (not last_layer) and (not penultimate_layer or last_layer_dropout):
                        layers.append(torch.nn.Dropout(p=0.3))

                i += 1

            output.append(torch.nn.Sequential(*layers))

        # Modules start in training mode; switch to eval so Dropout layers are
        # inert when the patch is applied instead of randomly zeroing values.
        out[dim] = torch.nn.ModuleList(output).eval()

    class hypernetwork_patch:
        """Callable that adds strength-scaled hypernetwork output to k and v."""

        def __init__(self, hypernet, strength):
            self.hypernet = hypernet
            self.strength = strength

        def __call__(self, q, k, v, extra_options):
            # The network is selected by the size of k's last dimension;
            # inputs with no matching network pass through unchanged.
            dim = k.shape[-1]
            if dim in self.hypernet:
                hn = self.hypernet[dim]
                k = k + hn[0](k) * self.strength
                v = v + hn[1](v) * self.strength

            return q, k, v

        def to(self, device):
            """Move every loaded network to ``device`` and return ``self``."""
            for d in self.hypernet:
                self.hypernet[d] = self.hypernet[d].to(device)
            return self

    return hypernetwork_patch(out, strength)
class HypernetworkLoader:
    """Node that clones a model and attaches a hypernetwork attention patch."""

    # Declared as a classmethod so it can be called on the class without an
    # instance (assumed to be how the node registry queries it — confirm
    # against the host framework).
    @classmethod
    def INPUT_TYPES(s):
        """Describe the node inputs: a model, a hypernetwork file, a strength.

        The hypernetwork choices come from
        ``folder_paths.get_filename_list("hypernetworks")``; strength is a
        float defaulting to 1.0 within [-10.0, 10.0] in steps of 0.01.
        """
        return {
            "required": {
                "model": ("MODEL",),
                "hypernetwork_name": (folder_paths.get_filename_list("hypernetworks"), ),
                "strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
            }
        }

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "load_hypernetwork"

    CATEGORY = "loaders"

    def load_hypernetwork(self, model, hypernetwork_name, strength):
        """Return a clone of ``model`` patched with the named hypernetwork.

        Args:
            model: Object providing ``clone()``; the clone must provide
                ``set_model_attn1_patch`` and ``set_model_attn2_patch``.
            hypernetwork_name: File name resolved in the "hypernetworks"
                folder via ``folder_paths.get_full_path_or_raise``.
            strength: Multiplier forwarded to ``load_hypernetwork_patch``.

        Returns:
            A 1-tuple holding the cloned model. When the checkpoint cannot be
            turned into a patch (loader returned ``None``), the clone is
            returned unpatched.
        """
        hypernetwork_path = folder_paths.get_full_path_or_raise("hypernetworks", hypernetwork_name)
        # Patch a clone so the caller's model object is not modified here.
        model_hypernetwork = model.clone()
        patch = load_hypernetwork_patch(hypernetwork_path, strength)
        if patch is not None:
            # The same patch object is registered on both attention hooks.
            model_hypernetwork.set_model_attn1_patch(patch)
            model_hypernetwork.set_model_attn2_patch(patch)
        return (model_hypernetwork,)
# Registry table: maps the node identifier string to the class implementing it.
NODE_CLASS_MAPPINGS = {"HypernetworkLoader": HypernetworkLoader}