oahzxl committed on
Commit
09f1eaa
1 Parent(s): 48357d2
Files changed (1) hide show
  1. app.py +28 -90
app.py CHANGED
@@ -5,9 +5,7 @@ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
5
 
6
  import uuid
7
 
8
- import GPUtil
9
  import gradio as gr
10
- import psutil
11
  import spaces
12
 
13
  from videosys import CogVideoXConfig, CogVideoXPABConfig, VideoSysEngine
@@ -30,33 +28,6 @@ def generate(engine, prompt, num_inference_steps=50, guidance_scale=6.0):
30
  return output_path
31
 
32
 
33
- def get_server_status():
34
- cpu_percent = psutil.cpu_percent()
35
- memory = psutil.virtual_memory()
36
- disk = psutil.disk_usage("/")
37
- gpus = GPUtil.getGPUs()
38
- gpu_info = []
39
- for gpu in gpus:
40
- gpu_info.append(
41
- {
42
- "id": gpu.id,
43
- "name": gpu.name,
44
- "load": f"{gpu.load*100:.1f}%",
45
- "memory_used": f"{gpu.memoryUsed}MB",
46
- "memory_total": f"{gpu.memoryTotal}MB",
47
- }
48
- )
49
-
50
- return {"cpu": f"{cpu_percent}%", "memory": f"{memory.percent}%", "disk": f"{disk.percent}%", "gpu": gpu_info}
51
-
52
-
53
- @spaces.GPU(duration=540)
54
- def generate_vanilla(model_name, prompt, num_inference_steps, guidance_scale, progress=gr.Progress(track_tqdm=True)):
55
- engine = load_model(model_name)
56
- video_path = generate(engine, prompt, num_inference_steps, guidance_scale)
57
- return video_path
58
-
59
-
60
  @spaces.GPU(duration=400)
