benibraz committed on
Commit
fc65614
1 Parent(s): 637b686

different tabs for different functionality

Files changed (1)
  1. app.py +246 -71
app.py CHANGED
@@ -24,12 +24,14 @@ hf_token = os.getenv("HF_TOKEN")
  # Set model download directory within Hugging Face Spaces
  model_path = "asset"
  if not os.path.exists(model_path):
-     snapshot_download("Lightricks/LTX-Video", local_dir=model_path, repo_type='model', token=hf_token)

  # Global variables to load components
- vae_dir = Path(model_path) / 'vae'
- unet_dir = Path(model_path) / 'unet'
- scheduler_dir = Path(model_path) / 'scheduler'

  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

@@ -37,7 +39,7 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  def load_vae(vae_dir):
      vae_ckpt_path = vae_dir / "vae_diffusion_pytorch_model.safetensors"
      vae_config_path = vae_dir / "config.json"
-     with open(vae_config_path, 'r') as f:
          vae_config = json.load(f)
      vae = CausalVideoAutoencoder.from_config(vae_config)
      vae_state_dict = safetensors.torch.load_file(vae_ckpt_path)
@@ -69,11 +71,11 @@ def center_crop_and_resize(frame, target_height, target_width):
      if aspect_ratio_frame > aspect_ratio_target:
          new_width = int(h * aspect_ratio_target)
          x_start = (w - new_width) // 2
-         frame_cropped = frame[:, x_start:x_start + new_width]
      else:
          new_height = int(w / aspect_ratio_target)
          y_start = (h - new_height) // 2
-         frame_cropped = frame[y_start:y_start + new_height, :]
      frame_resized = cv2.resize(frame_cropped, (target_width, target_height))
      return frame_resized

@@ -116,7 +118,7 @@ preset_options = [
      {"label": "544x320, 241 frames", "width": 544, "height": 320, "num_frames": 241},
      {"label": "512x320, 249 frames", "width": 512, "height": 320, "num_frames": 249},
      {"label": "512x320, 257 frames", "width": 512, "height": 320, "num_frames": 257},
-     {"label": "Custom", "height": None, "width": None, "num_frames": None}
  ]


@@ -130,10 +132,17 @@ def preset_changed(preset):
              selected["num_frames"],
              gr.update(visible=False),
              gr.update(visible=False),
-             gr.update(visible=False)
          )
      else:
-         return None, None, None, gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)


  # Load models
@@ -141,8 +150,12 @@ vae = load_vae(vae_dir)
  unet = load_unet(unet_dir)
  scheduler = load_scheduler(scheduler_dir)
  patchifier = SymmetricPatchifier(patch_size=1)
- text_encoder = T5EncoderModel.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="text_encoder").to(device)
- tokenizer = T5Tokenizer.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="tokenizer")

  pipeline = XoraVideoPipeline(
      transformer=unet,
@@ -154,26 +167,108 @@ pipeline = XoraVideoPipeline(
  ).to(device)


- # Modified function to include validation with gr.Error
- #@spaces.GPU(duration=120)
- def generate_video(image_path=None, prompt="", negative_prompt="",
-                    seed=171198, num_inference_steps=40, num_images_per_prompt=1,
-                    guidance_scale=3, height=512, width=768, num_frames=121, frame_rate=25, progress=gr.Progress()):
-     # Check prompt length and raise an error if it's too short
      if len(prompt.strip()) < 50:
-         raise gr.Error("Prompt must be at least 50 characters long. Please provide more details for the best results.", duration=5)

-     if image_path:
-         media_items = load_image_to_tensor_with_resize(image_path, height, width).to(device)
-     media_items=None


      sample = {
          "prompt": prompt,
-         'prompt_attention_mask': None,
-         'negative_prompt': negative_prompt,
-         'negative_prompt_attention_mask': None,
-         'media_items': media_items,
      }

      generator = torch.Generator(device="cpu").manual_seed(seed)
@@ -196,14 +291,16 @@ def generate_video(image_path=None, prompt="", negative_prompt="",
          vae_per_channel_normalize=True,
          conditioning_method=ConditioningMethod.FIRST_FRAME,
          mixed_precision=True,
-         callback_on_step_end=gradio_progress_callback
      ).images

      output_path = tempfile.mktemp(suffix=".mp4")
      video_np = images.squeeze(0).permute(1, 2, 3, 0).cpu().float().numpy()
      video_np = (video_np * 255).astype(np.uint8)
      height, width = video_np.shape[1:3]
-     out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), frame_rate, (width, height))
      for frame in video_np[..., ::-1]:
          out.write(frame)
      out.release()
@@ -211,55 +308,133 @@
      return output_path


