rodrigomasini commited on
Commit
554d8e7
1 Parent(s): 27dc24e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -30
app.py CHANGED
@@ -33,9 +33,9 @@ pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, times
33
  # Safety checkers
34
  from transformers import CLIPFeatureExtractor
35
 
36
- feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32")
37
 
38
- # Function
39
  @spaces.GPU(duration=15,enable_queue=True)
40
  def generate_image(prompt, base, motion, step, progress=gr.Progress()):
41
  global step_loaded
@@ -44,8 +44,8 @@ def generate_image(prompt, base, motion, step, progress=gr.Progress()):
44
  print(prompt, base, step)
45
 
46
  if step_loaded != step:
47
- repo = "ByteDance/AnimateDiff-Lightning"
48
- ckpt = f"animatediff_lightning_{step}step_diffusers.safetensors"
49
  pipe.unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device=device), strict=False)
50
  step_loaded = step
51
 
@@ -61,10 +61,11 @@ def generate_image(prompt, base, motion, step, progress=gr.Progress()):
61
  motion_loaded = motion
62
 
63
  progress((0, step))
 
64
  def progress_callback(i, t, z):
65
  progress((i+1, step))
66
 
67
- output = pipe(prompt=prompt, guidance_scale=1.2, num_inference_steps=step, callback=progress_callback, callback_steps=1)
68
 
69
  name = str(uuid.uuid4()).replace("-", "")
70
  path = f"/tmp/{name}.mp4"
@@ -73,13 +74,15 @@ def generate_image(prompt, base, motion, step, progress=gr.Progress()):
73
 
74
 
75
  # Gradio Interface
76
- with gr.Blocks(css="style.css") as demo:
77
  gr.HTML(
78
- "<h1><center>Instant⚡Video</center></h1>" +
79
- "<p><center>Lightning-fast text-to-video generation</center></p>" +
80
- "<p><center><span style='color: red;'>You may change the steps from 4 to 8, if you didn't get satisfied results.</center></p>" +
81
- "<p><center><strong> First Video Generating takes time then Videos generate faster.</p>" +
82
- "<p><center>Write prompts in style as Given in Example</p>"
 
 
83
  )
84
  with gr.Group():
85
  with gr.Row():
@@ -130,7 +133,7 @@ with gr.Blocks(css="style.css") as demo:
130
  variant='primary'
131
  )
132
  video = gr.Video(
133
- label='AnimateDiff-Lightning',
134
  autoplay=True,
135
  height=512,
136
  width=512,
@@ -148,21 +151,4 @@ with gr.Blocks(css="style.css") as demo:
148
  outputs=video,
149
  )
150
 
151
- gr.Examples(
152
- examples=[
153
- ["Focus: Eiffel Tower (Animate: Clouds moving)"], #Atmosphere Movement Example
154
- ["Focus: Trees In forest (Animate: Lion running)"], #Object Movement Example
155
- ["Focus: Astronaut in Space"], #Normal
156
- ["Focus: Group of Birds in sky (Animate: Birds Moving) (Shot From distance)"], #Camera distance
157
- ["Focus: Statue of liberty (Shot from Drone) (Animate: Drone coming toward statue)"], #Camera Movement
158
- ["Focus: Panda in Forest (Animate: Drinking Tea)"], #Doing Something
159
- ["Focus: Kids Playing (Season: Winter)"], #Atmosphere or Season
160
- {"Focus: Cars in Street (Season: Rain, Daytime) (Shot from Distance) (Movement: Cars running)"} #Mixture
161
- ],
162
- fn=generate_image,
163
- inputs=[prompt, select_base, select_motion, select_step],
164
- outputs=video,
165
- cache_examples=False,
166
- )
167
-
168
- demo.queue().launch()
 
33
  # Safety checkers
34
  from transformers import CLIPFeatureExtractor
35
 
36
+ feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32") # change for open-source model
37
 
38
+ # Function: we are using Gradio server to queue calls. However this is open for different architectures
39
  @spaces.GPU(duration=15,enable_queue=True)
40
  def generate_image(prompt, base, motion, step, progress=gr.Progress()):
41
  global step_loaded
 
44
  print(prompt, base, step)
45
 
46
  if step_loaded != step:
47
+ repo = "ByteDance/AnimateDiff-Lightning" # we can change to other Diffusion models...
48
+ ckpt = f"animatediff_lightning_{step}step_diffusers.safetensors" #...but you must change the implementation at this point to match with the checkpoint
49
  pipe.unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device=device), strict=False)
50
  step_loaded = step
51
 
 
61
  motion_loaded = motion
62
 
63
  progress((0, step))
64
+
65
  def progress_callback(i, t, z):
66
  progress((i+1, step))
67
 
68
+ output = pipe(prompt=prompt, guidance_scale=1.2, num_inference_steps=step, callback=progress_callback, callback_steps=1) #providing visibility to progress. Useful if using gradio interface
69
 
70
  name = str(uuid.uuid4()).replace("-", "")
71
  path = f"/tmp/{name}.mp4"
 
74
 
75
 
76
  # Gradio Interface
77
+ with gr.Blocks(css="style.css", theme='sudeepshouche/minimalist') as syntvideo:
78
  gr.HTML(
79
+ "<h1><center>MAGIC Demo: synthetic video generation application</center></h1>" +
80
+ "<p><center><span style='color: red;'>Change the steps from 4 to 8 to get better results.</center></p>" +
81
+ "<p><center>Write prompts in style as given in the examples below:</center></p>" +
82
+ "<p><center>Focus: Group of Birds in sky (Animate: Birds Moving) (Shot From distance)</center></p>" +
83
+ "<p><center>Focus: Trees In forest (Animate: Lion running)</center></p>" +
84
+ "<p><center>Focus: Kids Playing (Season: Winter)</center></p>" +
85
+ "<p><center>Focus: Cars in Street (Season: Rain, Daytime) (Shot from Distance) (Movement: Cars running)</center></p>"
86
  )
87
  with gr.Group():
88
  with gr.Row():
 
133
  variant='primary'
134
  )
135
  video = gr.Video(
136
+ label='Generate Synthetic Video',
137
  autoplay=True,
138
  height=512,
139
  width=512,
 
151
  outputs=video,
152
  )
153
 
154
+ syntvideo.queue().launch()