#!/usr/bin/env python # -*- coding: utf-8 -*- import argparse import logging import re from collections import defaultdict from pathlib import Path import numpy as np from safetensors.numpy import safe_open, save_file logger = logging.getLogger(__name__) def analyze_lora_layers( sft_fd: safe_open, ) -> tuple[list[tuple[tuple[str, int], set[str]]], set[str]]: """ Analyze the LoRA layers in a SafeTensors file. Args: sft_fd (safe_open): An open SafeTensors file. Returns: A tuple containing: - A list of tuples, each containing a (section, index) pair and a set of associated keys. - A set of pass-through keys (non-LoRA layers). """ RE_LORA_NAME = re.compile( r"lora_unet_((?:input|middle|output|down|mid|up)_blocks?)(?:(?:_(\d+))?_attentions)?_(\d+)_.*" ) pass_through_keys: set[str] = set() block2keys: dict[tuple[str, int], set[str]] = defaultdict(set) for k in sft_fd.keys(): m = RE_LORA_NAME.fullmatch(k.replace("_0_1_transformer_blocks_", "_0_")) if not m: pass_through_keys.add(k) continue section, idx1, idx2 = m.groups() if idx1 is None: idx = idx2 else: idx = f"{idx1}{idx2}" block2keys[(section, idx)].add(k) if not block2keys: raise ValueError( "No UNet layers found in the LoRA checkpoint (Maybe not a SDXL model?)" ) block2keys_sorted = sorted((k, sorted(v)) for k, v in block2keys.items()) for k in pass_through_keys: if not "te_" in k and "text_" not in k: logging.warning( f"key {k} removed but it doesn't look like a text encoder layer" ) def print_layers(layers): for layer, params in layers.items(): params = ", ".join(sorted(params)) dbg(f" - {layer:<70}: {params}") if logger.getEffectiveLevel() <= logging.DEBUG: dbg = logger.debug for (section, idx), keys in block2keys_sorted: layers = groupby_layer(keys) dbg(f"* {section=} {idx=} keys={len(keys)} layers={len(layers)}") print_layers(layers) logger.debug(f" * Pass through: ") print_layers(groupby_layer(pass_through_keys)) return block2keys_sorted, pass_through_keys def groupby_layer( keys, make_empty=set, update=lambda vs, layer_name, param_name: vs.add(param_name) ): d = defaultdict(make_empty) for k in keys: layer, _, param = k.rpartition(".") vs = d[layer] update(vs, layer, param) return d def print_block_layout( block2keys: list[tuple[tuple[str, int], set[str]]], weights: list[float] | None = None, ) -> None: """ Print the layout of LoRA blocks, optionally with weights. Args: block2keys: A list of tuples, each containing a (section, index) pair and a set of associated keys. weights: Optional list of weights corresponding to each block. """ logger.info("Blocks layout:") if weights is None: for i, ((section, idx), keys) in enumerate(block2keys): logger.info(f"\t[{i:>2d}] {section:>13}.{idx} layers={len(keys):<3}") section2shortname = { # SDXL names: "input_blocks": "INP", "middle_block": "MID", "output_blocks": "OUT", # SD1 names "down_blocks": "INP", "mid_block": "MID", "up_blocks": "OUT", } vector_string = ",".join( f"{section2shortname[section]}{idx:>02}" for (section, idx), _ in block2keys ) logger.info(f'Vector string format: "1,{vector_string}"') vector_string = ",".join("0" * len(block2keys)) logger.info(f'Example (drops all blocks): "1,{vector_string}"') else: for i, (((section, idx), keys), weight) in enumerate(zip(block2keys, weights)): if abs(weight) > 1e-6: if abs(weight - 1) < 1e-6: weight = 1 w_disp = f"weight={weight}" else: w_disp = "removed" layers = len( groupby_layer(keys, lambda: None, lambda _layers, _layer, _attr: None) ) logger.info( f"\t[{i:>2d}] {section:>13}.{idx} keys={len(keys):<3} layers={layers:<3} {w_disp}" ) def filter_blocks(sft_fd: safe_open, vector_string: str) -> dict[str, "numpy.ndarray"]: """ Filter LoRA blocks based on a vector string. Args: sft_fd (safe_open): An open SafeTensors file. vector_string (str): A string representing weights for each block. Returns: A dictionary containing the filtered state dict, or None if an error occurs. """ global_weight, *weights_vector = map(float, vector_string.split(",")) block2keys, pass_through_keys = analyze_lora_layers(sft_fd) if len(weights_vector) != len(block2keys): logger.error(f"expected {len(block2keys)} weights, got {len(weights_vector)}") print_block_layout(block2keys) return None if logger.getEffectiveLevel() >= logging.INFO: print_block_layout(block2keys, weights_vector) state_dict = {} for weight, ((s, idx), keys) in zip(weights_vector, block2keys): weight *= global_weight if abs(weight) < 1e-6: logger.debug("reject %s:%s (%s)", s, idx, keys[0]) continue for layer, params in groupby_layer(keys).items(): logger.debug( "accept %s:%s (%s) weight=%.2f params=%s", s, idx, layer, weight, ",".join(params), ) if "alpha" in params: params.remove("alpha") key = f"{layer}.alpha" state_dict[key] = sft_fd.get_tensor(key) * weight # if 'dora_scale' in params: # params.remove("dora_scale") # key = f"{layer}.dora_scale" # tensor = sft_fd.get_tensor(key) # if abs(weight - 1.0) > 1e-6: # tensor -= 1.0 # tensor *= weight # tensor += 1.0 # state_dict[key] = tensor for param in params: key = f"{layer}.{param}" state_dict[key] = sft_fd.get_tensor(key) else: logging.warning("no alpha parameter in layer %s: %r", layer, params) for param in params: key = f"{layer}.{param}" state_dict[key] = sft_fd.get_tensor(key) logger.info( "Keeping %d keys from the UNet, %d passing through (text encoders)", len(state_dict), len(pass_through_keys), ) for k in pass_through_keys: state_dict[k] = sft_fd.get_tensor(k) return state_dict def setup_logging(verbosity: int) -> None: """ Set up logging based on verbosity level and quiet flag. Args: verbosity (int): The verbosity level (0-2). quiet (bool): If True, suppress all output except errors. """ log_levels = [logging.WARNING, logging.INFO, logging.DEBUG] log_level = log_levels[max(0, min(verbosity, 2))] logging.basicConfig(level=log_level, format="%(levelname)s: %(message)s") def main() -> None: """ Main function to handle CLI arguments and execute the appropriate actions. """ parser = argparse.ArgumentParser( description="Analyze and filter LoRA layers in SafeTensors files." ) parser.add_argument("input_file", type=Path, help="Input SafeTensors file") parser.add_argument( "vector_string", nargs="?", help="Vector string for filtering blocks" ) parser.add_argument("-o", "--output", type=Path, help="Output file path") parser.add_argument( "-v", "--verbose", action="count", default=1, help="Increase verbosity (can be repeated)", ) parser.add_argument( "-q", "--quiet", action="count", default=0, help="Suppress all output except errors", ) args = parser.parse_args() setup_logging(args.verbose - args.quiet) with safe_open(args.input_file, framework="np") as sft_fd: if args.vector_string: # Filter blocks and save the result filtered_state_dict = filter_blocks(sft_fd, args.vector_string) if filtered_state_dict is None: logging.error("No lyaers in output!") exit(1) # Determine output path output_path = args.output or args.input_file.with_stem( f"{args.input_file.stem}-chop" ) metadata = sft_fd.metadata() metadata["block_vector_string"] = args.vector_string save_file(filtered_state_dict, output_path, metadata=metadata) logging.info(f"Filtered LoRA saved to {output_path}") else: # Analyze LoRA layers block2keys, pass_through_keys = analyze_lora_layers(sft_fd) print_block_layout(block2keys) logging.info(f"Pass through layers: {len(pass_through_keys)}") if __name__ == "__main__": main()