camenduru commited on
Commit
7004979
·
verified ·
1 Parent(s): 0d24c89

Create worker_runpod.py

Browse files
Files changed (1) hide show
  1. worker_runpod.py +246 -0
worker_runpod.py ADDED
@@ -0,0 +1,246 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from datetime import datetime
3
+ from pathlib import Path
4
+
5
+ import gradio as gr
6
+ import torch
7
+ import torchaudio
8
+
9
+ from mmaudio.eval_utils import (ModelConfig, all_model_cfg, generate, load_video, make_video,
10
+ setup_eval_logging)
11
+ from mmaudio.model.flow_matching import FlowMatching
12
+ from mmaudio.model.networks import MMAudio, get_my_mmaudio
13
+ from mmaudio.model.sequence_config import SequenceConfig
14
+ from mmaudio.model.utils.features_utils import FeaturesUtils
15
+
16
+ torch.backends.cuda.matmul.allow_tf32 = True
17
+ torch.backends.cudnn.allow_tf32 = True
18
+
19
+ log = logging.getLogger()
20
+
21
+ device = 'cuda'
22
+ dtype = torch.bfloat16
23
+
24
+ model: ModelConfig = all_model_cfg['large_44k_v2']
25
+ model.download_if_needed()
26
+ output_dir = Path('./output/gradio')
27
+
28
+ setup_eval_logging()
29
+
30
+
31
+ def get_model() -> tuple[MMAudio, FeaturesUtils, SequenceConfig]:
32
+ seq_cfg = model.seq_cfg
33
+
34
+ net: MMAudio = get_my_mmaudio(model.model_name).to(device, dtype).eval()
35
+ net.load_weights(torch.load(model.model_path, map_location=device, weights_only=True))
36
+ log.info(f'Loaded weights from {model.model_path}')
37
+
38
+ feature_utils = FeaturesUtils(tod_vae_ckpt=model.vae_path,
39
+ synchformer_ckpt=model.synchformer_ckpt,
40
+ enable_conditions=True,
41
+ mode=model.mode,
42
+ bigvgan_vocoder_ckpt=model.bigvgan_16k_path)
43
+ feature_utils = feature_utils.to(device, dtype).eval()
44
+
45
+ return net, feature_utils, seq_cfg
46
+
47
+
48
+ net, feature_utils, seq_cfg = get_model()
49
+
50
+
51
+ @torch.inference_mode()
52
+ def video_to_audio(video: gr.Video, prompt: str, negative_prompt: str, seed: int, num_steps: int,
53
+ cfg_strength: float, duration: float):
54
+
55
+ rng = torch.Generator(device=device)
56
+ rng.manual_seed(seed)
57
+ fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps)
58
+
59
+ clip_frames, sync_frames, duration = load_video(video, duration)
60
+ clip_frames = clip_frames.unsqueeze(0)
61
+ sync_frames = sync_frames.unsqueeze(0)
62
+ seq_cfg.duration = duration
63
+ net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)
64
+
65
+ audios = generate(clip_frames,
66
+ sync_frames, [prompt],
67
+ negative_text=[negative_prompt],
68
+ feature_utils=feature_utils,
69
+ net=net,
70
+ fm=fm,
71
+ rng=rng,
72
+ cfg_strength=cfg_strength)
73
+ audio = audios.float().cpu()[0]
74
+
75
+ current_time_string = datetime.now().strftime('%Y%m%d_%H%M%S')
76
+ output_dir.mkdir(exist_ok=True, parents=True)
77
+ video_save_path = output_dir / f'{current_time_string}.mp4'
78
+ make_video(video,
79
+ video_save_path,
80
+ audio,
81
+ sampling_rate=seq_cfg.sampling_rate,
82
+ duration_sec=seq_cfg.duration)
83
+ return video_save_path
84
+
85
+
86
+ @torch.inference_mode()
87
+ def text_to_audio(prompt: str, negative_prompt: str, seed: int, num_steps: int, cfg_strength: float,
88
+ duration: float):
89
+
90
+ rng = torch.Generator(device=device)
91
+ rng.manual_seed(seed)
92
+ fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps)
93
+
94
+ clip_frames = sync_frames = None
95
+ seq_cfg.duration = duration
96
+ net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)
97
+
98
+ audios = generate(clip_frames,
99
+ sync_frames, [prompt],
100
+ negative_text=[negative_prompt],
101
+ feature_utils=feature_utils,
102
+ net=net,
103
+ fm=fm,
104
+ rng=rng,
105
+ cfg_strength=cfg_strength)
106
+ audio = audios.float().cpu()[0]
107
+
108
+ current_time_string = datetime.now().strftime('%Y%m%d_%H%M%S')
109
+ output_dir.mkdir(exist_ok=True, parents=True)
110
+ audio_save_path = output_dir / f'{current_time_string}.flac'
111
+ torchaudio.save(audio_save_path, audio, seq_cfg.sampling_rate)
112
+ return audio_save_path
113
+
114
+
115
+ video_to_audio_tab = gr.Interface(
116
+ fn=video_to_audio,
117
+ inputs=[
118
+ gr.Video(),
119
+ gr.Text(label='Prompt'),
120
+ gr.Text(label='Negative prompt', value='music'),
121
+ gr.Number(label='Seed', value=0, precision=0, minimum=0),
122
+ gr.Number(label='Num steps', value=25, precision=0, minimum=1),
123
+ gr.Number(label='Guidance Strength', value=4.5, minimum=1),
124
+ gr.Number(label='Duration (sec)', value=8, minimum=1),
125
+ ],
126
+ outputs='playable_video',
127
+ cache_examples=False,
128
+ title='MMAudio — Video-to-Audio Synthesis',
129
+ examples=[
130
+ [
131
+ 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/sora_nyc.mp4',
132
+ '',
133
+ '',
134
+ 0,
135
+ 25,
136
+ 4.5,
137
+ 10,
138
+ ],
139
+ [
140
+ 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/sora_serpent.mp4',
141
+ '',
142
+ 'music',
143
+ 0,
144
+ 25,
145
+ 4.5,
146
+ 10,
147
+ ],
148
+ [
149
+ 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/sora_seahorse.mp4',
150
+ 'bubbles',
151
+ '',
152
+ 0,
153
+ 25,
154
+ 4.5,
155
+ 10,
156
+ ],
157
+ [
158
+ 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/sora_india.mp4',
159
+ 'Indian holy music',
160
+ '',
161
+ 0,
162
+ 25,
163
+ 4.5,
164
+ 10,
165
+ ],
166
+ [
167
+ 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/sora_galloping.mp4',
168
+ 'galloping',
169
+ '',
170
+ 0,
171
+ 25,
172
+ 4.5,
173
+ 10,
174
+ ],
175
+ [
176
+ 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/sora_beach.mp4',
177
+ 'waves, seagulls',
178
+ '',
179
+ 0,
180
+ 25,
181
+ 4.5,
182
+ 10,
183
+ ],
184
+ [
185
+ 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/sora_kraken.mp4',
186
+ 'waves, storm',
187
+ '',
188
+ 0,
189
+ 25,
190
+ 4.5,
191
+ 10,
192
+ ],
193
+ [
194
+ 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/mochi_storm.mp4',
195
+ 'storm',
196
+ '',
197
+ 0,
198
+ 25,
199
+ 4.5,
200
+ 10,
201
+ ],
202
+ [
203
+ 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/hunyuan_spring.mp4',
204
+ '',
205
+ '',
206
+ 0,
207
+ 25,
208
+ 4.5,
209
+ 10,
210
+ ],
211
+ [
212
+ 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/hunyuan_typing.mp4',
213
+ 'typing',
214
+ '',
215
+ 0,
216
+ 25,
217
+ 4.5,
218
+ 10,
219
+ ],
220
+ [
221
+ 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/hunyuan_wake_up.mp4',
222
+ '',
223
+ '',
224
+ 0,
225
+ 25,
226
+ 4.5,
227
+ 10,
228
+ ],
229
+ ])
230
+
231
+ text_to_audio_tab = gr.Interface(
232
+ fn=text_to_audio,
233
+ inputs=[
234
+ gr.Text(label='Prompt'),
235
+ gr.Text(label='Negative prompt'),
236
+ gr.Number(label='Seed', value=0, precision=0, minimum=0),
237
+ gr.Number(label='Num steps', value=25, precision=0, minimum=1),
238
+ gr.Number(label='Guidance Strength', value=4.5, minimum=1),
239
+ gr.Number(label='Duration (sec)', value=8, minimum=1),
240
+ ],
241
+ outputs='audio',
242
+ cache_examples=False,
243
+ title='MMAudio — Text-to-Audio Synthesis',
244
+ )
245
+
246
+ gr.TabbedInterface([video_to_audio_tab, text_to_audio_tab],['Video-to-Audio', 'Text-to-Audio']).launch(inline=False, share=False, debug=True, server_name='0.0.0.0', server_port=7860, allowed_paths=[output_dir])