|
import logging |
|
|
|
from safetensors.torch import load_file |
|
|
|
from animatediff import get_dir |
|
from animatediff.utils.lora_diffusers import (LoRANetwork, |
|
create_network_from_weights) |
|
|
|
# Module-level logger; handlers/levels are configured by the application.
logger = logging.getLogger(__name__)


# Base data directory resolved by animatediff.get_dir; LoRA file paths in
# this module are joined beneath it.
data_dir = get_dir("data")
|
|
|
|
|
def merge_safetensors_lora(text_encoder, unet, lora_path, alpha=0.75, is_animatediff=True):
    """Load a LoRA from a .safetensors file and merge it into the model weights.

    Args:
        text_encoder: text encoder module (or a list of two encoders for SDXL).
        unet: UNet to merge the LoRA deltas into.
        lora_path: path to the .safetensors LoRA file.
        alpha: merge strength multiplier.
        is_animatediff: forwarded to network construction (False for SDXL).
    """
    sd = load_file(lora_path)

    # Tensor name/shape dump, only when debug logging is on
    # (replaces former dead `if False:` debugging code).
    if logger.isEnabledFor(logging.DEBUG):
        for key in sd:
            logger.debug("%s %s", key, sd[key].shape)

    # Use the module logger instead of print so output respects log config.
    logger.info("create LoRA network")
    lora_network: LoRANetwork = create_network_from_weights(
        text_encoder, unet, sd, multiplier=alpha, is_animatediff=is_animatediff
    )
    logger.info("load LoRA network weights")
    # strict=False: LoRA state dicts legitimately cover only a subset of keys.
    lora_network.load_state_dict(sd, False)
    # Bake the LoRA deltas permanently into the base weights.
    lora_network.merge_to(alpha)
|
|
|
def load_lora_map(pipe, lora_map_config, video_length, is_sdxl=False):
    """Apply a LoRA map config to ``pipe``.

    Scalar entries (a plain int/float scale) are merged permanently into the
    model weights. Dict entries describe per-region/per-frame scheduling and
    are collected into a :class:`LoraMap` stored on ``pipe.lora_map`` (or
    ``None`` when no dynamic entries loaded successfully).
    """
    dynamic_map = {}
    for item, conf in lora_map_config.items():
        lora_path = data_dir.joinpath(item)
        # Exclude bool explicitly: isinstance(True, int) is True, but a bool
        # scale was never accepted by the original type() check.
        if isinstance(conf, (float, int)) and not isinstance(conf, bool):
            # Static scale -> merge once into the weights.
            te_en = [pipe.text_encoder, pipe.text_encoder_2] if is_sdxl else pipe.text_encoder
            merge_safetensors_lora(te_en, pipe.unet, lora_path, conf, not is_sdxl)
        else:
            dynamic_map[lora_path] = conf

    lora_map = LoraMap(pipe, dynamic_map, video_length, is_sdxl)
    pipe.lora_map = lora_map if lora_map.is_valid else None
|
|
|
def load_lcm_lora(pipe, lcm_map, is_sdxl=False, is_merge=False):
    """Attach the LCM LoRA to ``pipe`` (or merge it permanently).

    With ``is_merge`` the weights are baked into the model and ``pipe.lcm``
    is cleared; otherwise a scheduled :class:`LcmLora` wrapper is stored on
    ``pipe.lcm`` (``None`` if loading failed).
    """
    variant = "sdxl" if is_sdxl else "sd15"
    lora_path = data_dir.joinpath(
        "models/lcm_lora/" + variant + "/pytorch_lora_weights.safetensors"
    )
    logger.info(f"{lora_path=}")

    if not is_merge:
        lcm = LcmLora(pipe, is_sdxl, lora_path, lcm_map)
        pipe.lcm = lcm if lcm.is_valid else None
        return

    # Permanent merge at full strength; no runtime LCM scheduling needed.
    encoders = [pipe.text_encoder, pipe.text_encoder_2] if is_sdxl else pipe.text_encoder
    merge_safetensors_lora(encoders, pipe.unet, lora_path, 1.0, not is_sdxl)
    pipe.lcm = None
