WeichenFan commited on
Commit
290c968
1 Parent(s): df015e0

update demo

Browse files
Files changed (1) hide show
  1. app.py +6 -6
app.py CHANGED
@@ -24,14 +24,14 @@ os.makedirs("./output", exist_ok=True)
24
  os.makedirs("./gradio_tmp", exist_ok=True)
25
 
26
  @spaces.GPU(duration=120)
27
- def infer(prompt: str, num_inference_steps: int, guidance_scale: float, progress=gr.Progress(track_tqdm=True)):
28
  torch.cuda.empty_cache()
29
  with torch.cuda.amp.autocast(dtype=torch.bfloat16):
30
  video = pipe(
31
  prompt,
32
  negative_prompt="",
33
- num_inference_steps=num_inference_steps,
34
- guidance_scale=guidance_scale,
35
  width=432,
36
  height=240, #480x288 624x352 432x240 768x432
37
  frames=16
@@ -117,8 +117,8 @@ with gr.Blocks() as demo:
117
 
118
 
119
 
120
- def generate(prompt, num_inference_steps, guidance_scale, model_choice, progress=gr.Progress(track_tqdm=True)):
121
- tensor = infer(prompt, num_inference_steps, guidance_scale, progress=progress)
122
  video_path = save_video(tensor)
123
  video_update = gr.update(visible=True, value=video_path)
124
  gif_path = convert_to_gif(video_path)
@@ -133,7 +133,7 @@ with gr.Blocks() as demo:
133
 
134
  generate_button.click(
135
  generate,
136
- inputs=[prompt, num_inference_steps, guidance_scale],
137
  outputs=[video_output, download_video_button, download_gif_button]
138
  )
139
 
 
24
  os.makedirs("./gradio_tmp", exist_ok=True)
25
 
26
  @spaces.GPU(duration=120)
27
+ def infer(prompt: str, progress=gr.Progress(track_tqdm=True)):
28
  torch.cuda.empty_cache()
29
  with torch.cuda.amp.autocast(dtype=torch.bfloat16):
30
  video = pipe(
31
  prompt,
32
  negative_prompt="",
33
+ num_inference_steps=50,
34
+ guidance_scale=7.5,
35
  width=432,
36
  height=240, #480x288 624x352 432x240 768x432
37
  frames=16
 
117
 
118
 
119
 
120
+ def generate(prompt, model_choice, progress=gr.Progress(track_tqdm=True)):
121
+ tensor = infer(prompt, progress=progress)
122
  video_path = save_video(tensor)
123
  video_update = gr.update(visible=True, value=video_path)
124
  gif_path = convert_to_gif(video_path)
 
133
 
134
  generate_button.click(
135
  generate,
136
+ inputs=[prompt],
137
  outputs=[video_output, download_video_button, download_gif_button]
138
  )
139