multimodalart HF staff committed on
Commit
a7729a1
·
1 Parent(s): f3f18e3

Delete simple_video_sample.py

Browse files
Files changed (1) hide show
  1. simple_video_sample.py +0 -278
simple_video_sample.py DELETED
@@ -1,278 +0,0 @@
1
- import math
2
- import os
3
- from glob import glob
4
- from pathlib import Path
5
- from typing import Optional
6
-
7
- import cv2
8
- import numpy as np
9
- import torch
10
- from einops import rearrange, repeat
11
- from fire import Fire
12
- from omegaconf import OmegaConf
13
- from PIL import Image
14
- from torchvision.transforms import ToTensor
15
-
16
- from scripts.util.detection.nsfw_and_watermark_dectection import \
17
- DeepFloydDataFiltering
18
- from sgm.inference.helpers import embed_watermark
19
- from sgm.util import default, instantiate_from_config
20
-
21
-
22
def sample(
    input_path: str = "assets/test_image.png",  # Can either be image file or folder with image files
    num_frames: Optional[int] = None,
    num_steps: Optional[int] = None,
    version: str = "svd",
    fps_id: int = 6,
    motion_bucket_id: int = 127,
    cond_aug: float = 0.02,
    seed: int = 23,
    decoding_t: int = 14,  # Number of frames decoded at a time! This eats most VRAM. Reduce if necessary.
    device: str = "cuda",
    output_folder: Optional[str] = None,
):
    """
    Simple script to generate a single sample conditioned on an image `input_path` or multiple images, one for each
    image file in folder `input_path`. If you run out of VRAM, try decreasing `decoding_t`.

    Args:
        input_path: A `.jpg`/`.jpeg`/`.png` file (extension matched case-insensitively) or a folder
            containing such files; every image in the folder is processed in sorted order.
        num_frames: Number of frames to sample. Defaults to a per-`version` value when None.
        num_steps: Number of sampler steps. Defaults to a per-`version` value when None.
        version: One of "svd", "svd_xt", "svd_image_decoder", "svd_xt_image_decoder". Selects the
            config file and the default `num_frames`, `num_steps` and `output_folder`.
        fps_id: Conditioning value passed to the model; the video is written at `fps_id + 1` fps.
        motion_bucket_id: Conditioning value passed to the model.
        cond_aug: Scale of the Gaussian noise added to the conditioning frame; also passed to the
            model as a conditioning value.
        seed: Seed given to `torch.manual_seed` after the model has been loaded.
        decoding_t: Assigned to `model.en_and_decode_n_samples_a_time` before decoding.
        device: Device string used for the model and all tensors created here.
        output_folder: Folder for the resulting `.mp4` files (created if missing). Files are named
            with a six-digit counter equal to the number of `.mp4` files already present.

    Raises:
        ValueError: If `version` is unknown, `input_path` is a file without an image extension,
            a folder without images, or neither a file nor a folder.
    """

    # (default num_frames, default num_steps) for every supported version.
    version_defaults = {
        "svd": (14, 25),
        "svd_xt": (25, 30),
        "svd_image_decoder": (14, 25),
        "svd_xt_image_decoder": (25, 30),
    }
    if version not in version_defaults:
        raise ValueError(f"Version {version} does not exist.")

    # Validate the input before loading the model so that a bad path fails fast.
    # The same case-insensitive suffix test is used for single files and folders.
    image_suffixes = (".jpg", ".jpeg", ".png")
    path = Path(input_path)
    if path.is_file():
        if path.suffix.lower() not in image_suffixes:
            raise ValueError("Path is not valid image file.")
        all_img_paths = [input_path]
    elif path.is_dir():
        all_img_paths = sorted(
            f
            for f in path.iterdir()
            if f.is_file() and f.suffix.lower() in image_suffixes
        )
        if not all_img_paths:
            raise ValueError("Folder does not contain any images.")
    else:
        raise ValueError(
            f"Input path {input_path} is neither an existing file nor a folder."
        )

    default_frames, default_steps = version_defaults[version]
    num_frames = default(num_frames, default_frames)
    num_steps = default(num_steps, default_steps)
    output_folder = default(output_folder, f"outputs/simple_video_sample/{version}/")
    model_config = f"scripts/sampling/configs/{version}.yaml"

    model, data_filter = load_model(
        model_config,
        device,
        num_frames,
        num_steps,
    )
    torch.manual_seed(seed)

    # torch.autocast expects a bare device type ("cuda", "cpu"), not an indexed
    # device string such as "cuda:0".
    autocast_device = torch.device(device).type

    for input_img_path in all_img_paths:
        with Image.open(input_img_path) as source_image:
            # The model input must have exactly three channels, so normalise every
            # non-RGB mode (RGBA, grayscale, palette, ...) instead of only RGBA.
            if source_image.mode != "RGB":
                image = source_image.convert("RGB")
            else:
                image = source_image
            w, h = image.size

            if h % 64 != 0 or w % 64 != 0:
                # Round both sides down to the nearest multiple of 64.
                width, height = map(lambda x: x - x % 64, (w, h))
                image = image.resize((width, height))
                print(
                    f"WARNING: Your image is of size {h}x{w} which is not divisible by 64. We are resizing to {height}x{width}!"
                )

            image = ToTensor()(image)
            # Rescale from [0, 1] to [-1, 1].
            image = image * 2.0 - 1.0

        image = image.unsqueeze(0).to(device)
        H, W = image.shape[2:]
        assert image.shape[1] == 3
        # F is the spatial divisor and C the channel count of the noise tensor
        # handed to the sampler below.
        F = 8
        C = 4
        shape = (num_frames, C, H // F, W // F)
        if (H, W) != (576, 1024):
            print(
                "WARNING: The conditioning frame you provided is not 576x1024. This leads to suboptimal performance as model was only trained on 576x1024. Consider increasing `cond_aug`."
            )
        if motion_bucket_id > 255:
            print(
                "WARNING: High motion bucket! This may lead to suboptimal performance."
            )

        if fps_id < 5:
            print("WARNING: Small fps value! This may lead to suboptimal performance.")

        if fps_id > 30:
            print("WARNING: Large fps value! This may lead to suboptimal performance.")

        value_dict = {
            "motion_bucket_id": motion_bucket_id,
            "fps_id": fps_id,
            "cond_aug": cond_aug,
            "cond_frames_without_noise": image,
            "cond_frames": image + cond_aug * torch.randn_like(image),
        }

        with torch.no_grad():
            with torch.autocast(autocast_device):
                batch, batch_uc = get_batch(
                    get_unique_embedder_keys_from_conditioner(model.conditioner),
                    value_dict,
                    [1, num_frames],
                    T=num_frames,
                    device=device,
                )
                c, uc = model.conditioner.get_unconditional_conditioning(
                    batch,
                    batch_uc=batch_uc,
                    force_uc_zero_embeddings=[
                        "cond_frames",
                        "cond_frames_without_noise",
                    ],
                )

                # Repeat each conditioning entry once per frame and fold the frame
                # axis into the batch axis.
                for k in ["crossattn", "concat"]:
                    uc[k] = repeat(uc[k], "b ... -> b t ...", t=num_frames)
                    uc[k] = rearrange(uc[k], "b t ... -> (b t) ...", t=num_frames)
                    c[k] = repeat(c[k], "b ... -> b t ...", t=num_frames)
                    c[k] = rearrange(c[k], "b t ... -> (b t) ...", t=num_frames)

                randn = torch.randn(shape, device=device)

                additional_model_inputs = {
                    "image_only_indicator": torch.zeros(2, num_frames).to(device),
                    "num_video_frames": batch["num_video_frames"],
                }

                def denoiser(input, sigma, c):
                    return model.denoiser(
                        model.model, input, sigma, c, **additional_model_inputs
                    )

                samples_z = model.sampler(denoiser, randn, cond=c, uc=uc)
                model.en_and_decode_n_samples_a_time = decoding_t
                samples_x = model.decode_first_stage(samples_z)
                # Map from [-1, 1] back to [0, 1].
                samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)

                os.makedirs(output_folder, exist_ok=True)
                base_count = len(glob(os.path.join(output_folder, "*.mp4")))
                video_path = os.path.join(output_folder, f"{base_count:06d}.mp4")
                writer = cv2.VideoWriter(
                    video_path,
                    cv2.VideoWriter_fourcc(*"MP4V"),
                    fps_id + 1,
                    (samples.shape[-1], samples.shape[-2]),
                )
                try:
                    samples = embed_watermark(samples)
                    samples = data_filter(samples)
                    vid = (
                        (rearrange(samples, "t c h w -> t h w c") * 255)
                        .cpu()
                        .numpy()
                        .astype(np.uint8)
                    )
                    for frame in vid:
                        # OpenCV writes BGR frames.
                        frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
                        writer.write(frame)
                finally:
                    # Always finalise the file, even if watermarking/filtering fails.
                    writer.release()
