import os
import gradio as gr
import numpy as np
from semanticodec import SemantiCodec
import spaces
import torch
import uuid
import pickle
import time
# Initialize the model and ensure it's on the correct device
def load_model():
    model = SemantiCodec(token_rate=100, semantic_vocab_size=32768)  # ~1.40 kbps at this setting (0.35 kbps corresponds to token_rate=25)
if torch.cuda.is_available():
# Move the model to CUDA and ensure it's fully initialized on CUDA
model = model.to("cuda:0")
# Force CUDA initialization
dummy_input = torch.zeros(1, 1, 1, dtype=torch.long).cuda()
try:
with torch.no_grad():
_ = model.decoder(dummy_input)
        except Exception:
            print("Dummy forward pass failed, but CUDA initialization attempted")
return model
# Initialize model
semanticodec = load_model()
# Get the device of the model
model_device = "cuda:0" if torch.cuda.is_available() else "cpu"
print(f"Model initialized on device: {model_device}")
# Define sample rate as a constant
# Changed from 32000 to 16000 to fix playback speed
SAMPLE_RATE = 16000
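# Note: on Hugging Face Spaces, @spaces.GPU reserves a ZeroGPU device for at most
# `duration` seconds per call; calls that exceed it are cut off, so the decode and
# streaming durations below are sized generously.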
@spaces.GPU(duration=20)
def encode_audio(audio_path):
"""Encode audio file to tokens and return them as a file"""
try:
print(f"Encoding audio on device: {model_device}")
# Ensure model is on the right device
semanticodec.to(model_device)
tokens = semanticodec.encode(audio_path)
print(f"Tokens device after encode: {tokens.device if isinstance(tokens, torch.Tensor) else 'numpy'}")
# Move tokens to CPU before converting to numpy
if isinstance(tokens, torch.Tensor):
tokens = tokens.cpu().numpy()
# Ensure tokens are in the right shape for later decoding
if tokens.ndim == 1:
# Reshape to match expected format [batch, seq_len, features]
tokens = tokens.reshape(1, -1, 1)
# Save tokens in a way that preserves shape information
token_data = {
'tokens': tokens,
'shape': tokens.shape,
'device': str(model_device) # Store intended device information
}
# Create a temporary file in /tmp which is writable in Spaces
temp_dir = "/tmp"
os.makedirs(temp_dir, exist_ok=True)
temp_file_path = os.path.join(temp_dir, f"tokens_{uuid.uuid4()}.oterin")
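        # ".oterin" is an arbitrary extension chosen for this app; the file is just a pickled dict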
# Write using pickle instead of numpy save
with open(temp_file_path, "wb") as f:
pickle.dump(token_data, f)
# Verify the file exists and has content
if not os.path.exists(temp_file_path) or os.path.getsize(temp_file_path) == 0:
raise Exception("Failed to create token file")
return temp_file_path, f"Encoded to {tokens.shape[1]} tokens"
except Exception as e:
print(f"Encoding error: {str(e)}")
return None, f"Error encoding audio: {str(e)}"
@spaces.GPU(duration=160)
def decode_tokens(token_file):
"""Decode tokens to audio"""
# Ensure the file exists and has content
if not token_file or not os.path.exists(token_file):
return None, "Error: Empty or missing token file"
try:
# Load tokens using pickle instead of numpy load
with open(token_file, "rb") as f:
token_data = pickle.load(f)
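        # Security note: pickle.load can execute arbitrary code from a malicious file,
        # so only load token files produced by this app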
tokens = token_data['tokens']
intended_device = token_data.get('device', model_device)
print(f"Loaded tokens with shape {tokens.shape}, intended device: {intended_device}")
# Ensure model is on the right device first
semanticodec.to(model_device)
print(f"Model device before tensor creation: {next(semanticodec.parameters()).device}")
# Convert to torch tensor with Long dtype for embedding
tokens_tensor = torch.tensor(tokens, dtype=torch.long)
print(f"Tokens tensor created on device: {tokens_tensor.device} with dtype: {tokens_tensor.dtype}")
# Explicitly move tokens to the model's device
tokens_tensor = tokens_tensor.to(model_device)
print(f"Tokens moved to device: {tokens_tensor.device}")
# Decode the tokens
waveform = semanticodec.decode(tokens_tensor)
print(f"Waveform device after decode: {waveform.device if isinstance(waveform, torch.Tensor) else 'numpy'}")
# Move waveform to CPU for audio processing
if isinstance(waveform, torch.Tensor):
waveform = waveform.cpu().numpy()
# Extract audio data - this should be a numpy array
audio_data = waveform[0, 0] # Shape should be [time]
print(f"Audio data shape: {audio_data.shape}, dtype: {audio_data.dtype}")
# Return in Gradio Audio compatible format: (sample_rate, audio_data)
return (SAMPLE_RATE, audio_data), f"Decoded {tokens.shape[1]} tokens to audio"
except Exception as e:
print(f"Decoding error: {str(e)}")
return None, f"Error decoding tokens: {str(e)}"
@spaces.GPU(duration=250)
def process_both(audio_path):
"""Encode and then decode the audio without saving intermediate files"""
try:
print(f"Processing both on device: {model_device}")
# Ensure model is on the right device
semanticodec.to(model_device)
# Encode
tokens = semanticodec.encode(audio_path)
print(f"Tokens device after encode: {tokens.device if isinstance(tokens, torch.Tensor) else 'numpy'}")
if isinstance(tokens, torch.Tensor):
tokens = tokens.cpu().numpy()
# Ensure tokens are in the right shape for decoding
if tokens.ndim == 1:
# Reshape to match expected format [batch, seq_len, features]
tokens = tokens.reshape(1, -1, 1)
# Convert back to torch tensor with Long dtype for embedding
tokens_tensor = torch.tensor(tokens, dtype=torch.long)
print(f"Tokens tensor created on device: {tokens_tensor.device} with dtype: {tokens_tensor.dtype}")
# Explicitly move tokens to the model's device
tokens_tensor = tokens_tensor.to(model_device)
print(f"Tokens moved to device: {tokens_tensor.device}")
# Ensure model is on the right device again before decoding
semanticodec.to(model_device)
print(f"Model device before decode: {next(semanticodec.parameters()).device}")
# Decode
waveform = semanticodec.decode(tokens_tensor)
print(f"Waveform device after decode: {waveform.device if isinstance(waveform, torch.Tensor) else 'numpy'}")
# Move waveform to CPU for audio processing
if isinstance(waveform, torch.Tensor):
waveform = waveform.cpu().numpy()
# Extract audio data - this should be a numpy array
audio_data = waveform[0, 0] # Shape should be [time]
print(f"Audio data shape: {audio_data.shape}, dtype: {audio_data.dtype}")
# Return in Gradio Audio compatible format: (sample_rate, audio_data)
return (SAMPLE_RATE, audio_data), f"Encoded to {tokens.shape[1]} tokens\nDecoded {tokens.shape[1]} tokens to audio"
except Exception as e:
print(f"Processing error: {str(e)}")
return None, f"Error processing audio: {str(e)}"
@spaces.GPU(duration=250)
def stream_both(audio_path):
"""Encode and then stream decode the audio"""
try:
print(f"Processing both (streaming) on device: {model_device}")
# Ensure model is on the right device
semanticodec.to(model_device)
# First encode the audio
tokens = semanticodec.encode(audio_path)
if isinstance(tokens, torch.Tensor):
tokens = tokens.cpu().numpy()
# Ensure tokens are in the right shape for decoding
if tokens.ndim == 1:
tokens = tokens.reshape(1, -1, 1)
print(f"Encoded audio to {tokens.shape[1]} tokens, now streaming decoding...")
yield None, f"Encoded to {tokens.shape[1]} tokens, starting decoding..."
# If tokens are too small, decode all at once
if tokens.shape[1] < 1500: # Changed from 500 to 1500 (15 seconds at 100 tokens/sec)
# Convert to torch tensor with Long dtype for embedding
tokens_tensor = torch.tensor(tokens, dtype=torch.long).to(model_device)
# Decode the tokens
semanticodec.to(model_device)
waveform = semanticodec.decode(tokens_tensor)
if isinstance(waveform, torch.Tensor):
waveform = waveform.cpu().numpy()
audio_data = waveform[0, 0]
yield (SAMPLE_RATE, audio_data), f"Encoded to {tokens.shape[1]} tokens and decoded to audio"
return
# Split tokens into chunks for streaming
chunk_size = 1500 # Changed from 500 to 1500 (15 seconds at 100 tokens/sec)
num_chunks = (tokens.shape[1] + chunk_size - 1) // chunk_size # Ceiling division
all_audio_chunks = []
for i in range(num_chunks):
start_idx = i * chunk_size
end_idx = min((i + 1) * chunk_size, tokens.shape[1])
print(f"Decoding chunk {i+1}/{num_chunks}, tokens {start_idx} to {end_idx}")
# Extract chunk of tokens
token_chunk = tokens[:, start_idx:end_idx, :]
# Convert to torch tensor with Long dtype
tokens_tensor = torch.tensor(token_chunk, dtype=torch.long).to(model_device)
# Ensure model is on the expected device
semanticodec.to(model_device)
# Decode the tokens
waveform = semanticodec.decode(tokens_tensor)
if isinstance(waveform, torch.Tensor):
waveform = waveform.cpu().numpy()
# Extract audio data
audio_chunk = waveform[0, 0]
all_audio_chunks.append(audio_chunk)
# Combine all chunks we have so far
combined_audio = np.concatenate(all_audio_chunks)
# Yield the combined audio for streaming playback
yield (SAMPLE_RATE, combined_audio), f"Encoded to {tokens.shape[1]} tokens\nDecoded chunk {i+1}/{num_chunks} ({end_idx}/{tokens.shape[1]} tokens)"
# Small delay to allow Gradio to update UI
time.sleep(0.1)
# Final complete audio
combined_audio = np.concatenate(all_audio_chunks)
yield (SAMPLE_RATE, combined_audio), f"Completed: Encoded to {tokens.shape[1]} tokens and fully decoded"
except Exception as e:
print(f"Streaming process error: {str(e)}")
yield None, f"Error processing audio: {str(e)}"
@spaces.GPU(duration=250)
def stream_decode_tokens(token_file):
"""Decode tokens to audio in streaming chunks"""
# Ensure the file exists and has content
if not token_file or not os.path.exists(token_file):
yield None, "Error: Empty or missing token file"
return
try:
# Load tokens using pickle instead of numpy load
with open(token_file, "rb") as f:
token_data = pickle.load(f)
tokens = token_data['tokens']
intended_device = token_data.get('device', model_device)
print(f"Loaded tokens with shape {tokens.shape}, intended device: {intended_device}")
# Ensure model is on the right device
semanticodec.to(model_device)
# If tokens are too small, decode all at once
if tokens.shape[1] < 1500: # Changed from 500 to 1500 (15 seconds at 100 tokens/sec)
# Convert to torch tensor with Long dtype for embedding
tokens_tensor = torch.tensor(tokens, dtype=torch.long)
tokens_tensor = tokens_tensor.to(model_device)
# Decode the tokens
waveform = semanticodec.decode(tokens_tensor)
if isinstance(waveform, torch.Tensor):
waveform = waveform.cpu().numpy()
audio_data = waveform[0, 0]
yield (SAMPLE_RATE, audio_data), f"Decoded {tokens.shape[1]} tokens to audio"
return
# Split tokens into chunks for streaming
chunk_size = 1500 # Changed from 500 to 1500 (15 seconds at 100 tokens/sec)
num_chunks = (tokens.shape[1] + chunk_size - 1) // chunk_size # Ceiling division
# First status update
yield None, f"Starting decoding of {tokens.shape[1]} tokens in {num_chunks} chunks..."
all_audio_chunks = []
for i in range(num_chunks):
start_idx = i * chunk_size
end_idx = min((i + 1) * chunk_size, tokens.shape[1])
print(f"Decoding chunk {i+1}/{num_chunks}, tokens {start_idx} to {end_idx}")
# Extract chunk of tokens
token_chunk = tokens[:, start_idx:end_idx, :]
# Convert to torch tensor with Long dtype
tokens_tensor = torch.tensor(token_chunk, dtype=torch.long)
tokens_tensor = tokens_tensor.to(model_device)
# Ensure model is on the expected device
semanticodec.to(model_device)
# Decode the tokens
waveform = semanticodec.decode(tokens_tensor)
if isinstance(waveform, torch.Tensor):
waveform = waveform.cpu().numpy()
# Extract audio data
audio_chunk = waveform[0, 0]
all_audio_chunks.append(audio_chunk)
# Combine all chunks we have so far
combined_audio = np.concatenate(all_audio_chunks)
# Yield the combined audio for streaming playback
yield (SAMPLE_RATE, combined_audio), f"Decoded chunk {i+1}/{num_chunks} ({end_idx}/{tokens.shape[1]} tokens)"
# Small delay to allow Gradio to update UI
time.sleep(0.1)
# Final complete audio
combined_audio = np.concatenate(all_audio_chunks)
yield (SAMPLE_RATE, combined_audio), f"Completed decoding all {tokens.shape[1]} tokens"
except Exception as e:
print(f"Streaming decode error: {str(e)}")
yield None, f"Error decoding tokens: {str(e)}"
# Create Gradio interface
with gr.Blocks(title="Oterin Audio Codec") as demo:
gr.Markdown("# Oterin Audio Codec")
gr.Markdown("Upload an audio file to encode it to semantic tokens, decode tokens back to audio, or do both.")
with gr.Tab("Encode Audio"):
with gr.Row():
encode_input = gr.Audio(type="filepath", label="Input Audio")
encode_output = gr.File(label="Encoded Tokens (.oterin)", file_types=[".oterin"])
encode_status = gr.Textbox(label="Status")
encode_btn = gr.Button("Encode")
encode_btn.click(encode_audio, inputs=encode_input, outputs=[encode_output, encode_status])
with gr.Tab("Decode Tokens"):
with gr.Row():
decode_input = gr.File(label="Token File (.oterin)", file_types=[".oterin"])
decode_output = gr.Audio(label="Decoded Audio")
decode_status = gr.Textbox(label="Status")
decode_btn = gr.Button("Decode")
decode_btn.click(decode_tokens, inputs=decode_input, outputs=[decode_output, decode_status])
with gr.Tab("Stream Decode (Listen while decoding)"):
with gr.Row():
stream_decode_input = gr.File(label="Token File (.oterin)", file_types=[".oterin"])
stream_decode_output = gr.Audio(label="Streaming Audio Output")
stream_decode_status = gr.Textbox(label="Status")
stream_decode_btn = gr.Button("Start Streaming Decode")
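        # Wiring a generator function to .click makes Gradio stream each yielded
        # value to the outputs as it is produced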
stream_decode_btn.click(
stream_decode_tokens,
inputs=stream_decode_input,
outputs=[stream_decode_output, stream_decode_status],
show_progress=True
)
with gr.Tab("Both (Encode & Decode)"):
with gr.Row():
both_input = gr.Audio(type="filepath", label="Input Audio")
both_output = gr.Audio(label="Reconstructed Audio")
both_status = gr.Textbox(label="Status")
both_btn = gr.Button("Process")
both_btn.click(process_both, inputs=both_input, outputs=[both_output, both_status])
with gr.Tab("Both Streaming (Encode & Stream Decode)"):
with gr.Row():
stream_both_input = gr.Audio(type="filepath", label="Input Audio")
stream_both_output = gr.Audio(label="Streaming Reconstructed Audio")
stream_both_status = gr.Textbox(label="Status")
stream_both_btn = gr.Button("Encode & Stream Decode")
stream_both_btn.click(
stream_both,
inputs=stream_both_input,
outputs=[stream_both_output, stream_both_status],
show_progress=True
)
if __name__ == "__main__":
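    # share=True creates a temporary public gradio.live link when run locally;
    # it has no effect on Hugging Face Spaces, which serves the app directly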
demo.launch(share=True) |