Files changed (1) hide show
  1. app1.py +310 -0
app1.py ADDED
@@ -0,0 +1,310 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from io import BytesIO
3
+ import base64
4
+ from functools import partial
5
+
6
+ from PIL import Image, ImageOps
7
+ import gradio as gr
8
+
9
+ from makeavid_sd.inference import (
10
+ InferenceUNetPseudo3D,
11
+ jnp,
12
+ SCHEDULERS
13
+ )
14
+
15
# Log the XLA client memory settings at startup; 'NotSet' is printed when a
# variable is absent from the environment.
print(os.environ.get('XLA_PYTHON_CLIENT_PREALLOCATE', 'NotSet'))
print(os.environ.get('XLA_PYTHON_CLIENT_ALLOCATOR', 'NotSet'))

# When True, two warm-up generations are run before the UI is built.
_preheat: bool = False

# Parameter combinations that generate() has already been run with.
_seen_compilations = set()

# The model is loaded once at import time; the hub token is optional.
_model = InferenceUNetPseudo3D(
    model_path='TempoFunk/makeavid-sd-jax',
    dtype=jnp.float16,
    hf_auth_token=os.environ.get('HUGGING_FACE_HUB_TOKEN'),
)

# `failed` is False on success; any other value is rendered as the error text
# in a minimal one-component page.
if _model.failed != False:
    trace = f'```{_model.failed}```'
    with gr.Blocks(title='Make-A-Video Stable Diffusion JAX', analytics_enabled=False) as demo:
        exception = gr.Markdown(trace)

    demo.launch()

