Spaces: Runtime error
rodrigomasini committed · Commit 554d8e7 · 1 Parent(s): 27dc24e
Update app.py
app.py CHANGED
@@ -33,9 +33,9 @@ pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, times
 # Safety checkers
 from transformers import CLIPFeatureExtractor
 
-feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32")
+feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32")  # change for open-source model
 
-# Function
+# Function: we use the Gradio server to queue calls; however, this is open to different architectures
 @spaces.GPU(duration=15, enable_queue=True)
 def generate_image(prompt, base, motion, step, progress=gr.Progress()):
     global step_loaded
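Note: this hunk only annotates the feature-extractor line; nothing in the commit actually attaches a checker to the pipeline. For context, a minimal sketch of the usual pairing, assuming the stock diffusers StableDiffusionSafetyChecker (everything past the extractor line is an assumption, not code from this Space):

    from transformers import CLIPFeatureExtractor
    from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker

    # Load the CLIP preprocessor that the safety checker expects.
    feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32")
    safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
    # Hypothetical wiring; whether the pipeline accepts these attributes depends on the pipeline class.
    # pipe.feature_extractor = feature_extractor
    # pipe.safety_checker = safety_checker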
@@ -44,8 +44,8 @@ def generate_image(prompt, base, motion, step, progress=gr.Progress()):
     print(prompt, base, step)
 
     if step_loaded != step:
-        repo = "ByteDance/AnimateDiff-Lightning"
-        ckpt = f"animatediff_lightning_{step}step_diffusers.safetensors"
+        repo = "ByteDance/AnimateDiff-Lightning"  # we can change to other diffusion models...
+        ckpt = f"animatediff_lightning_{step}step_diffusers.safetensors"  # ...but you must change the implementation at this point to match the checkpoint
         pipe.unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device=device), strict=False)
         step_loaded = step
 
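Note: the checkpoint filename is derived from `step`, so swapping in another repo means the f-string must match that repo's naming convention. A small sketch for checking what a repo actually ships, using huggingface_hub's list_repo_files (the repo id is the one from the diff; the print is illustrative):

    from huggingface_hub import list_repo_files

    # Enumerate available checkpoint files before hardcoding a filename pattern.
    files = list_repo_files("ByteDance/AnimateDiff-Lightning")
    print([f for f in files if f.endswith(".safetensors")])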
@@ -61,10 +61,11 @@ def generate_image(prompt, base, motion, step, progress=gr.Progress()):
         motion_loaded = motion
 
     progress((0, step))
+
     def progress_callback(i, t, z):
         progress((i+1, step))
 
-    output = pipe(prompt=prompt, guidance_scale=1.2, num_inference_steps=step, callback=progress_callback, callback_steps=1)
+    output = pipe(prompt=prompt, guidance_scale=1.2, num_inference_steps=step, callback=progress_callback, callback_steps=1)  # provides visibility into progress; useful with the Gradio interface
 
     name = str(uuid.uuid4()).replace("-", "")
     path = f"/tmp/{name}.mp4"
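Note: `callback`/`callback_steps` still work here but are deprecated in newer diffusers releases in favor of `callback_on_step_end`. A sketch of equivalent progress reporting under the newer API, assuming the same `progress` and `step` objects from this function:

    # Newer-diffusers variant; the callback must return the kwargs dict.
    def on_step_end(pipeline, i, t, callback_kwargs):
        progress((i + 1, step))  # same Gradio progress update as the diff
        return callback_kwargs

    output = pipe(prompt=prompt, guidance_scale=1.2, num_inference_steps=step,
                  callback_on_step_end=on_step_end)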
@@ -73,13 +74,15 @@ def generate_image(prompt, base, motion, step, progress=gr.Progress()):
 
 
 # Gradio Interface
-with gr.Blocks(css="style.css") as demo:
+with gr.Blocks(css="style.css", theme='sudeepshouche/minimalist') as syntvideo:
     gr.HTML(
-        "<h1><center>…
-        "<p><center>…
-        "<p><center…
-        "<p><center…
-        "<p><center>…
+        "<h1><center>MAGIC Demo: synthetic video generation application</center></h1>" +
+        "<p><center><span style='color: red;'>Change the steps from 4 to 8 to get better results.</span></center></p>" +
+        "<p><center>Write prompts in style as given in the examples below:</center></p>" +
+        "<p><center>Focus: Group of Birds in sky (Animate: Birds Moving) (Shot From distance)</center></p>" +
+        "<p><center>Focus: Trees In forest (Animate: Lion running)</center></p>" +
+        "<p><center>Focus: Kids Playing (Season: Winter)</center></p>" +
+        "<p><center>Focus: Cars in Street (Season: Rain, Daytime) (Shot from Distance) (Movement: Cars running)</center></p>"
     )
     with gr.Group():
         with gr.Row():
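Note: `theme='sudeepshouche/minimalist'` loads a community theme from the Hugging Face Hub by name at startup. If you want to resolve it explicitly (for pinning or inspection), a sketch assuming gr.Theme.from_hub, the loader described in Gradio's theming guide:

    import gradio as gr

    # Resolve the Hub theme up front instead of passing the bare string.
    theme = gr.Theme.from_hub("sudeepshouche/minimalist")
    with gr.Blocks(css="style.css", theme=theme) as syntvideo:
        ...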
@@ -130,7 +133,7 @@ with gr.Blocks(css="style.css") as demo:
         variant='primary'
     )
     video = gr.Video(
-        label='…
+        label='Generate Synthetic Video',
         autoplay=True,
         height=512,
         width=512,
@@ -148,21 +151,4 @@ with gr.Blocks(css="style.css") as demo:
         outputs=video,
     )
 
-    gr.Examples(
-        examples=[
-            ["Focus: Eiffel Tower (Animate: Clouds moving)"],  # Atmosphere Movement Example
-            ["Focus: Trees In forest (Animate: Lion running)"],  # Object Movement Example
-            ["Focus: Astronaut in Space"],  # Normal
-            ["Focus: Group of Birds in sky (Animate: Birds Moving) (Shot From distance)"],  # Camera distance
-            ["Focus: Statue of liberty (Shot from Drone) (Animate: Drone coming toward statue)"],  # Camera Movement
-            ["Focus: Panda in Forest (Animate: Drinking Tea)"],  # Doing Something
-            ["Focus: Kids Playing (Season: Winter)"],  # Atmosphere or Season
-            ["Focus: Cars in Street (Season: Rain, Daytime) (Shot from Distance) (Movement: Cars running)"]  # Mixture
-        ],
-        fn=generate_image,
-        inputs=[prompt, select_base, select_motion, select_step],
-        outputs=video,
-        cache_examples=False,
-    )
-
-demo.queue().launch()
+syntvideo.queue().launch()
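Note: the deleted block was a `gr.Examples` component bound to the same `generate_image` function, so the commit removes the canned prompts entirely rather than rewiring them. If they are ever restored, a minimal sketch of the wiring inside the `with gr.Blocks(...)` context (list shortened; names taken from the diff):

    gr.Examples(
        examples=[
            ["Focus: Astronaut in Space"],
            ["Focus: Panda in Forest (Animate: Drinking Tea)"],
        ],
        fn=generate_image,
        inputs=[prompt, select_base, select_motion, select_step],
        outputs=video,
        cache_examples=False,  # run examples live instead of caching outputs
    )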