# Spaces: Running on Zero
from typing import Optional | |
# Named groups of UNet module-name prefixes used to select attention blocks.
# Entries are dotted paths rooted at 'unet.'; get_target_modules() drops that
# leading component before matching against attention-processor names.
BLOCKS = dict(
    content=['unet.up_blocks.0.attentions.0'],
    style=['unet.up_blocks.0.attentions.1'],
)
def is_belong_to_blocks(key, blocks):
    """Return True if any entry of ``blocks`` occurs as a substring of ``key``.

    Matching is plain substring containment (``g in key``), not prefix or
    whole-segment matching, so e.g. ``'attentions.1'`` also matches a key
    containing ``'attentions.10'``. Note that a bare string passed as
    ``blocks`` is iterated character by character.

    Args:
        key: Name to test, e.g. a state-dict key or attention-processor name.
        blocks: Iterable of substrings to look for. An empty iterable
            yields False.

    Returns:
        bool: True on the first matching entry, otherwise False.

    Raises:
        Exception: The same exception type as the underlying failure (e.g.
            TypeError if ``key`` does not support ``in`` or ``blocks`` is not
            iterable), with a message naming this function and the original
            exception attached as ``__cause__``.
    """
    try:
        return any(g in key for g in blocks)
    except Exception as e:
        # Keep the original exception type so callers' except clauses still
        # match; chain it explicitly so the traceback shows the real cause.
        raise type(e)(f'failed to is_belong_to_blocks, due to: {e}') from e
def filter_lora(state_dict, blocks_):
    """Return a new dict with only the entries whose key matches ``blocks_``.

    A key is kept when ``is_belong_to_blocks(key, blocks_)`` is true, i.e.
    when any entry of ``blocks_`` is a substring of the key. The input
    mapping is not modified; kept values are the same objects, not copies.

    Args:
        state_dict: Mapping of names to values (presumably LoRA weights —
            not enforced here); must provide ``.items()``.
        blocks_: Iterable of substrings. It is re-iterated for every key, so
            pass a reusable collection (list/tuple), not a one-shot iterator.

    Returns:
        dict: The matching ``(key, value)`` pairs in ``state_dict`` order.

    Raises:
        Exception: The same exception type as the underlying failure, with a
            message naming this function and the original exception attached
            as ``__cause__``.
    """
    try:
        return {k: v for k, v in state_dict.items() if is_belong_to_blocks(k, blocks_)}
    except Exception as e:
        # Preserve the exception type for callers; chain the original cause.
        raise type(e)(f'failed to filter_lora, due to: {e}') from e
def scale_lora(state_dict, alpha):
    """Return a new dict with every value of ``state_dict`` multiplied by ``alpha``.

    Keys and their order are preserved and the input mapping is not
    modified; each value becomes the result of ``v * alpha``.

    Args:
        state_dict: Mapping of names to values supporting ``* alpha``
            (presumably LoRA weight tensors — not enforced here).
        alpha: Scale factor applied to every value.

    Returns:
        dict: ``{k: v * alpha}`` for every entry.

    Raises:
        Exception: The same exception type as the underlying failure (e.g.
            TypeError for a value that cannot be multiplied by ``alpha``),
            with a message naming this function and the original exception
            attached as ``__cause__``.
    """
    try:
        return {k: v * alpha for k, v in state_dict.items()}
    except Exception as e:
        # Preserve the exception type for callers; chain the original cause.
        raise type(e)(f'failed to scale_lora, due to: {e}') from e
def get_target_modules(unet, blocks: Optional[list] = None):
    """Build the list of attention projection module names to target.

    Every attention-processor name in ``unet.attn_processors`` that matches
    ``blocks`` (substring match via ``is_belong_to_blocks``) is reduced to its
    parent path by dropping the final dotted component. Each such path
    ``<attn>`` contributes four targets: ``<attn>.to_k``, ``<attn>.to_q``,
    ``<attn>.to_v`` and ``<attn>.to_out.0``. The result is grouped by
    projection (all ``to_k`` entries first, then ``to_q``, ...); within a
    group, order follows ``unet.attn_processors`` iteration order.

    Args:
        unet: Object exposing an ``attn_processors`` mapping keyed by
            processor name (presumably a diffusers UNet — not checked here).
        blocks: Substrings selecting processors. If falsy (``None`` or
            empty), every entry of ``BLOCKS['content']`` and
            ``BLOCKS['style']`` is used with its first dotted component
            (``'unet.'``) stripped.

    Returns:
        list[str]: The target module names.

    Raises:
        Exception: The same exception type as the underlying failure (e.g.
            AttributeError if ``unet`` has no ``attn_processors``), with a
            message naming this function and the original exception
            attached as ``__cause__``.
    """
    try:
        if not blocks:
            # Drop the leading component ('unet.') so the default prefixes
            # line up with attention-processor names; partition() yields ''
            # for an entry without a dot, matching the old split/join form.
            blocks = [blk.partition('.')[2] for blk in BLOCKS['content'] + BLOCKS['style']]
        # Processor names end in a final component (e.g. '.processor');
        # strip it to get the owning attention module's path.
        attns = [name.rsplit('.', 1)[0] for name in unet.attn_processors
                 if is_belong_to_blocks(name, blocks)]
        return [f'{attn}.{mat}' for mat in ("to_k", "to_q", "to_v", "to_out.0") for attn in attns]
    except Exception as e:
        # Preserve the exception type for callers; chain the original cause.
        raise type(e)(f'failed to get_target_modules, due to: {e}') from e