204
-
205
-
206
def get_unique_embedder_keys_from_conditioner(conditioner):
    """Return the distinct `input_key` values of `conditioner.embedders` as a list.

    Keys are returned in first-seen order. Deduplicating through a plain `set`
    would make the order of string keys vary between interpreter runs (hash
    randomization); `dict.fromkeys` keeps the result reproducible.
    """
    return list(dict.fromkeys(x.input_key for x in conditioner.embedders))
208
-
209
-
210
def get_batch(keys, value_dict, N, T, device):
    """Build the conditioning batch and its unconditional counterpart.

    For every name in `keys` the matching entry of `value_dict` is expanded:

    * "fps_id" / "motion_bucket_id": wrapped in a 1-element tensor on `device`
      and repeated `prod(N)` times.
    * "cond_aug": wrapped in a 1-element tensor on `device` and repeated to
      length `prod(N)`.
    * "cond_frames" / "cond_frames_without_noise": the leading axis of size 1
      is repeated `N[0]` times.
    * anything else: copied through unchanged.

    If `T` is not None it is stored under "num_video_frames".

    Returns:
        `(batch, batch_uc)` where `batch_uc` holds a clone of every tensor-valued
        entry of `batch` (non-tensor entries are not copied over).
    """
    scalar_keys = ("fps_id", "motion_bucket_id")
    frame_keys = ("cond_frames", "cond_frames_without_noise")

    batch = {}
    for name in keys:
        value = value_dict[name]
        if name in scalar_keys:
            batch[name] = torch.tensor([value]).to(device).repeat(int(math.prod(N)))
        elif name == "cond_aug":
            batch[name] = repeat(
                torch.tensor([value]).to(device), "1 -> b", b=math.prod(N)
            )
        elif name in frame_keys:
            batch[name] = repeat(value, "1 ... -> b ...", b=N[0])
        else:
            batch[name] = value

    if T is not None:
        batch["num_video_frames"] = T

    # Only tensors are duplicated for the unconditional branch.
    batch_uc = {
        name: torch.clone(entry)
        for name, entry in batch.items()
        if isinstance(entry, torch.Tensor)
    }
    return batch, batch_uc
