Update demo.py
demo.py CHANGED
@@ -4,8 +4,7 @@ import torch
 import argparse
 import spaces
 import torchvision
-
-
+from transformers import pipeline
 from pipelines.pipeline_videogen import VideoGenPipeline
 from diffusers.schedulers import DDIMScheduler
 from diffusers.models import AutoencoderKL
@@ -27,7 +26,15 @@ from copy import deepcopy
 import requests
 from datetime import datetime
 import random
-
+
+# Create the translation pipeline
+translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en")
+
+# Translation helper
+def translate_prompt(korean_prompt):
+    translation = translator(korean_prompt, max_length=512)
+    return translation[0]['translation_text']
+
 parser = argparse.ArgumentParser()
 parser.add_argument("--config", type=str, default="./configs/sample.yaml")
 args = parser.parse_args()
@@ -35,7 +42,7 @@ args = OmegaConf.load(args.config)
 
 torch.set_grad_enabled(False)
 device = "cuda" if torch.cuda.is_available() else "cpu"
-dtype = torch.float16
+dtype = torch.float16
 
 unet = get_models(args).to(device, dtype=dtype)
 
@@ -49,15 +56,14 @@ else:
     vae_for_base_content = AutoencoderKL.from_pretrained(args.pretrained_model_path, subfolder="vae",).to(device, dtype=torch.float64)
 vae = deepcopy(vae_for_base_content).to(dtype=dtype)
 tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_path, subfolder="tokenizer")
-text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_path, subfolder="text_encoder", torch_dtype=dtype).to(device)
+text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_path, subfolder="text_encoder", torch_dtype=dtype).to(device)
 
-# set eval mode
 unet.eval()
 vae.eval()
 text_encoder.eval()
 
-basedir
-savedir
+basedir = os.getcwd()
+savedir = os.path.join(basedir, "samples/Gradio", datetime.now().strftime("%Y-%m-%dT%H-%M-%S"))
 savedir_sample = os.path.join(savedir, "sample")
 os.makedirs(savedir, exist_ok=True)
 
@@ -66,56 +72,55 @@ def update_and_resize_image(input_image_path, height_slider, width_slider):
         pil_image = Image.open(requests.get(input_image_path, stream=True).raw).convert('RGB')
     else:
         pil_image = Image.open(input_image_path).convert('RGB')
-
+
     original_width, original_height = pil_image.size
 
     if original_height == height_slider and original_width == width_slider:
         return gr.Image(value=np.array(pil_image))
-
+
     ratio1 = height_slider / original_height
     ratio2 = width_slider / original_width
-
+
     if ratio1 > ratio2:
         new_width = int(original_width * ratio1)
         new_height = int(original_height * ratio1)
     else:
         new_width = int(original_width * ratio2)
         new_height = int(original_height * ratio2)
-
+
     pil_image = pil_image.resize((new_width, new_height), Image.LANCZOS)
-
+
     left = (new_width - width_slider) / 2
     top = (new_height - height_slider) / 2
     right = left + width_slider
     bottom = top + height_slider
-
+
     pil_image = pil_image.crop((left, top, right, bottom))
-
-    return gr.Image(value=np.array(pil_image))
 
+    return gr.Image(value=np.array(pil_image))
 
 def update_textbox_and_save_image(input_image, height_slider, width_slider):
     pil_image = Image.fromarray(input_image.astype(np.uint8)).convert("RGB")
 
     original_width, original_height = pil_image.size
-
+
     ratio1 = height_slider / original_height
     ratio2 = width_slider / original_width
-
+
     if ratio1 > ratio2:
         new_width = int(original_width * ratio1)
         new_height = int(original_height * ratio1)
     else:
         new_width = int(original_width * ratio2)
         new_height = int(original_height * ratio2)
-
+
     pil_image = pil_image.resize((new_width, new_height), Image.LANCZOS)
-
+
     left = (new_width - width_slider) / 2
     top = (new_height - height_slider) / 2
     right = left + width_slider
     bottom = top + height_slider
-
+
     pil_image = pil_image.crop((left, top, right, bottom))
 
     img_path = os.path.join(savedir, "input_image.png")
@@ -130,10 +135,9 @@ def prepare_image(image, vae, transform_video, device, dtype=torch.float16):
     image = image.unsqueeze(2)
     return image
 
-
 @spaces.GPU
-def gen_video(input_image,
-
+def gen_video(input_image, korean_prompt, negative_prompt, diffusion_step, height, width, scfg_scale, use_dctinit, dct_coefficients, noise_level, motion_bucket_id, seed):
+    english_prompt = translate_prompt(korean_prompt)
     torch.manual_seed(seed)
 
     scheduler = DDIMScheduler.from_pretrained(args.pretrained_model_path,
@@ -147,7 +151,6 @@ def gen_video(input_image, prompt, negative_prompt, diffusion_step, height, widt
                                          tokenizer=tokenizer,
                                          scheduler=scheduler,
                                          unet=unet).to(device)
-    # videogen_pipeline.enable_xformers_memory_efficient_attention()
 
     transform_video = transforms.Compose([
         video_transforms.ToTensorVideo(),
@@ -160,33 +163,25 @@ def gen_video(input_image, prompt, negative_prompt, diffusion_step, height, widt
     base_content = prepare_image(input_image, vae_for_base_content, transform_video, device, dtype=torch.float16).to(device)
 
     if use_dctinit:
-        # filter params
-        print("Using DCT!")
         base_content_repeat = repeat(base_content, 'b c f h w -> b c (f r) h w', r=15).contiguous()
-
-        # define filter
         freq_filter = dct_low_pass_filter(dct_coefficients=base_content, percentage=dct_coefficients)
-
-        noise = torch.randn(1, 4, 15, 40, 64).to(device)
 
-
-        diffuse_timesteps = torch.full((1,),int(noise_level))
+        noise = torch.randn(1, 4, 15, 40, 64).to(device)
+        diffuse_timesteps = torch.full((1,), int(noise_level))
         diffuse_timesteps = diffuse_timesteps.long()
-
-        # 3d content
+
         base_content_noise = scheduler.add_noise(
             original_samples=base_content_repeat.to(device),
             noise=noise,
             timesteps=diffuse_timesteps.to(device))
-
-        # 3d content
+
         latents = exchanged_mixed_dct_freq(noise=noise,
                                            base_content=base_content_noise,
                                            LPF_3d=freq_filter).to(dtype=torch.float16)
-
+
     base_content = base_content.to(dtype=torch.float16)
 
-    videos = videogen_pipeline(
+    videos = videogen_pipeline(english_prompt,
                                negative_prompt=negative_prompt,
                                latents=latents if use_dctinit else None,
                                base_content=base_content,
@@ -197,13 +192,11 @@ def gen_video(input_image, prompt, negative_prompt, diffusion_step, height, widt
                                guidance_scale=scfg_scale,
                                motion_bucket_id=100-motion_bucket_id,
                                enable_vae_temporal_decoder=args.enable_vae_temporal_decoder).video
-
+
     save_path = args.save_img_path + 'temp' + '.mp4'
-    # torchvision.io.write_video(save_path, videos[0], fps=8, video_codec='h264', options={'crf': '10'})
     imageio.mimwrite(save_path, videos[0], fps=8, quality=7)
     return save_path
 
-
 if not os.path.exists(args.save_img_path):
     os.makedirs(args.save_img_path)
 
@@ -215,11 +208,9 @@ footer {
 
 with gr.Blocks(theme="Nymbo/Nymbo_Theme", css=css) as demo:
 
-
-
     with gr.Column(variant="panel"):
         with gr.Row():
-            prompt_textbox = gr.Textbox(label="Prompt", lines=1)
+            prompt_textbox = gr.Textbox(label="Korean Prompt", lines=1)
             negative_prompt_textbox = gr.Textbox(label="Negative prompt", lines=1)
 
         with gr.Row(equal_height=False):
@@ -231,13 +222,7 @@ with gr.Blocks(theme="Nymbo/Nymbo_Theme", css=css) as demo:
                 generate_button = gr.Button(value="Generate", variant='primary')
 
         with gr.Accordion("Advanced options", open=False):
-
-            """
-            - Input image can be specified using the "Input Image URL" text box or uploaded by clicking or dragging the image to the "Input Image" box.
-            - Input image will be resized and/or center cropped to a given resolution (320 x 512) automatically.
-            - After setting the input image path, press the "Preview" button to visualize the resized input image.
-            """
-            )
+
             with gr.Column():
                 with gr.Row():
                     input_image_path = gr.Textbox(label="Input Image URL", lines=1, scale=10, info="Press Enter or the Preview button to confirm the input image.")
@@ -248,9 +233,6 @@ with gr.Blocks(theme="Nymbo/Nymbo_Theme", css=css) as demo:
 
                 with gr.Row():
                     seed_textbox = gr.Slider(label="Seed", value=100, minimum=1, maximum=int(1e8), step=1, interactive=True)
-                    # seed_textbox = gr.Textbox(label="Seed", value=100)
-                    # seed_button = gr.Button(value="\U0001F3B2", elem_classes="toolbutton")
-                    # seed_button.click(fn=lambda: gr.Textbox(value=random.randint(1, int(1e8))), inputs=[], outputs=[seed_textbox])
 
                 with gr.Row():
                     height = gr.Slider(label="Height", value=320, minimum=0, maximum=512, step=16, interactive=False)
@@ -268,28 +250,6 @@ with gr.Blocks(theme="Nymbo/Nymbo_Theme", css=css) as demo:
     preview_button.click(fn=update_and_resize_image, inputs=[input_image_path, height, width], outputs=[input_image])
     input_image_path.submit(fn=update_and_resize_image, inputs=[input_image_path, height, width], outputs=[input_image])
 
-    EXAMPLES = [
-        ["./example/red_panda_eating_bamboo/0.jpg", "red panda eating bamboo" , "low quality", 50, 320, 512, 7.5, True, 0.23, 975, 10, 100],
-        ["./example/fireworks/0.jpg", "fireworks" , "low quality", 50, 320, 512, 7.5, True, 0.23, 975, 10, 100],
-        ["./example/flowers_swaying/0.jpg", "flowers swaying" , "", 50, 320, 512, 7.5, True, 0.23, 975, 10, 100],
-        ["./example/girl_walking_on_the_beach/0.jpg", "girl walking on the beach" , "low quality, background changing", 50, 320, 512, 7.5, True, 0.25, 995, 10, 49494220],
-        ["./example/house_rotating/0.jpg", "house rotating" , "low quality", 50, 320, 512, 7.5, True, 0.23, 985, 10, 46640174],
-        ["./example/people_runing/0.jpg", "people runing" , "low quality, background changing", 50, 320, 512, 7.5, True, 0.23, 975, 10, 100],
-        ["./example/shark_swimming/0.jpg", "shark swimming" , "", 50, 320, 512, 7.5, True, 0.23, 975, 10, 32947978],
-        ["./example/car_moving/0.jpg", "car moving" , "", 50, 320, 512, 7.5, True, 0.23, 975, 10, 75469653],
-        ["./example/windmill_turning/0.jpg", "windmill turning" , "background changing", 50, 320, 512, 7.5, True, 0.21, 975, 10, 89378613],
-    ]
-
-
-    examples = gr.Examples(
-        examples = EXAMPLES,
-        fn = gen_video,
-        inputs=[input_image, prompt_textbox, negative_prompt_textbox, sample_step_slider, height, width, txt_cfg_scale, use_dctinit, dct_coefficients, noise_level, motion_bucket_id, seed_textbox],
-        outputs=[result_video],
-        cache_examples=True,
-        # cache_examples="lazy",
-    )
-
     generate_button.click(
         fn=gen_video,
         inputs=[
@@ -309,4 +269,4 @@ with gr.Blocks(theme="Nymbo/Nymbo_Theme", css=css) as demo:
             outputs=[result_video]
     )
 
-demo.launch(debug=False, share=True)
+demo.launch(debug=False, share=True)
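For reference, a minimal standalone sketch of the Korean-to-English prompt-translation step this commit wires into gen_video. It reuses the pipeline call and translate_prompt helper exactly as added in the diff; it assumes the transformers package (and the sentencepiece dependency Marian models need) is installed and that the Helsinki-NLP/opus-mt-ko-en weights can be downloaded. The sample Korean prompt below is illustrative only.

# Standalone sketch of the prompt-translation step added in this commit.
from transformers import pipeline

# Translation pipeline, same model as in the diff above.
translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en")

def translate_prompt(korean_prompt):
    # Same helper as in the diff: the English output is what gen_video
    # passes to videogen_pipeline as the generation prompt.
    translation = translator(korean_prompt, max_length=512)
    return translation[0]['translation_text']

if __name__ == "__main__":
    # Hypothetical Korean input ("a girl walking on the beach"), for illustration only.
    print(translate_prompt("해변을 걷는 소녀"))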