- # Define the Gradio interface with presets
  with gr.Blocks() as iface:
      gr.Markdown("# Video Generation with LTX Video")

-     with gr.Row():
-         with gr.Column():
-             image_input = gr.Image(type="filepath", label="Image Input")
-             prompt = gr.Textbox(label="Prompt", value="A man riding a motorcycle down a winding road, surrounded by lush, green scenery and distant mountains. The sky is clear with a few wispy clouds, and the sunlight glistens on the motorcycle as it speeds along. The rider is dressed in a black leather jacket and helmet, leaning slightly forward as the wind rustles through nearby trees. The wheels kick up dust, creating a slight trail behind the motorcycle, adding a sense of speed and excitement to the scene.")
-             negative_prompt = gr.Textbox(label="Negative Prompt", value="worst quality, inconsistent motion...")
-
-             # Preset dropdown for resolution and frame settings
-             preset_dropdown = gr.Dropdown(
-                 choices=[p["label"] for p in preset_options],
-                 value="1216x704, 41 frames",
-                 label="Resolution Preset"
-             )
-
-             # Advanced options section
-             with gr.Accordion("Advanced Options", open=False):
-                 seed = gr.Slider(label="Seed", minimum=0, maximum=1000000, step=1, value=171198)
-                 inference_steps = gr.Slider(label="Inference Steps", minimum=1, maximum=100, step=1, value=40)
-                 images_per_prompt = gr.Slider(label="Images per Prompt", minimum=1, maximum=10, step=1, value=1)
-                 guidance_scale = gr.Slider(label="Guidance Scale", minimum=1.0, maximum=20.0, step=0.1, value=3.0)
-
-                 # Sliders to appear at the end of the advanced settings
-                 height_slider = gr.Slider(label="Height", minimum=256, maximum=1024, step=64, value=704, visible=False)
-                 width_slider = gr.Slider(label="Width", minimum=256, maximum=1024, step=64, value=1216, visible=False)
-                 num_frames_slider = gr.Slider(label="Number of Frames", minimum=1, maximum=200, step=1, value=41,
-                                               visible=False)
-
-                 frame_rate = gr.Slider(label="Frame Rate", minimum=1, maximum=60, step=1, value=25, visible=False)
-
-             generate_button = gr.Button("Generate Video")
-
-         with gr.Column():
-             output_video = gr.Video(label="Generated Video")
-
-     # Link dropdown change to update sliders visibility and values
-     preset_dropdown.change(
          fn=preset_changed,
-         inputs=[preset_dropdown],
-         outputs=[height_slider, width_slider, num_frames_slider, height_slider, width_slider, frame_rate]
      )

-     generate_button.click(
-         fn=generate_video,
-         inputs=[image_input, prompt, negative_prompt, seed, inference_steps, images_per_prompt, guidance_scale,
-                 height_slider, width_slider, num_frames_slider, frame_rate],
-         outputs=output_video
      )

  iface.launch(share=True)
 
  # Set model download directory within Hugging Face Spaces
  model_path = "asset"
  if not os.path.exists(model_path):
+     snapshot_download(
+         "Lightricks/LTX-Video", local_dir=model_path, repo_type="model", token=hf_token
+     )

  # Global variables to load components
+ vae_dir = Path(model_path) / "vae"
+ unet_dir = Path(model_path) / "unet"
+ scheduler_dir = Path(model_path) / "scheduler"

  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
  def load_vae(vae_dir):
      vae_ckpt_path = vae_dir / "vae_diffusion_pytorch_model.safetensors"
      vae_config_path = vae_dir / "config.json"
+     with open(vae_config_path, "r") as f:
          vae_config = json.load(f)
      vae = CausalVideoAutoencoder.from_config(vae_config)
      vae_state_dict = safetensors.torch.load_file(vae_ckpt_path)
 
      if aspect_ratio_frame > aspect_ratio_target:
          new_width = int(h * aspect_ratio_target)
          x_start = (w - new_width) // 2
+         frame_cropped = frame[:, x_start : x_start + new_width]
      else:
          new_height = int(w / aspect_ratio_target)
          y_start = (h - new_height) // 2
+         frame_cropped = frame[y_start : y_start + new_height, :]
      frame_resized = cv2.resize(frame_cropped, (target_width, target_height))
      return frame_resized
 
      {"label": "544x320, 241 frames", "width": 544, "height": 320, "num_frames": 241},
      {"label": "512x320, 249 frames", "width": 512, "height": 320, "num_frames": 249},
      {"label": "512x320, 257 frames", "width": 512, "height": 320, "num_frames": 257},
+     {"label": "Custom", "height": None, "width": None, "num_frames": None},
  ]
 
              selected["num_frames"],
              gr.update(visible=False),
              gr.update(visible=False),
