EcoDiff / norm_attn_hook.py
JJRWTH's picture
update hf hub
bb7ee19
raw
history blame
9.43 kB
# TODO should be a parent class for all the hooks !! for the offical repo
# 1: FLUX Norm
import logging
import os
from collections import OrderedDict
from functools import partial
import torch
from torch import nn
import re
class NormHooker:
def __init__(
self,
pipeline: nn.Module,
regex: str,
dtype: torch.dtype,
masking: str,
dst: str,
epsilon: float = 0.0,
eps: float = 1e-6,
use_log: bool = False,
binary: bool = False,
):
self.pipeline = pipeline
self.net = pipeline.unet if hasattr(pipeline, "unet") else pipeline.transformer
self.logger = logging.getLogger(__name__)
self.dtype = dtype
self.regex = regex
self.hook_dict = {}
self.masking = masking
self.dst = dst
self.epsilon = epsilon
self.eps = eps
self.use_log = use_log
self.lambs = []
self.lambs_module_names = [] # store the module names for each lambda block
self.hook_counter = 0
self.module_neurons = OrderedDict()
self.binary = (
binary # default, need to discuss if we need to keep this attribute or not
)
def add_hooks_to_norm(self, hook_fn: callable):
"""
Add forward hooks to every feed forward layer matching the regex
:param hook_fn: a callable to be added to torch nn module as a hook
:return: dictionary of added hooks
"""
total_hooks = 0
for name, module in self.net.named_modules():
name_last_word = name.split(".")[-1]
if "norm1_context" in name_last_word:
if re.match(self.regex, name):
hook_fn_with_name = partial(hook_fn, name=name)
if hasattr(module, "linear"):
actual_module = module.linear
else:
if isinstance(module, nn.Linear):
actual_module = module
else:
continue
hook = actual_module.register_forward_hook(
hook_fn_with_name, with_kwargs=True
)
self.hook_dict[name] = hook
# AdaLayerNormZero
if isinstance(actual_module, torch.nn.Linear):
self.module_neurons[name] = actual_module.out_features
else:
raise NotImplementedError(
f"Module {name} is not implemented, please check"
)
self.logger.info(
f"Adding hook to {name}, neurons: {self.module_neurons[name]}"
)
total_hooks += 1
self.logger.info(f"Total hooks added: {total_hooks}")
return self.hook_dict
def add_hooks(self, init_value=1.0):
hook_fn = self.get_norm_masking_hook(init_value)
self.add_hooks_to_norm(hook_fn)
# initialize the lambda
self.lambs = [None] * len(self.hook_dict)
# initialize the lambda module names
self.lambs_module_names = [None] * len(self.hook_dict)
def clear_hooks(self):
"""clear all hooks"""
for hook in self.hook_dict.values():
hook.remove()
self.hook_dict.clear()
def save(self, name: str = None):
if name is not None:
dst = os.path.join(os.path.dirname(self.dst), name)
else:
dst = self.dst
dst_dir = os.path.dirname(dst)
if not os.path.exists(dst_dir):
self.logger.info(f"Creating directory {dst_dir}")
os.makedirs(dst_dir)
torch.save(self.lambs, dst)
@property
def get_lambda_block_names(self):
return self.lambs_module_names
def load(self, device, threshold):
if os.path.exists(self.dst):
self.logger.info(f"loading lambda from {self.dst}")
self.lambs = torch.load(self.dst, weights_only=True, map_location=device)
if self.binary:
# set binary masking for each lambda by using clamp
self.lambs = [
(torch.relu(lamb - threshold) > 0).float() for lamb in self.lambs
]
else:
self.lambs = [torch.clamp(lamb, min=0.0) for lamb in self.lambs]
# self.lambs_module_names = [None for _ in self.lambs]
else:
self.logger.info("skipping loading, training from scratch")
def binarize(self, scope: str, ratio: float):
"""
binarize lambda to be 0 or 1
:param scope: either locally (sparsity within layer) or globally (sparsity within model)
:param ratio: the ratio of the number of 1s to the total number of elements
"""
assert scope in ["local", "global"], "scope must be either local or global"
assert (
not self.binary
), "binarization is not supported when using binary mask already"
if scope == "local":
# Local binarization
for i, lamb in enumerate(self.lambs):
num_heads = lamb.size(0)
num_activate_heads = int(num_heads * ratio)
# Sort the lambda values with stable sorting to maintain order for equal values
sorted_lamb, sorted_indices = torch.sort(
lamb, descending=True, stable=True
)
# Find the threshold value
threshold = sorted_lamb[num_activate_heads - 1]
# Create a mask based on the sorted indices
mask = torch.zeros_like(lamb)
mask[sorted_indices[:num_activate_heads]] = 1.0
# Binarize the lambda based on the threshold and the mask
self.lambs[i] = torch.where(
lamb > threshold, torch.ones_like(lamb), mask
)
else:
# Global binarization
all_lambs = torch.cat([lamb.flatten() for lamb in self.lambs])
num_total = all_lambs.numel()
num_activate = int(num_total * ratio)
# Sort all lambda values globally with stable sorting
sorted_lambs, sorted_indices = torch.sort(
all_lambs, descending=True, stable=True
)
# Find the global threshold value
threshold = sorted_lambs[num_activate - 1]
# Create a global mask based on the sorted indices
global_mask = torch.zeros_like(all_lambs)
global_mask[sorted_indices[:num_activate]] = 1.0
# Binarize all lambdas based on the global threshold and mask
start_idx = 0
for i in range(len(self.lambs)):
end_idx = start_idx + self.lambs[i].numel()
lamb_mask = global_mask[start_idx:end_idx].reshape(self.lambs[i].shape)
self.lambs[i] = torch.where(
self.lambs[i] > threshold, torch.ones_like(self.lambs[i]), lamb_mask
)
start_idx = end_idx
self.binary = True
@staticmethod
def masking_fn(hidden_states, **kwargs):
hidden_states_dtype = hidden_states.dtype
lamb = kwargs["lamb"].view(1, 1, kwargs["lamb"].shape[0])
if kwargs.get("masking", None) == "sigmoid":
mask = torch.sigmoid(lamb)
elif kwargs.get("masking", None) == "binary":
mask = lamb
elif kwargs.get("masking", None) == "continues2binary":
# TODO: this might cause potential issue as it hard threshold at 0
mask = (lamb > 0).float()
elif kwargs.get("masking", None) == "no_masking":
mask = torch.ones_like(lamb)
else:
raise NotImplementedError
epsilon = kwargs.get("epsilon", 0.0)
if hidden_states.dim() == 2:
mask = mask.squeeze(1)
hidden_states = hidden_states * mask + torch.randn_like(
hidden_states
) * epsilon * (1 - mask)
return hidden_states.to(hidden_states_dtype)
def get_norm_masking_hook(self, init_value=1.0):
"""
Get a hook function to mask feed forward layer
"""
def hook_fn(module, args, kwargs, output, name):
# initialize lambda with acual head dim in the first run
if self.lambs[self.hook_counter] is None:
self.lambs[self.hook_counter] = (
torch.ones(
self.module_neurons[name],
device=self.pipeline.device,
dtype=self.dtype,
)
* init_value
)
self.lambs[self.hook_counter].requires_grad = True
# load norm lambda module name for logging
self.lambs_module_names[self.hook_counter] = name
# perform masking
output = self.masking_fn(
output,
masking=self.masking,
lamb=self.lambs[self.hook_counter],
epsilon=self.epsilon,
eps=self.eps,
use_log=self.use_log,
)
self.hook_counter += 1
self.hook_counter %= len(self.lambs)
return output
return hook_fn