Update app.py
app.py CHANGED
@@ -15,10 +15,17 @@ from pathlib import Path
 
 # Initialize the model and ensure it's on the correct device
 def load_model():
-    model = SemantiCodec(token_rate=
+    model = SemantiCodec(token_rate=25, semantic_vocab_size=32768)  # 0.35 kbps
     if torch.cuda.is_available():
-        # Move the model to CUDA
-        model.to("cuda:0")
+        # Move the model to CUDA and ensure it's fully initialized on CUDA
+        model = model.to("cuda:0")
+        # Force CUDA initialization
+        dummy_input = torch.zeros(1, 1, 1, dtype=torch.long).cuda()
+        try:
+            with torch.no_grad():
+                _ = model.decoder(dummy_input)
+        except:
+            print("Dummy forward pass failed, but CUDA initialization attempted")
     return model
 
 # Initialize model
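Note: the added dummy forward pass exists only to force PyTorch's lazy CUDA initialization (context creation, kernel loading) at startup instead of on the first user request. A minimal, self-contained sketch of the same warm-up pattern; the `warm_up` helper and the `nn.Embedding` stand-in are illustrative, not part of SemantiCodec:

```python
import torch
import torch.nn as nn

def warm_up(model: nn.Module) -> nn.Module:
    """Move a model to the GPU (when available) and run one throwaway
    forward pass so one-time CUDA costs are paid at startup."""
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    model = model.to(device)
    # Long dtype to match token-index inputs, as in the diff
    dummy = torch.zeros(1, 1, dtype=torch.long, device=device)
    try:
        with torch.no_grad():
            _ = model(dummy)
    except Exception as e:
        print(f"Warm-up forward failed ({e}); device transfer still done")
    return model

# Hypothetical usage: an embedding layer expects Long indices,
# roughly like a codec decoder's token input
decoder = warm_up(nn.Embedding(num_embeddings=32768, embedding_dim=8))
```

A bare `except:` as in the committed code also swallows `KeyboardInterrupt` and `SystemExit`; catching `Exception` is the safer default.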
@@ -28,13 +35,16 @@ model_device = "cuda:0" if torch.cuda.is_available() else "cpu"
 print(f"Model initialized on device: {model_device}")
 
 # Define sample rate as a constant
-SAMPLE_RATE = 32000
+# Changed from 32000 to 16000 to fix playback speed
+SAMPLE_RATE = 16000
 
 @spaces.GPU(duration=20)
 def encode_audio(audio_path):
     """Encode audio file to tokens and return them as a file"""
     try:
         print(f"Encoding audio on device: {model_device}")
+        # Ensure model is on the right device
+        semanticodec.to(model_device)
         tokens = semanticodec.encode(audio_path)
         print(f"Tokens device after encode: {tokens.device if isinstance(tokens, torch.Tensor) else 'numpy'}")
 
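Note: the playback-speed fix is plain sample-rate arithmetic: duration equals sample count divided by the declared rate, so audio reconstructed at 16 kHz but labeled 32 kHz plays in half the time at double pitch. A worked example (assuming, as the new comment implies, that the codec outputs 16 kHz audio):

```python
import numpy as np

samples = np.zeros(16000 * 2)   # 2 seconds of audio decoded at 16 kHz
print(len(samples) / 16000)     # 2.0 -> labeled 16000 Hz: correct speed
print(len(samples) / 32000)     # 1.0 -> labeled 32000 Hz: plays in half
                                #        the time, an octave too high
```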
@@ -88,6 +98,10 @@ def decode_tokens(token_file):
         intended_device = token_data.get('device', model_device)
         print(f"Loaded tokens with shape {tokens.shape}, intended device: {intended_device}")
 
+        # Ensure model is on the right device first
+        semanticodec.to(model_device)
+        print(f"Model device before tensor creation: {next(semanticodec.parameters()).device}")
+
         # Convert to torch tensor with Long dtype for embedding
         tokens_tensor = torch.tensor(tokens, dtype=torch.long)
         print(f"Tokens tensor created on device: {tokens_tensor.device} with dtype: {tokens_tensor.dtype}")
@@ -96,10 +110,6 @@ def decode_tokens(token_file):
         tokens_tensor = tokens_tensor.to(model_device)
         print(f"Tokens moved to device: {tokens_tensor.device}")
 
-        # Also ensure model is on the expected device
-        semanticodec.to(model_device)
-        print(f"Model device before decode: {next(semanticodec.parameters()).device}")
-
         # Decode the tokens
         waveform = semanticodec.decode(tokens_tensor)
         print(f"Waveform device after decode: {waveform.device if isinstance(waveform, torch.Tensor) else 'numpy'}")
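Note: dropping the second `semanticodec.to(model_device)` before decode is safe because the same call is now made earlier in the function, and `nn.Module.to()` moves parameters in place, so repeating it on an already-moved model is a no-op. `Tensor.to()`, by contrast, returns a new tensor and must be reassigned, which is why `tokens_tensor = tokens_tensor.to(model_device)` keeps the assignment. A small sketch of the distinction:

```python
import torch
import torch.nn as nn

layer = nn.Linear(4, 4)
vec = torch.zeros(4)

layer.to("cpu")      # nn.Module.to() modifies the module in place (and
                     # returns it), so a repeated semanticodec.to(model_device)
                     # on an already-moved model changes nothing
vec = vec.to("cpu")  # Tensor.to() returns a new tensor, so the result
                     # must be reassigned, as with tokens_tensor above
print(next(layer.parameters()).device, vec.device)
```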
@@ -124,6 +134,8 @@ def process_both(audio_path):
     """Encode and then decode the audio without saving intermediate files"""
     try:
         print(f"Processing both on device: {model_device}")
+        # Ensure model is on the right device
+        semanticodec.to(model_device)
         # Encode
         tokens = semanticodec.encode(audio_path)
         print(f"Tokens device after encode: {tokens.device if isinstance(tokens, torch.Tensor) else 'numpy'}")
@@ -144,7 +156,7 @@ def process_both(audio_path):
         tokens_tensor = tokens_tensor.to(model_device)
         print(f"Tokens moved to device: {tokens_tensor.device}")
 
-        #
+        # Ensure model is on the right device again before decoding
         semanticodec.to(model_device)
         print(f"Model device before decode: {next(semanticodec.parameters()).device}")
 
@@ -167,6 +179,88 @@ def process_both(audio_path):
         print(f"Processing error: {str(e)}")
         return None, f"Error processing audio: {str(e)}"
 
+@spaces.GPU(duration=360)
+def stream_both(audio_path):
+    """Encode and then stream decode the audio"""
+    try:
+        print(f"Processing both (streaming) on device: {model_device}")
+        # Ensure model is on the right device
+        semanticodec.to(model_device)
+
+        # First encode the audio
+        tokens = semanticodec.encode(audio_path)
+        if isinstance(tokens, torch.Tensor):
+            tokens = tokens.cpu().numpy()
+
+        # Ensure tokens are in the right shape for decoding
+        if tokens.ndim == 1:
+            tokens = tokens.reshape(1, -1, 1)
+
+        print(f"Encoded audio to {tokens.shape[1]} tokens, now streaming decoding...")
+        yield None, f"Encoded to {tokens.shape[1]} tokens, starting decoding..."
+
+        # If tokens are too small, decode all at once
+        if tokens.shape[1] < 500:
+            # Convert to torch tensor with Long dtype for embedding
+            tokens_tensor = torch.tensor(tokens, dtype=torch.long).to(model_device)
+
+            # Decode the tokens
+            semanticodec.to(model_device)
+            waveform = semanticodec.decode(tokens_tensor)
+            if isinstance(waveform, torch.Tensor):
+                waveform = waveform.cpu().numpy()
+
+            audio_data = waveform[0, 0]
+            yield (SAMPLE_RATE, audio_data), f"Encoded to {tokens.shape[1]} tokens and decoded to audio"
+            return
+
+        # Split tokens into chunks for streaming
+        chunk_size = 500  # Number of tokens per chunk
+        num_chunks = (tokens.shape[1] + chunk_size - 1) // chunk_size  # Ceiling division
+
+        all_audio_chunks = []
+
+        for i in range(num_chunks):
+            start_idx = i * chunk_size
+            end_idx = min((i + 1) * chunk_size, tokens.shape[1])
+
+            print(f"Decoding chunk {i+1}/{num_chunks}, tokens {start_idx} to {end_idx}")
+
+            # Extract chunk of tokens
+            token_chunk = tokens[:, start_idx:end_idx, :]
+
+            # Convert to torch tensor with Long dtype
+            tokens_tensor = torch.tensor(token_chunk, dtype=torch.long).to(model_device)
+
+            # Ensure model is on the expected device
+            semanticodec.to(model_device)
+
+            # Decode the tokens
+            waveform = semanticodec.decode(tokens_tensor)
+            if isinstance(waveform, torch.Tensor):
+                waveform = waveform.cpu().numpy()
+
+            # Extract audio data
+            audio_chunk = waveform[0, 0]
+            all_audio_chunks.append(audio_chunk)
+
+            # Combine all chunks we have so far
+            combined_audio = np.concatenate(all_audio_chunks)
+
+            # Yield the combined audio for streaming playback
+            yield (SAMPLE_RATE, combined_audio), f"Encoded to {tokens.shape[1]} tokens\nDecoded chunk {i+1}/{num_chunks} ({end_idx}/{tokens.shape[1]} tokens)"
+
+            # Small delay to allow Gradio to update the UI
+            time.sleep(0.1)
+
+        # Final complete audio
+        combined_audio = np.concatenate(all_audio_chunks)
+        yield (SAMPLE_RATE, combined_audio), f"Completed: Encoded to {tokens.shape[1]} tokens and fully decoded"
+
+    except Exception as e:
+        print(f"Streaming process error: {str(e)}")
+        yield None, f"Error processing audio: {str(e)}"
+
 @spaces.GPU(duration=360)
 def stream_decode_tokens(token_file):
     """Decode tokens to audio in streaming chunks"""
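Note: `stream_both` boils down to ceiling-division chunking plus a generator that yields the cumulative waveform so far. A model-free sketch of that skeleton; `decode_chunk` is a hypothetical stand-in for `semanticodec.decode`, and the 320 samples-per-token figure is arbitrary:

```python
import numpy as np

def decode_chunk(tokens: np.ndarray) -> np.ndarray:
    """Hypothetical stand-in for semanticodec.decode:
    pretend every token becomes 320 audio samples."""
    return np.zeros(tokens.shape[1] * 320, dtype=np.float32)

def stream_decode(tokens: np.ndarray, chunk_size: int = 500):
    num_chunks = (tokens.shape[1] + chunk_size - 1) // chunk_size  # ceiling division
    decoded = []
    for i in range(num_chunks):
        start = i * chunk_size
        end = min(start + chunk_size, tokens.shape[1])
        decoded.append(decode_chunk(tokens[:, start:end, :]))
        # Yield everything decoded so far, so playback can begin early
        yield np.concatenate(decoded), f"chunk {i + 1}/{num_chunks}"

tokens = np.zeros((1, 1234, 1), dtype=np.int64)
for audio, status in stream_decode(tokens):
    print(status, audio.shape)  # chunks of 500, 500, 234 tokens
```

One caveat: decoding fixed windows independently may introduce audible seams at chunk boundaries if the decoder relies on context beyond the window.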
@@ -184,6 +278,9 @@ def stream_decode_tokens(token_file):
         intended_device = token_data.get('device', model_device)
         print(f"Loaded tokens with shape {tokens.shape}, intended device: {intended_device}")
 
+        # Ensure model is on the right device
+        semanticodec.to(model_device)
+
         # If tokens are too small, decode all at once
         if tokens.shape[1] < 500:
             # Convert to torch tensor with Long dtype for embedding
@@ -291,6 +388,19 @@ with gr.Blocks(title="Oterin Audio Codec") as demo:
             both_status = gr.Textbox(label="Status")
             both_btn = gr.Button("Process")
             both_btn.click(process_both, inputs=both_input, outputs=[both_output, both_status])
+
+    with gr.Tab("Both Streaming (Encode & Stream Decode)"):
+        with gr.Row():
+            stream_both_input = gr.Audio(type="filepath", label="Input Audio")
+            stream_both_output = gr.Audio(label="Streaming Reconstructed Audio")
+        stream_both_status = gr.Textbox(label="Status")
+        stream_both_btn = gr.Button("Encode & Stream Decode")
+        stream_both_btn.click(
+            stream_both,
+            inputs=stream_both_input,
+            outputs=[stream_both_output, stream_both_status],
+            show_progress=True
+        )
 
 if __name__ == "__main__":
     demo.launch(share=True)
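Note: the new tab streams because Gradio treats a generator event handler as a stream: each `yield` pushes fresh values to the outputs, which is why `stream_both` yields `(SAMPLE_RATE, audio)` tuples instead of returning once. A minimal, app-independent sketch of the same wiring:

```python
import time
import gradio as gr

def count_up(n):
    # A generator handler: every yield pushes a fresh value to the outputs
    for i in range(int(n)):
        time.sleep(0.5)
        yield f"step {i + 1} of {int(n)}"

with gr.Blocks() as demo:
    steps = gr.Number(value=5, label="Steps")
    status = gr.Textbox(label="Status")
    gr.Button("Run").click(count_up, inputs=steps, outputs=status)

if __name__ == "__main__":
    demo.launch()
```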