aiqtech committed on
Commit
2d89730
1 Parent(s): ac9a6d7

Update demo.py

Files changed (1)
  1. demo.py +38 -78
demo.py CHANGED
@@ -4,8 +4,7 @@ import torch
import argparse
import spaces
import torchvision
-
-
+ from transformers import pipeline
from pipelines.pipeline_videogen import VideoGenPipeline
from diffusers.schedulers import DDIMScheduler
from diffusers.models import AutoencoderKL
@@ -27,7 +26,15 @@ from copy import deepcopy
import requests
from datetime import datetime
import random
-
+
+ # Create the translation pipeline
+ translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en")
+
+ # Translation function
+ def translate_prompt(korean_prompt):
+     translation = translator(korean_prompt, max_length=512)
+     return translation[0]['translation_text']
+
parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, default="./configs/sample.yaml")
args = parser.parse_args()
@@ -35,7 +42,7 @@ args = OmegaConf.load(args.config)

torch.set_grad_enabled(False)
device = "cuda" if torch.cuda.is_available() else "cpu"
- dtype = torch.float16 # torch.float16
+ dtype = torch.float16

unet = get_models(args).to(device, dtype=dtype)

@@ -49,15 +56,14 @@ else:
    vae_for_base_content = AutoencoderKL.from_pretrained(args.pretrained_model_path, subfolder="vae",).to(device, dtype=torch.float64)
vae = deepcopy(vae_for_base_content).to(dtype=dtype)
tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_path, subfolder="tokenizer")
- text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_path, subfolder="text_encoder", torch_dtype=dtype).to(device) # huge
+ text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_path, subfolder="text_encoder", torch_dtype=dtype).to(device)

- # set eval mode
unet.eval()
vae.eval()
text_encoder.eval()

- basedir = os.getcwd()
- savedir = os.path.join(basedir, "samples/Gradio", datetime.now().strftime("%Y-%m-%dT%H-%M-%S"))
+ basedir = os.getcwd()
+ savedir = os.path.join(basedir, "samples/Gradio", datetime.now().strftime("%Y-%m-%dT%H-%M-%S"))
savedir_sample = os.path.join(savedir, "sample")
os.makedirs(savedir, exist_ok=True)

@@ -66,56 +72,55 @@ def update_and_resize_image(input_image_path, height_slider, width_slider):
        pil_image = Image.open(requests.get(input_image_path, stream=True).raw).convert('RGB')
    else:
        pil_image = Image.open(input_image_path).convert('RGB')
-
+
    original_width, original_height = pil_image.size

    if original_height == height_slider and original_width == width_slider:
        return gr.Image(value=np.array(pil_image))
-
+
    ratio1 = height_slider / original_height
    ratio2 = width_slider / original_width
-
+
    if ratio1 > ratio2:
        new_width = int(original_width * ratio1)
        new_height = int(original_height * ratio1)
    else:
        new_width = int(original_width * ratio2)
        new_height = int(original_height * ratio2)
-
+
    pil_image = pil_image.resize((new_width, new_height), Image.LANCZOS)
-
+
    left = (new_width - width_slider) / 2
    top = (new_height - height_slider) / 2
    right = left + width_slider
    bottom = top + height_slider
-
+
    pil_image = pil_image.crop((left, top, right, bottom))
-
-     return gr.Image(value=np.array(pil_image))

+     return gr.Image(value=np.array(pil_image))

def update_textbox_and_save_image(input_image, height_slider, width_slider):
    pil_image = Image.fromarray(input_image.astype(np.uint8)).convert("RGB")

    original_width, original_height = pil_image.size
-
+
    ratio1 = height_slider / original_height
    ratio2 = width_slider / original_width
-
+
    if ratio1 > ratio2:
        new_width = int(original_width * ratio1)
        new_height = int(original_height * ratio1)
    else:
        new_width = int(original_width * ratio2)
        new_height = int(original_height * ratio2)
-
+
    pil_image = pil_image.resize((new_width, new_height), Image.LANCZOS)
-
+
    left = (new_width - width_slider) / 2
    top = (new_height - height_slider) / 2
    right = left + width_slider
    bottom = top + height_slider
-
+
    pil_image = pil_image.crop((left, top, right, bottom))

    img_path = os.path.join(savedir, "input_image.png")
@@ -130,10 +135,9 @@ def prepare_image(image, vae, transform_video, device, dtype=torch.float16):
    image = image.unsqueeze(2)
    return image

-
@spaces.GPU
- def gen_video(input_image, prompt, negative_prompt, diffusion_step, height, width, scfg_scale, use_dctinit, dct_coefficients, noise_level, motion_bucket_id, seed):
-
+ def gen_video(input_image, korean_prompt, negative_prompt, diffusion_step, height, width, scfg_scale, use_dctinit, dct_coefficients, noise_level, motion_bucket_id, seed):
+     english_prompt = translate_prompt(korean_prompt)
    torch.manual_seed(seed)

    scheduler = DDIMScheduler.from_pretrained(args.pretrained_model_path,
@@ -147,7 +151,6 @@ def gen_video(input_image, prompt, negative_prompt, diffusion_step, height, widt
                                         tokenizer=tokenizer,
                                         scheduler=scheduler,
                                         unet=unet).to(device)
-     # videogen_pipeline.enable_xformers_memory_efficient_attention()

    transform_video = transforms.Compose([
        video_transforms.ToTensorVideo(),
@@ -160,33 +163,25 @@ def gen_video(input_image, prompt, negative_prompt, diffusion_step, height, widt
    base_content = prepare_image(input_image, vae_for_base_content, transform_video, device, dtype=torch.float16).to(device)

    if use_dctinit:
-         # filter params
-         print("Using DCT!")
        base_content_repeat = repeat(base_content, 'b c f h w -> b c (f r) h w', r=15).contiguous()
-
-         # define filter
        freq_filter = dct_low_pass_filter(dct_coefficients=base_content, percentage=dct_coefficients)
-
-         noise = torch.randn(1, 4, 15, 40, 64).to(device)

-         # add noise to base_content
-         diffuse_timesteps = torch.full((1,),int(noise_level))
+         noise = torch.randn(1, 4, 15, 40, 64).to(device)
+         diffuse_timesteps = torch.full((1,), int(noise_level))
        diffuse_timesteps = diffuse_timesteps.long()
-
-         # 3d content
+
        base_content_noise = scheduler.add_noise(
            original_samples=base_content_repeat.to(device),
            noise=noise,
            timesteps=diffuse_timesteps.to(device))
-
-         # 3d content
+
        latents = exchanged_mixed_dct_freq(noise=noise,
                                           base_content=base_content_noise,
                                           LPF_3d=freq_filter).to(dtype=torch.float16)
-
+
        base_content = base_content.to(dtype=torch.float16)

-     videos = videogen_pipeline(prompt,
+     videos = videogen_pipeline(english_prompt,
                               negative_prompt=negative_prompt,
                               latents=latents if use_dctinit else None,
                               base_content=base_content,
@@ -197,13 +192,11 @@ def gen_video(input_image, prompt, negative_prompt, diffusion_step, height, widt
                               guidance_scale=scfg_scale,
                               motion_bucket_id=100-motion_bucket_id,
                               enable_vae_temporal_decoder=args.enable_vae_temporal_decoder).video
-
+
    save_path = args.save_img_path + 'temp' + '.mp4'
-     # torchvision.io.write_video(save_path, videos[0], fps=8, video_codec='h264', options={'crf': '10'})
    imageio.mimwrite(save_path, videos[0], fps=8, quality=7)
    return save_path

-
if not os.path.exists(args.save_img_path):
    os.makedirs(args.save_img_path)

@@ -215,11 +208,9 @@ footer {

with gr.Blocks(theme="Nymbo/Nymbo_Theme", css=css) as demo:

-
-
    with gr.Column(variant="panel"):
        with gr.Row():
-             prompt_textbox = gr.Textbox(label="Prompt", lines=1)
+             prompt_textbox = gr.Textbox(label="Korean Prompt", lines=1)
            negative_prompt_textbox = gr.Textbox(label="Negative prompt", lines=1)

        with gr.Row(equal_height=False):
@@ -231,13 +222,7 @@ with gr.Blocks(theme="Nymbo/Nymbo_Theme", css=css) as demo:
            generate_button = gr.Button(value="Generate", variant='primary')

            with gr.Accordion("Advanced options", open=False):
-                 gr.Markdown(
-                     """
-                     - Input image can be specified using the "Input Image URL" text box or uploaded by clicking or dragging the image to the "Input Image" box.
-                     - Input image will be resized and/or center cropped to a given resolution (320 x 512) automatically.
-                     - After setting the input image path, press the "Preview" button to visualize the resized input image.
-                     """
-                 )
+
                with gr.Column():
                    with gr.Row():
                        input_image_path = gr.Textbox(label="Input Image URL", lines=1, scale=10, info="Press Enter or the Preview button to confirm the input image.")
@@ -248,9 +233,6 @@ with gr.Blocks(theme="Nymbo/Nymbo_Theme", css=css) as demo:

                    with gr.Row():
                        seed_textbox = gr.Slider(label="Seed", value=100, minimum=1, maximum=int(1e8), step=1, interactive=True)
-                         # seed_textbox = gr.Textbox(label="Seed", value=100)
-                         # seed_button = gr.Button(value="\U0001F3B2", elem_classes="toolbutton")
-                         # seed_button.click(fn=lambda: gr.Textbox(value=random.randint(1, int(1e8))), inputs=[], outputs=[seed_textbox])

                    with gr.Row():
                        height = gr.Slider(label="Height", value=320, minimum=0, maximum=512, step=16, interactive=False)
@@ -268,28 +250,6 @@ with gr.Blocks(theme="Nymbo/Nymbo_Theme", css=css) as demo:
    preview_button.click(fn=update_and_resize_image, inputs=[input_image_path, height, width], outputs=[input_image])
    input_image_path.submit(fn=update_and_resize_image, inputs=[input_image_path, height, width], outputs=[input_image])

-     EXAMPLES = [
-         ["./example/red_panda_eating_bamboo/0.jpg", "red panda eating bamboo" , "low quality", 50, 320, 512, 7.5, True, 0.23, 975, 10, 100],
-         ["./example/fireworks/0.jpg", "fireworks" , "low quality", 50, 320, 512, 7.5, True, 0.23, 975, 10, 100],
-         ["./example/flowers_swaying/0.jpg", "flowers swaying" , "", 50, 320, 512, 7.5, True, 0.23, 975, 10, 100],
-         ["./example/girl_walking_on_the_beach/0.jpg", "girl walking on the beach" , "low quality, background changing", 50, 320, 512, 7.5, True, 0.25, 995, 10, 49494220],
-         ["./example/house_rotating/0.jpg", "house rotating" , "low quality", 50, 320, 512, 7.5, True, 0.23, 985, 10, 46640174],
-         ["./example/people_runing/0.jpg", "people runing" , "low quality, background changing", 50, 320, 512, 7.5, True, 0.23, 975, 10, 100],
-         ["./example/shark_swimming/0.jpg", "shark swimming" , "", 50, 320, 512, 7.5, True, 0.23, 975, 10, 32947978],
-         ["./example/car_moving/0.jpg", "car moving" , "", 50, 320, 512, 7.5, True, 0.23, 975, 10, 75469653],
-         ["./example/windmill_turning/0.jpg", "windmill turning" , "background changing", 50, 320, 512, 7.5, True, 0.21, 975, 10, 89378613],
-     ]
-
-
-     examples = gr.Examples(
-         examples = EXAMPLES,
-         fn = gen_video,
-         inputs=[input_image, prompt_textbox, negative_prompt_textbox, sample_step_slider, height, width, txt_cfg_scale, use_dctinit, dct_coefficients, noise_level, motion_bucket_id, seed_textbox],
-         outputs=[result_video],
-         cache_examples=True,
-         # cache_examples="lazy",
-     )
-
    generate_button.click(
        fn=gen_video,
        inputs=[
@@ -309,4 +269,4 @@ with gr.Blocks(theme="Nymbo/Nymbo_Theme", css=css) as demo:
            outputs=[result_video]
    )

- demo.launch(debug=False, share=True)
+ demo.launch(debug=False, share=True)
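
Note: the Korean-to-English prompt translation introduced by this commit can be sanity-checked on its own before launching the full Gradio demo. The snippet below is a minimal sketch that reuses the same model and call pattern as demo.py; it assumes the transformers (and sentencepiece) packages are installed and that the Helsinki-NLP/opus-mt-ko-en weights can be downloaded. The sample Korean prompt and its English output are illustrative only.

# Standalone check of the Korean -> English prompt translation
# used in demo.py (same model and call pattern as in the diff above).
from transformers import pipeline

translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en")

def translate_prompt(korean_prompt):
    # The pipeline returns a list of dicts, e.g. [{'translation_text': '...'}]
    translation = translator(korean_prompt, max_length=512)
    return translation[0]['translation_text']

if __name__ == "__main__":
    # Illustrative prompt ("a red panda eating bamboo"); the exact English
    # wording depends on the translation model.
    print(translate_prompt("대나무를 먹는 레서판다"))

In the demo itself, gen_video calls translate_prompt on the value of the "Korean Prompt" textbox and passes the resulting English text to videogen_pipeline.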