# Container formats offered for the animated output.
_output_formats = ('webp', 'gif')
38
+
39
# NOTE: gradio does not cope with type hints on the parameters, so generate() deliberately leaves them unannotated.
40
def generate(
        prompt='An elderly man having a great time in the park.',
        neg_prompt='',
        hint_image=None,
        inference_steps=20,
        cfg=15.0,
        cfg_image=9.0,
        seed=0,
        fps=24,
        num_frames=24,
        height=512,
        width=512,
        scheduler_type='DPM',
        output_format='webp'
) -> str:
    """Run the model and return the animation as a base64 ``data:`` URI.

    Inputs are sanitised rather than rejected:

    * ``num_frames``, ``inference_steps``, ``height``, ``width`` and ``seed``
      are coerced to ``int``; ``height`` and ``width`` are rounded down to a
      multiple of 64; a negative ``seed`` is replaced by its absolute value.
    * ``cfg`` and ``cfg_image`` are raised to at least 1.0.
    * a non-positive ``fps`` is replaced by 1.
    * ``hint_image`` (a PIL image or ``None``) is converted to RGB and
      cropped/resized to ``(width, height)`` when it does not already match.
    * an unknown ``scheduler_type`` falls back to ``'DPM'`` and an unknown
      ``output_format`` falls back to ``'webp'``.

    The parameter combination that was run is recorded in
    ``_seen_compilations``.

    Returns:
        A string of the form ``data:image/<format>;base64,<payload>``.
    """
    num_frames = int(num_frames)
    inference_steps = int(inference_steps)
    # Round both dimensions down to a multiple of 64.
    height = (int(height) // 64) * 64
    width = (int(width) // 64) * 64
    cfg = max(cfg, 1.0)
    cfg_image = max(cfg_image, 1.0)
    seed = abs(int(seed))
    # The frame duration below divides by fps, so zero/negative values are clamped.
    if fps <= 0:
        fps = 1

    if hint_image is not None:
        if hint_image.mode != 'RGB':
            hint_image = hint_image.convert('RGB')
        if hint_image.size != (width, height):
            hint_image = ImageOps.fit(hint_image, (width, height), method=Image.Resampling.LANCZOS)

    if scheduler_type not in SCHEDULERS:
        scheduler_type = 'DPM'
    output_format = output_format.lower()
    if output_format not in _output_formats:
        output_format = 'webp'

    mask_image = None
    images = _model.generate(
        # One copy of the prompt per device reported by the model.
        prompt=[prompt] * _model.device_count,
        neg_prompt=neg_prompt,
        hint_image=hint_image,
        mask_image=mask_image,
        inference_steps=inference_steps,
        cfg=cfg,
        cfg_image=cfg_image,
        height=height,
        width=width,
        num_frames=num_frames,
        seed=seed,
        scheduler_type=scheduler_type
    )

    # The key has the same six components, in the same order, that
    # check_if_compiled() looks up; the scheduler is part of it.
    _seen_compilations.add(
        (hint_image is None, inference_steps, height, width, num_frames, scheduler_type)
    )

    # NOTE(review): the animation starts at images[1]; images[0] is skipped,
    # presumably because it is the hint frame -- confirm against the model code.
    with BytesIO() as buffer:
        images[1].save(
            buffer,
            format=output_format,
            save_all=True,
            append_images=images[2:],
            loop=0,
            duration=round(1000 / fps),  # milliseconds per frame
            allow_mixed=True
        )
        encoded = base64.b64encode(buffer.getvalue()).decode()

    return f'data:image/{output_format};base64,' + encoded
106
def check_if_compiled(hint_image, inference_steps, height, width, num_frames, scheduler_type, message):
    """Return ``message`` unless this parameter combination was already run.

    The lookup key is ``(hint_image is None, inference_steps, height, width,
    num_frames, scheduler_type)``. The numeric parts are coerced to ``int`` and
    ``height``/``width`` are rounded down to a multiple of 64, the same
    normalisation ``generate`` applies before it runs the model.

    Returns:
        ``''`` when the key is present in ``_seen_compilations``, otherwise
        ``message`` formatted as a string.
    """
    inference_steps = int(inference_steps)
    # num_frames is normalised like the other numeric key components.
    num_frames = int(num_frames)
    height = (int(height) // 64) * 64
    width = (int(width) // 64) * 64

    key = (hint_image is None, inference_steps, height, width, num_frames, scheduler_type)
    if key in _seen_compilations:
        return ''
    return f'{message}'
116
+
117
# Optional warm-up: run generate() once without and once with a hint image
# before the UI is built. Disabled unless _preheat is set to True.
if _preheat:
    print('\npreheating the oven')
    generate(
        prompt='preheating the oven',
        neg_prompt='',
        # generate() names this parameter `hint_image`; passing `image=` raises TypeError.
        hint_image=None,
        inference_steps=20,
        cfg=12.0,
        seed=0
    )
    print('Entertaining the guests with sailor songs played on an old piano.')
    # Second pass with a black 512x512 hint image; the result is discarded.
    generate(
        prompt='Entertaining the guests with sailor songs played on an old harmonium.',
        neg_prompt='',
        hint_image=Image.new('RGB', size=(512, 512), color=(0, 0, 0)),
        inference_steps=20,
        cfg=12.0,
        seed=0
    )
    print('dinner is ready\n')
137
+
138
# Main application UI: an intro row, then a controls column next to an output
# column, with the submit button wired to generate().
with gr.Blocks(title='Make-A-Video Stable Diffusion JAX', analytics_enabled=False) as demo:
    variant = 'panel'

    # --- Introduction and usage notes -------------------------------------
    with gr.Row():
        with gr.Column():
            intro1 = gr.Markdown("""
# Make-A-Video Stable Diffusion JAX
We have extended a pretrained LDM inpainting image generation model with temporal convolutions and attention.
By taking advantage of the extra 5 input channels of the inpaint model, we guide the video generation with a hint image.
In this demo the hint image can be given by the user, otherwise it is generated by an generative image model.
The temporal layers are a port of [Make-A-Video PyTorch](https://github.com/lucidrains/make-a-video-pytorch) to FLAX.
The convolution is pseudo 3D and seperately convolves accross the spatial dimension in 2D and over the temporal dimension in 1D.
Temporal attention is purely self attention and also separately attends to time.
Only the new temporal layers have been fine tuned on a dataset of videos themed around dance.
The model has been trained for 80 epochs on a dataset of 18,000 Videos with 120 frames each, randomly selecting a 24 frame range from each sample.
See model and dataset links in the metadata.
Model implementation and training code can be found at <https://github.com/lopho/makeavid-sd-tpu>
""")
        with gr.Column():
            intro3 = gr.Markdown("""
**Please be patient. The model might have to compile with current parameters.**
This can take up to 5 minutes on the first run, and 2-3 minutes on later runs.
The compilation will be cached and consecutive runs with the same parameters
will be much faster.
Changes to the following parameters require the model to compile
- Number of frames
- Width & Height
- Inference steps
- Input image vs. no input image
- Noise scheduler type
If you encounter any issues, please report them here: [Space discussions](https://huggingface.co/spaces/TempoFunk/makeavid-sd-jax/discussions)
""")

    with gr.Row(variant=variant):
        # --- Left column: inputs ------------------------------------------
        with gr.Column():
            with gr.Row():
                # cancel_button = gr.Button(value = 'Cancel')
                submit_button = gr.Button(value='Make A Video', variant='primary')
            prompt_input = gr.Textbox(
                label='Prompt',
                value='They are dancing in the club but everybody is a 3d cg hairy monster wearing a hairy costume.',
                interactive=True,
            )
            neg_prompt_input = gr.Textbox(
                label='Negative prompt (optional)',
                value='monochrome, saturated',
                interactive=True,
            )
            cfg_input = gr.Slider(
                label='Guidance scale video',
                minimum=1.0,
                maximum=20.0,
                step=0.1,
                value=15.0,
                interactive=True,
            )
            cfg_image_input = gr.Slider(
                label='Guidance scale hint (no effect with input image)',
                minimum=1.0,
                maximum=20.0,
                step=0.1,
                value=9.0,
                interactive=True,
            )
            seed_input = gr.Number(
                label='Random seed',
                value=0,
                interactive=True,
                precision=0,
            )
            image_input = gr.Image(
                label='Hint image (optional)',
                interactive=True,
                image_mode='RGB',
                type='pil',
                optional=True,
                source='upload',
            )
            inference_steps_input = gr.Slider(
                label='Steps',
                minimum=2,
                maximum=100,
                value=20,
                step=1,
                interactive=True,
            )
            num_frames_input = gr.Slider(
                label='Number of frames to generate',
                minimum=1,
                maximum=24,
                step=1,
                value=24,
                interactive=True,
            )
            width_input = gr.Slider(
                label='Width',
                minimum=64,
                maximum=576,
                step=64,
                value=512,
                interactive=True,
            )
            height_input = gr.Slider(
                label='Height',
                minimum=64,
                maximum=576,
                step=64,
                value=512,
                interactive=True,
            )
            scheduler_input = gr.Dropdown(
                label='Noise scheduler',
                choices=list(SCHEDULERS.keys()),
                value='DPM',
                interactive=True,
            )
            with gr.Row():
                fps_input = gr.Slider(
                    label='Output FPS',
                    minimum=1,
                    maximum=1000,
                    step=1,
                    value=12,
                    interactive=True,
                )
                output_format = gr.Dropdown(
                    label='Output format',
                    choices=_output_formats,
                    value='gif',
                    interactive=True,
                )

        # --- Right column: output -----------------------------------------
        with gr.Column():
            # will_trigger = gr.Markdown('')
            patience = gr.Markdown('**Please be patient. The model might have to compile with current parameters.**')
            image_output = gr.Image(
                label='Output',
                value='example.gif',
                interactive=False,
            )

    # Disabled: live "needs compilation" hint driven by check_if_compiled().
    # trigger_inputs = [ image_input, inference_steps_input, height_input, width_input, num_frames_input, scheduler_input ]
    # trigger_check_fun = partial(check_if_compiled, message = 'Current parameters need compilation.')
    # height_input.change(fn = trigger_check_fun, inputs = trigger_inputs, outputs = will_trigger)
    # width_input.change(fn = trigger_check_fun, inputs = trigger_inputs, outputs = will_trigger)
    # num_frames_input.change(fn = trigger_check_fun, inputs = trigger_inputs, outputs = will_trigger)
    # image_input.change(fn = trigger_check_fun, inputs = trigger_inputs, outputs = will_trigger)
    # inference_steps_input.change(fn = trigger_check_fun, inputs = trigger_inputs, outputs = will_trigger)
    # scheduler_input.change(fn = trigger_check_fun, inputs = trigger_inputs, outputs = will_trigger)

    # The inputs are listed in the positional order of generate()'s parameters.
    # postprocess is off, so the data URI returned by generate() is passed to
    # the output component as-is.
    submit_button.click(
        fn=generate,
        inputs=[
            prompt_input,
            neg_prompt_input,
            image_input,
            inference_steps_input,
            cfg_input,
            cfg_image_input,
            seed_input,
            fps_input,
            num_frames_input,
            height_input,
            width_input,
            scheduler_input,
            output_format,
        ],
        outputs=image_output,
        postprocess=False,
    )
    # cancel_button.click(fn = lambda: None, cancels = ev)

# Enable request queuing with the configured concurrency and queue size, then serve.
demo.queue(concurrency_count=1, max_size=12)
demo.launch()
308
+ # Photorealistic fantasy oil painting of the angry minotaur in a threatening pose by Randy Vargas.
309
+ # A girl is dancing by a beautiful lake by sophie anderson and greg rutkowski and alphonse mucha.
310
+ # They are dancing in the club but everybody is a 3d cg hairy monster wearing a hairy costume.