+             gr.update(visible=False),
          )
      else:
+         return (
+             None,
+             None,
+             None,
+             gr.update(visible=True),
+             gr.update(visible=True),
+             gr.update(visible=True),
+         )


  # Load models
 
  unet = load_unet(unet_dir)
  scheduler = load_scheduler(scheduler_dir)
  patchifier = SymmetricPatchifier(patch_size=1)
+ text_encoder = T5EncoderModel.from_pretrained(
+     "PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="text_encoder"
+ ).to(device)
+ tokenizer = T5Tokenizer.from_pretrained(
+     "PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="tokenizer"
+ )

  pipeline = XoraVideoPipeline(
      transformer=unet,
 
  ).to(device)


+ import gradio as gr
+ import torch
+ from huggingface_hub import snapshot_download
+
+ # [Previous imports remain the same...]
+
+
+ def generate_video_from_text(
+     prompt="",
+     negative_prompt="",
+     seed=171198,
+     num_inference_steps=40,
+     num_images_per_prompt=1,
+     guidance_scale=3,
+     height=512,
+     width=768,
+     num_frames=121,
+     frame_rate=25,
+     progress=gr.Progress(),
+ ):
+     if len(prompt.strip()) < 50:
+         raise gr.Error(
+             "Prompt must be at least 50 characters long. Please provide more details for the best results.",
+             duration=5,
+         )
+
+     sample = {
+         "prompt": prompt,
+         "prompt_attention_mask": None,
+         "negative_prompt": negative_prompt,
+         "negative_prompt_attention_mask": None,
+         "media_items": None,
+     }
+
+     generator = torch.Generator(device="cpu").manual_seed(seed)
+
+     def gradio_progress_callback(self, step, timestep, kwargs):
+         progress((step + 1) / num_inference_steps)
+
+     images = pipeline(
+         num_inference_steps=num_inference_steps,
+         num_images_per_prompt=num_images_per_prompt,
+         guidance_scale=guidance_scale,
+         generator=generator,
+         output_type="pt",
+         height=height,
+         width=width,
+         num_frames=num_frames,
+         frame_rate=frame_rate,
+         **sample,
+         is_video=True,
+         vae_per_channel_normalize=True,
+         conditioning_method=ConditioningMethod.FIRST_FRAME,
+         mixed_precision=True,
+         callback_on_step_end=gradio_progress_callback,
+     ).images
+
+     output_path = tempfile.mktemp(suffix=".mp4")
+     video_np = images.squeeze(0).permute(1, 2, 3, 0).cpu().float().numpy()
+     video_np = (video_np * 255).astype(np.uint8)
+     height, width = video_np.shape[1:3]
+     out = cv2.VideoWriter(
+         output_path, cv2.VideoWriter_fourcc(*"mp4v"), frame_rate, (width, height)
+     )
+     for frame in video_np[..., ::-1]:
+         out.write(frame)
+     out.release()
+
+     return output_path
+
+
+ def generate_video_from_image(
+     image_path,
+     prompt="",
+     negative_prompt="",
+     seed=171198,
+     num_inference_steps=40,
+     num_images_per_prompt=1,
+     guidance_scale=3,
+     height=512,
+     width=768,
+     num_frames=121,
+     frame_rate=25,
+     progress=gr.Progress(),
+ ):
      if len(prompt.strip()) < 50:
+         raise gr.Error(
+             "Prompt must be at least 50 characters long. Please provide more details for the best results.",
+             duration=5,
+         )

+     if not image_path:
+         raise gr.Error("Please provide an input image.", duration=5)

+     media_items = load_image_to_tensor_with_resize(image_path, height, width).to(device)

      sample = {
          "prompt": prompt,
+         "prompt_attention_mask": None,
+         "negative_prompt": negative_prompt,
+         "negative_prompt_attention_mask": None,
+         "media_items": media_items,
      }

      generator = torch.Generator(device="cpu").manual_seed(seed)
 
          vae_per_channel_normalize=True,
          conditioning_method=ConditioningMethod.FIRST_FRAME,
          mixed_precision=True,
+         callback_on_step_end=gradio_progress_callback,
      ).images

      output_path = tempfile.mktemp(suffix=".mp4")
      video_np = images.squeeze(0).permute(1, 2, 3, 0).cpu().float().numpy()
      video_np = (video_np * 255).astype(np.uint8)
      height, width = video_np.shape[1:3]
+     out = cv2.VideoWriter(
+         output_path, cv2.VideoWriter_fourcc(*"mp4v"), frame_rate, (width, height)
+     )
      for frame in video_np[..., ::-1]:
          out.write(frame)
      out.release()
 
      return output_path


