# NOTE: web-page scrape residue removed from the top of this file
# (original listing: rev c19ca42, ~4,850 bytes).
# https://github.com/ogkalu2/Merge-Stable-Diffusion-models-without-distortion
from collections import defaultdict
from random import shuffle
from typing import NamedTuple
import torch
from scipy.optimize import linear_sum_assignment
from modules.shared import log
# Keys singled out for special treatment by the merge driver.
# NOTE(review): not referenced anywhere in this chunk — presumably the
# caller excludes these norm/output layers from permutation; verify there.
SPECIAL_KEYS = [
    "first_stage_model.decoder.norm_out.weight",
    "first_stage_model.decoder.norm_out.bias",
    "first_stage_model.encoder.norm_out.weight",
    "first_stage_model.encoder.norm_out.bias",
    "model.diffusion_model.out.0.weight",
    "model.diffusion_model.out.0.bias",
]
class PermutationSpec(NamedTuple):
    """Bidirectional bookkeeping for re-basin weight matching.

    The two dicts describe the same association in opposite directions:
    which permutation governs which axis of which weight tensor.
    """
    # Maps a permutation name -> list of (weight_key, axis) pairs it governs.
    perm_to_axes: dict
    # Maps a weight key -> per-axis sequence of permutation names,
    # with None on axes that have no associated permutation.
    axes_to_perm: dict
def permutation_spec_from_axes_to_perm(axes_to_perm: dict) -> PermutationSpec:
    """Build a PermutationSpec from the forward (weight -> perms) mapping.

    Derives the reverse mapping (perm name -> list of (weight_key, axis))
    by walking every axis of every weight; axes whose permutation entry is
    None carry no permutation and are skipped.
    """
    reverse = {}
    for weight_key, axis_perms in axes_to_perm.items():
        for axis_index, perm_name in enumerate(axis_perms):
            if perm_name is None:
                continue
            reverse.setdefault(perm_name, []).append((weight_key, axis_index))
    return PermutationSpec(perm_to_axes=reverse, axes_to_perm=axes_to_perm)
def get_permuted_param(ps: "PermutationSpec", perm, k: str, params, except_axis=None):
    """Get parameter `k` from `params`, with the permutations applied.

    Args:
        ps: permutation spec; `ps.axes_to_perm[k]` lists, per axis of the
            tensor, the permutation name for that axis (None = no permutation).
        perm: mapping from permutation name to a 1-D index tensor.
        k: key of the parameter to fetch.
        params: mapping of parameter key -> tensor (e.g. a state dict).
        except_axis: axis to leave unpermuted — used while solving for that
            axis's own permutation.

    Returns:
        `params[k]` with every applicable permutation applied via index_select.
    """
    w = params[k]
    for axis, p in enumerate(ps.axes_to_perm[k]):
        # Skip the axis we're trying to permute.
        if axis == except_axis:
            continue
        # None indicates that there is no permutation relevant to that axis.
        # Fix: explicit None-check instead of truthiness, so valid-but-falsy
        # permutation ids (0, "") are not silently skipped.
        if p is not None:
            w = torch.index_select(w, axis, perm[p].int())
    return w
def apply_permutation(ps: PermutationSpec, perm, params):
    """Apply a `perm` to `params`, returning a new dict of permuted tensors."""
    permuted = {}
    for name in params.keys():
        permuted[name] = get_permuted_param(ps, perm, name, params)
    return permuted
def update_model_a(ps: PermutationSpec, perm, model_a, new_alpha):
    """Blend every weight of `model_a` toward its permuted self, in place.

    Each tensor becomes `(1 - new_alpha) * original + new_alpha * permuted`.
    Weights the permutation cannot be applied to (RuntimeError) are left
    untouched — this covers pix2pix and inpainting models.
    """
    for key in model_a:
        try:
            permuted = get_permuted_param(ps, perm, key, model_a)
            model_a[key] = (1 - new_alpha) * model_a[key] + new_alpha * permuted
        except RuntimeError:
            # shape/layout mismatch (pix2pix, inpainting) — skip this weight
            continue
    return model_a
