eggarsway commited on
Commit
85456ff
Β·
1 Parent(s): 6ee204b
.gitignore ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ assets
2
+ __pycache__
3
+ *.pyc
4
+ *.png
5
+ *undo*
TrailBlazer/CrossAttn/BaseProc.py ADDED
@@ -0,0 +1,291 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, TypedDict
2
+ import numpy as np
3
+ import math
4
+ import torch
5
+ from abc import ABC, abstractmethod
6
+ from diffusers.models.attention_processor import Attention as CrossAttention
7
+ from einops import rearrange
8
+ from ..Misc import Logger as log
9
+ from ..Misc.BBox import BoundingBox
10
+
11
+ KERNEL_DIVISION = 3.
12
+ INJECTION_SCALE = 1.0
13
+
14
+
15
+ def reshape_fortran(x, shape):
16
+ """ Reshape a tensor in the fortran index. See
17
+ https://stackoverflow.com/a/63964246
18
+ """
19
+ if len(x.shape) > 0:
20
+ x = x.permute(*reversed(range(len(x.shape))))
21
+ return x.reshape(*reversed(shape)).permute(*reversed(range(len(shape))))
22
+
23
+
24
+ def gaussian_2d(x=0, y=0, mx=0, my=0, sx=1, sy=1):
25
+ """ 2d Gaussian weight function
26
+ """
27
+ gaussian_map = (
28
+ 1
29
+ / (2 * math.pi * sx * sy)
30
+ * torch.exp(-((x - mx) ** 2 / (2 * sx**2) + (y - my) ** 2 / (2 * sy**2)))
31
+ )
32
+ gaussian_map.div_(gaussian_map.max())
33
+ return gaussian_map
34
+
35
+
36
+ class BundleType(TypedDict):
37
+ selected_inds: List[int] # the 1-indexed indices of a subject
38
+ trailing_inds: List[int] # the 1-indexed indices of trailings
39
+ bbox: List[
40
+ float
41
+ ] # four floats to determine the bounding box [left, right, top, bottom]
42
+
43
+
44
+ class CrossAttnProcessorBase:
45
+
46
+ MAX_LEN_CLIP_TOKENS = 77
47
+ DEVICE = "cuda"
48
+
49
+ def __init__(self, bundle, is_text2vidzero=False):
50
+
51
+ self.prompt = bundle["prompt_base"]
52
+ base_prompt = self.prompt.split(";")[0]
53
+ self.len_prompt = len(base_prompt.split(" "))
54
+ self.prompt_len = len(self.prompt.split(" "))
55
+ self.use_dd = False
56
+ self.use_dd_temporal = False
57
+ self.unet_chunk_size = 2
58
+ self._cross_attention_map = None
59
+ self._loss = None
60
+ self._parameters = None
61
+ self.is_text2vidzero = is_text2vidzero
62
+ bbox = None
63
+
64
+ @property
65
+ def cross_attention_map(self):
66
+ return self._cross_attention_map
67
+
68
+ @property
69
+ def loss(self):
70
+ return self._loss
71
+
72
+ @property
73
+ def parameters(self):
74
+ if type(self._parameters) == type(None):
75
+ log.warn("No parameters being initialized. Be cautious!")
76
+ return self._parameters
77
+
78
+ def __call__(
79
+ self,
80
+ attn: CrossAttention,
81
+ hidden_states,
82
+ encoder_hidden_states=None,
83
+ attention_mask=None,
84
+ ):
85
+
86
+ batch_size, sequence_length, _ = hidden_states.shape
87
+ attention_mask = attn.prepare_attention_mask(
88
+ attention_mask, sequence_length, batch_size
89
+ )
90
+ #print("====================")
91
+ query = attn.to_q(hidden_states)
92
+
93
+ is_cross_attention = encoder_hidden_states is not None
94
+ if encoder_hidden_states is None:
95
+ encoder_hidden_states = hidden_states
96
+ # elif attn.cross_attention_norm:
97
+ elif attn.norm_cross:
98
+ encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
99
+
100
+ key = attn.to_k(encoder_hidden_states)
101
+ value = attn.to_v(encoder_hidden_states)
102
+
103
+ def rearrange_3(tensor, f):
104
+ F, D, C = tensor.size()
105
+ return torch.reshape(tensor, (F // f, f, D, C))
106
+
107
+ def rearrange_4(tensor):
108
+ B, F, D, C = tensor.size()
109
+ return torch.reshape(tensor, (B * F, D, C))
110
+
111
+ # Cross Frame Attention
112
+ if not is_cross_attention and self.is_text2vidzero:
113
+ video_length = key.size()[0] // 2
114
+ first_frame_index = [0] * video_length
115
+
116
+ # rearrange keys to have batch and frames in the 1st and 2nd dims respectively
117
+ key = rearrange_3(key, video_length)
118
+ key = key[:, first_frame_index]
119
+ # rearrange values to have batch and frames in the 1st and 2nd dims respectively
120
+ value = rearrange_3(value, video_length)
121
+ value = value[:, first_frame_index]
122
+
123
+ # rearrange back to original shape
124
+ key = rearrange_4(key)
125
+ value = rearrange_4(value)
126
+
127
+ query = attn.head_to_batch_dim(query)
128
+ key = attn.head_to_batch_dim(key)
129
+ value = attn.head_to_batch_dim(value)
130
+ # Cross attention map
131
+ #print(query.shape, key.shape, value.shape)
132
+ attention_probs = attn.get_attention_scores(query, key)
133
+ # print(attention_probs.shape)
134
+ # torch.Size([960, 77, 64]) torch.Size([960, 256, 64]) torch.Size([960, 77, 64]) torch.Size([960, 256, 77])
135
+ # torch.Size([10240, 24, 64]) torch.Size([10240, 24, 64]) torch.Size([10240, 24, 64]) torch.Size([10240, 24, 24])
136
+
137
+ n = attention_probs.shape[0] // 2
138
+ if attention_probs.shape[-1] == CrossAttnProcessorBase.MAX_LEN_CLIP_TOKENS:
139
+ dim = int(np.sqrt(attention_probs.shape[1]))
140
+ if self.use_dd:
141
+ # self.use_dd = False
142
+ attention_probs_4d = attention_probs.view(
143
+ attention_probs.shape[0], dim, dim, attention_probs.shape[-1]
144
+ )[n:]
145
+ attention_probs_4d = self.dd_core(attention_probs_4d)
146
+ attention_probs[n:] = attention_probs_4d.reshape(
147
+ attention_probs_4d.shape[0], dim * dim, attention_probs_4d.shape[-1]
148
+ )
149
+
150
+ self._cross_attention_map = attention_probs.view(
151
+ attention_probs.shape[0], dim, dim, attention_probs.shape[-1]
152
+ )[n:]
153
+
154
+ elif (
155
+ attention_probs.shape[-1] == self.num_frames
156
+ and (attention_probs.shape[0] == 65536)
157
+ ):
158
+ dim = int(np.sqrt(attention_probs.shape[0] // (2 * attn.heads)))
159
+ if self.use_dd_temporal:
160
+ # self.use_dd_temporal = False
161
+ def temporal_doit(origin_attn):
162
+ temporal_attn = reshape_fortran(
163
+ origin_attn,
164
+ (attn.heads, dim, dim, self.num_frames, self.num_frames),
165
+ )
166
+ temporal_attn = torch.transpose(temporal_attn, 1, 2)
167
+ temporal_attn = self.dd_core(temporal_attn)
168
+ # torch.Size([8, 64, 64, 24, 24])
169
+ temporal_attn = torch.transpose(temporal_attn, 1, 2)
170
+ temporal_attn = reshape_fortran(
171
+ temporal_attn,
172
+ (attn.heads * dim * dim, self.num_frames, self.num_frames),
173
+ )
174
+ return temporal_attn
175
+
176
+
177
+ # NOTE: So null text embedding for classification free guidance
178
+ # doesn't really help?
179
+ #attention_probs[n:] = temporal_doit(attention_probs[n:])
180
+ attention_probs[:n] = temporal_doit(attention_probs[:n])
181
+
182
+ self._cross_attention_map = reshape_fortran(
183
+ attention_probs[:n],
184
+ (attn.heads, dim, dim, self.num_frames, self.num_frames),
185
+ )
186
+ self._cross_attention_map = self._cross_attention_map.mean(dim=0)
187
+ self._cross_attention_map = torch.transpose(self._cross_attention_map, 0, 1)
188
+
189
+ attention_probs = torch.abs(attention_probs)
190
+ hidden_states = torch.bmm(attention_probs, value)
191
+ hidden_states = attn.batch_to_head_dim(hidden_states)
192
+ # linear proj
193
+ hidden_states = attn.to_out[0](hidden_states)
194
+ # dropout
195
+ hidden_states = attn.to_out[1](hidden_states)
196
+ return hidden_states
197
+
198
+ @abstractmethod
199
+ def dd_core(self):
200
+ """All DD variants implement this function"""
201
+ pass
202
+
203
+ @staticmethod
204
+ def localized_weight_map(attention_probs_4d, token_inds, bbox_per_frame, scale=1):
205
+ """Using guassian 2d distribution to generate weight map and return the
206
+ array with the same size of the attention argument.
207
+ """
208
+ dim = int(attention_probs_4d.size()[1])
209
+ max_val = attention_probs_4d.max()
210
+ weight_map = torch.zeros_like(attention_probs_4d).half()
211
+ frame_size = attention_probs_4d.shape[0] // len(bbox_per_frame)
212
+
213
+ for i in range(len(bbox_per_frame)):
214
+ bbox_ratios = bbox_per_frame[i]
215
+ bbox = BoundingBox(dim, bbox_ratios)
216
+ # Generating the gaussian distribution map patch
217
+ x = torch.linspace(0, bbox.height, bbox.height)
218
+ y = torch.linspace(0, bbox.width, bbox.width)
219
+ x, y = torch.meshgrid(x, y, indexing="ij")
220
+ noise_patch = (
221
+ gaussian_2d(
222
+ x,
223
+ y,
224
+ mx=int(bbox.height / 2),
225
+ my=int(bbox.width / 2),
226
+ sx=float(bbox.height / KERNEL_DIVISION),
227
+ sy=float(bbox.width / KERNEL_DIVISION),
228
+ )
229
+ .unsqueeze(0)
230
+ .unsqueeze(-1)
231
+ .repeat(frame_size, 1, 1, len(token_inds))
232
+ .to(attention_probs_4d.device)
233
+ ).half()
234
+
235
+ scale = attention_probs_4d.max() * INJECTION_SCALE
236
+ noise_patch.mul_(scale)
237
+
238
+ b_idx = frame_size * i
239
+ e_idx = frame_size * (i + 1)
240
+ bbox.sliced_tensor_in_bbox(weight_map)[
241
+ b_idx:e_idx, ..., token_inds
242
+ ] = noise_patch
243
+ return weight_map
244
+
245
+ @staticmethod
246
+ def localized_temporal_weight_map(attention_probs_5d, bbox_per_frame, scale=1):
247
+ """Using guassian 2d distribution to generate weight map and return the
248
+ array with the same size of the attention argument.
249
+ """
250
+ dim = int(attention_probs_5d.size()[1])
251
+ f = attention_probs_5d.shape[-1]
252
+ max_val = attention_probs_5d.max()
253
+ weight_map = torch.zeros_like(attention_probs_5d).half()
254
+
255
+ def get_patch(bbox_at_frame, i, j, bbox_per_frame):
256
+ bbox = BoundingBox(dim, bbox_at_frame)
257
+ # Generating the gaussian distribution map patch
258
+ x = torch.linspace(0, bbox.height, bbox.height)
259
+ y = torch.linspace(0, bbox.width, bbox.width)
260
+ x, y = torch.meshgrid(x, y, indexing="ij")
261
+ noise_patch = (
262
+ gaussian_2d(
263
+ x,
264
+ y,
265
+ mx=int(bbox.height / 2),
266
+ my=int(bbox.width / 2),
267
+ sx=float(bbox.height / KERNEL_DIVISION),
268
+ sy=float(bbox.width / KERNEL_DIVISION),
269
+ )
270
+ .unsqueeze(0)
271
+ .repeat(attention_probs_5d.shape[0], 1, 1)
272
+ .to(attention_probs_5d.device)
273
+ ).half()
274
+ scale = attention_probs_5d.max() * INJECTION_SCALE
275
+ noise_patch.mul_(scale)
276
+ inv_noise_patch = noise_patch - noise_patch.max()
277
+ dist = (float(abs(j - i))) / len(bbox_per_frame)
278
+ final_patch = inv_noise_patch * dist + noise_patch * (1. - dist)
279
+ #final_patch = noise_patch * (1. - dist)
280
+ #final_patch = inv_noise_patch * dist
281
+ return final_patch, bbox
282
+
283
+
284
+ for j in range(len(bbox_per_frame)):
285
+ for i in range(len(bbox_per_frame)):
286
+ patch_i, bbox_i = get_patch(bbox_per_frame[i], i, j, bbox_per_frame)
287
+ patch_j, bbox_j = get_patch(bbox_per_frame[j], i, j, bbox_per_frame)
288
+ bbox_i.sliced_tensor_in_bbox(weight_map)[..., i, j] = patch_i
289
+ bbox_j.sliced_tensor_in_bbox(weight_map)[..., i, j] = patch_j
290
+
291
+ return weight_map
TrailBlazer/CrossAttn/InjecterProc.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, TypedDict
2
+ import numpy as np
3
+ import torch
4
+ import math
5
+
6
+ from ..Misc import Logger as log
7
+
8
+ from .BaseProc import CrossAttnProcessorBase
9
+ from .BaseProc import BundleType
10
+ from ..Misc.BBox import BoundingBox
11
+
12
+
13
+ class InjecterProcessor(CrossAttnProcessorBase):
14
+ def __init__(
15
+ self,
16
+ bundle: BundleType,
17
+ bbox_per_frame: List[BoundingBox],
18
+ name: str,
19
+ strengthen_scale: float = 0.0,
20
+ weaken_scale: float = 1.0,
21
+ is_text2vidzero: bool = False,
22
+ ):
23
+ super().__init__(bundle, is_text2vidzero=is_text2vidzero)
24
+ self.strengthen_scale = strengthen_scale
25
+ self.weaken_scale = weaken_scale
26
+ self.bundle = bundle
27
+ self.num_frames = len(bbox_per_frame)
28
+ self.bbox_per_frame = bbox_per_frame
29
+ self.use_weaken = True
30
+ self.name = name
31
+
32
+ def dd_core(self, attention_probs: torch.Tensor):
33
+ """ """
34
+
35
+ frame_size = attention_probs.shape[0] // self.num_frames
36
+ num_affected_frames = self.num_frames
37
+ attention_probs_copied = attention_probs.detach().clone()
38
+
39
+ token_inds = self.bundle.get("token_inds")
40
+ trailing_length = self.bundle.get("trailing_length")
41
+ trailing_inds = list(
42
+ range(self.len_prompt + 1, self.len_prompt + trailing_length + 1)
43
+ )
44
+ # NOTE: Spatial cross attention editing
45
+ if len(attention_probs.size()) == 4:
46
+ all_tokens_inds = list(set(token_inds).union(set(trailing_inds)))
47
+ strengthen_map = self.localized_weight_map(
48
+ attention_probs_copied,
49
+ token_inds=all_tokens_inds,
50
+ bbox_per_frame=self.bbox_per_frame,
51
+ )
52
+
53
+ weaken_map = torch.ones_like(strengthen_map)
54
+ zero_indices = torch.where(strengthen_map == 0)
55
+ weaken_map[zero_indices] = self.weaken_scale
56
+
57
+ # weakening
58
+ attention_probs_copied[..., all_tokens_inds] *= weaken_map[
59
+ ..., all_tokens_inds
60
+ ]
61
+ # strengthen
62
+ attention_probs_copied[..., all_tokens_inds] += (
63
+ self.strengthen_scale * strengthen_map[..., all_tokens_inds]
64
+ )
65
+ # NOTE: Temporal cross attention editing
66
+ elif len(attention_probs.size()) == 5:
67
+ strengthen_map = self.localized_temporal_weight_map(
68
+ attention_probs_copied,
69
+ bbox_per_frame=self.bbox_per_frame,
70
+ )
71
+ weaken_map = torch.ones_like(strengthen_map)
72
+ zero_indices = torch.where(strengthen_map == 0)
73
+ weaken_map[zero_indices] = self.weaken_scale
74
+ # weakening
75
+ attention_probs_copied *= weaken_map
76
+ # strengthen
77
+ attention_probs_copied += self.strengthen_scale * strengthen_map
78
+
79
+ return attention_probs_copied
TrailBlazer/CrossAttn/Utils.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import enum
2
+ import torch
3
+ import torchvision
4
+ import numpy as np
5
+
6
+ from ..Misc import Logger as log
7
+ from ..Setting import Config
8
+
9
+ import matplotlib.pyplot as plt
10
+ import matplotlib
11
+
12
+ # To avoid plt.imshow crash
13
+ matplotlib.use("Agg")
14
+
15
+
16
+ class CAttnProcChoice(enum.Enum):
17
+ INVALID = -1
18
+ BASIC = 0
19
+
20
+
21
+ def plot_activations(cross_attn, prompt, plot_with_trailings=False):
22
+ num_frames = cross_attn.shape[0]
23
+ cross_attn = cross_attn.cpu()
24
+ for i in range(num_frames):
25
+ filename = "/tmp/out.{:04d}.jpg".format(i)
26
+ plot_activation(cross_attn[i], prompt, filename, plot_with_trailings)
27
+
28
+
29
+ def plot_activation(cross_attn, prompt, filepath="", plot_with_trailings=False):
30
+
31
+ splitted_prompt = prompt.split(" ")
32
+ n = len(splitted_prompt)
33
+ start = 0
34
+ arrs = []
35
+ if plot_with_trailings:
36
+ for j in range(5):
37
+ arr = []
38
+ for i in range(start, start + n):
39
+ cross_attn_sliced = cross_attn[..., i + 1]
40
+ arr.append(cross_attn_sliced.T)
41
+ start += n
42
+ arr = np.hstack(arr)
43
+ arrs.append(arr)
44
+ arrs = np.vstack(arrs).T
45
+ else:
46
+ arr = []
47
+ for i in range(start, start + n):
48
+ print(i)
49
+ cross_attn_sliced = cross_attn[..., i + 1]
50
+ arr.append(cross_attn_sliced)
51
+ arrs = np.hstack(arr).astype(np.float32)
52
+ plt.clf()
53
+
54
+ v_min = arrs.min()
55
+ v_max = arrs.max()
56
+ n_min = 0.0
57
+ n_max = 1
58
+
59
+ arrs = (arrs - v_min) / (v_max - v_min)
60
+ arrs = (arrs * (n_max - n_min)) + n_min
61
+
62
+ plt.imshow(arrs, cmap="jet")
63
+ plt.title(prompt)
64
+ plt.colorbar(orientation="horizontal", pad=0.2)
65
+ if filepath:
66
+ plt.savefig(filepath)
67
+ log.info(f"Saved [{filepath}]")
68
+ else:
69
+ plt.show()
70
+
71
+
72
+ def get_cross_attn(
73
+ unet,
74
+ resolution=32,
75
+ target_size=64,
76
+ ):
77
+ """To get the cross attention map softmax(QK^T) from Unet.
78
+ Args:
79
+ unet (UNet2DConditionModel): unet
80
+ resolution (int): the cross attention map with specific resolution. It only supports 64, 32, 16, and 8
81
+ target_size (int): the target resolution for resizing the cross attention map
82
+ Returns:
83
+ (torch.tensor): a tensor with shape (target_size, target_size, 77)
84
+ """
85
+ attns = []
86
+ check = [8, 16, 32, 64]
87
+ if resolution not in check:
88
+ raise ValueError(
89
+ "The cross attention resolution only support 8x8, 16x16, 32x32, and 64x64. "
90
+ "The given resolution {}x{} is not in the list. Abort.".format(
91
+ resolution, resolution
92
+ )
93
+ )
94
+ for name, module in unet.named_modules():
95
+ module_name = type(module).__name__
96
+ # NOTE: attn2 is for cross-attention while attn1 is self-attention
97
+ dim = resolution * resolution
98
+ if not hasattr(module, "processor"):
99
+ continue
100
+ if hasattr(module.processor, "cross_attention_map"):
101
+ attn = module.processor.cross_attention_map[None, ...]
102
+ attns.append(attn)
103
+
104
+ if not attns:
105
+ print("Err: Quried attns size [{}]".format(len(attns)))
106
+ return
107
+ attns = torch.cat(attns, dim=0)
108
+ attns = torch.sum(attns, dim=0)
109
+ # resized = torch.zeros([target_size, target_size, 77])
110
+ # f = torchvision.transforms.Resize(size=(64, 64))
111
+ # dim = attns.shape[1]
112
+ # print(attns.shape)
113
+ # for i in range(77):
114
+ # attn_slice = attns[..., i].view(1, dim, dim)
115
+ # resized[..., i] = f(attn_slice)[0]
116
+ return attns
117
+
118
+
119
+ def get_avg_cross_attn(unet, resolutions, resize):
120
+ """To get the average cross attention map across its resolutions.
121
+ Args:
122
+ unet (UNet2DConditionModel): unet
123
+ resolution (list): a list of specific resolution. It only supports 64, 32, 16, and 8
124
+ target_size (int): the target resolution for resizing the cross attention map
125
+ Returns:
126
+ (torch.tensor): a tensor with shape (target_size, target_size, 77)
127
+ """
128
+ cross_attns = []
129
+ for resolution in resolutions:
130
+ try:
131
+ cross_attns.append(get_cross_attn(unet, resolution, resize))
132
+ except:
133
+ log.warn(f"No cross-attention map with resolution [{resolution}]")
134
+ if cross_attns:
135
+ cross_attns = torch.stack(cross_attns).mean(0)
136
+ return cross_attns
137
+
138
+
139
+ def save_cross_attn(unet):
140
+ """TODO: to save cross attn"""
141
+ for name, module in unet.named_modules():
142
+ module_name = type(module).__name__
143
+ if module_name == "CrossAttention" and "attn2" in name:
144
+ folder = "/tmp"
145
+ filepath = os.path.join(folder, name + ".pt")
146
+ torch.save(module.attn, filepath)
147
+ print(filepath)
148
+
149
+
150
+ def use_dd(unet, use=True):
151
+ for name, module in unet.named_modules():
152
+ module_name = type(module).__name__
153
+ if module_name == "CrossAttention" and "attn2" in name:
154
+ module.processor.use_dd = use
155
+
156
+
157
+ def use_dd_temporal(unet, use=True):
158
+ for name, module in unet.named_modules():
159
+ module_name = type(module).__name__
160
+ if module_name == "CrossAttention" and "attn2" in name:
161
+ module.processor.use_dd_temporal = use
162
+
163
+
164
+ def get_loss(unet):
165
+ loss = 0
166
+ total = 0
167
+ for name, module in unet.named_modules():
168
+ module_name = type(module).__name__
169
+ if module_name == "CrossAttention" and "attn2" in name:
170
+ loss += module.processor.loss
171
+ total += 1
172
+ return loss / total
173
+
174
+
175
+ def get_params(unet):
176
+ parameters = []
177
+ for name, module in unet.named_modules():
178
+ module_name = type(module).__name__
179
+ if module_name == "CrossAttention" and "attn2" in name:
180
+ parameters.append(module.processor.parameters)
181
+ return parameters
TrailBlazer/CrossAttn/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+
TrailBlazer/Misc/BBox.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ """
3
+ import torch
4
+
5
+
6
+ class BoundingBox:
7
+ """A rectangular bounding box determines the directed regions."""
8
+
9
+ def __init__(self, resolution, box_ratios, margin=0.0):
10
+ """
11
+ Args:
12
+ resolution(int): the resolution of the 2d spatial input
13
+ box_ratios(List[float]):
14
+ Returns:
15
+ """
16
+ assert (
17
+ box_ratios[1] < box_ratios[3]
18
+ ), "the boundary top ratio should be less than bottom"
19
+ assert (
20
+ box_ratios[0] < box_ratios[2]
21
+ ), "the boundary left ratio should be less than right"
22
+ self.left = int((box_ratios[0] - margin) * resolution)
23
+ self.right = int((box_ratios[2] + margin) * resolution)
24
+ self.top = int((box_ratios[1] - margin) * resolution)
25
+ self.bottom = int((box_ratios[3] + margin) * resolution)
26
+ self.height = self.bottom - self.top
27
+ self.width = self.right - self.left
28
+ if self.height == 0:
29
+ self.height = 1
30
+ if self.width == 0:
31
+ self.width = 1
32
+
33
+ def sliced_tensor_in_bbox(self, tensor: torch.tensor) -> torch.tensor:
34
+ """ slicing the tensor with bbox area
35
+
36
+ Args:
37
+ tensor(torch.tensor): the original tensor in 4d
38
+ Returns:
39
+ (torch.tensor): the reduced tensor inside bbox
40
+ """
41
+ return tensor[:, self.top : self.bottom, self.left : self.right, :]
42
+
43
+ def mask_reweight_out_bbox(
44
+ self, tensor: torch.tensor, value: float = 0.0
45
+ ) -> torch.tensor:
46
+ """reweighting value outside bbox
47
+
48
+ Args:
49
+ tensor(torch.tensor): the original tensor in 4d
50
+ value(float): reweighting factor default with 0.0
51
+ Returns:
52
+ (torch.tensor): the reweighted tensor
53
+ """
54
+ mask = torch.ones_like(tensor).to(tensor.device) * value
55
+ mask[:, self.top : self.bottom, self.left : self.right, :] = 1
56
+ return tensor * mask
57
+
58
+ def mask_reweight_in_bbox(
59
+ self, tensor: torch.tensor, value: float = 0.0
60
+ ) -> torch.tensor:
61
+ """reweighting value within bbox
62
+
63
+ Args:
64
+ tensor(torch.tensor): the original tensor in 4d
65
+ value(float): reweighting factor default with 0.0
66
+ Returns:
67
+ (torch.tensor): the reweighted tensor
68
+ """
69
+ mask = torch.ones_like(tensor).to(tensor.device)
70
+ mask[:, self.top : self.bottom, self.left : self.right, :] = value
71
+ return tensor * mask
72
+
73
+ def __str__(self):
74
+ """it prints Box(L:%d, R:%d, T:%d, B:%d) for better ingestion"""
75
+ return f"Box(L:{self.left}, R:{self.right}, T:{self.top}, B:{self.bottom})"
76
+
77
+ def __rerp__(self):
78
+ """ """
79
+ return f"Box(L:{self.left}, R:{self.right}, T:{self.top}, B:{self.bottom})"
80
+
81
+
82
+ if __name__ == "__main__":
83
+ # Example: second quadrant
84
+ input_res = 32
85
+ left = 0.0
86
+ top = 0.0
87
+ right = 0.5
88
+ bottom = 0.5
89
+ box_ratios = [left, top, right, bottom]
90
+ bbox = BoundingBox(resolution=input_res, box_ratios=box_ratios)
91
+
92
+ print(bbox)
93
+ # Box(L:0, R:16, T:0, B:16)
TrailBlazer/Misc/ConfigIO.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import yaml
2
+
3
+ def config_loader(filepath):
4
+ data = None
5
+ with open(filepath, "r") as yamlfile:
6
+ data = yaml.load(yamlfile, Loader=yaml.FullLoader)
7
+ yamlfile.close()
8
+ return data
9
+
10
+ def config_saver(data, filepath):
11
+ with open(filepath, 'w') as yamlfile:
12
+ data1 = yaml.dump(data, yamlfile)
13
+ yamlfile.close()
TrailBlazer/Misc/Const.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ # https://okuha.com/best-stable-diffusion-prompts/
2
+
3
+ NEGATIVE_PROMPT = "bad anatomy, bad proportions, blurry, cloned face, cropped, deformed, dehydrated, disfigured, duplicate, error, extra arms, extra fingers, extra legs, extra limbs, fused fingers, gross proportions, jpeg artifacts, long neck, low quality, lowres, malformed limbs, missing arms, missing legs, morbid, mutated hands, mutation, mutilated, out of frame, poorly drawn face, poorly drawn hands, signature, text, too many fingers, ugly, username, watermark, worst quality, Amputee, Autograph, Bad anatomy, Bad illustration, Bad proportions, Beyond the borders, Blank background, Blurry, Body out of frame, Boring background, Branding, Cropped, Cut off, Deformed, Disfigured, Dismembered, Disproportioned, Distorted, Draft, Duplicate, Duplicated features, Extra arms, Extra fingers, Extra hands, Extra legs, Extra limbs, Fault, Flaw, Fused fingers, Grains, Grainy, Gross proportions, Hazy, Identifying mark, Improper scale, Incorrect physiology, Incorrect ratio, Indistinct, Kitsch, Logo, Long neck, Low quality, Low resolution, Macabre, Malformed, Mark, Misshapen, Missing arms, Missing fingers, Missing hands, Missing legs, Mistake, Morbid, Mutated hands, Mutation, Mutilated, Off-screen, Out of frame, Outside the picture, Pixelated, Poorly drawn face, Poorly drawn feet, Poorly drawn hands, Printed words, Render, Repellent, Replicate, Reproduce, Revolting dimensions, Script, Shortened, Sign, Signature, Split image, Squint, Storyboard, Text, Tiling, Trimmed, Ugly, Unfocused, Unattractive, Unnatural pose, Unreal engine, Unsightly, Watermark, Written language, Absent limbs, Additional appendages, Additional digits, Additional limbs, Altered appendages, Amputee, Asymmetric, Asymmetric ears, Bad anatomy, Bad ears, Bad eyes, Bad face, Bad proportions, Broken finger, Broken hand, Broken leg, Broken wrist, Cartoon, Cloned face, Cloned head, Collapsed eyeshadow, Combined appendages, Conjoined, Copied visage, Corpse, Cripple, Cropped head, Cross-eyed, Depressed, Desiccated, Disconnected limb, Disfigured, Dismembered, Disproportionate, Double face, Duplicated features, Eerie, Elongated throat, lowres, low quality, jpeg, artifacts, duplicate, morbid, mutilated, out of frame, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck, drawing, painting, crayon, sketch, graphite, impressionist, noisy, soft, extra tails"
4
+
5
+
6
+ POSITIVE_PROMPT = "; masterpiece, best quality, intricate, detailed, sharp, focused, intricate details, hyperdetailed, 8k, RAW photo,realistic style, national geography, fantasy, hyper-realistic, rich colors, realistic texture"
TrailBlazer/Misc/Logger.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+ import logging
3
+
4
+ from io import StringIO # Python3
5
+
6
+ import sys
7
+
8
+ class SilencedStdOut:
9
+ # https://stackoverflow.com/questions/65608502/is-there-a-way-to-force-any-function-to-not-be-verbose-in-python
10
+ def __enter__(self):
11
+ self.old_stdout = sys.stdout
12
+ self.result = StringIO()
13
+ sys.stdout = self.result
14
+
15
+ def __exit__(self, *args, **kwargs):
16
+
17
+ sys.stdout = self.old_stdout
18
+ result_string = self.result.getvalue() # use if you want or discard.
19
+
20
+ class CustomFormatter(logging.Formatter):
21
+
22
+ GRAY = "\x1b[38m"
23
+ YELLOW = "\x1b[33m"
24
+ CYAN = "\x1b[36m"
25
+ RED = "\x1b[31m"
26
+ BOLD_RED = "\x1b[31;1m"
27
+ RESET = "\x1b[0m"
28
+ FORMAT = "[%(asctime)s - %(name)s - %(levelname)8s] - %(message)s (%(filename)s:%(lineno)d)"
29
+
30
+ FORMATS = {
31
+ logging.DEBUG: GRAY + FORMAT + RESET,
32
+ logging.INFO: GRAY + FORMAT + RESET,
33
+ logging.WARNING: YELLOW + FORMAT + RESET,
34
+ logging.ERROR: RED + FORMAT + RESET,
35
+ logging.CRITICAL: BOLD_RED + FORMAT + RESET,
36
+ logging.DEBUG: CYAN + FORMAT + RESET,
37
+ }
38
+
39
+ def format(self, record):
40
+ log_fmt = self.FORMATS.get(record.levelno)
41
+ formatter = logging.Formatter(log_fmt)
42
+ return formatter.format(record)
43
+
44
+ # create logger with 'spam_application'
45
+
46
+ logger = logging.getLogger("TrailBlazer")
47
+ logger.handlers = []
48
+ logger.setLevel(logging.DEBUG)
49
+ # create console handler with a higher log level
50
+ console_handler = logging.StreamHandler()
51
+ console_handler.setLevel(logging.DEBUG)
52
+
53
+ console_handler.setFormatter(CustomFormatter())
54
+ logger.addHandler(console_handler)
55
+
56
+ critical = logger.critical
57
+ fatal = logger.fatal
58
+ error = logger.error
59
+ warning = logger.warning
60
+ warn = logger.warn
61
+ info = logger.info
62
+ debug = logger.debug
63
+
64
+ if __name__ == "__main__":
65
+ from DirectedDiffusion import Logger as log
66
+ log.info("info message")
67
+ log.warning("warning message")
68
+ log.error("error message")
69
+ log.debug("debug message")
70
+ log.critical("critical message")
TrailBlazer/Misc/Painter.py ADDED
@@ -0,0 +1,224 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ """
3
+ import torch
4
+ import matplotlib.pyplot as plt
5
+ import numpy as np
6
+ import torch.nn.functional as nnf
7
+ import torchvision
8
+ import einops
9
+ import matplotlib.pyplot as plt
10
+ import scipy.stats as st
11
+ from PIL import Image, ImageFont, ImageDraw
12
+
13
+ plt.rcParams["figure.figsize"] = [
14
+ float(v) * 1.5 for v in plt.rcParams["figure.figsize"]
15
+ ]
16
+
17
+
18
+ class CrossAttnPainter:
19
+
20
+ def __init__(self, bundle, pipe, root="/tmp"):
21
+ self.dim = 64
22
+ self.folder =
23
+
24
+ def plot_frames(self):
25
+ folder = "/tmp"
26
+ from PIL import Image
27
+ for i, f in enumerate(video_frames):
28
+ img = Image.fromarray(f)
29
+ filepath = os.path.join(folder, "recons.{:04d}.jpg".format(i))
30
+ img.save(filepath)
31
+
32
+
33
+ def plot_spatial_attn(self):
34
+
35
+ arr = (
36
+ pipe.unet.up_blocks[1]
37
+ .attentions[0]
38
+ .transformer_blocks[0]
39
+ .attn2.processor.cross_attention_map
40
+ )
41
+ heads = pipe.unet.up_blocks[1].attentions[0].transformer_blocks[0].attn2.heads
42
+ arr = torch.transpose(arr, 1, 3)
43
+ arr = nnf.interpolate(arr, size=(64, 64), mode='bicubic', align_corners=False)
44
+ arr = torch.transpose(arr, 1, 3)
45
+ arr = arr.cpu().numpy()
46
+ arr = arr.reshape(24, heads, 64, 64, 77)
47
+ arr = arr.mean(axis=1)
48
+ n = arr.shape[0]
49
+ for i in range(n):
50
+ filename = "/tmp/spatialca.{:04d}.jpg".format(i)
51
+ plt.clf()
52
+ plt.imshow(arr[i, :, :, 2], cmap="jet")
53
+ plt.gca().set_axis_off()
54
+ plt.subplots_adjust(top = 1, bottom = 0, right = 1, left = 0,
55
+ hspace = 0, wspace = 0)
56
+ plt.margins(0,0)
57
+ plt.gca().xaxis.set_major_locator(plt.NullLocator())
58
+ plt.gca().yaxis.set_major_locator(plt.NullLocator())
59
+ plt.savefig(filename, bbox_inches = 'tight',pad_inches = 0)
60
+ print(filename)
61
+
62
+ def plot_temporal_attn(self):
63
+
64
+ # arr = pipe.unet.mid_block.temp_attentions[0].transformer_blocks[0].attn2.processor.cross_attention_map
65
+ import matplotlib.pyplot as plt
66
+ import torch.nn.functional as nnf
67
+ arr = (
68
+ pipe.unet.up_blocks[2]
69
+ .temp_attentions[1]
70
+ .transformer_blocks[0]
71
+ .attn2.processor.cross_attention_map
72
+ )
73
+ #arr = pipe.unet.transformer_in.transformer_blocks[0].attn2.processor.cross_attention_map
74
+ arr = torch.transpose(arr, 0, 2).transpose(1, 3)
75
+ arr = nnf.interpolate(arr, size=(64, 64), mode="bicubic", align_corners=False)
76
+ arr = torch.transpose(arr, 0, 2).transpose(1, 3)
77
+ arr = arr.cpu().numpy()
78
+ n = arr.shape[-1]
79
+ for i in range(n-2):
80
+ filename = "/tmp/tempcaiip2.{:04d}.jpg".format(i)
81
+ plt.clf()
82
+ plt.imshow(arr[..., i+2, i], cmap="jet")
83
+ plt.gca().set_axis_off()
84
+ plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
85
+ plt.margins(0, 0)
86
+ plt.gca().xaxis.set_major_locator(plt.NullLocator())
87
+ plt.gca().yaxis.set_major_locator(plt.NullLocator())
88
+ plt.savefig(filename, bbox_inches="tight", pad_inches=0)
89
+ print(filename)
90
+
91
+
92
+
93
+
94
+
95
+
96
+
97
+
98
+
99
+
100
+ def plot_latent_noise(latents, mode):
101
+
102
+ for i in range(latents.shape[0]):
103
+ tensor = latents[i].cpu()
104
+ min_val = torch.min(tensor)
105
+ max_val = torch.max(tensor)
106
+ scale = 255 * (max_val - min_val)
107
+ tensor = scale * (tensor - min_val)
108
+ tensor = tensor.type(torch.int8)
109
+ tensor = einops.rearrange(tensor, "c w h -> w h c")
110
+ if mode == "RGB":
111
+ tensor = tensor[...,:3]
112
+ mode_ = "RGB"
113
+ elif mode == "RGBA":
114
+ mode_ = "RGBA"
115
+ pass
116
+ elif mode == "GRAY":
117
+ tensor = tensor[...,0]
118
+ mode_ = "L"
119
+
120
+ x = tensor.numpy()
121
+
122
+ img = Image.fromarray(x, mode_)
123
+ img = img.resize((256, 256), resample=Image.NEAREST )
124
+ filepath = f"/tmp/out.{i:04d}.jpg"
125
+ img.save(filepath)
126
+
127
+ tensor = latents[i].cpu()
128
+ x = tensor.flatten().numpy()
129
+ x /= x.max()
130
+ plt.hist(x, density=True, bins=20, range=[-1, 1])
131
+ mn, mx = plt.xlim()
132
+ plt.xlim(mn, mx)
133
+ kde_xs = np.linspace(mn, mx, 300)
134
+ kde = st.gaussian_kde(x)
135
+ plt.plot(kde_xs, kde.pdf(kde_xs), label="PDF")
136
+ filepath = f"/tmp/hist.{i:04d}.jpg"
137
+ plt.savefig(filepath)
138
+ plt.clf()
139
+
140
+ print(i)
141
+
142
+
143
+ def plot_activation(cross_attn, prompt, filepath="", plot_with_trailings=False, n_trailing=2):
144
+ splitted_prompt = prompt.split(" ")
145
+ n = len(splitted_prompt)
146
+ start = 0
147
+ arrs = []
148
+ if plot_with_trailings:
149
+ for j in range(n_trailing):
150
+ arr = []
151
+ for i in range(start, start + n):
152
+ cross_attn_sliced = cross_attn[..., i + 1]
153
+ arr.append(cross_attn_sliced.T)
154
+ start += n
155
+ arr = np.hstack(arr)
156
+ arrs.append(arr)
157
+ arrs = np.vstack(arrs).T
158
+ else:
159
+ arr = []
160
+ for i in range(start, start + n):
161
+ cross_attn_sliced = cross_attn[..., i + 1]
162
+ arr.append(cross_attn_sliced)
163
+ arrs = np.vstack(arr)
164
+ plt.imshow(arrs, cmap="jet", vmin=0.0, vmax=.5)
165
+ plt.title(prompt)
166
+ if filepath:
167
+ plt.savefig(filepath)
168
+ else:
169
+ plt.show()
170
+
171
+
172
+ def draw_dd_metadata(img, bbox, text="", target_res=1024):
173
+ img = img.resize((target_res, target_res))
174
+ image_editable = ImageDraw.Draw(img)
175
+
176
+ for region in [bbox]:
177
+ x0 = region[0] * target_res
178
+ y0 = region[2] * target_res
179
+ x1 = region[1] * target_res
180
+ y1 = region[3] * target_res
181
+ image_editable.rectangle(xy=[x0, y0, x1, y1], outline=(255, 0, 0, 255), width=5)
182
+ if text:
183
+ font = ImageFont.truetype("./assets/JetBrainsMono-Bold.ttf", size=13)
184
+ image_editable.multiline_text(
185
+ (15, 15),
186
+ text,
187
+ (255, 255, 255, 0),
188
+ font=font,
189
+ stroke_width=2,
190
+ stroke_fill=(0, 0, 0, 255),
191
+ spacing=0,
192
+ )
193
+ return img
194
+
195
+
196
+
197
+
198
+
199
+
200
+
201
+
202
+
203
+
204
+
205
+
206
+
207
+
208
+
209
+
210
+
211
+
212
+
213
+
214
+
215
+
216
+
217
+
218
+
219
+
220
+
221
+
222
+ if __name__ == "__main__":
223
+ latents = torch.load("assets/experiments/a-cat-sitting-on-a-car_230615-144611/latents.pt")
224
+ plot_latent_noise(latents, "GRAY")
TrailBlazer/Misc/__init__.py ADDED
File without changes
TrailBlazer/Pipeline/TextToVideoSDPipelineCall.py ADDED
@@ -0,0 +1,339 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import inspect
2
+ from typing import Any, Callable, Dict, List, Optional, Union
3
+
4
+ import numpy as np
5
+ import torch
6
+ from transformers import CLIPTextModel, CLIPTokenizer
7
+ from dataclasses import dataclass
8
+
9
+ from diffusers.loaders import LoraLoaderMixin, TextualInversionLoaderMixin
10
+ from diffusers.models import AutoencoderKL, UNet3DConditionModel
11
+ from diffusers.models.lora import adjust_lora_scale_text_encoder
12
+ from diffusers.schedulers import KarrasDiffusionSchedulers
13
+ from diffusers.utils import (
14
+ deprecate,
15
+ logging,
16
+ replace_example_docstring,
17
+ BaseOutput,
18
+ )
19
+ from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth import (
20
+ tensor2vid,
21
+ )
22
+
23
+ from ..Misc import Logger as log
24
+ from ..Misc import Const
25
+ from .Utils import initiailization, keyframed_bbox, keyframed_prompt_embeds, use_dd, use_dd_temporal
26
+
27
+ @dataclass
28
+ class TextToVideoSDPipelineOutput(BaseOutput):
29
+ """
30
+ Output class for text-to-video pipelines.
31
+
32
+ Args:
33
+ frames (`List[np.ndarray]` or `torch.FloatTensor`)
34
+ List of denoised frames (essentially images) as NumPy arrays of shape `(height, width, num_channels)` or as
35
+ a `torch` tensor. The length of the list denotes the video length (the number of frames).
36
+ """
37
+
38
+ frames: Union[List[np.ndarray], torch.FloatTensor]
39
+ latents: Union[List[np.ndarray], torch.FloatTensor]
40
+ bbox_per_frame: torch.tensor
41
+
42
+
43
+ @torch.no_grad()
44
+ def text_to_video_sd_pipeline_call(
45
+ self,
46
+ bundle=None,
47
+ # prompt: Union[str, List[str]] = None,
48
+ height: Optional[int] = None,
49
+ width: Optional[int] = None,
50
+ # num_frames: int = 16,
51
+ num_inference_steps: int = 50,
52
+ # num_dd_steps: int = 0,
53
+ guidance_scale: float = 9.0,
54
+ negative_prompt: Optional[Union[str, List[str]]] = None,
55
+ eta: float = 0.0,
56
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
57
+ latents: Optional[torch.FloatTensor] = None,
58
+ prompt_embeds: Optional[torch.FloatTensor] = None,
59
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
60
+ output_type: Optional[str] = "np",
61
+ return_dict: bool = True,
62
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
63
+ callback_steps: int = 1,
64
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
65
+ ):
66
+ r"""
67
+ The call function to the pipeline for generation.
68
+
69
+ Args:
70
+ prompt (`str` or `List[str]`, *optional*):
71
+ The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
72
+ height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
73
+ The height in pixels of the generated video.
74
+ width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
75
+ The width in pixels of the generated video.
76
+ num_frames (`int`, *optional*, defaults to 16):
77
+ The number of video frames that are generated. Defaults to 16 frames which at 8 frames per seconds
78
+ amounts to 2 seconds of video.
79
+ num_inference_steps (`int`, *optional*, defaults to 50):
80
+ The number of denoising steps. More denoising steps usually lead to a higher quality videos at the
81
+ expense of slower inference.
82
+ guidance_scale (`float`, *optional*, defaults to 7.5):
83
+ A higher guidance scale value encourages the model to generate images closely linked to the text
84
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
85
+ negative_prompt (`str` or `List[str]`, *optional*):
86
+ The prompt or prompts to guide what to not include in image generation. If not defined, you need to
87
+ pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
88
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
89
+ The number of images to generate per prompt.
90
+ eta (`float`, *optional*, defaults to 0.0):
91
+ Corresponds to parameter eta (Ξ·) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
92
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
93
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
94
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
95
+ generation deterministic.
96
+ latents (`torch.FloatTensor`, *optional*):
97
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video
98
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
99
+ tensor is generated by sampling using the supplied random `generator`. Latents should be of shape
100
+ `(batch_size, num_channel, num_frames, height, width)`.
101
+ prompt_embeds (`torch.FloatTensor`, *optional*):
102
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
103
+ provided, text embeddings are generated from the `prompt` input argument.
104
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
105
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
106
+ not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
107
+ output_type (`str`, *optional*, defaults to `"np"`):
108
+ The output format of the generated video. Choose between `torch.FloatTensor` or `np.array`.
109
+ return_dict (`bool`, *optional*, defaults to `True`):
110
+ Whether or not to return a [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] instead
111
+ of a plain tuple.
112
+ callback (`Callable`, *optional*):
113
+ A function that calls every `callback_steps` steps during inference. The function is called with the
114
+ following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
115
+ callback_steps (`int`, *optional*, defaults to 1):
116
+ The frequency at which the `callback` function is called. If not specified, the callback is called at
117
+ every step.
118
+ cross_attention_kwargs (`dict`, *optional*):
119
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
120
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
121
+
122
+ Examples:
123
+
124
+ Returns:
125
+ [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] or `tuple`:
126
+ If `return_dict` is `True`, [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] is
127
+ returned, otherwise a `tuple` is returned where the first element is a list with the generated frames.
128
+ """
129
+
130
+ assert (
131
+ len(bundle["keyframe"]) >= 2
132
+ ), "Must be greater than 2 keyframes. Input {} keys".format(len(bundle["keyframe"]))
133
+
134
+ assert (
135
+ bundle["keyframe"][0]["frame"] == 0
136
+ ), "First keyframe must indicate frame at 0, but given {}".format(
137
+ bundle["keyframe"][0]["frame"]
138
+ )
139
+
140
+ if bundle["keyframe"][-1]["frame"] != 23:
141
+ log.info(
142
+ "It's recommended to set the last key to 23 to match"
143
+ " the sequence length 24 used in training ZeroScope"
144
+ )
145
+
146
+ for i in range(len(bundle["keyframe"]) - 1):
147
+ log.info
148
+ assert (
149
+ bundle["keyframe"][i + 1]["frame"] > bundle["keyframe"][i]["frame"]
150
+ ), "The keyframe indices must be ordered in the config file, Sorry!"
151
+
152
+ bundle["prompt_base"] = bundle["keyframe"][0]["prompt"]
153
+ prompt = bundle["prompt_base"]
154
+ #prompt += Const.POSITIVE_PROMPT
155
+ num_frames = bundle["keyframe"][-1]["frame"] + 1
156
+ num_dd_spatial_steps = bundle["num_dd_spatial_steps"]
157
+ num_dd_temporal_steps = bundle["num_dd_temporal_steps"]
158
+
159
+ bbox_per_frame = keyframed_bbox(bundle)
160
+ initiailization(unet=self.unet, bundle=bundle, bbox_per_frame=bbox_per_frame)
161
+
162
+ from pprint import pprint
163
+
164
+ log.info("Experiment parameters:")
165
+ print("==========================================")
166
+ pprint(bundle)
167
+ print("==========================================")
168
+ # 0. Default height and width to unet
169
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
170
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
171
+
172
+ num_images_per_prompt = 1
173
+ negative_prompt = Const.NEGATIVE_PROMPT
174
+ # 1. Check inputs. Raise error if not correct
175
+ # self.check_inputs(
176
+ # prompt,
177
+ # height,
178
+ # width,
179
+ # callback_steps,
180
+ # negative_prompt,
181
+ # prompt_embeds,
182
+ # negative_prompt_embeds,
183
+ # )
184
+
185
+ # # 2. Define call parameters
186
+ if prompt is not None and isinstance(prompt, str):
187
+ batch_size = 1
188
+ elif prompt is not None and isinstance(prompt, list):
189
+ batch_size = len(prompt)
190
+ else:
191
+ batch_size = prompt_embeds.shape[0]
192
+
193
+ device = self._execution_device
194
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
195
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
196
+ # corresponds to doing no classifier free guidance.
197
+ do_classifier_free_guidance = guidance_scale > 1.0
198
+
199
+ # 3. Encode input prompt
200
+ text_encoder_lora_scale = (
201
+ cross_attention_kwargs.get("scale", None)
202
+ if cross_attention_kwargs is not None
203
+ else None
204
+ )
205
+
206
+ # prompt_embeds, negative_prompt_embeds = self.encode_prompt(
207
+ # prompt,
208
+ # device,
209
+ # num_images_per_prompt,
210
+ # do_classifier_free_guidance,
211
+ # negative_prompt,
212
+ # prompt_embeds=prompt_embeds,
213
+ # negative_prompt_embeds=negative_prompt_embeds,
214
+ # lora_scale=text_encoder_lora_scale,
215
+ # )
216
+
217
+ prompt_embeds, negative_prompt_embeds = keyframed_prompt_embeds(
218
+ bundle, self.encode_prompt, device
219
+ )
220
+
221
+ # For classifier free guidance, we need to do two forward passes.
222
+ # Here we concatenate the unconditional and text embeddings into a single batch
223
+ # to avoid doing two forward passes
224
+ if do_classifier_free_guidance:
225
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
226
+
227
+ # 4. Prepare timesteps
228
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
229
+ timesteps = self.scheduler.timesteps
230
+
231
+ # 5. Prepare latent variables
232
+ num_channels_latents = self.unet.config.in_channels
233
+ latents = self.prepare_latents(
234
+ batch_size * num_images_per_prompt,
235
+ num_channels_latents,
236
+ num_frames,
237
+ height,
238
+ width,
239
+ prompt_embeds.dtype,
240
+ device,
241
+ generator,
242
+ latents,
243
+ )
244
+
245
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
246
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
247
+
248
+ # 7. Denoising loop
249
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
250
+
251
+ latents_at_steps = []
252
+
253
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
254
+ for i, t in enumerate(timesteps):
255
+ # expand the latents if we are doing classifier free guidance
256
+ latent_model_input = (
257
+ torch.cat([latents] * 2) if do_classifier_free_guidance else latents
258
+ )
259
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
260
+
261
+ # predict the noise residual
262
+ if i < (num_dd_spatial_steps):
263
+ use_dd(self.unet, True)
264
+
265
+ if i < (num_dd_temporal_steps):
266
+ use_dd_temporal(self.unet, True)
267
+
268
+ noise_pred = self.unet(
269
+ latent_model_input,
270
+ t,
271
+ encoder_hidden_states=prompt_embeds,
272
+ cross_attention_kwargs=cross_attention_kwargs,
273
+ return_dict=False,
274
+ )[0]
275
+
276
+ use_dd(self.unet, False)
277
+ use_dd_temporal(self.unet, False)
278
+
279
+ # perform guidance
280
+ if do_classifier_free_guidance:
281
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
282
+ noise_pred = noise_pred_uncond + guidance_scale * (
283
+ noise_pred_text - noise_pred_uncond
284
+ )
285
+
286
+ # reshape latents
287
+ bsz, channel, frames, width, height = latents.shape
288
+ latents = latents.permute(0, 2, 1, 3, 4).reshape(
289
+ bsz * frames, channel, width, height
290
+ )
291
+ noise_pred = noise_pred.permute(0, 2, 1, 3, 4).reshape(
292
+ bsz * frames, channel, width, height
293
+ )
294
+
295
+ # compute the previous noisy sample x_t -> x_t-1
296
+ latents = self.scheduler.step(
297
+ noise_pred, t, latents, **extra_step_kwargs
298
+ ).prev_sample
299
+
300
+ # if i==num_dd_steps:
301
+ # print("PF!", latents.shape)
302
+ # n = latents.shape[0]
303
+ # for f in range(n):
304
+ # latents[f] = torch.roll(latents[f], -f, dims=-1)
305
+
306
+ # reshape latents back
307
+ latents = (
308
+ latents[None, :]
309
+ .reshape(bsz, frames, channel, width, height)
310
+ .permute(0, 2, 1, 3, 4)
311
+ )
312
+ latents_at_steps.append(latents)
313
+
314
+ # call the callback, if provided
315
+ if i == len(timesteps) - 1 or (
316
+ (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
317
+ ):
318
+ progress_bar.update()
319
+ if callback is not None and i % callback_steps == 0:
320
+ callback(i, t, latents)
321
+
322
+ if output_type == "latent":
323
+ return TextToVideoSDPipelineOutput(frames=latents)
324
+
325
+ video_tensor = self.decode_latents(latents)
326
+
327
+ if output_type == "pt":
328
+ video = video_tensor
329
+ else:
330
+ video = tensor2vid(video_tensor)
331
+
332
+ # Offload all models
333
+ self.maybe_free_model_hooks()
334
+
335
+ if not return_dict:
336
+ return (video,)
337
+
338
+ latents_at_steps = torch.cat(latents_at_steps)
339
+ return TextToVideoSDPipelineOutput(frames=video, latents=latents_at_steps, bbox_per_frame=bbox_per_frame)
TrailBlazer/Pipeline/UNet3DConditionModelCall.py ADDED
@@ -0,0 +1,229 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 Alibaba DAMO-VILAB and The HuggingFace Team. All rights reserved.
2
+ # Copyright 2023 The ModelScope Team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ from dataclasses import dataclass
16
+ from typing import Any, Dict, List, Optional, Tuple, Union
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.utils.checkpoint
21
+
22
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
23
+ from diffusers.loaders import UNet2DConditionLoadersMixin
24
+ from diffusers.utils import BaseOutput, logging
25
+ from diffusers.models.attention_processor import (
26
+ ADDED_KV_ATTENTION_PROCESSORS,
27
+ CROSS_ATTENTION_PROCESSORS,
28
+ AttentionProcessor,
29
+ AttnAddedKVProcessor,
30
+ AttnProcessor,
31
+ )
32
+ from diffusers.models.embeddings import TimestepEmbedding, Timesteps
33
+ from diffusers.models.modeling_utils import ModelMixin
34
+ from diffusers.models.transformer_temporal import TransformerTemporalModel
35
+ from diffusers.models.unet_3d_blocks import (
36
+ CrossAttnDownBlock3D,
37
+ CrossAttnUpBlock3D,
38
+ DownBlock3D,
39
+ UNetMidBlock3DCrossAttn,
40
+ UpBlock3D,
41
+ get_down_block,
42
+ get_up_block,
43
+ )
44
+ from diffusers.models.unet_3d_condition import UNet3DConditionOutput
45
+
46
+
47
+
48
+ def unet3d_condition_model_forward(
49
+ self,
50
+ sample: torch.FloatTensor,
51
+ timestep: Union[torch.Tensor, float, int],
52
+ encoder_hidden_states: torch.Tensor,
53
+ class_labels: Optional[torch.Tensor] = None,
54
+ timestep_cond: Optional[torch.Tensor] = None,
55
+ attention_mask: Optional[torch.Tensor] = None,
56
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
57
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
58
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
59
+ return_dict: bool = True,
60
+ ) -> Union[UNet3DConditionOutput, Tuple]:
61
+ r"""
62
+ The [`UNet3DConditionModel`] forward method.
63
+
64
+ Args:
65
+ sample (`torch.FloatTensor`):
66
+ The noisy input tensor with the following shape `(batch, num_frames, channel, height, width`.
67
+ timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
68
+ encoder_hidden_states (`torch.FloatTensor`):
69
+ The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
70
+ return_dict (`bool`, *optional*, defaults to `True`):
71
+ Whether or not to return a [`~models.unet_3d_condition.UNet3DConditionOutput`] instead of a plain
72
+ tuple.
73
+ cross_attention_kwargs (`dict`, *optional*):
74
+ A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
75
+
76
+ Returns:
77
+ [`~models.unet_3d_condition.UNet3DConditionOutput`] or `tuple`:
78
+ If `return_dict` is True, an [`~models.unet_3d_condition.UNet3DConditionOutput`] is returned, otherwise
79
+ a `tuple` is returned where the first element is the sample tensor.
80
+ """
81
+ # By default samples have to be AT least a multiple of the overall upsampling factor.
82
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layears).
83
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
84
+ # on the fly if necessary.
85
+ default_overall_up_factor = 2**self.num_upsamplers
86
+
87
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
88
+ forward_upsample_size = False
89
+ upsample_size = None
90
+
91
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
92
+ logger.info("Forward upsample size to force interpolation output size.")
93
+ forward_upsample_size = True
94
+
95
+ # prepare attention_mask
96
+ if attention_mask is not None:
97
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
98
+ attention_mask = attention_mask.unsqueeze(1)
99
+
100
+ # 1. time
101
+ timesteps = timestep
102
+ if not torch.is_tensor(timesteps):
103
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
104
+ # This would be a good case for the `match` statement (Python 3.10+)
105
+ is_mps = sample.device.type == "mps"
106
+ if isinstance(timestep, float):
107
+ dtype = torch.float32 if is_mps else torch.float64
108
+ else:
109
+ dtype = torch.int32 if is_mps else torch.int64
110
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
111
+ elif len(timesteps.shape) == 0:
112
+ timesteps = timesteps[None].to(sample.device)
113
+
114
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
115
+ num_frames = sample.shape[2]
116
+ timesteps = timesteps.expand(sample.shape[0])
117
+
118
+ t_emb = self.time_proj(timesteps)
119
+
120
+ # timesteps does not contain any weights and will always return f32 tensors
121
+ # but time_embedding might actually be running in fp16. so we need to cast here.
122
+ # there might be better ways to encapsulate this.
123
+ t_emb = t_emb.to(dtype=self.dtype)
124
+
125
+ emb = self.time_embedding(t_emb, timestep_cond)
126
+ emb = emb.repeat_interleave(repeats=num_frames, dim=0)
127
+ # encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=num_frames, dim=0)
128
+ # print(encoder_hidden_states.shape)
129
+ # quit()
130
+
131
+ # 2. pre-process
132
+ sample = sample.permute(0, 2, 1, 3, 4).reshape((sample.shape[0] * num_frames, -1) + sample.shape[3:])
133
+ sample = self.conv_in(sample)
134
+
135
+ sample = self.transformer_in(
136
+ sample,
137
+ num_frames=num_frames,
138
+ cross_attention_kwargs=cross_attention_kwargs,
139
+ return_dict=False,
140
+ )[0]
141
+
142
+ # 3. down
143
+ down_block_res_samples = (sample,)
144
+ for downsample_block in self.down_blocks:
145
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
146
+ sample, res_samples = downsample_block(
147
+ hidden_states=sample,
148
+ temb=emb,
149
+ encoder_hidden_states=encoder_hidden_states,
150
+ attention_mask=attention_mask,
151
+ num_frames=num_frames,
152
+ cross_attention_kwargs=cross_attention_kwargs,
153
+ )
154
+ else:
155
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb, num_frames=num_frames)
156
+
157
+ down_block_res_samples += res_samples
158
+
159
+ if down_block_additional_residuals is not None:
160
+ new_down_block_res_samples = ()
161
+
162
+ for down_block_res_sample, down_block_additional_residual in zip(
163
+ down_block_res_samples, down_block_additional_residuals
164
+ ):
165
+ down_block_res_sample = down_block_res_sample + down_block_additional_residual
166
+ new_down_block_res_samples += (down_block_res_sample,)
167
+
168
+ down_block_res_samples = new_down_block_res_samples
169
+
170
+ # 4. mid
171
+ if self.mid_block is not None:
172
+ sample = self.mid_block(
173
+ sample,
174
+ emb,
175
+ encoder_hidden_states=encoder_hidden_states,
176
+ attention_mask=attention_mask,
177
+ num_frames=num_frames,
178
+ cross_attention_kwargs=cross_attention_kwargs,
179
+ )
180
+
181
+ if mid_block_additional_residual is not None:
182
+ sample = sample + mid_block_additional_residual
183
+
184
+ # 5. up
185
+ for i, upsample_block in enumerate(self.up_blocks):
186
+ is_final_block = i == len(self.up_blocks) - 1
187
+
188
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
189
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
190
+
191
+ # if we have not reached the final block and need to forward the
192
+ # upsample size, we do it here
193
+ if not is_final_block and forward_upsample_size:
194
+ upsample_size = down_block_res_samples[-1].shape[2:]
195
+
196
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
197
+ sample = upsample_block(
198
+ hidden_states=sample,
199
+ temb=emb,
200
+ res_hidden_states_tuple=res_samples,
201
+ encoder_hidden_states=encoder_hidden_states,
202
+ upsample_size=upsample_size,
203
+ attention_mask=attention_mask,
204
+ num_frames=num_frames,
205
+ cross_attention_kwargs=cross_attention_kwargs,
206
+ )
207
+ else:
208
+ sample = upsample_block(
209
+ hidden_states=sample,
210
+ temb=emb,
211
+ res_hidden_states_tuple=res_samples,
212
+ upsample_size=upsample_size,
213
+ num_frames=num_frames,
214
+ )
215
+
216
+ # 6. post-process
217
+ if self.conv_norm_out:
218
+ sample = self.conv_norm_out(sample)
219
+ sample = self.conv_act(sample)
220
+
221
+ sample = self.conv_out(sample)
222
+
223
+ # reshape to (batch, channel, framerate, width, height)
224
+ sample = sample[None, :].reshape((-1, num_frames) + sample.shape[1:]).permute(0, 2, 1, 3, 4)
225
+
226
+ if not return_dict:
227
+ return (sample,)
228
+
229
+ return UNet3DConditionOutput(sample=sample)
TrailBlazer/Pipeline/Utils.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import inspect
2
+ from typing import Any, Callable, Dict, List, Optional, Union
3
+
4
+ import numpy as np
5
+ import torch
6
+ from transformers import CLIPTextModel, CLIPTokenizer
7
+ from dataclasses import dataclass
8
+
9
+ from diffusers.loaders import LoraLoaderMixin, TextualInversionLoaderMixin
10
+ from diffusers.models import AutoencoderKL, UNet3DConditionModel
11
+ from diffusers.models.lora import adjust_lora_scale_text_encoder
12
+ from diffusers.schedulers import KarrasDiffusionSchedulers
13
+ from diffusers.utils import (
14
+ deprecate,
15
+ logging,
16
+ replace_example_docstring,
17
+ BaseOutput,
18
+ )
19
+ from diffusers.utils.torch_utils import randn_tensor
20
+ from diffusers.pipeline_utils import DiffusionPipeline
21
+ from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth import (
22
+ tensor2vid,
23
+ )
24
+ from ..CrossAttn.InjecterProc import InjecterProcessor
25
+ from ..Misc import Logger as log
26
+ from ..Misc import Const
27
+
28
+
29
+
30
+
31
+ def use_dd_temporal(unet, use=True):
32
+ """ To determine using the temporal attention editing at a step
33
+ """
34
+ for name, module in unet.named_modules():
35
+ module_name = type(module).__name__
36
+ if module_name == "Attention" and "attn2" in name:
37
+ module.processor.use_dd_temporal = use
38
+
39
+
40
+ def use_dd(unet, use=True):
41
+ """ To determine using the spatial attention editing at a step
42
+ """
43
+ for name, module in unet.named_modules():
44
+ module_name = type(module).__name__
45
+ # if module_name == "CrossAttention" and "attn2" in name:
46
+ if module_name == "Attention" and "attn2" in name:
47
+ module.processor.use_dd = use
48
+
49
+
50
+ def initiailization(unet, bundle, bbox_per_frame):
51
+ log.info("Intialization")
52
+
53
+ for name, module in unet.named_modules():
54
+ module_name = type(module).__name__
55
+ if module_name == "Attention" and "attn2" in name:
56
+ if "temp_attentions" in name:
57
+ processor = InjecterProcessor(
58
+ bundle=bundle,
59
+ bbox_per_frame=bbox_per_frame,
60
+ strengthen_scale=bundle["temp_strengthen_scale"],
61
+ weaken_scale=bundle["temp_weaken_scale"],
62
+ is_text2vidzero=False,
63
+ name=name,
64
+ )
65
+ else:
66
+ processor = InjecterProcessor(
67
+ bundle=bundle,
68
+ bbox_per_frame=bbox_per_frame,
69
+ strengthen_scale=bundle["spatial_strengthen_scale"],
70
+ weaken_scale=bundle["spatial_weaken_scale"],
71
+ is_text2vidzero=False,
72
+ name=name,
73
+ )
74
+ module.processor = processor
75
+ # print(name)
76
+ log.info("Initialized")
77
+
78
+
79
+ def keyframed_prompt_embeds(bundle, encode_prompt_func, device):
80
+ num_frames = bundle["keyframe"][-1]["frame"] + 1
81
+ keyframe = bundle["keyframe"]
82
+ f = lambda start, end, index: (1 - index) * start + index * end
83
+ n = len(keyframe)
84
+ keyed_prompt_embeds = []
85
+ for i in range(n - 1):
86
+ if i == 0:
87
+ start_fr = keyframe[i]["frame"]
88
+ else:
89
+ start_fr = keyframe[i]["frame"] + 1
90
+ end_fr = keyframe[i + 1]["frame"]
91
+
92
+ start_prompt = keyframe[i]["prompt"] + Const.POSITIVE_PROMPT
93
+ end_prompt = keyframe[i + 1]["prompt"] + Const.POSITIVE_PROMPT
94
+ clip_length = end_fr - start_fr + 1
95
+
96
+ start_prompt_embeds, _ = encode_prompt_func(
97
+ start_prompt,
98
+ device=device,
99
+ num_images_per_prompt=1,
100
+ do_classifier_free_guidance=True,
101
+ negative_prompt=Const.NEGATIVE_PROMPT,
102
+ )
103
+
104
+ end_prompt_embeds, negative_prompt_embeds = encode_prompt_func(
105
+ end_prompt,
106
+ device=device,
107
+ num_images_per_prompt=1,
108
+ do_classifier_free_guidance=True,
109
+ negative_prompt=Const.NEGATIVE_PROMPT,
110
+ )
111
+
112
+ for fr in range(clip_length):
113
+ index = float(fr) / (clip_length - 1)
114
+ keyed_prompt_embeds.append(f(start_prompt_embeds, end_prompt_embeds, index))
115
+ assert len(keyed_prompt_embeds) == num_frames
116
+
117
+ return torch.cat(keyed_prompt_embeds), negative_prompt_embeds.repeat_interleave(
118
+ num_frames, dim=0
119
+ )
120
+
121
+
122
+ def keyframed_bbox(bundle):
123
+
124
+ keyframe = bundle["keyframe"]
125
+ bbox_per_frame = []
126
+ f = lambda start, end, index: (1 - index) * start + index * end
127
+ n = len(keyframe)
128
+ for i in range(n - 1):
129
+ if i == 0:
130
+ start_fr = keyframe[i]["frame"]
131
+ else:
132
+ start_fr = keyframe[i]["frame"] + 1
133
+ end_fr = keyframe[i + 1]["frame"]
134
+ start_bbox = keyframe[i]["bbox_ratios"]
135
+ end_bbox = keyframe[i + 1]["bbox_ratios"]
136
+ clip_length = end_fr - start_fr + 1
137
+ for fr in range(clip_length):
138
+ index = float(fr) / (clip_length - 1)
139
+ bbox = []
140
+ for j in range(4):
141
+ bbox.append(f(start_bbox[j], end_bbox[j], index))
142
+ bbox_per_frame.append(bbox)
143
+
144
+ return bbox_per_frame
TrailBlazer/Pipeline/__init__.py ADDED
File without changes
TrailBlazer/README.md ADDED
@@ -0,0 +1 @@
 
 
1
+ # TrailBlazer - Codebase
TrailBlazer/Setting/Config.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import os
3
+
4
+ DEVICE = "cuda"
5
+ GUIDANCE_SCALE = 7.5
6
+ WIDTH = 512
7
+ HEIGHT = 512
8
+ NUM_BACKWARD_STEPS = 50
9
+ STEPS = 50
10
+ DTYPE = torch.float16
11
+
12
+ MODEL_HOME = f"{os.path.expanduser('~')}/Workspace/Project/Models"
13
+
14
+ NEGATIVE_PROMPT = "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer difits, cropped, worst quality, low quality, deformed body, bloated, ugly, unrealistic"
15
+ POSITIVE_PROMPT = "best quality, extremely detailed, HD, ultra-realistic, 8K, HQ, masterpiece, trending on artstation, art, smooth"
16
+
17
+
18
+ SD_V1_5_ID = "runwayml/stable-diffusion-v1-5"
19
+ SD_V1_5_PATH = f"{MODEL_HOME}/{SD_V1_5_ID}"
20
+ CNET_CANNY_ID = "lllyasviel/sd-controlnet-canny"
21
+ CNET_CANNY_PATH = f"{MODEL_HOME}/{CNET_CANNY_ID}"
22
+ CNET_OPENPOSE_ID = "lllyasviel/sd-controlnet-openpose"
23
+ CNET_OPENPOSE_PATH = f"{MODEL_HOME}/{CNET_OPENPOSE_ID}"
TrailBlazer/Setting/Const.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ RECONS_NAME = "recons.jpg"
2
+ LATENTS_NAME = "latents.pt"
3
+ CATTN_NAME = "cattn.pt"
4
+ CATTN_VIZ_NAME = "cattn.jpg"
TrailBlazer/Setting/__init__.py ADDED
File without changes
TrailBlazer/__init__.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ # # VideoDiffusion
2
+ # from .Pipeline.Dumnmy import DummyPipeline
3
+ # from .Pipeline.Standard import StandardPipeline
4
+ # from .Pipeline.ControlNet import ControlNetPipeline
5
+ # from .Pipeline.Img2Img import Img2ImgPipeline
6
+ # from .Pipeline.Video import VideoPipeline
7
+
8
+ # from .Pipeline.TestMayaNoise import TestMayaNoisePipeline
app.py ADDED
@@ -0,0 +1,415 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import os
3
+ import torch
4
+ import gradio as gr
5
+ import numpy as np
6
+ from PIL import Image, ImageOps, ImageDraw, ImageFont, ImageColor
7
+ from urllib.request import urlopen
8
+
9
+ root = os.path.dirname(os.path.abspath(__file__))
10
+ static = os.path.join(root, "static")
11
+
12
+ from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
13
+ from diffusers.pipelines import TextToVideoSDPipeline
14
+ from diffusers.utils import export_to_video
15
+ from TrailBlazer.Misc import ConfigIO
16
+ from TrailBlazer.Misc import Logger as log
17
+ from TrailBlazer.Pipeline.TextToVideoSDPipelineCall import (
18
+ text_to_video_sd_pipeline_call,
19
+ )
20
+ from TrailBlazer.Pipeline.UNet3DConditionModelCall import (
21
+ unet3d_condition_model_forward,
22
+ )
23
+
24
+ TextToVideoSDPipeline.__call__ = text_to_video_sd_pipeline_call
25
+ from diffusers.models.unet_3d_condition import UNet3DConditionModel
26
+
27
+ unet3d_condition_model_forward_copy = UNet3DConditionModel.forward
28
+ UNet3DConditionModel.forward = unet3d_condition_model_forward
29
+
30
+
31
+ from diffusers.utils import export_to_video
32
+
33
+ model_id = "cerspense/zeroscope_v2_576w"
34
+ model_path = model_id
35
+ pipe = DiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.float16)
36
+ pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
37
+ pipe.enable_model_cpu_offload()
38
+
39
+ def core(bundle):
40
+
41
+ generator = torch.Generator().manual_seed(int(bundle["seed"]))
42
+ result = pipe(
43
+ bundle=bundle,
44
+ height=512,
45
+ width=512,
46
+ generator=generator,
47
+ num_inference_steps=40,
48
+ )
49
+ return result.frames
50
+
51
+
52
+ def clear_btn_fn():
53
+ return "", "", "", ""
54
+
55
+
56
+ def gen_btn_fn(
57
+ prompts,
58
+ bboxes,
59
+ frames,
60
+ word_prompt_indices,
61
+ trailing_length,
62
+ n_spatial_steps,
63
+ n_temporal_steps,
64
+ spatial_strengthen_scale,
65
+ spatial_weaken_scale,
66
+ temporal_strengthen_scale,
67
+ temporal_weaken_scale,
68
+ rand_seed,
69
+ ):
70
+
71
+ bundle = {}
72
+ bundle["trailing_length"] = trailing_length
73
+ bundle["num_dd_spatial_steps"] = n_spatial_steps
74
+ bundle["num_dd_temporal_steps"] = n_temporal_steps
75
+ bundle["num_frames"] = 24
76
+ bundle["seed"] = rand_seed
77
+ bundle["spatial_strengthen_scale"] = spatial_strengthen_scale
78
+ bundle["spatial_weaken_scale"] = spatial_weaken_scale
79
+ bundle["temp_strengthen_scale"] = temporal_strengthen_scale
80
+ bundle["temp_weaken_scale"] = temporal_weaken_scale
81
+ bundle["token_inds"] = [int(v) for v in word_prompt_indices.split(",")]
82
+
83
+ bundle["keyframe"] = []
84
+ frames = frames.split(";")
85
+ bboxes = bboxes.split(";")
86
+ if ";" in prompts:
87
+ prompts = prompts.split(";")
88
+ else:
89
+ prompts = [prompts for i in range(len(frames))]
90
+
91
+ assert (
92
+ len(frames) == len(bboxes) == len(prompts)
93
+ ), "Inconsistent number of keyframes in the given inputs."
94
+
95
+ frames.pop()
96
+ bboxes.pop()
97
+ prompts.pop()
98
+
99
+
100
+
101
+ for i in range(len(frames)):
102
+ keyframe = {}
103
+ keyframe["bbox_ratios"] = [float(v) for v in bboxes[i].split(",")]
104
+ keyframe["frame"] = int(frames[i])
105
+ keyframe["prompt"] = prompts[i]
106
+ bundle["keyframe"].append(keyframe)
107
+ print(bundle)
108
+ result = core(bundle)
109
+ path = export_to_video(result)
110
+ return path
111
+
112
+
113
+ def save_mask(inputs):
114
+ layers = inputs["layers"]
115
+ if not layers:
116
+ return inputs["background"]
117
+ mask = layers[0]
118
+ new_image = Image.new("RGBA", mask.size, color="white")
119
+ new_image.paste(mask, mask=mask)
120
+ new_image = new_image.convert("RGB")
121
+ print("SAve")
122
+ return ImageOps.invert(new_image)
123
+
124
+
125
+ def out_label_cb(im):
126
+ layers = im["layers"]
127
+ if not isinstance(layers, list):
128
+ layers = [layers]
129
+
130
+ img = None
131
+ text = "Bboxes: "
132
+ for idx, layer in enumerate(layers):
133
+ mask = np.array(layer).sum(axis=-1)
134
+ ys, xs = np.where(mask != 0)
135
+ h, w = mask.shape
136
+ if not list(xs) or not list(ys):
137
+ continue
138
+ x_min = np.min(xs)
139
+ x_max = np.max(xs)
140
+ y_min = np.min(ys)
141
+ y_max = np.max(ys)
142
+
143
+ text += "{:.2f},{:.2f},{:.2f},{:.2f}".format(
144
+ x_min * 1.0 / w, y_min * 1.0 / h, x_max * 1.0 / w, y_max * 1.0 / h
145
+ )
146
+ text += ";\n"
147
+ return text
148
+
149
+
150
+ def out_board_cb(im):
151
+
152
+ layers = im["layers"]
153
+ if not isinstance(layers, list):
154
+ layers = [layers]
155
+
156
+ img = None
157
+ for idx, layer in enumerate(layers):
158
+ mask = np.array(layer).sum(axis=-1)
159
+ ys, xs = np.where(mask != 0)
160
+
161
+ if not list(xs) or not list(ys):
162
+ continue
163
+
164
+ h, w = mask.shape
165
+ if not img:
166
+ img = Image.new("RGBA", (w, h))
167
+ x_min = np.min(xs)
168
+ x_max = np.max(xs)
169
+ y_min = np.min(ys)
170
+ y_max = np.max(ys)
171
+
172
+ # output
173
+ shape = [(x_min, y_min), (x_max, y_max)]
174
+ colors = list(ImageColor.colormap.keys())
175
+ draw = ImageDraw.Draw(img)
176
+ draw.rectangle(shape, outline=colors[idx], width=5)
177
+ text = "Bbox#{}".format(idx)
178
+ font = ImageFont.load_default()
179
+ draw.text((x_max - 0.5 * (x_max - x_min), y_max), text, font=font, align="left")
180
+
181
+ return img
182
+
183
+
184
+ with gr.Blocks(
185
+ analytics_enabled=False,
186
+ title="TrailBlazer Demo",
187
+ ) as main:
188
+
189
+ description = """
190
+ <h1 align="center" style="font-size: 48px">TrailBlazer: Trajectory Control for Diffusion-Based Video Generation</h1>
191
+ <h4 align="center" style="margin: 0;">If you like our project, please give us a star ✨ at our Huggingface space, and our Github repository.</h4>
192
+ <br>
193
+ <span align="center" style="font-size: 18px">
194
+ [<a href="https://hohonu-vicml.github.io/Trailblazer.Page/" target="_blank">Project Page</a>]
195
+ [<a href="http://arxiv.org/abs/2401.00896" target="_blank">Paper</a>]
196
+ [<a href="https://github.com/hohonu-vicml/Trailblazer" target="_blank">GitHub</a>]
197
+ [<a href="https://www.youtube.com/watch?v=kEN-32wN-xQ" target="_blank">Project Video</a>]
198
+ [<a href="https://www.youtube.com/watch?v=P-PSkS7sNco" target="_blank">Result Video</a>]
199
+ </span>
200
+ </p>
201
+ <p>
202
+ <strong>Usage:</strong> Our Gradio app is implemented based on our executable script CmdTrailBlazer in our github repository. Please see our general information below for a quick guidance, as well as the hints within the app widgets.
203
+ <ul>
204
+ <li>Basic: The bounding box (bbox) is the tuple of four floats for the rectangular corners: left, top, right, bottom in the normalized ratio. The Word prompt indices is a list of 1-indexed numbers determining the prompt word.</li>
205
+ <li>Advanced Options: We also offer some key parameters to adjust the synthesis result. Please see our paper for more information about the ablations.</li>
206
+ </ul>
207
+ </p>
208
+ """
209
+ gr.HTML(description)
210
+
211
+ with gr.Row():
212
+ with gr.Column(scale=2):
213
+ with gr.Row():
214
+ with gr.Tab("Main"):
215
+ text_prompt_tb = gr.Textbox(
216
+ interactive=True, label="Keyframe: Prompt"
217
+ )
218
+ bboxes_tb = gr.Textbox(interactive=True, label="Keyframe: Bboxes")
219
+ frame_tb = gr.Textbox(
220
+ interactive=True, label="Keyframe: frame indices"
221
+ )
222
+ with gr.Row():
223
+ word_prompt_indices_tb = gr.Textbox(
224
+ interactive=True, label="Word prompt indices:"
225
+ )
226
+ text = "Hint: Each keyframe ends with <strong>SEMICOLON</strong>, and <strong>COMMA</strong> for separating each value in the keyframe. The prompt field can be a single prompt without semicolon, or multiple prompts ended semicolon. One can use the SketchPadHelper tab to help to design the bboxes field."
227
+ gr.HTML(text)
228
+ with gr.Row():
229
+ clear_btn = gr.Button(value="Clear")
230
+ gen_btn = gr.Button(value="Generate")
231
+
232
+ with gr.Accordion("Advanced Options", open=False):
233
+ text = "Hint: This default value should be sufficient for most tasks. However, it's important to note that our approach is currently implemented on ZeroScope, and its performance may be influenced by the model's characteristics. We plan to conduct experiments on different models in the future."
234
+ gr.HTML(text)
235
+ with gr.Row():
236
+ trailing_length = gr.Slider(
237
+ minimum=0,
238
+ maximum=30,
239
+ step=1,
240
+ value=13,
241
+ interactive=True,
242
+ label="#Trailing",
243
+ )
244
+ n_spatial_steps = gr.Slider(
245
+ minimum=0,
246
+ maximum=30,
247
+ step=1,
248
+ value=5,
249
+ interactive=True,
250
+ label="#Spatial edits",
251
+ )
252
+ n_temporal_steps = gr.Slider(
253
+ minimum=0,
254
+ maximum=30,
255
+ step=1,
256
+ value=5,
257
+ interactive=True,
258
+ label="#Temporal edits",
259
+ )
260
+ with gr.Row():
261
+ spatial_strengthen_scale = gr.Slider(
262
+ minimum=0,
263
+ maximum=2,
264
+ step=0.01,
265
+ value=0.15,
266
+ interactive=True,
267
+ label="Spatial Strengthen Scale",
268
+ )
269
+ spatial_weaken_scale = gr.Slider(
270
+ minimum=0,
271
+ maximum=1,
272
+ step=0.01,
273
+ value=0.001,
274
+ interactive=True,
275
+ label="Spatial Weaken Scale",
276
+ )
277
+ temporal_strengthen_scale = gr.Slider(
278
+ minimum=0,
279
+ maximum=2,
280
+ step=0.01,
281
+ value=0.15,
282
+ interactive=True,
283
+ label="Temporal Strengthen Scale",
284
+ )
285
+ temporal_weaken_scale = gr.Slider(
286
+ minimum=0,
287
+ maximum=1,
288
+ step=0.01,
289
+ value=0.001,
290
+ interactive=True,
291
+ label="Temporal Weaken Scale",
292
+ )
293
+
294
+ with gr.Row():
295
+ guidance_scale = gr.Slider(
296
+ minimum=0,
297
+ maximum=50,
298
+ step=0.5,
299
+ value=7.5,
300
+ interactive=True,
301
+ label="Guidance Scale",
302
+ )
303
+ rand_seed = gr.Slider(
304
+ minimum=0,
305
+ maximum=523451232531,
306
+ step=1,
307
+ value=0,
308
+ interactive=True,
309
+ label="Seed",
310
+ )
311
+
312
+ with gr.Tab("SketchPadHelper"):
313
+ with gr.Row():
314
+ user_board = gr.ImageMask(type="pil", label="Draw me")
315
+ out_board = gr.Image(type="pil", label="Processed bbox")
316
+ user_board.change(
317
+ out_board_cb, inputs=[user_board], outputs=[out_board]
318
+ )
319
+ with gr.Row():
320
+ text = "Hint: Utilize a black pen with the Draw Button to create a ``rough'' bbox. When you press the green ``Save Changes'' Button, the app calculates the minimum and maximum boundaries. Each ``Layer'', located at the bottom left of the pad, corresponds to one bounding box. Copy the returned value to the bbox textfield in the main tab."
321
+ gr.HTML(text)
322
+ with gr.Row():
323
+ out_label = gr.Label(label="Converted bboxes string")
324
+ user_board.change(
325
+ out_label_cb, inputs=[user_board], outputs=[out_label]
326
+ )
327
+
328
+ with gr.Column(scale=1):
329
+ gr.HTML(
330
+ '<span style="font-size: 20px; font-weight: bold">Generated Images</span>'
331
+ )
332
+ with gr.Row():
333
+ out_gen_1 = gr.Video(visible=True, show_label=False)
334
+
335
+ with gr.Row():
336
+ gr.Examples(
337
+ examples=[
338
+ [
339
+ "A clown fish swimming in a coral reef",
340
+ "0.5,0.35,1.0,0.65; 0.0,0.35,0.5,0.65;",
341
+ "0; 24;",
342
+ "1,2,3",
343
+ "123451232531",
344
+ "assets/gradio/fish-RL.mp4",
345
+ ],
346
+ [
347
+ "A cat is running on the grass",
348
+ "0.0,0.35,0.4,0.65; 0.6,0.35,1.0,0.65; 0.0,0.35,0.4,0.65;"
349
+ "0.6,0.35,1.0,0.65; 0.0,0.35,0.4,0.65;",
350
+ "0; 6; 12; 18; 24;",
351
+ "1,2",
352
+ "123451232530",
353
+ "assets/gradio/cat-LRLR.mp4",
354
+ ],
355
+ [
356
+ "A fish swimming in the ocean",
357
+ "0.0,0.0,0.1,0.1; 0.5,0.5,1.0,1.0;",
358
+ "0; 24;",
359
+ "1, 2",
360
+ "0",
361
+ "assets/gradio/fish-TL2BR.mp4"
362
+ ],
363
+ [
364
+ "A tiger walking alone down the street",
365
+ "0.0,0.0,0.1,0.1; 0.5,0.5,1.0,1.0;",
366
+ "0; 24;",
367
+ "1, 2",
368
+ "0",
369
+ "assets/gradio/tiger-TL2BR.mp4"
370
+ ],
371
+ [
372
+ "A white cat walking on the grass; A yellow dog walking on the grass;",
373
+ "0.7,0.4,1.0,0.65; 0.0,0.4,0.3,0.65;",
374
+ "0; 24;",
375
+ "1,2,3",
376
+ "123451232531",
377
+ "assets/gradio/Cat2Dog.mp4",
378
+ ],
379
+ ],
380
+ inputs=[text_prompt_tb, bboxes_tb, frame_tb, word_prompt_indices_tb, rand_seed,out_gen_1],
381
+ outputs=None,
382
+ fn=None,
383
+ cache_examples=False,
384
+ )
385
+
386
+ clear_btn.click(
387
+ clear_btn_fn,
388
+ inputs=[],
389
+ outputs=[text_prompt_tb, bboxes_tb, frame_tb, word_prompt_indices_tb],
390
+ queue=False,
391
+ )
392
+
393
+ gen_btn.click(
394
+ gen_btn_fn,
395
+ inputs=[
396
+ text_prompt_tb,
397
+ bboxes_tb,
398
+ frame_tb,
399
+ word_prompt_indices_tb,
400
+ trailing_length,
401
+ n_spatial_steps,
402
+ n_temporal_steps,
403
+ spatial_strengthen_scale,
404
+ spatial_weaken_scale,
405
+ temporal_strengthen_scale,
406
+ temporal_weaken_scale,
407
+ rand_seed,
408
+ ],
409
+ outputs=[out_gen_1],
410
+ queue=False,
411
+ )
412
+
413
+
414
+ if __name__ == "__main__":
415
+ main.launch(share=False)