BestWishYsh commited on
Commit
f5123f5
·
verified ·
1 Parent(s): 1c6d7a8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -13
app.py CHANGED
@@ -116,13 +116,15 @@ os.makedirs("./gradio_tmp", exist_ok=True)
116
  upscale_model = load_sd_upscale("model_real_esran/RealESRGAN_x4.pth", device)
117
  frame_interpolation_model = load_rife_model("model_rife")
118
 
119
- def infer(
 
120
  prompt: str,
121
  image_input: str,
122
  num_inference_steps: int,
123
  guidance_scale: float,
124
  seed: int = 42,
125
- progress=gr.Progress(track_tqdm=True),
 
126
  ):
127
  if seed == -1:
128
  seed = random.randint(0, 2**8 - 1)
@@ -167,6 +169,12 @@ def infer(
167
  ).frames
168
 
169
  free_memory()
 
 
 
 
 
 
170
  return (video_pt, seed)
171
 
172
 
@@ -320,8 +328,8 @@ with gr.Blocks() as demo:
320
  </table>
321
  """)
322
 
323
- @spaces.GPU(duration=180)
324
- def generate(
325
  prompt,
326
  image_input,
327
  seed_value,
@@ -329,18 +337,15 @@ with gr.Blocks() as demo:
329
  rife_status,
330
  progress=gr.Progress(track_tqdm=True)
331
  ):
332
- latents, seed = infer(
333
  prompt,
334
  image_input,
335
  num_inference_steps=50,
336
  guidance_scale=7.0,
337
  seed=seed_value,
338
- progress=progress,
339
- )
340
- if scale_status:
341
- latents = upscale_batch_and_concatenate(upscale_model, latents, device)
342
- if rife_status:
343
- latents = rife_inference_with_latents(frame_interpolation_model, latents)
344
 
345
  batch_size = latents.shape[0]
346
  batch_video_frames = []
@@ -361,11 +366,11 @@ with gr.Blocks() as demo:
361
  return video_path, video_update, gif_update, seed_update
362
 
363
  generate_button.click(
364
- generate,
365
  inputs=[prompt, image_input, seed_param, enable_scale, enable_rife],
366
  outputs=[video_output, download_video_button, download_gif_button, seed_text],
367
  )
368
 
369
  if __name__ == "__main__":
370
  demo.queue(max_size=15)
371
- demo.launch()
 
116
  upscale_model = load_sd_upscale("model_real_esran/RealESRGAN_x4.pth", device)
117
  frame_interpolation_model = load_rife_model("model_rife")
118
 
119
+ @spaces.GPU(duration=300)
120
+ def generate(
121
  prompt: str,
122
  image_input: str,
123
  num_inference_steps: int,
124
  guidance_scale: float,
125
  seed: int = 42,
126
+ scale_status: bool = False,
127
+ rife_status: bool = False,
128
  ):
129
  if seed == -1:
130
  seed = random.randint(0, 2**8 - 1)
 
169
  ).frames
170
 
171
  free_memory()
172
+
173
+ if scale_status:
174
+ video_pt = upscale_batch_and_concatenate(upscale_model, video_pt, device)
175
+ if rife_status:
176
+ video_pt = rife_inference_with_latents(frame_interpolation_model, video_pt)
177
+
178
  return (video_pt, seed)
179
 
180
 
 
328
  </table>
329
  """)
330
 
331
+
332
+ def run(
333
  prompt,
334
  image_input,
335
  seed_value,
 
337
  rife_status,
338
  progress=gr.Progress(track_tqdm=True)
339
  ):
340
+ latents, seed = generate(
341
  prompt,
342
  image_input,
343
  num_inference_steps=50,
344
  guidance_scale=7.0,
345
  seed=seed_value,
346
+ scale_status=scale_status,
347
+ rife_status=rife_status,
348
+ )
 
 
 
349
 
350
  batch_size = latents.shape[0]
351
  batch_video_frames = []
 
366
  return video_path, video_update, gif_update, seed_update
367
 
368
  generate_button.click(
369
+ fn=run,
370
  inputs=[prompt, image_input, seed_param, enable_scale, enable_rife],
371
  outputs=[video_output, download_video_button, download_gif_button, seed_text],
372
  )
373
 
374
  if __name__ == "__main__":
375
  demo.queue(max_size=15)
376
+ demo.launch()