def inner_matching(
    n,
    ps,
    p,
    params_a,
    params_b,
    usefp16,
    progress,
    number,
    linear_sum,
    perm,
    device,
):
    """Solve one linear-assignment step for permutation `p`.

    Accumulates an n x n similarity matrix A over every (weight, axis) pair
    governed by `p`, solves the assignment problem on A, and overwrites
    `perm[p]` with the resulting column order.

    Args:
        n: size of permutation `p` (rows/cols of A).
        ps: PermutationSpec; `ps.perm_to_axes[p]` lists (weight_key, axis) pairs.
        p: name of the permutation being optimized.
        params_a, params_b: parameter dicts of the two models.
        usefp16: accumulate A (and the objective deltas) in float16.
        progress: running flag; becomes True if this step improved the objective.
        number: running count of permutations whose objective changed.
        linear_sum: running sum of absolute objective improvements.
        perm: dict of current permutations; `perm[p]` is overwritten.
        device: torch device for the matmul work.

    Returns:
        Updated (linear_sum, number, perm, progress).
    """
    A = torch.zeros((n, n), dtype=torch.float16) if usefp16 else torch.zeros((n, n))
    A = A.to(device)
    for wk, axis in ps.perm_to_axes[p]:
        w_a = params_a[wk]
        # b's weight with every OTHER axis's permutation already applied,
        # leaving `axis` free — that is the axis we are solving for.
        w_b = get_permuted_param(ps, perm, wk, params_b, except_axis=axis)
        # Bring the permuted axis to the front and flatten the rest, so each
        # row of w_a (and column of w_b.T) corresponds to one permutation slot.
        w_a = torch.moveaxis(w_a, axis, 0).reshape((n, -1)).to(device)
        w_b = torch.moveaxis(w_b, axis, 0).reshape((n, -1)).T.to(device)
        if usefp16:
            w_a = w_a.half().to(device)
            w_b = w_b.half().to(device)
        try:
            A += torch.matmul(w_a, w_b)
        except RuntimeError:
            # Quantized tensors can't matmul directly — dequantize and retry.
            A += torch.matmul(torch.dequantize(w_a), torch.dequantize(w_b))
    # scipy's LAP solver needs a CPU numpy array; move back to device after.
    A = A.cpu()
    ri, ci = linear_sum_assignment(A.detach().numpy(), maximize=True)
    A = A.to(device)
    # With maximize=True the row indices come back as the identity order.
    assert (torch.tensor(ri) == torch.arange(len(ri))).all()
    eye_tensor = torch.eye(n).to(device)
    # Objective value under the old permutation vs. the newly solved one:
    # vdot(A, P) sums the entries of A selected by permutation matrix P.
    oldL = torch.vdot(
        torch.flatten(A).float(), torch.flatten(eye_tensor[perm[p].long()])
    )
    newL = torch.vdot(torch.flatten(A).float(), torch.flatten(eye_tensor[ci, :]))
    if usefp16:
        oldL = oldL.half()
        newL = newL.half()
    if newL - oldL != 0:
        linear_sum += abs((newL - oldL).item())
        number += 1
        log.debug(f"Merge Rebasin permutation: {p}={newL-oldL}")
    # Count as progress only when the improvement beats a small tolerance.
    progress = progress or newL > oldL + 1e-12
    perm[p] = torch.Tensor(ci).to(device)
    return linear_sum, number, perm, progress
def weight_matching(
    ps: PermutationSpec,
    params_a,
    params_b,
    max_iter=1,
    init_perm=None,
    usefp16=False,
    device="cpu",
):
    """Iteratively solve for permutations aligning `params_b` to `params_a`.

    Args:
        ps: PermutationSpec describing which permutation governs which axes.
        params_a, params_b: parameter dicts of the two models.
        max_iter: maximum number of passes over the permutation set.
        init_perm: optional starting permutations; identity when None.
        usefp16: run the inner matching in float16.
        device: torch device for the tensor work.

    Returns:
        (perm, average): the solved permutation dict and the mean absolute
        objective improvement per changed permutation (0 if none changed).
    """
    # Size of each permutation, taken from the first weight it governs;
    # permutations whose first weight is absent from params_a are dropped.
    perm_sizes = {
        p: params_a[axes[0][0]].shape[axes[0][1]]
        for p, axes in ps.perm_to_axes.items()
        if axes[0][0] in params_a.keys()
    }
    perm = {}
    # (the assignment above is immediately overwritten here)
    perm = (
        {p: torch.arange(n).to(device) for p, n in perm_sizes.items()}
        if init_perm is None
        else init_perm
    )
    linear_sum = 0
    number = 0
    # Only this single hard-coded permutation is optimized; the shuffle below
    # is a no-op for a one-element list.
    special_layers = ["P_bg324"]
    for _i in range(max_iter):
        progress = False
        shuffle(special_layers)
        for p in special_layers:
            n = perm_sizes[p]
            linear_sum, number, perm, progress = inner_matching(
                n,
                ps,
                p,
                params_a,
                params_b,
                usefp16,
                progress,
                number,
                linear_sum,
                perm,
                device,
            )
        # NOTE(review): this unconditional override discards the progress flag
        # returned by inner_matching and makes the break below unreachable, so
        # all max_iter passes always run — confirm this is intended.
        progress = True
        if not progress:
            break
    average = linear_sum / number if number > 0 else 0
    return perm, average