import inspect
from typing import Any, Callable, Dict, List, Optional, Union
import numpy as np
import torch
from transformers import CLIPTextModel, CLIPTokenizer
from dataclasses import dataclass
from diffusers.loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from diffusers.models import AutoencoderKL, UNet3DConditionModel
from diffusers.models.lora import adjust_lora_scale_text_encoder
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils import (
deprecate,
logging,
replace_example_docstring,
BaseOutput,
)
from diffusers.utils.torch_utils import randn_tensor
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth import (
tensor2vid,
)
from ..CrossAttn.InjecterProc import InjecterProcessor
from ..Misc import Logger as log
from ..Misc import Const


def use_dd_temporal(unet, use=True):
    """Enable or disable temporal attention editing for the current denoising step."""
for name, module in unet.named_modules():
module_name = type(module).__name__
if module_name == "Attention" and "attn2" in name:
            module.processor.use_dd_temporal = use


def use_dd(unet, use=True):
    """Enable or disable spatial attention editing for the current denoising step."""
for name, module in unet.named_modules():
module_name = type(module).__name__
# if module_name == "CrossAttention" and "attn2" in name:
if module_name == "Attention" and "attn2" in name:
module.processor.use_dd = use
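# Illustrative usage (a sketch, not part of the original module): the two toggles
# above are typically flipped inside the pipeline's denoising loop so that the
# attention editing is applied only during the first few steps. The names
# `timesteps`, `unet`, and `num_editing_steps` below are hypothetical.
#
#     for i, t in enumerate(timesteps):
#         editing = i < num_editing_steps
#         use_dd(unet, use=editing)
#         use_dd_temporal(unet, use=editing)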
def initiailization(unet, bundle, bbox_per_frame):
    """Attach an InjecterProcessor to every cross-attention ("attn2") module of the UNet,
    using the temporal scales for `temp_attentions` layers and the spatial scales otherwise.
    """
    log.info("Initialization")
for name, module in unet.named_modules():
module_name = type(module).__name__
if module_name == "Attention" and "attn2" in name:
if "temp_attentions" in name:
processor = InjecterProcessor(
bundle=bundle,
bbox_per_frame=bbox_per_frame,
strengthen_scale=bundle["temp_strengthen_scale"],
weaken_scale=bundle["temp_weaken_scale"],
is_text2vidzero=False,
name=name,
)
else:
processor = InjecterProcessor(
bundle=bundle,
bbox_per_frame=bbox_per_frame,
strengthen_scale=bundle["spatial_strengthen_scale"],
weaken_scale=bundle["spatial_weaken_scale"],
is_text2vidzero=False,
name=name,
)
module.processor = processor
# print(name)
log.info("Initialized")
def keyframed_prompt_embeds(bundle, encode_prompt_func, device):
    """Build one prompt embedding per frame by linearly interpolating between keyframe prompts."""
    num_frames = bundle["keyframe"][-1]["frame"] + 1
keyframe = bundle["keyframe"]
f = lambda start, end, index: (1 - index) * start + index * end
n = len(keyframe)
keyed_prompt_embeds = []
for i in range(n - 1):
if i == 0:
start_fr = keyframe[i]["frame"]
else:
start_fr = keyframe[i]["frame"] + 1
end_fr = keyframe[i + 1]["frame"]
start_prompt = keyframe[i]["prompt"] + Const.POSITIVE_PROMPT
end_prompt = keyframe[i + 1]["prompt"] + Const.POSITIVE_PROMPT
clip_length = end_fr - start_fr + 1
start_prompt_embeds, _ = encode_prompt_func(
start_prompt,
device=device,
num_images_per_prompt=1,
do_classifier_free_guidance=True,
negative_prompt=Const.NEGATIVE_PROMPT,
)
end_prompt_embeds, negative_prompt_embeds = encode_prompt_func(
end_prompt,
device=device,
num_images_per_prompt=1,
do_classifier_free_guidance=True,
negative_prompt=Const.NEGATIVE_PROMPT,
)
for fr in range(clip_length):
index = float(fr) / (clip_length - 1)
keyed_prompt_embeds.append(f(start_prompt_embeds, end_prompt_embeds, index))
assert len(keyed_prompt_embeds) == num_frames
return torch.cat(keyed_prompt_embeds), negative_prompt_embeds.repeat_interleave(
num_frames, dim=0
)
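# Illustrative call (a sketch; `pipe` is assumed to be a diffusers text-to-video
# pipeline whose `encode_prompt` accepts the keyword arguments used above):
#
#     prompt_embeds, negative_prompt_embeds = keyframed_prompt_embeds(
#         bundle, pipe.encode_prompt, pipe._execution_device
#     )
#
# The first tensor stacks one embedding per frame along the batch dimension; the
# negative embedding is repeated to the same length for classifier-free guidance.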
def keyframed_bbox(bundle):
    """Build one bounding box per frame by linearly interpolating the keyframe `bbox_ratios`."""
    keyframe = bundle["keyframe"]
bbox_per_frame = []
f = lambda start, end, index: (1 - index) * start + index * end
n = len(keyframe)
for i in range(n - 1):
if i == 0:
start_fr = keyframe[i]["frame"]
else:
start_fr = keyframe[i]["frame"] + 1
end_fr = keyframe[i + 1]["frame"]
start_bbox = keyframe[i]["bbox_ratios"]
end_bbox = keyframe[i + 1]["bbox_ratios"]
clip_length = end_fr - start_fr + 1
for fr in range(clip_length):
index = float(fr) / (clip_length - 1)
bbox = []
for j in range(4):
bbox.append(f(start_bbox[j], end_bbox[j], index))
bbox_per_frame.append(bbox)
return bbox_per_frame
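

if __name__ == "__main__":
    # Minimal sanity sketch (illustrative values only, reachable when this file is
    # executed as a module so the relative imports above resolve): two keyframes
    # spanning 24 frames should interpolate to exactly 24 boxes.
    _bundle = {
        "keyframe": [
            {"frame": 0, "prompt": "a cat", "bbox_ratios": [0.0, 0.3, 0.3, 0.7]},
            {"frame": 23, "prompt": "a cat", "bbox_ratios": [0.7, 0.3, 1.0, 0.7]},
        ]
    }
    _boxes = keyframed_bbox(_bundle)
    assert len(_boxes) == 24
    assert _boxes[0] == [0.0, 0.3, 0.3, 0.7] and _boxes[-1] == [0.7, 0.3, 1.0, 0.7]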