ysharma HF staff commited on
Commit
a447e83
·
verified ·
1 Parent(s): d65c6e4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -195
app.py CHANGED
@@ -1,198 +1,25 @@
1
  import gradio as gr
2
- import torch
3
- import os
4
- import spaces
5
- import uuid
6
 
7
- from diffusers import AnimateDiffPipeline, MotionAdapter, EulerDiscreteScheduler
8
- from diffusers.utils import export_to_video
9
- from huggingface_hub import hf_hub_download
10
- from safetensors.torch import load_file
11
- from PIL import Image
12
- from gradio_client import Client, file
13
- from moviepy.editor import VideoFileClip, AudioFileClip, concatenate_videoclips
14
 
15
-
16
- # using tango2 via Gradio python client
17
- client = Client("declare-lab/tango2")
18
-
19
- # Constants
20
- bases = {
21
- "ToonYou": "frankjoshua/toonyou_beta6",
22
- "epiCRealism": "emilianJR/epiCRealism"
23
- }
24
- step_loaded = None
25
- base_loaded = "epiCRealism"
26
- motion_loaded = None
27
-
28
- # Ensure model and scheduler are initialized in GPU-enabled function
29
- if not torch.cuda.is_available():
30
- raise NotImplementedError("No GPU detected!")
31
-
32
- device = "cuda"
33
- dtype = torch.float16
34
- pipe = AnimateDiffPipeline.from_pretrained(bases[base_loaded], torch_dtype=dtype).to(device)
35
- pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing", beta_schedule="linear")
36
-
37
- # Safety checkers
38
- from safety_checker import StableDiffusionSafetyChecker
39
- from transformers import CLIPFeatureExtractor
40
-
41
- safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker").to(device)
42
- feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32")
43
-
44
- def check_nsfw_images(images: list[Image.Image]) -> list[bool]:
45
- safety_checker_input = feature_extractor(images, return_tensors="pt").to(device)
46
- has_nsfw_concepts = safety_checker(images=[images], clip_input=safety_checker_input.pixel_values.to(device))
47
- return has_nsfw_concepts
48
-
49
- # Function
50
- @spaces.GPU(enable_queue=True)
51
- def generate_image(prompt, base, motion, step, progress=gr.Progress()):
52
- global step_loaded
53
- global base_loaded
54
- global motion_loaded
55
- print(prompt, base, step)
56
-
57
- if step_loaded != step:
58
- repo = "ByteDance/AnimateDiff-Lightning"
59
- ckpt = f"animatediff_lightning_{step}step_diffusers.safetensors"
60
- pipe.unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device=device), strict=False)
61
- step_loaded = step
62
-
63
- if base_loaded != base:
64
- pipe.unet.load_state_dict(torch.load(hf_hub_download(bases[base], "unet/diffusion_pytorch_model.bin"), map_location=device), strict=False)
65
- base_loaded = base
66
-
67
- if motion_loaded != motion:
68
- pipe.unload_lora_weights()
69
- if motion != "":
70
- pipe.load_lora_weights(motion, adapter_name="motion")
71
- pipe.set_adapters(["motion"], [0.7])
72
- motion_loaded = motion
73
-
74
- progress((0, step))
75
- def progress_callback(i, t, z):
76
- progress((i+1, step))
77
-
78
- output = pipe(prompt=prompt, guidance_scale=1.0, num_inference_steps=step, callback=progress_callback, callback_steps=1)
79
-
80
- has_nsfw_concepts = check_nsfw_images([output.frames[0][0]])
81
- if has_nsfw_concepts[0]:
82
- gr.Warning("NSFW content detected.")
83
- return None
84
-
85
- name = str(uuid.uuid4()).replace("-", "")
86
- video_path = f"/tmp/{name}.mp4"
87
- export_to_video(output.frames[0], video_path, fps=10)
88
-
89
- audio_path = tango2(prompt)
90
- final_video_path = fuse_together(audio_path, video_path)
91
-
92
- return final_video_path
93
-
94
-
95
- def tango2(prompt):
96
- results = client.predict(
97
- prompt=prompt,
98
- steps=100,
99
- guidance=3,
100
- api_name="/predict"
101
- )
102
- return results
103
-
104
- def fuse_together(audio, video):
105
-
106
- # Load your video and audio files
107
- video_clip = VideoFileClip(video)
108
- audio_clip = AudioFileClip(audio)
109
-
110
- # Loop the video twice
111
- looped_video = concatenate_videoclips([video_clip, video_clip])
112
-
113
- # Cut the audio to match the duration of the looped video
114
- looped_audio = audio_clip.subclip(0, looped_video.duration)
115
-
116
- # Set the audio of the looped video to the adjusted audio clip
117
- final_video = looped_video.set_audio(looped_audio)
118
-
119
- # Write the result to a file (output will be twice the length of the original video)
120
- name = str(uuid.uuid4()).replace("-", "")
121
- path = f"/tmp/{name}.mp4"
122
- final_video.write_videofile(path, codec="libx264", audio_codec="aac")
123
-
124
- return path
125
-
126
-
127
- # Gradio Interface
128
- with gr.Blocks(css="style.css") as demo:
129
- gr.HTML(
130
- "<h1><center>AnimateDiff-Lightning⚡ + TANGO 2</center></h1>" +
131
- "<p><center>Using Gradio Python Client to combine <b>AnimateDiff Lightning</b> with <b>Tango2</b> to give Voice to your Generated Videos</center></p>" +
132
- "<p><center>Refer Gradio Guide for Python Clients here :<a href='https://www.gradio.app/guides/getting-started-with-the-python-client'>Getting Started with the Gradio Python client</a></center></p>"
133
- )
134
- with gr.Group():
135
- with gr.Row():
136
- prompt = gr.Textbox(
137
- label='Prompt (English)'
138
- )
139
- with gr.Row():
140
- select_base = gr.Dropdown(
141
- label='Base model',
142
- choices=[
143
- "ToonYou",
144
- "epiCRealism",
145
- ],
146
- value=base_loaded,
147
- interactive=True
148
- )
149
- select_motion = gr.Dropdown(
150
- label='Motion',
151
- choices=[
152
- ("Default", ""),
153
- ("Zoom in", "guoyww/animatediff-motion-lora-zoom-in"),
154
- ("Zoom out", "guoyww/animatediff-motion-lora-zoom-out"),
155
- ("Tilt up", "guoyww/animatediff-motion-lora-tilt-up"),
156
- ("Tilt down", "guoyww/animatediff-motion-lora-tilt-down"),
157
- ("Pan left", "guoyww/animatediff-motion-lora-pan-left"),
158
- ("Pan right", "guoyww/animatediff-motion-lora-pan-right"),
159
- ("Roll left", "guoyww/animatediff-motion-lora-rolling-anticlockwise"),
160
- ("Roll right", "guoyww/animatediff-motion-lora-rolling-clockwise"),
161
- ],
162
- value="",
163
- interactive=True
164
- )
165
- select_step = gr.Dropdown(
166
- label='Inference steps',
167
- choices=[
168
- ('1-Step', 1),
169
- ('2-Step', 2),
170
- ('4-Step', 4),
171
- ('8-Step', 8)],
172
- value=4,
173
- interactive=True
174
- )
175
- submit = gr.Button(
176
- scale=1,
177
- variant='primary'
178
- )
179
- video = gr.Video(
180
- label='AnimateDiff-Lightning',
181
- autoplay=True,
182
- height=512,
183
- width=512,
184
- elem_id="video_output"
185
- )
186
-
187
- prompt.submit(
188
- fn=generate_image,
189
- inputs=[prompt, select_base, select_motion, select_step],
190
- outputs=video,
191
- )
192
- submit.click(
193
- fn=generate_image,
194
- inputs=[prompt, select_base, select_motion, select_step],
195
- outputs=video,
196
- )
197
-
198
- demo.queue().launch()
 