|
|
|
class LcmLora:
    """LCM LoRA wrapper with a per-step scale schedule.

    Loads the LoRA weights, attaches them to the pipeline's text encoder(s)
    and unet, and ramps the active scale from ``start_scale`` to
    ``end_scale`` across the [gradient_start, gradient_end] fraction of the
    denoising schedule.
    """

    def __init__(self, pipe, is_sdxl, lora_path, lcm_map):
        # Callers must check is_valid; it stays False if loading fails.
        self.is_valid = False

        state = load_file(lora_path)
        if not state:
            return

        encoders = [pipe.text_encoder, pipe.text_encoder_2] if is_sdxl else pipe.text_encoder
        network = create_network_from_weights(
            encoders, pipe.unet, state, multiplier=1.0, is_animatediff=not is_sdxl
        )
        network.load_state_dict(state, False)
        # apply_to hooks the LoRA modules in without merging the weights.
        network.apply_to(1.0)
        self.network = network

        self.is_valid = True

        # Scale schedule parameters from the config map.
        self.start_scale = lcm_map["start_scale"]
        self.end_scale = lcm_map["end_scale"]
        self.gradient_start = lcm_map["gradient_start"]
        self.gradient_end = lcm_map["gradient_end"]

    def to(self, device, dtype):
        """Move the LoRA network to the given device/dtype."""
        self.network.to(device=device, dtype=dtype)

    def apply(self, step, total_steps):
        """Activate the LoRA with a scale interpolated for this step."""
        progress = (step + 1) / total_steps

        if progress < self.gradient_start:
            scale = self.start_scale
        elif progress > self.gradient_end:
            scale = self.end_scale
        else:
            span = self.gradient_end - self.gradient_start
            # Degenerate (near-zero) window: pin to the start value.
            t = 0 if span < 1e-4 else (progress - self.gradient_start) / span
            scale = self.start_scale + (self.end_scale - self.start_scale) * t

        self.network.active(scale)

    def unapply(self):
        """Deactivate the LoRA hooks."""
        self.network.deactive()
|
|
|
|
|
|
|
class LoraMap:
    """Per-region, per-frame scheduled LoRA networks for a pipeline."""

    def __init__(self, pipe, lora_map, video_length, is_sdxl):
        self.networks = []

        def build_schedule(raw_scales, length):
            # Expand {frame: scale} keyframes into a dense per-frame map of
            # `length` entries, easing quadratically between keyframes and
            # wrapping from the last keyframe back around to the first.
            keyframes = {int(k): v for k, v in raw_scales.items()}
            ordered = sorted(keyframes)

            if len(ordered) == 1:
                constant = keyframes[ordered[0]]
                return {frame: constant for frame in range(length)}

            # Close the loop so the final segment interpolates back to the
            # first keyframe.
            ordered = ordered + [ordered[0]]

            schedule = {}

            def ease(rate, begin, end):
                # Quadratic ease-in between the two keyframe values.
                return begin + (rate * rate) * (end - begin)

            for start_key, end_key in zip(ordered[:-1], ordered[1:]):
                begin_v = keyframes[start_key]
                end_v = keyframes[end_key % length if end_key >= length else end_key]
                end_v = keyframes[end_key] if end_key in keyframes else end_v
                if start_key > end_key:
                    end_key += length  # wrap-around segment
                for pos in range(start_key, end_key):
                    rate = (pos - start_key) / (end_key - start_key)
                    frame = pos - length if pos >= length else pos
                    schedule[frame] = ease(rate, begin_v, end_v)
            return schedule

        for lora_path in lora_map:
            sd = load_file(lora_path)
            if not sd:
                continue
            encoders = [pipe.text_encoder, pipe.text_encoder_2] if is_sdxl else pipe.text_encoder
            network = create_network_from_weights(
                encoders, pipe.unet, sd, multiplier=0.75, is_animatediff=not is_sdxl
            )
            network.load_state_dict(sd, False)
            # Hook in without merging; scales are driven per-frame in apply().
            network.apply_to(0.75)

            self.networks.append(
                {
                    "network": network,
                    "region": lora_map[lora_path]["region"],
                    "schedule": build_schedule(lora_map[lora_path]["scale"], video_length),
                }
            )

        def to_region_index(value):
            # "background" maps to region 0; numbered regions shift up by one.
            return 0 if value == "background" else int(value) + 1

        for entry in self.networks:
            entry["region"] = [to_region_index(v) for v in entry["region"]]

        self.is_valid = bool(self.networks)

    def to(self, device, dtype):
        """Move every managed LoRA network to device/dtype."""
        for entry in self.networks:
            entry["network"].to(device=device, dtype=dtype)

    def apply(self, cond_index, cond_nums, frame_no):
        """Activate networks matching this cond slot's region at ``frame_no``.

        The cond batch is laid out negative-first then positive
        (neg bg, neg 1, neg 2, pos bg, pos 1, pos 2, ...), so the region
        index repeats halfway through ``cond_nums``.
        """
        half = cond_nums // 2
        region_index = cond_index - half if cond_index >= half else cond_index

        for entry in self.networks:
            if region_index not in entry["region"]:
                entry["network"].deactive()
                continue
            scale = entry["schedule"][frame_no]
            if scale > 0:
                entry["network"].active(scale)
            else:
                entry["network"].deactive()

    def unapply(self):
        """Deactivate all managed networks."""
        for entry in self.networks:
            entry["network"].deactive()
|
|
|
|