+ def create_advanced_options():
+     with gr.Accordion("Advanced Options", open=False):
+         seed = gr.Slider(label="Seed", minimum=0, maximum=1000000, step=1, value=171198)
+         inference_steps = gr.Slider(
+             label="Inference Steps", minimum=1, maximum=100, step=1, value=40
+         )
+         images_per_prompt = gr.Slider(
+             label="Images per Prompt", minimum=1, maximum=10, step=1, value=1
+         )
+         guidance_scale = gr.Slider(
+             label="Guidance Scale", minimum=1.0, maximum=20.0, step=0.1, value=3.0
+         )
+
+         height_slider = gr.Slider(
+             label="Height", minimum=256, maximum=1024, step=64, value=704, visible=False
+         )
+         width_slider = gr.Slider(
+             label="Width", minimum=256, maximum=1024, step=64, value=1216, visible=False
+         )
+         num_frames_slider = gr.Slider(
+             label="Number of Frames",
+             minimum=1,
+             maximum=200,
+             step=1,
+             value=41,
+             visible=False,
+         )
+         frame_rate = gr.Slider(
+             label="Frame Rate", minimum=1, maximum=60, step=1, value=25, visible=False
+         )
+
+         return [
+             seed,
+             inference_steps,
+             images_per_prompt,
+             guidance_scale,
+             height_slider,
+             width_slider,
+             num_frames_slider,
+             frame_rate,
+         ]
+
+
+ # Define the Gradio interface with tabs
  with gr.Blocks() as iface:
      gr.Markdown("# Video Generation with LTX Video")

+     with gr.Tabs():
+         with gr.TabItem("Text to Video"):
+             with gr.Row():
+                 with gr.Column():
+                     txt2vid_prompt = gr.Textbox(
+                         label="Prompt",
+                         value="A man riding a motorcycle down a winding road, surrounded by lush, green scenery and distant mountains. The sky is clear with a few wispy clouds, and the sunlight glistens on the motorcycle as it speeds along. The rider is dressed in a black leather jacket and helmet, leaning slightly forward as the wind rustles through nearby trees. The wheels kick up dust, creating a slight trail behind the motorcycle, adding a sense of speed and excitement to the scene.",
+                     )
+                     txt2vid_negative_prompt = gr.Textbox(
+                         label="Negative Prompt",
+                         value="worst quality, inconsistent motion...",
+                     )
+
+                     # Preset dropdown for resolution and frame settings
+                     txt2vid_preset = gr.Dropdown(
+                         choices=[p["label"] for p in preset_options],
+                         value="1216x704, 41 frames",
+                         label="Resolution Preset",
+                     )
+
+                     txt2vid_advanced = create_advanced_options()
+                     txt2vid_generate = gr.Button("Generate Video")
+
+                 with gr.Column():
+                     txt2vid_output = gr.Video(label="Generated Video")
+
+         with gr.TabItem("Image to Video"):
+             with gr.Row():
+                 with gr.Column():
+                     img2vid_image = gr.Image(type="filepath", label="Input Image")
+                     img2vid_prompt = gr.Textbox(
+                         label="Prompt",
+                         value="A man riding a motorcycle down a winding road, surrounded by lush, green scenery and distant mountains...",
+                     )
+                     img2vid_negative_prompt = gr.Textbox(
+                         label="Negative Prompt",
+                         value="worst quality, inconsistent motion...",
+                     )
+
+                     img2vid_preset = gr.Dropdown(
+                         choices=[p["label"] for p in preset_options],
+                         value="1216x704, 41 frames",
+                         label="Resolution Preset",
+                     )
+
+                     img2vid_advanced = create_advanced_options()
+                     img2vid_generate = gr.Button("Generate Video")
+
+                 with gr.Column():
+                     img2vid_output = gr.Video(label="Generated Video")
+
+     # Event handlers for text-to-video tab
+     txt2vid_preset.change(
+         fn=preset_changed,
+         inputs=[txt2vid_preset],
+         outputs=txt2vid_advanced[4:],  # height, width, num_frames, and their visibility
+     )
+
+     txt2vid_generate.click(
+         fn=generate_video_from_text,
+         inputs=[txt2vid_prompt, txt2vid_negative_prompt, *txt2vid_advanced],
+         outputs=txt2vid_output,
+     )
+
+     # Event handlers for image-to-video tab
+     img2vid_preset.change(
          fn=preset_changed,
+         inputs=[img2vid_preset],
+         outputs=img2vid_advanced[4:],  # height, width, num_frames, and their visibility
      )

+     img2vid_generate.click(
+         fn=generate_video_from_image,
+         inputs=[
+             img2vid_image,
+             img2vid_prompt,
+             img2vid_negative_prompt,
+             *img2vid_advanced,
+         ],
+         outputs=img2vid_output,
      )

  iface.launch(share=True)