# B-LoRA / blora_utils.py
# Uploaded by Yardenfren in commit ddea0a0 ("Upload 3 files", verified).
from typing import Optional
# Dotted module paths of the UNet attention blocks that B-LoRA associates
# with each component. Every path starts with a leading 'unet.' component;
# callers that work relative to the UNet object need to strip it.
BLOCKS = dict(
    content=['unet.up_blocks.0.attentions.0'],
    style=['unet.up_blocks.0.attentions.1'],
)
def is_belong_to_blocks(key, blocks):
    """Return True if any element of *blocks* is contained in *key*.

    Matching uses the ``in`` operator, so for string keys this is a
    substring test (not an exact-name or prefix test).

    Args:
        key: The name to test, e.g. a state-dict key or an attention
            processor name.
        blocks: Iterable of block identifiers to look for inside ``key``.

    Returns:
        True on the first element of ``blocks`` found in ``key``; False
        otherwise, including when ``blocks`` is empty.

    Raises:
        Any exception raised while iterating ``blocks`` or evaluating the
        ``in`` test is re-raised as the same exception type with a
        ``'failed to is_belong_to_blocks, due to: ...'`` message, chained to
        the original exception via ``__cause__``.
    """
    try:
        # any() short-circuits on the first match, like the explicit loop did.
        return any(g in key for g in blocks)
    except Exception as e:
        # Same exception type as before so existing ``except`` clauses still
        # match; ``from e`` keeps the original error as the explicit cause.
        raise type(e)(f'failed to is_belong_to_blocks, due to: {e}') from e
def filter_lora(state_dict, blocks_):
    """Return the subset of *state_dict* whose keys belong to *blocks_*.

    A key is kept when ``is_belong_to_blocks(key, blocks_)`` is true. The
    result is a new dict in the iteration order of ``state_dict``; values are
    the same objects as in the input (no copies are made).

    Raises:
        Any exception raised along the way is re-raised as the same type with
        a ``'failed to filter_lora, due to: ...'`` message.
    """
    try:
        kept = {}
        for name, value in state_dict.items():
            if is_belong_to_blocks(name, blocks_):
                kept[name] = value
        return kept
    except Exception as e:
        raise type(e)(f'failed to filter_lora, due to: {e}')
def scale_lora(state_dict, alpha):
    """Return a new dict mapping every key of *state_dict* to ``value * alpha``.

    Key order follows the iteration order of ``state_dict``. The input dict
    itself is not reassigned; results are stored in a freshly built dict.

    Raises:
        Any exception raised along the way (e.g. a value that cannot be
        multiplied by ``alpha``) is re-raised as the same type with a
        ``'failed to scale_lora, due to: ...'`` message.
    """
    try:
        scaled = {}
        for name, weight in state_dict.items():
            scaled[name] = weight * alpha
        return scaled
    except Exception as e:
        raise type(e)(f'failed to scale_lora, due to: {e}')
def get_target_modules(unet, blocks=None):
    """Build the attention projection module names selected by *blocks*.

    Args:
        unet: Object with an ``attn_processors`` mapping whose keys are
            dotted processor names (presumably a diffusers UNet — TODO
            confirm against callers).
        blocks: Identifiers used to select processor names via
            ``is_belong_to_blocks``. When falsy (``None`` or empty), it
            defaults to all entries of ``BLOCKS['content']`` followed by
            ``BLOCKS['style']``, each with its first dotted component (the
            ``'unet'`` prefix) removed.

    Returns:
        A list of ``'<attn>.<proj>'`` strings. ``<attn>`` is each selected
        processor name with its last dotted component dropped. ``<proj>``
        ranges over ``to_k``, ``to_q``, ``to_v``, ``to_out.0``. The list is
        grouped by projection: every ``to_k`` name first, then ``to_q``, and
        so on.

    Raises:
        Any exception raised along the way is re-raised as the same type with
        a ``'failed to get_target_modules, due to: ...'`` message.
    """
    try:
        if not blocks:
            # BLOCKS paths start with 'unet.'; strip that leading component.
            default_paths = BLOCKS['content'] + BLOCKS['style']
            blocks = [path.partition('.')[2] for path in default_paths]

        # Keep matching processor names, dropping their final dotted part.
        attn_modules = []
        for processor_name in unet.attn_processors:
            if is_belong_to_blocks(processor_name, blocks):
                attn_modules.append(processor_name.rsplit('.', 1)[0])

        target_modules = []
        for proj in ("to_k", "to_q", "to_v", "to_out.0"):
            target_modules.extend(f'{attn}.{proj}' for attn in attn_modules)
        return target_modules
    except Exception as e:
        raise type(e)(f'failed to get_target_modules, due to: {e}')