import logging
from safetensors.torch import load_file
from animatediff import get_dir
from animatediff.utils.lora_diffusers import (LoRANetwork,
                                              create_network_from_weights)
logger = logging.getLogger(__name__)
data_dir = get_dir("data")
def merge_safetensors_lora(text_encoder, unet, lora_path, alpha=0.75, is_animatediff=True):
    def dump(loaded):
        for a in loaded:
            logger.info(f"{a} {loaded[a].shape}")

    sd = load_file(lora_path)
    # dump(sd)  # uncomment to log every LoRA tensor name and shape

    logger.info("create LoRA network")
    lora_network: LoRANetwork = create_network_from_weights(text_encoder, unet, sd, multiplier=alpha, is_animatediff=is_animatediff)
    logger.info("load LoRA network weights")
    lora_network.load_state_dict(sd, False)
    lora_network.merge_to(alpha)

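# Illustrative call (the .safetensors path is hypothetical, not part of the repo); merging
# bakes the LoRA into the text encoder and UNet weights at the given alpha:
#   merge_safetensors_lora(
#       pipe.text_encoder, pipe.unet,
#       data_dir.joinpath("models/lora/example_style.safetensors"),
#       alpha=0.75, is_animatediff=True,
#   )
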
def load_lora_map(pipe, lora_map_config, video_length, is_sdxl=False):
    new_map = {}
    for item in lora_map_config:
        lora_path = data_dir.joinpath(item)
        if type(lora_map_config[item]) in (float, int):
            # a plain number is a static strength: merge the LoRA into the model weights once
            te_en = [pipe.text_encoder, pipe.text_encoder_2] if is_sdxl else pipe.text_encoder
            merge_safetensors_lora(te_en, pipe.unet, lora_path, lora_map_config[item], not is_sdxl)
        else:
            # anything else is a per-region/per-frame schedule handled by LoraMap below
            new_map[lora_path] = lora_map_config[item]

    lora_map = LoraMap(pipe, new_map, video_length, is_sdxl)

    pipe.lora_map = lora_map if lora_map.is_valid else None

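# Illustrative lora_map_config (file names are hypothetical): a numeric value is merged
# statically, while a dict entry schedules the LoRA per region and per keyframe:
#   {
#       "models/lora/style_a.safetensors": 0.8,
#       "models/lora/char_b.safetensors": {
#           "region": ["background", "0"],
#           "scale": {"0": 0.5, "8": 1.0},
#       },
#   }
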
def load_lcm_lora(pipe, lcm_map, is_sdxl=False, is_merge=False):
    if is_sdxl:
        lora_path = data_dir.joinpath("models/lcm_lora/sdxl/pytorch_lora_weights.safetensors")
    else:
        lora_path = data_dir.joinpath("models/lcm_lora/sd15/pytorch_lora_weights.safetensors")

    logger.info(f"{lora_path=}")

    if is_merge:
        # merge the LCM LoRA into the model weights at full strength; no runtime scheduling
        te_en = [pipe.text_encoder, pipe.text_encoder_2] if is_sdxl else pipe.text_encoder
        merge_safetensors_lora(te_en, pipe.unet, lora_path, 1.0, not is_sdxl)
        pipe.lcm = None
        return

    lcm = LcmLora(pipe, is_sdxl, lora_path, lcm_map)

    pipe.lcm = lcm if lcm.is_valid else None

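# Illustrative lcm_map (values are assumptions): LcmLora ramps the LCM LoRA scale from
# start_scale to end_scale across the [gradient_start, gradient_end] fraction of the steps:
#   lcm_map = {"start_scale": 1.0, "end_scale": 0.5, "gradient_start": 0.2, "gradient_end": 0.8}
#   load_lcm_lora(pipe, lcm_map, is_sdxl=False, is_merge=False)
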
class LcmLora:
    def __init__(
        self,
        pipe,
        is_sdxl,
        lora_path,
        lcm_map
    ):
        self.is_valid = False

        sd = load_file(lora_path)
        if not sd:
            return

        te_en = [pipe.text_encoder, pipe.text_encoder_2] if is_sdxl else pipe.text_encoder
        lora_network: LoRANetwork = create_network_from_weights(te_en, pipe.unet, sd, multiplier=1.0, is_animatediff=not is_sdxl)
        lora_network.load_state_dict(sd, False)
        lora_network.apply_to(1.0)

        self.network = lora_network
        self.is_valid = True
        self.start_scale = lcm_map["start_scale"]
        self.end_scale = lcm_map["end_scale"]
        self.gradient_start = lcm_map["gradient_start"]
        self.gradient_end = lcm_map["gradient_end"]

    def to(
        self,
        device,
        dtype,
    ):
        self.network.to(device=device, dtype=dtype)

    def apply(
        self,
        step,
        total_steps,
    ):
        # steps are 0-indexed; convert to a progress ratio in (0, 1]
        step += 1
        progress = step / total_steps

        if progress < self.gradient_start:
            scale = self.start_scale
        elif progress > self.gradient_end:
            scale = self.end_scale
        else:
            # remap progress into the [gradient_start, gradient_end] window, then interpolate
            if (self.gradient_end - self.gradient_start) < 1e-4:
                progress = 0
            else:
                progress = (progress - self.gradient_start) / (self.gradient_end - self.gradient_start)
            scale = self.start_scale + (self.end_scale - self.start_scale) * progress

        self.network.active(scale)

    def unapply(
        self,
    ):
        self.network.deactive()

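# Worked example for LcmLora.apply() (numbers are illustrative): with start_scale=1.0,
# end_scale=0.5, gradient_start=0.2, gradient_end=0.8 and progress=0.5, the remapped progress
# is (0.5 - 0.2) / (0.8 - 0.2) = 0.5, so scale = 1.0 + 0.5 * (0.5 - 1.0) = 0.75.
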
class LoraMap:
    def __init__(
        self,
        pipe,
        lora_map,
        video_length,
        is_sdxl,
    ):
        self.networks = []

        def create_schedule(scales, length):
            # keyframe index -> strength; values are eased quadratically between keys and the
            # last key wraps around to the first so the schedule covers the whole video
            scales = {int(i): scales[i] for i in scales}
            keys = sorted(scales.keys())
            if len(keys) == 1:
                return {i: scales[keys[0]] for i in range(length)}

            keys = keys + [keys[0]]
            schedule = {}

            def calc(rate, start_v, end_v):
                return start_v + (rate * rate) * (end_v - start_v)

            for key_prev, key_next in zip(keys[:-1], keys[1:]):
                v1 = scales[key_prev]
                v2 = scales[key_next]
                if key_prev > key_next:
                    key_next += length
                for i in range(key_prev, key_next):
                    dist = i - key_prev
                    if i >= length:
                        i -= length
                    schedule[i] = calc(dist / (key_next - key_prev), v1, v2)

            return schedule

        for lora_path in lora_map:
            sd = load_file(lora_path)
            if not sd:
                continue

            te_en = [pipe.text_encoder, pipe.text_encoder_2] if is_sdxl else pipe.text_encoder
            lora_network: LoRANetwork = create_network_from_weights(te_en, pipe.unet, sd, multiplier=0.75, is_animatediff=not is_sdxl)
            lora_network.load_state_dict(sd, False)
            lora_network.apply_to(0.75)

            self.networks.append(
                {
                    "network": lora_network,
                    "region": lora_map[lora_path]["region"],
                    "schedule": create_schedule(lora_map[lora_path]["scale"], video_length),
                }
            )

        def region_convert(i):
            # "background" maps to region 0; numbered foreground regions are shifted up by one
            if i == "background":
                return 0
            else:
                return int(i) + 1

        for net in self.networks:
            net["region"] = [region_convert(i) for i in net["region"]]

        # for n in self.networks:
        #     logger.info(f"{n['region']=}")
        #     logger.info(f"{n['schedule']=}")

        self.is_valid = bool(self.networks)

    def to(
        self,
        device,
        dtype,
    ):
        for net in self.networks:
            net["network"].to(device=device, dtype=dtype)

    def apply(
        self,
        cond_index,
        cond_nums,
        frame_no,
    ):
        '''
        cond_index layout (negative conditionings first, then positive, one per region):
            neg 0 (bg)
            neg 1
            neg 2
            pos 0 (bg)
            pos 1
            pos 2
        '''
        # fold the positive half back onto the same region indices as the negative half
        region_index = cond_index if cond_index < cond_nums // 2 else cond_index - cond_nums // 2
        # logger.info(f"{cond_index=}")
        # logger.info(f"{cond_nums=}")
        # logger.info(f"{region_index=}")
        for i, net in enumerate(self.networks):
            if region_index in net["region"]:
                scale = net["schedule"][frame_no]
                if scale > 0:
                    net["network"].active(scale)
                    # logger.info(f"{i=} active {scale=}")
                else:
                    net["network"].deactive()
                    # logger.info(f"{i=} DEactive")
            else:
                net["network"].deactive()
                # logger.info(f"{i=} DEactive")

    def unapply(
        self,
    ):
        for net in self.networks:
            net["network"].deactive()