1
  import gradio as gr
 
 
 
 
2
 
 
 
 
 
 
 
 
3
 
4
+ with gr.Blocks() as demo:
5
+ with gr.Row():
6
+ with gr.Column(visible=False, min_width=200, scale=0) as sidebar:
7
+ btn1 = gr.Button("Button 1")
8
+ btn2 = gr.Button("Button 2")
9
+ with gr.Column() as main:
10
+ open_sidebar_btn = gr.Button("Open Sidebar", scale=0)
11
+ close_sidebar_btn = gr.Button("Close Sidebar", visible=False, scale=0)
12
+ open_sidebar_btn.click(lambda: {
13
+ open_sidebar_btn: gr.Button(visible=False),
14
+ close_sidebar_btn: gr.Button(visible=True),
15
+ sidebar: gr.Column(visible=True)
16
+ }, outputs={open_sidebar_btn, close_sidebar_btn, sidebar})
17
+ close_sidebar_btn.click(lambda: {
18
+ open_sidebar_btn: gr.Button(visible=True),
19
+ close_sidebar_btn: gr.Button(visible=False),
20
+ sidebar: gr.Column(visible=False)
21
+ }, outputs={open_sidebar_btn, close_sidebar_btn, sidebar})
22
+ gr.Markdown("# Hello Blocks")
23
+ gr.Markdown("Lorem ipsum dolor sit amet, consectetur adipiscing elit. Nullam nec nulla nec nulla fermentum fermentum. Nullam nec nulla nec nulla fermentum fermentum.")
24
+
25
+ demo.launch()