"""Helpers for selecting, scaling and targeting LoRA weights by UNet block name."""

from typing import Any, Collection, Dict, List, Mapping, Optional

# Name fragments identifying the attention groups treated as "content" and
# "style" blocks. Keys are matched against these by substring containment.
BLOCKS = {
    'content': ['unet.up_blocks.0.attentions.0'],
    'style': ['unet.up_blocks.0.attentions.1'],
}

# Projection sub-modules appended to every selected attention module name.
_PROJECTIONS = ("to_k", "to_q", "to_v", "to_out.0")


def _wrap_error(exc: Exception, action: str) -> Exception:
    """Return an exception of the same type as *exc* with a contextual message.

    The message has the form ``failed to <action>, due to: <exc>``. The
    original exception is attached as ``__cause__`` so its traceback is kept.

    Some exception classes cannot be built from a single message argument
    (``UnicodeDecodeError`` needs five). In that case the original exception
    is returned unchanged, so the caller still sees the real error instead of
    a ``TypeError`` raised by the failed constructor call.
    """
    message = f'failed to {action}, due to: {exc}'
    try:
        wrapped = type(exc)(message)
    except Exception:
        return exc
    wrapped.__cause__ = exc
    return wrapped


def is_belong_to_blocks(key: str, blocks: Collection[str]) -> bool:
    """Return True if any entry of *blocks* occurs inside *key*.

    Args:
        key: Name to test (e.g. a state-dict key or attention processor name).
        blocks: Fragments to look for; an empty collection yields ``False``.

    Raises:
        Exception: Any error raised by the containment check is re-raised as
            the same exception type with a ``failed to ...`` message.
    """
    try:
        return any(fragment in key for fragment in blocks)
    except Exception as e:
        raise _wrap_error(e, 'is_belong_to_block')


def filter_lora(state_dict: Mapping[str, Any], blocks_: Collection[str]) -> Dict[str, Any]:
    """Return the entries of *state_dict* whose key matches one of *blocks_*.

    Args:
        state_dict: Mapping of parameter names to values.
        blocks_: Name fragments; a key is kept when it contains any of them.

    Returns:
        A new dict with the matching items; values are not copied.

    Raises:
        Exception: Same type as the underlying failure, with a
            ``failed to filter_lora`` message.
    """
    try:
        return {
            name: value
            for name, value in state_dict.items()
            if is_belong_to_blocks(name, blocks_)
        }
    except Exception as e:
        raise _wrap_error(e, 'filter_lora')


def scale_lora(state_dict: Mapping[str, Any], alpha) -> Dict[str, Any]:
    """Return a new dict with every value of *state_dict* multiplied by *alpha*.

    Args:
        state_dict: Mapping of parameter names to values supporting
            ``value * alpha``.
        alpha: Multiplier applied to each value.

    Raises:
        Exception: Same type as the underlying failure, with a
            ``failed to scale_lora`` message.
    """
    try:
        return {name: value * alpha for name, value in state_dict.items()}
    except Exception as e:
        raise _wrap_error(e, 'scale_lora')


def get_target_modules(unet, blocks: Optional[Collection[str]] = None) -> List[str]:
    """List the projection module names of the attention layers in *blocks*.

    Args:
        unet: Object exposing an ``attn_processors`` mapping keyed by
            attention processor name.
        blocks: Name fragments selecting which processors to include. When
            empty or ``None``, the content and style entries of ``BLOCKS``
            are used with their first dotted component removed.

    Returns:
        Names of the form ``<attention>.<projection>``, grouped by projection
        (``to_k``, ``to_q``, ``to_v``, ``to_out.0``) and, within each group,
        in ``attn_processors`` order.

    Raises:
        Exception: Same type as the underlying failure, with a
            ``failed to get_target_modules`` message.
    """
    try:
        if not blocks:
            # Drop the first dotted component ('unet') of each default entry
            # before matching against the processor names.
            blocks = [
                '.'.join(blk.split('.')[1:])
                for blk in BLOCKS['content'] + BLOCKS['style']
            ]
        # The attention module name is the processor name minus its last
        # dotted component.
        attns = [
            processor_name.rsplit('.', 1)[0]
            for processor_name in unet.attn_processors
            if is_belong_to_blocks(processor_name, blocks)
        ]
        return [f'{attn}.{mat}' for mat in _PROJECTIONS for attn in attns]
    except Exception as e:
        raise _wrap_error(e, 'get_target_modules')