MaxMilan1 committed
Commit 95ecf9b
1 Parent(s): a2559bc

change app.py to original from V3D

Files changed (1):
  1. app.py +268 -79
app.py CHANGED
@@ -1,82 +1,271 @@
  import gradio as gr
- # from util.text_img import generate_image
- from util.v3d import generate_v3d, prep
-
- # Prepare the V3D model
- model, clip_model, ae_model, device, num_frames, num_steps, rembg_session, output_folder = prep()
-
- _TITLE = "Shoe Generator"
- with gr.Blocks(_TITLE) as ShoeGen:
-     # with gr.Tab("Text to Image Generator"):
-     #     with gr.Row():
-     #         with gr.Column():
-     #             prompt = gr.Textbox(label="Enter a discription of a shoe")
-     #             # neg_prompt = gr.Textbox(label="Enter a negative prompt", value="low quality, watermark, ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, body out of frame, blurry, bad anatomy, blurred, watermark, grainy, signature, cut off, draft, closed eyes, text, logo")
-     #             button_gen = gr.Button("Generate Image")
-     #         with gr.Column():
-     #             with gr.Tab("With Background"):
-     #                 image = gr.Image(label="Generated Image", show_download_button=True, show_label=False)
-     #             with gr.Tab("Without Background"):
-     #                 image_nobg = gr.Image(label="Generated Image", show_download_button=True, show_label=False)
-
-     #     button_gen.click(generate_image, inputs=[prompt], outputs=[image, image_nobg])
-
-     with gr.Tab("Image to Video Generator (V3D)"):
-         with gr.Row(equal_height=True):
-             with gr.Column():
-                 input_image = gr.Image(value=None, label="Input Image")
-
-                 border_ratio_slider = gr.Slider(
-                     value=0.3,
-                     label="Border Ratio",
-                     minimum=0.05,
-                     maximum=0.5,
-                     step=0.05,
-                 )
-                 decoding_t_slider = gr.Slider(
-                     value=1,
-                     label="Number of Decoding frames",
-                     minimum=1,
-                     maximum=num_frames,
-                     step=1,
-                 )
-                 min_guidance_slider = gr.Slider(
-                     value=3.5,
-                     label="Min CFG Value",
-                     minimum=0.05,
-                     maximum=0.5,
-                     step=0.05,
-                 )
-                 max_guidance_slider = gr.Slider(
-                     value=3.5,
-                     label="Max CFG Value",
-                     minimum=0.05,
-                     maximum=0.5,
-                     step=0.05,
                  )
-                 run_button = gr.Button(value="Run V3D")
-
-             with gr.Column():
-                 output_video = gr.Video(value=None, label="Output Orbit Video")
-
-         run_button.click(generate_v3d,
-             inputs=[
-                 input_image,
-                 model,
-                 clip_model,
-                 ae_model,
-                 num_frames,
-                 num_steps,
-                 int(decoding_t_slider),
-                 border_ratio_slider,
-                 False,
-                 rembg_session,
-                 output_folder,
-                 min_guidance_slider,
-                 max_guidance_slider,
-                 device,
-             ],
-             outputs=[output_video],
          )
-
- ShoeGen.launch()
 
 
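A note on why this wiring was reverted: the removed version passes live Python objects (model, clip_model, ae_model, rembg_session, device) and even int(decoding_t_slider) in the inputs list of run_button.click, but Gradio inputs may only contain components, and int() cannot be applied to a Slider object at build time. The restored file below fixes this by keeping only components in inputs, holding the models at module scope, and casting the slider value inside the callback. A minimal, self-contained sketch of that pattern (load_models, sample, and on_run are hypothetical stand-ins, not the project's functions):

import gradio as gr

# Hypothetical stand-in for the project's model loading (prep() in the old file);
# heavy objects live at module scope instead of the event's inputs list.
def load_models():
    return object()

model = load_models()

# Hypothetical stand-in for the real sampler (generate_v3d / do_sample).
def sample(image, decoding_t):
    return image

with gr.Blocks(title="wiring sketch") as demo:
    input_image = gr.Image(label="Input Image")
    decoding_t_slider = gr.Slider(value=1, minimum=1, maximum=18, step=1,
                                  label="Number of Decoding frames")
    output_image = gr.Image(label="Output")
    run_button = gr.Button("Run")

    def on_run(image, decoding_t):
        # Only component values arrive here; cast the value, not the component.
        return sample(image, int(decoding_t))

    run_button.click(on_run, inputs=[input_image, decoding_t_slider],
                     outputs=[output_image])

demo.launch()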
+ # TODO
+ import numpy as np
+ import argparse
+ import torch
+ from torchvision.utils import make_grid
+ import tempfile
  import gradio as gr
+ from omegaconf import OmegaConf
+ from einops import rearrange
+ from scripts.pub.V3D_512 import (
+     sample_one,
+     get_batch,
+     get_unique_embedder_keys_from_conditioner,
+     load_model,
+ )
+ from sgm.util import default, instantiate_from_config
+ from safetensors.torch import load_file as load_safetensors
+ from PIL import Image
+ from kiui.op import recenter
+ from torchvision.transforms import ToTensor
+ from einops import rearrange, repeat
+ import rembg
+ import os
+ from glob import glob
+ from mediapy import write_video
+ from pathlib import Path
+ import spaces
+ from huggingface_hub import hf_hub_download
+ import imageio
+
+ import cv2
+
+
+ @spaces.GPU
+ def do_sample(
+     image,
+     num_frames,
+     num_steps,
+     decoding_t,
+     border_ratio,
+     ignore_alpha,
+     output_folder,
+     seed,
+ ):
+     # if image.mode == "RGBA":
+     #     image = image.convert("RGB")
+     torch.manual_seed(seed)
+     image = Image.fromarray(image)
+     w, h = image.size
+
+     if border_ratio > 0:
+         if image.mode != "RGBA" or ignore_alpha:
+             image = image.convert("RGB")
+             image = np.asarray(image)
+             carved_image = rembg.remove(image, session=rembg_session)  # [H, W, 4]
+         else:
+             image = np.asarray(image)
+             carved_image = image
+         mask = carved_image[..., -1] > 0
+         image = recenter(carved_image, mask, border_ratio=border_ratio)
+         image = image.astype(np.float32) / 255.0
+         if image.shape[-1] == 4:
+             image = image[..., :3] * image[..., 3:4] + (1 - image[..., 3:4])
+         image = Image.fromarray((image * 255).astype(np.uint8))
+     else:
+         print("Ignore border ratio")
+     image = image.resize((512, 512))
+
+     image = ToTensor()(image)
+     image = image * 2.0 - 1.0
+
+     image = image.unsqueeze(0).to(device)
+     H, W = image.shape[2:]
+     assert image.shape[1] == 3
+     F = 8
+     C = 4
+     shape = (num_frames, C, H // F, W // F)
+
+     value_dict = {}
+     value_dict["motion_bucket_id"] = 0
+     value_dict["fps_id"] = 0
+     value_dict["cond_aug"] = 0.05
+     value_dict["cond_frames_without_noise"] = clip_model(image)
+     value_dict["cond_frames"] = ae_model.encode(image)
+     value_dict["cond_frames"] += 0.05 * torch.randn_like(value_dict["cond_frames"])
+     value_dict["cond_aug"] = 0.05
+
+     print(device)
+     with torch.no_grad():
+         with torch.autocast(device_type="cuda"):
+             batch, batch_uc = get_batch(
+                 get_unique_embedder_keys_from_conditioner(model.conditioner),
+                 value_dict,
+                 [1, num_frames],
+                 T=num_frames,
+                 device=device,
+             )
+             c, uc = model.conditioner.get_unconditional_conditioning(
+                 batch,
+                 batch_uc=batch_uc,
+                 force_uc_zero_embeddings=[
+                     "cond_frames",
+                     "cond_frames_without_noise",
+                 ],
+             )
+
+             for k in ["crossattn", "concat"]:
+                 uc[k] = repeat(uc[k], "b ... -> b t ...", t=num_frames)
+                 uc[k] = rearrange(uc[k], "b t ... -> (b t) ...", t=num_frames)
+                 c[k] = repeat(c[k], "b ... -> b t ...", t=num_frames)
+                 c[k] = rearrange(c[k], "b t ... -> (b t) ...", t=num_frames)
+
+             randn = torch.randn(shape, device=device)
+             randn = randn.to(device)
+
+             additional_model_inputs = {}
+             additional_model_inputs["image_only_indicator"] = torch.zeros(
+                 2, num_frames
+             ).to(device)
+             additional_model_inputs["num_video_frames"] = batch["num_video_frames"]
+
+             def denoiser(input, sigma, c):
+                 return model.denoiser(
+                     model.model, input, sigma, c, **additional_model_inputs
                  )
+
+             samples_z = model.sampler(denoiser, randn, cond=c, uc=uc)
+             model.en_and_decode_n_samples_a_time = decoding_t
+             samples_x = model.decode_first_stage(samples_z)
+             samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
+
+             os.makedirs(output_folder, exist_ok=True)
+             base_count = len(glob(os.path.join(output_folder, "*.mp4")))
+             video_path = os.path.join(output_folder, f"{base_count:06d}.mp4")
+
+             frames = (
+                 (rearrange(samples, "t c h w -> t h w c") * 255)
+                 .cpu()
+                 .numpy()
+                 .astype(np.uint8)
+             )
+             # write_video(video_path, frames, fps=6)
+             # writer = cv2.VideoWriter(
+             #     video_path,
+             #     cv2.VideoWriter_fourcc("m", "p", "4", "v"),
+             #     6,
+             #     (frames.shape[-1], frames.shape[-2]),
+             # )
+             # for fr in frames:
+             #     writer.write(cv2.cvtColor(fr, cv2.COLOR_RGB2BGR))
+             # writer.release()
+             imageio.mimwrite(video_path, frames, fps=6)
+
+             return video_path
+
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ # download
+ V3D_ckpt_path = hf_hub_download(repo_id="heheyas/V3D", filename="V3D.ckpt")
+ svd_xt_ckpt_path = hf_hub_download(
+     repo_id="stabilityai/stable-video-diffusion-img2vid-xt",
+     filename="svd_xt.safetensors",
+ )
+
+ model_config = "./scripts/pub/configs/V3D_512.yaml"
+ num_frames = OmegaConf.load(
+     model_config
+ ).model.params.sampler_config.params.guider_config.params.num_frames
+ print("Detected num_frames:", num_frames)
+ # num_steps = default(num_steps, 25)
+ num_steps = 25
+ output_folder = "outputs/V3D_512"
+
+ sd = load_safetensors(svd_xt_ckpt_path)
+ clip_model_config = OmegaConf.load("./configs/embedder/clip_image.yaml")
+ clip_model = instantiate_from_config(clip_model_config).eval()
+ clip_sd = dict()
+ for k, v in sd.items():
+     if "conditioner.embedders.0" in k:
+         clip_sd[k.replace("conditioner.embedders.0.", "")] = v
+ clip_model.load_state_dict(clip_sd)
+ clip_model = clip_model.to(device)
+
+ ae_model_config = OmegaConf.load("./configs/ae/video.yaml")
+ ae_model = instantiate_from_config(ae_model_config).eval()
+ encoder_sd = dict()
+ for k, v in sd.items():
+     if "first_stage_model" in k:
+         encoder_sd[k.replace("first_stage_model.", "")] = v
+ ae_model.load_state_dict(encoder_sd)
+ ae_model = ae_model.to(device)
+ rembg_session = rembg.new_session()
+
+ model, _ = load_model(
+     model_config,
+     device,
+     num_frames,
+     num_steps,
+     min_cfg=3.5,
+     max_cfg=3.5,
+     ckpt_path=V3D_ckpt_path,
+ )
+ model = model.to(device)
+
+ with gr.Blocks(title="V3D", theme=gr.themes.Monochrome()) as demo:
+     with gr.Row(equal_height=True):
+         with gr.Column():
+             input_image = gr.Image(value=None, label="Input Image")
+
+             border_ratio_slider = gr.Slider(
+                 value=0.3,
+                 label="Border Ratio",
+                 minimum=0.05,
+                 maximum=0.5,
+                 step=0.05,
+             )
+             seed_input = gr.Number(value=42)
+             decoding_t_slider = gr.Slider(
+                 value=1,
+                 label="Number of Decoding frames",
+                 minimum=1,
+                 maximum=num_frames,
+                 step=1,
+             )
+             min_guidance_slider = gr.Slider(
+                 value=3.5,
+                 label="Min CFG Value",
+                 minimum=0.05,
+                 maximum=5,
+                 step=0.05,
+             )
+             max_guidance_slider = gr.Slider(
+                 value=3.5,
+                 label="Max CFG Value",
+                 minimum=0.05,
+                 maximum=5,
+                 step=0.05,
+             )
+             run_button = gr.Button(value="Run V3D")
+
+         with gr.Column():
+             output_video = gr.Video(value=None, label="Output Orbit Video")
+
+     @run_button.click(
+         inputs=[
+             input_image,
+             border_ratio_slider,
+             min_guidance_slider,
+             max_guidance_slider,
+             decoding_t_slider,
+             seed_input,
+         ],
+         outputs=[output_video],
+     )
+     def _(image, border_ratio, min_guidance, max_guidance, decoding_t, seed):
+         model.sampler.guider.max_scale = max_guidance
+         model.sampler.guider.min_scale = min_guidance
+         return do_sample(
+             image,
+             num_frames,
+             num_steps,
+             int(decoding_t),
+             border_ratio,
+             False,
+             output_folder,
+             seed,
          )
+
+
+ demo.launch()
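For anyone who wants to exercise the restored preprocessing without the full pipeline, the steps inside do_sample (background removal, recentering with a border, compositing onto white, resizing to 512x512, and normalizing to [-1, 1]) run fine standalone. A minimal sketch under the same dependencies as the file above; input.png is a hypothetical test image and 0.3 mirrors the Border Ratio default:

import numpy as np
import rembg
from PIL import Image
from kiui.op import recenter
from torchvision.transforms import ToTensor

session = rembg.new_session()
image = np.asarray(Image.open("input.png").convert("RGB"))  # hypothetical input

carved = rembg.remove(image, session=session)      # [H, W, 4] RGBA, uint8
mask = carved[..., -1] > 0                         # foreground mask from alpha
image = recenter(carved, mask, border_ratio=0.3)   # recenter, keep a 0.3 border
image = image.astype(np.float32) / 255.0
image = image[..., :3] * image[..., 3:4] + (1 - image[..., 3:4])  # white background
image = Image.fromarray((image * 255).astype(np.uint8)).resize((512, 512))

tensor = ToTensor()(image) * 2.0 - 1.0             # scale [0, 1] -> [-1, 1]
print(tensor.shape)                                # torch.Size([3, 512, 512])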