61
  def generate_vs(
62
  model_name,
@@ -75,33 +46,6 @@ def generate_vs(
75
  return video_path
76
 
77
 
78
- def get_server_status():
79
- cpu_percent = psutil.cpu_percent()
80
- memory = psutil.virtual_memory()
81
- disk = psutil.disk_usage("/")
82
- try:
83
- gpus = GPUtil.getGPUs()
84
- if gpus:
85
- gpu = gpus[0]
86
- gpu_memory = f"{gpu.memoryUsed}/{gpu.memoryTotal}MB ({gpu.memoryUtil*100:.1f}%)"
87
- else:
88
- gpu_memory = "No GPU found"
89
- except:
90
- gpu_memory = "GPU information unavailable"
91
-
92
- return {
93
- "cpu": f"{cpu_percent}%",
94
- "memory": f"{memory.percent}%",
95
- "disk": f"{disk.percent}%",
96
- "gpu_memory": gpu_memory,
97
- }
98
-
99
-
100
- def update_server_status():
101
- status = get_server_status()
102
- return (status["cpu"], status["memory"], status["disk"], status["gpu_memory"])
103
-
104
-
105
  css = """
106
  body {
107
  font-family: Arial, sans-serif;
@@ -203,7 +147,7 @@ with gr.Blocks(css=css) as demo:
203
 
204
  with gr.Row():
205
  with gr.Column():
206
- prompt = gr.Textbox(label="Prompt (Less than 200 Words)", value="Sunset over the sea.", lines=3)
207
 
208
  with gr.Column():
209
  gr.Markdown("**Generation Parameters**<br>")
@@ -212,44 +156,40 @@ with gr.Blocks(css=css) as demo:
212
  ["THUDM/CogVideoX-2b", "THUDM/CogVideoX-5b"], label="Model Type", value="THUDM/CogVideoX-2b"
213
  )
214
  with gr.Row():
215
- num_inference_steps = gr.Number(label="Inference Steps", value=50)
216
- guidance_scale = gr.Number(label="Guidance Scale", value=6.0)
 
217
  with gr.Row():
218
- pab_range = gr.Number(
219
- label="PAB Broadcast Range", value=2, precision=0, info="Broadcast timesteps range."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
220
  )
221
- pab_threshold_start = gr.Number(label="PAB Start Timestep", value=850, info="Start from step 1000.")
222
- pab_threshold_end = gr.Number(label="PAB End Timestep", value=100, info="End at step 0.")
223
  with gr.Row():
224
- generate_button_vs = gr.Button("⚡️ Generate Video with VideoSys (Faster)")
225
- generate_button = gr.Button("🎬 Generate Video (Original)")
226
- with gr.Column(elem_classes="server-status"):
227
- gr.Markdown("#### Server Status")
228
-
229
- with gr.Row():
230
- cpu_status = gr.Textbox(label="CPU", scale=1)
231
- memory_status = gr.Textbox(label="Memory", scale=1)
232
-
233
- with gr.Row():
234
- disk_status = gr.Textbox(label="Disk", scale=1)
235
- gpu_status = gr.Textbox(label="GPU Memory", scale=1)
236
-
237
- with gr.Row():
238
- refresh_button = gr.Button("Refresh")
239
 
240
  with gr.Column():
241
  with gr.Row():
242
  video_output_vs = gr.Video(label="CogVideoX with VideoSys", width=720, height=480)
243
- with gr.Row():
244
- video_output = gr.Video(label="CogVideoX", width=720, height=480)
245
-
246
- generate_button.click(
247
- generate_vanilla,
248
- inputs=[model_name, prompt, num_inference_steps, guidance_scale],
249
- outputs=[video_output],
250
- concurrency_id="gen",
251
- concurrency_limit=1,
252
- )
253
 
254
  generate_button_vs.click(
255
  generate_vs,
@@ -267,8 +207,6 @@ with gr.Blocks(css=css) as demo:
267
  concurrency_limit=1,
268
  )
269
 
270
- refresh_button.click(update_server_status, outputs=[cpu_status, memory_status, disk_status, gpu_status])
271
- demo.load(update_server_status, outputs=[cpu_status, memory_status, disk_status, gpu_status], every=1)
272
 
273
  if __name__ == "__main__":
274
  demo.queue(max_size=10, default_concurrency_limit=1)
 
5
 
6
  import uuid
7
 
 
8
  import gradio as gr
 
9
  import spaces
10
 
11
  from videosys import CogVideoXConfig, CogVideoXPABConfig, VideoSysEngine
 
28
  return output_path
29
 
30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  @spaces.GPU(duration=400)
32
  def generate_vs(
33
  model_name,
 
46
  return video_path
47
 
48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  css = """
50
  body {
51
  font-family: Arial, sans-serif;
 
147
 
148
  with gr.Row():
149
  with gr.Column():
150
+ prompt = gr.Textbox(label="Prompt (Less than 200 Words)", value="Sunset over the sea.", lines=2)
151
 
152
  with gr.Column():
153
  gr.Markdown("**Generation Parameters**<br>")
 
156
  ["THUDM/CogVideoX-2b", "THUDM/CogVideoX-5b"], label="Model Type", value="THUDM/CogVideoX-2b"
157
  )
158
  with gr.Row():
159
+ num_inference_steps = gr.Slider(label="Inference Steps", maximum=50, value=50)
160
+ guidance_scale = gr.Slider(label="Guidance Scale", value=6.0, maximum=15.0)
161
+ gr.Markdown("**Pyramid Attention Broadcast Parameters**<br>")
162
  with gr.Row():
163
+ pab_range = gr.Slider(
164
+ label="Broadcast Range",
165
+ value=2,
166
+ step=1,
167
+ minimum=1,
168
+ maximum=4,
169
+ info="Attention broadcast range.",
170
+ )
171
+ pab_threshold_start = gr.Slider(
172
+ label="Start Timestep",
173
+ minimum=500,
174
+ maximum=1000,
175
+ value=850,
176
+ step=1,
177
+ info="Broadcast start timestep (1000 is the first).",
178
+ )
179
+ pab_threshold_end = gr.Slider(
180
+ label="End Timestep",
181
+ minimum=0,
182
+ maximum=500,
183
+ step=1,
184
+ value=100,
185
+ info="Broadcast end timestep (0 is the last).",
186
  )
 
 
187
  with gr.Row():
188
+ generate_button_vs = gr.Button("⚡️ Generate Video with VideoSys")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
189
 
190
  with gr.Column():
191
  with gr.Row():
192
  video_output_vs = gr.Video(label="CogVideoX with VideoSys", width=720, height=480)
 
 
 
 
 
 
 
 
 
 
193
 
194
  generate_button_vs.click(
195
  generate_vs,
 
207
  concurrency_limit=1,
208
  )
209
 
 
 
210
 
211
  if __name__ == "__main__":
212
  demo.queue(max_size=10, default_concurrency_limit=1)