249
-
250
-
251
def load_model(
    config: str,
    device: str,
    num_frames: int,
    num_steps: int,
):
    """Load the model described by the config file `config` together with an output filter.

    The loaded config is patched before instantiation: the sampler's `num_steps`
    and the guider's `num_frames` are overwritten with the given values, and when
    `device` is exactly "cuda" the first embedder's open-clip `init_device` is set
    to it as well.

    Returns:
        `(model, filter)` - the instantiated model moved to `device` and put in
        eval mode, and a `DeepFloydDataFiltering` instance created for `device`.
    """
    cfg = OmegaConf.load(config)
    on_cuda = device == "cuda"
    model_params = cfg.model.params

    if on_cuda:
        first_embedder = model_params.conditioner_config.params.emb_models[0]
        first_embedder.params.open_clip_embedding_config.params.init_device = device

    sampler_params = model_params.sampler_config.params
    sampler_params.num_steps = num_steps
    sampler_params.guider_config.params.num_frames = num_frames

    def _build():
        return instantiate_from_config(cfg.model).to(device).eval()

    if on_cuda:
        # Instantiate under the device context so tensors created during
        # construction default to `device`.
        with torch.device(device):
            model = _build()
    else:
        model = _build()

    data_filter = DeepFloydDataFiltering(verbose=False, device=device)
    return model, data_filter
275
-
276
-
277
# Script entry point: expose `sample` and its keyword arguments through Fire.
if __name__ == "__main__":
    Fire(sample)