xxxpo13 committed
Commit 023670c · verified · 1 Parent(s): 1359eac

Update app.py

Files changed (1): app.py +135 -119

app.py CHANGED
@@ -1,127 +1,143 @@
  import os
- import sys
- import torch
- import argparse
- from PIL import Image
- from diffusers.utils import export_to_video
-
- # Add the project root directory to sys.path
- SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
- PROJECT_ROOT = os.path.dirname(SCRIPT_DIR)
- if PROJECT_ROOT not in sys.path:
-     sys.path.insert(0, PROJECT_ROOT)
-
- from pyramid_dit import PyramidDiTForVideoGeneration
- from trainer_misc import init_distributed_mode, init_sequence_parallel_group
-
- def get_args():
-     parser = argparse.ArgumentParser('PyTorch multi-process script', add_help=False)
-     parser.add_argument('--model_dtype', default='bf16', type=str, help="The model dtype: bf16")
-     parser.add_argument('--model_path', required=True, type=str, help='Path to the downloaded checkpoint directory')
-     parser.add_argument('--variant', default='diffusion_transformer_768p', type=str)
-     parser.add_argument('--task', default='t2v', type=str, choices=['i2v', 't2v'])
-     parser.add_argument('--temp', default=16, type=int, help='The number of generated latents; num_frames = temp * 8 + 1')
-     parser.add_argument('--sp_group_size', default=2, type=int, help="The number of GPUs used for inference; should be 2 or 4")
-     parser.add_argument('--sp_proc_num', default=-1, type=int, help="The number of processes used for video training; default=-1 means using all processes")
-     parser.add_argument('--prompt', type=str, required=True, help="Text prompt for video generation")
-     parser.add_argument('--image_path', type=str, help="Path to the input image for image-to-video")
-     parser.add_argument('--video_guidance_scale', type=float, default=5.0, help="Video guidance scale")
-     parser.add_argument('--guidance_scale', type=float, default=9.0, help="Guidance scale for text-to-video")
-     parser.add_argument('--resolution', type=str, default='768p', choices=['768p', '384p'], help="Model resolution")
-     parser.add_argument('--output_path', type=str, required=True, help="Path to save the generated video")
-     return parser.parse_args()
-
- def main():
-     args = get_args()
-
-     # Set up DDP
-     init_distributed_mode(args)
-
-     assert args.world_size == args.sp_group_size, "The sequence parallel size should match the DDP world size"
-
-     # Enable sequence parallelism
-     init_sequence_parallel_group(args)
-
-     device = torch.device('cuda')
-     rank = args.rank
-     model_dtype = args.model_dtype
-
-     model = PyramidDiTForVideoGeneration(
-         args.model_path,
-         model_dtype,
-         model_variant=args.variant,
-     )
-
-     model.vae.to(device)
-     model.dit.to(device)
-     model.text_encoder.to(device)
-     model.vae.enable_tiling()
-
-     if model_dtype == "bf16":
-         torch_dtype = torch.bfloat16
-     elif model_dtype == "fp16":
-         torch_dtype = torch.float16
      else:
-         torch_dtype = torch.float32
-
-     # The video generation config
-     if args.resolution == '768p':
-         width = 1280
-         height = 768
-     else:
-         width = 640
-         height = 384
-
-     try:
-         if args.task == 't2v':
-             prompt = args.prompt
-             with torch.no_grad(), torch.cuda.amp.autocast(enabled=(model_dtype != 'fp32'), dtype=torch_dtype):
-                 frames = model.generate(
-                     prompt=prompt,
-                     num_inference_steps=[20, 20, 20],
-                     video_num_inference_steps=[10, 10, 10],
-                     height=height,
-                     width=width,
-                     temp=args.temp,
-                     guidance_scale=args.guidance_scale,
-                     video_guidance_scale=args.video_guidance_scale,
-                     output_type="pil",
-                     save_memory=True,
-                     cpu_offloading=True,
-                     inference_multigpu=True,
-                 )
-             if rank == 0:
-                 export_to_video(frames, args.output_path, fps=24)
-
-         elif args.task == 'i2v':
-             if not args.image_path:
-                 raise ValueError("An image path is required for the image-to-video task")
-             image = Image.open(args.image_path).convert("RGB")
-             image = image.resize((width, height))
-
-             prompt = args.prompt
-
-             with torch.no_grad(), torch.cuda.amp.autocast(enabled=(model_dtype != 'fp32'), dtype=torch_dtype):
-                 frames = model.generate_i2v(
-                     prompt=prompt,
-                     input_image=image,
-                     num_inference_steps=[10, 10, 10],
-                     temp=args.temp,
-                     video_guidance_scale=args.video_guidance_scale,
-                     output_type="pil",
-                     save_memory=True,
-                     cpu_offloading=True,
-                     inference_multigpu=True,
                  )
-             if rank == 0:
-                 export_to_video(frames, args.output_path, fps=24)
-
-     except Exception as e:
-         if rank == 0:
-             print(f"[ERROR] Video generation failed: {e}")
-         raise
-     finally:
-         torch.distributed.barrier()
-
- if __name__ == "__main__":
-     main()
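The removed entry point relied on `init_distributed_mode`, so it was presumably launched with a multi-process launcher rather than plain `python`. A minimal sketch, assuming torchrun-style environment initialization on two GPUs (the checkpoint path and prompt are illustrative; the flags match `get_args()` above):

```python
import subprocess

# Illustrative launch of the removed distributed CLI.
subprocess.run([
    "torchrun", "--nproc_per_node=2", "app.py",
    "--model_path", "./pyramid_flow_model",
    "--variant", "diffusion_transformer_768p",
    "--task", "t2v",
    "--sp_group_size", "2",
    "--prompt", "A drone shot over a rocky coastline at sunset",
    "--output_path", "output.mp4",
], check=True)
```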
 
  import os
+ import uuid
+ import gradio as gr
+ import subprocess
+ import tempfile
+ import shutil
+
+ def run_inference_multigpu(gpus, variant, model_path, temp, guidance_scale, video_guidance_scale, resolution, prompt):
+     """
+     Run the external multi-GPU inference script and return the path to the generated video.
+     """
+     # Create a temporary directory to store inputs and outputs
+     with tempfile.TemporaryDirectory() as tmpdir:
+         output_video = os.path.join(tmpdir, f"{uuid.uuid4()}_output.mp4")
+
+         # Path to the external shell script
+         script_path = "./scripts/app_multigpu_engine.sh"  # Updated script path
+
+         # Prepare the command
+         cmd = [
+             script_path,
+             str(gpus),
+             variant,
+             model_path,
+             't2v',  # The task is always 't2v' since 'i2v' has been removed
+             str(temp),
+             str(guidance_scale),
+             str(video_guidance_scale),
+             resolution,
+             output_video,
+             prompt,  # Pass the prompt directly as an argument
+         ]
+
+         try:
+             # Run the external script
+             subprocess.run(cmd, check=True)
+         except subprocess.CalledProcessError as e:
+             raise RuntimeError(f"Error during video generation: {e}")
+
+         # After generation, move the video to a permanent location
+         # before the temporary directory is cleaned up
+         final_output = os.path.join("generated_videos", f"{uuid.uuid4()}_output.mp4")
+         os.makedirs("generated_videos", exist_ok=True)
+         shutil.move(output_video, final_output)
+
+     return final_output
+
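Note that `subprocess.run(cmd, check=True)` lets the script's output stream to the console but does not include it in the raised error. A minimal sketch of surfacing stderr in the `RuntimeError` instead (assuming the same `cmd` list as above):

```python
import subprocess

def run_script(cmd):
    # Capture output so a failure message can include the script's stderr
    result = subprocess.run(cmd, capture_output=True, text=True)
    if result.returncode != 0:
        raise RuntimeError(f"Error during video generation: {result.stderr.strip()}")
```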
+ def generate_text_to_video(prompt, temp, guidance_scale, video_guidance_scale, resolution, gpus):
+     model_path = "./pyramid_flow_model"  # Use the model path as specified
+     # Determine variant based on resolution
+     if resolution == "768p":
+         variant = "diffusion_transformer_768p"
      else:
+         variant = "diffusion_transformer_384p"
+     return run_inference_multigpu(gpus, variant, model_path, temp, guidance_scale, video_guidance_scale, resolution, prompt)
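For reference, this wrapper can be exercised without the UI. A hypothetical direct call (the model checkpoint and `scripts/app_multigpu_engine.sh` must exist; keyword names follow the signature above):

```python
# Hypothetical direct invocation of the Gradio handler, bypassing the UI.
# temp=16 latents decode to 16 * 8 + 1 = 129 frames, about 5.4 s at 24 fps
# (per the --temp help string and the fps=24 export in the removed CLI).
video_path = generate_text_to_video(
    prompt="A drone shot over a rocky coastline at sunset",
    temp=16,
    guidance_scale=9.0,
    video_guidance_scale=5.0,
    resolution="768p",
    gpus=4,
)
print(video_path)  # e.g. generated_videos/<uuid>_output.mp4
```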
 
+ # Gradio interface
+ with gr.Blocks() as demo:
+     gr.Markdown(
+         """
+         # Pyramid Flow Video Generation Demo
+
+         Pyramid Flow is a training-efficient **Autoregressive Video Generation** model based on **Flow Matching**. It is trained only on open-source datasets within 20.7k A100 GPU hours.
+
+         [[Paper]](https://arxiv.org/abs/2410.05954) [[Project Page]](https://pyramid-flow.github.io) [[Code]](https://github.com/jy0205/Pyramid-Flow) [[Model]](https://huggingface.co/rain1011/pyramid-flow-sd3)
+         """
+     )
+
+     # Shared settings
+     with gr.Row():
+         gpus_dropdown = gr.Dropdown(
+             choices=[2, 4],
+             value=4,
+             label="Number of GPUs"
+         )
+         resolution_dropdown = gr.Dropdown(
+             choices=["768p", "384p"],
+             value="768p",
+             label="Model Resolution"
+         )
+
+     with gr.Tab("Text-to-Video"):
+         with gr.Row():
+             with gr.Column():
+                 text_prompt = gr.Textbox(
+                     label="Prompt (less than 128 words)",
+                     placeholder="Enter a text prompt for the video",
+                     lines=2
                  )
+                 temp_slider = gr.Slider(1, 31, value=16, step=1, label="Duration")
+                 guidance_scale_slider = gr.Slider(1.0, 15.0, value=9.0, step=0.1, label="Guidance Scale")
+                 video_guidance_scale_slider = gr.Slider(1.0, 10.0, value=5.0, step=0.1, label="Video Guidance Scale")
+                 txt_generate = gr.Button("Generate Video")
+             with gr.Column():
+                 txt_output = gr.Video(label="Generated Video")
+                 gr.Examples(
+                     examples=[
+                         [
+                             "A movie trailer featuring the adventures of the 30 year old space man wearing a red wool knitted motorcycle helmet, blue sky, salt desert, cinematic style, shot on 35mm film, vivid colors",
+                             16,
+                             9.0,
+                             5.0,
+                             "768p",
+                             4
+                         ],
+                         [
+                             "Beautiful, snowy Tokyo city is bustling. The camera moves through the bustling city street, following several people enjoying the beautiful snowy weather and shopping at nearby stalls. Gorgeous sakura petals are flying through the wind along with snowflakes",
+                             16,
+                             9.0,
+                             5.0,
+                             "768p",
+                             4
+                         ],
+                         [
+                             "Extreme close-up of chicken and green pepper kebabs grilling on a barbeque with flames. Shallow focus and light smoke. Vivid colours",
+                             31,
+                             9.0,
+                             5.0,
+                             "768p",
+                             4
+                         ],
+                     ],
+                     inputs=[text_prompt, temp_slider, guidance_scale_slider, video_guidance_scale_slider, resolution_dropdown, gpus_dropdown],
+                     outputs=[txt_output],
+                     fn=generate_text_to_video,
+                     cache_examples='lazy',
+                 )
+
+     # Wire the generate button to the Text-to-Video handler
+     txt_generate.click(
+         generate_text_to_video,
+         inputs=[
+             text_prompt,
+             temp_slider,
+             guidance_scale_slider,
+             video_guidance_scale_slider,
+             resolution_dropdown,
+             gpus_dropdown
+         ],
+         outputs=txt_output
+     )
+
+ # Launch the Gradio app
+ demo.launch(share=False)
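Each click spawns a multi-GPU subprocess, so concurrent requests would contend for the same GPUs. One possible hardening, not part of this commit, is to serialize requests through Gradio's built-in queue:

```python
# Serialize generation requests; each one occupies all selected GPUs.
demo.queue(max_size=4).launch(share=False)
```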