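"""Oterin Audio Codec: a Gradio demo around SemantiCodec.

Encodes audio files into semantic tokens (saved as pickled .oterin files),
decodes token files back to 16 kHz audio, and offers streaming tabs that
yield decoded audio chunk by chunk. Intended to run on Hugging Face Spaces,
where the @spaces.GPU decorators request GPU time per call.
"""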
import os
import gradio as gr
import numpy as np
from semanticodec import SemantiCodec
import spaces
import torch
import uuid
import pickle
import time

# Initialize the model and ensure it's on the correct device
def load_model():
    model = SemantiCodec(token_rate=100, semantic_vocab_size=32768)  # 0.35 kbps
    if torch.cuda.is_available():
        # Move the model to CUDA and ensure it's fully initialized on CUDA
        model = model.to("cuda:0")
        # Force CUDA initialization
        dummy_input = torch.zeros(1, 1, 1, dtype=torch.long).cuda()
        try:
            with torch.no_grad():
                _ = model.decoder(dummy_input)
        except Exception as e:
            print(f"Dummy forward pass failed ({e}), but CUDA initialization was attempted")
    return model

# Initialize model
semanticodec = load_model()
# Get the device of the model
model_device = "cuda:0" if torch.cuda.is_available() else "cpu"
print(f"Model initialized on device: {model_device}")

# Define the output sample rate as a constant.
# 16 kHz matches the codec's output; labeling the audio as 32 kHz made
# playback run at double speed.
SAMPLE_RATE = 16000
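
# Back-of-envelope duration math (a sketch, assuming the token_rate=100
# configured in load_model): seconds ≈ num_tokens / token_rate, so the
# 1500-token streaming chunks used below each cover about 15 s of audio.
def approx_chunk_seconds(num_tokens: int, token_rate: int = 100) -> float:
    """Illustrative helper: approximate audio duration covered by num_tokens."""
    return num_tokens / token_rate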

@spaces.GPU(duration=20)
def encode_audio(audio_path):
    """Encode audio file to tokens and return them as a file"""
    try:
        print(f"Encoding audio on device: {model_device}")
        # Ensure model is on the right device
        semanticodec.to(model_device)
        tokens = semanticodec.encode(audio_path)
        print(f"Tokens device after encode: {tokens.device if isinstance(tokens, torch.Tensor) else 'numpy'}")
        
        # Move tokens to CPU before converting to numpy
        if isinstance(tokens, torch.Tensor):
            tokens = tokens.cpu().numpy()
        
        # Ensure tokens are in the right shape for later decoding
        if tokens.ndim == 1:
            # Reshape to match expected format [batch, seq_len, features]
            tokens = tokens.reshape(1, -1, 1)
        
        # Save tokens in a way that preserves shape information
        token_data = {
            'tokens': tokens,
            'shape': tokens.shape,
            'device': str(model_device)  # Store intended device information
        }
        
        # Create a temporary file in /tmp which is writable in Spaces
        temp_dir = "/tmp"
        os.makedirs(temp_dir, exist_ok=True)
        temp_file_path = os.path.join(temp_dir, f"tokens_{uuid.uuid4()}.oterin")
        
        # Write using pickle instead of numpy save
        with open(temp_file_path, "wb") as f:
            pickle.dump(token_data, f)
        
        # Verify the file exists and has content
        if not os.path.exists(temp_file_path) or os.path.getsize(temp_file_path) == 0:
            raise Exception("Failed to create token file")
        
        return temp_file_path, f"Encoded to {tokens.shape[1]} tokens"
    except Exception as e:
        print(f"Encoding error: {str(e)}")
        return None, f"Error encoding audio: {str(e)}"

@spaces.GPU(duration=160)
def decode_tokens(token_file):
    """Decode tokens to audio"""
    # Ensure the file exists and has content
    if not token_file or not os.path.exists(token_file):
        return None, "Error: Empty or missing token file"
    
    try:
        # Load tokens using pickle instead of numpy load
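        # Security note: pickle.load can execute arbitrary code from a
        # malicious file, so only load .oterin files produced by this app.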
        with open(token_file, "rb") as f:
            token_data = pickle.load(f)
        
        tokens = token_data['tokens']
        intended_device = token_data.get('device', model_device)
        print(f"Loaded tokens with shape {tokens.shape}, intended device: {intended_device}")
        
        # Ensure model is on the right device first
        semanticodec.to(model_device)
        print(f"Model device before tensor creation: {next(semanticodec.parameters()).device}")
        
        # Convert to torch tensor with Long dtype for embedding
        tokens_tensor = torch.tensor(tokens, dtype=torch.long)
        print(f"Tokens tensor created on device: {tokens_tensor.device} with dtype: {tokens_tensor.dtype}")
        
        # Explicitly move tokens to the model's device
        tokens_tensor = tokens_tensor.to(model_device)
        print(f"Tokens moved to device: {tokens_tensor.device}")
        
        # Decode the tokens
        waveform = semanticodec.decode(tokens_tensor)
        print(f"Waveform device after decode: {waveform.device if isinstance(waveform, torch.Tensor) else 'numpy'}")
        
        # Move waveform to CPU for audio processing
        if isinstance(waveform, torch.Tensor):
            waveform = waveform.cpu().numpy()
        
        # Extract audio data - this should be a numpy array
        audio_data = waveform[0, 0]  # Shape should be [time]
        
        print(f"Audio data shape: {audio_data.shape}, dtype: {audio_data.dtype}")
        
        # Return in Gradio Audio compatible format: (sample_rate, audio_data)
        return (SAMPLE_RATE, audio_data), f"Decoded {tokens.shape[1]} tokens to audio"
    except Exception as e:
        print(f"Decoding error: {str(e)}")
        return None, f"Error decoding tokens: {str(e)}"

@spaces.GPU(duration=250)
def process_both(audio_path):
    """Encode and then decode the audio without saving intermediate files"""
    try:
        print(f"Processing both on device: {model_device}")
        # Ensure model is on the right device
        semanticodec.to(model_device)
        # Encode
        tokens = semanticodec.encode(audio_path)
        print(f"Tokens device after encode: {tokens.device if isinstance(tokens, torch.Tensor) else 'numpy'}")
        
        if isinstance(tokens, torch.Tensor):
            tokens = tokens.cpu().numpy()
        
        # Ensure tokens are in the right shape for decoding
        if tokens.ndim == 1:
            # Reshape to match expected format [batch, seq_len, features]
            tokens = tokens.reshape(1, -1, 1)
        
        # Convert back to torch tensor with Long dtype for embedding
        tokens_tensor = torch.tensor(tokens, dtype=torch.long)
        print(f"Tokens tensor created on device: {tokens_tensor.device} with dtype: {tokens_tensor.dtype}")
        
        # Explicitly move tokens to the model's device
        tokens_tensor = tokens_tensor.to(model_device)
        print(f"Tokens moved to device: {tokens_tensor.device}")
        
        # Ensure model is on the right device again before decoding
        semanticodec.to(model_device)
        print(f"Model device before decode: {next(semanticodec.parameters()).device}")
        
        # Decode
        waveform = semanticodec.decode(tokens_tensor)
        print(f"Waveform device after decode: {waveform.device if isinstance(waveform, torch.Tensor) else 'numpy'}")
        
        # Move waveform to CPU for audio processing
        if isinstance(waveform, torch.Tensor):
            waveform = waveform.cpu().numpy()
        
        # Extract audio data - this should be a numpy array
        audio_data = waveform[0, 0]  # Shape should be [time]
        
        print(f"Audio data shape: {audio_data.shape}, dtype: {audio_data.dtype}")
        
        # Return in Gradio Audio compatible format: (sample_rate, audio_data)
        return (SAMPLE_RATE, audio_data), f"Encoded to {tokens.shape[1]} tokens\nDecoded {tokens.shape[1]} tokens to audio"
    except Exception as e:
        print(f"Processing error: {str(e)}")
        return None, f"Error processing audio: {str(e)}"

@spaces.GPU(duration=250)
def stream_both(audio_path):
    """Encode and then stream decode the audio"""
    try:
        print(f"Processing both (streaming) on device: {model_device}")
        # Ensure model is on the right device
        semanticodec.to(model_device)
        
        # First encode the audio
        tokens = semanticodec.encode(audio_path)
        if isinstance(tokens, torch.Tensor):
            tokens = tokens.cpu().numpy()
        
        # Ensure tokens are in the right shape for decoding
        if tokens.ndim == 1:
            tokens = tokens.reshape(1, -1, 1)
        
        print(f"Encoded audio to {tokens.shape[1]} tokens, now streaming decoding...")
        yield None, f"Encoded to {tokens.shape[1]} tokens, starting decoding..."
        
        # If tokens are too small, decode all at once
        if tokens.shape[1] < 1500:  # under ~15 s of audio at 100 tokens/sec
            # Convert to torch tensor with Long dtype for embedding
            tokens_tensor = torch.tensor(tokens, dtype=torch.long).to(model_device)
            
            # Decode the tokens
            semanticodec.to(model_device)
            waveform = semanticodec.decode(tokens_tensor)
            if isinstance(waveform, torch.Tensor):
                waveform = waveform.cpu().numpy()
            
            audio_data = waveform[0, 0]
            yield (SAMPLE_RATE, audio_data), f"Encoded to {tokens.shape[1]} tokens and decoded to audio"
            return
        
        # Split tokens into chunks for streaming
        chunk_size = 1500  # ~15 seconds of audio at 100 tokens/sec
        num_chunks = (tokens.shape[1] + chunk_size - 1) // chunk_size  # Ceiling division
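        # e.g. 3200 tokens -> three chunks covering tokens 0-1500, 1500-3000, 3000-3200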
        
        all_audio_chunks = []
        
        for i in range(num_chunks):
            start_idx = i * chunk_size
            end_idx = min((i + 1) * chunk_size, tokens.shape[1])
            
            print(f"Decoding chunk {i+1}/{num_chunks}, tokens {start_idx} to {end_idx}")
            
            # Extract chunk of tokens
            token_chunk = tokens[:, start_idx:end_idx, :]
            
            # Convert to torch tensor with Long dtype
            tokens_tensor = torch.tensor(token_chunk, dtype=torch.long).to(model_device)
            
            # Ensure model is on the expected device
            semanticodec.to(model_device)
            
            # Decode the tokens
            waveform = semanticodec.decode(tokens_tensor)
            if isinstance(waveform, torch.Tensor):
                waveform = waveform.cpu().numpy()
            
            # Extract audio data
            audio_chunk = waveform[0, 0]
            all_audio_chunks.append(audio_chunk)
            
            # Combine all chunks we have so far
            combined_audio = np.concatenate(all_audio_chunks)
            
            # Yield the combined audio for streaming playback
            yield (SAMPLE_RATE, combined_audio), f"Encoded to {tokens.shape[1]} tokens\nDecoded chunk {i+1}/{num_chunks} ({end_idx}/{tokens.shape[1]} tokens)"
            
            # Small delay to allow Gradio to update UI
            time.sleep(0.1)
        
        # Final complete audio
        combined_audio = np.concatenate(all_audio_chunks)
        yield (SAMPLE_RATE, combined_audio), f"Completed: Encoded to {tokens.shape[1]} tokens and fully decoded"
        
    except Exception as e:
        print(f"Streaming process error: {str(e)}")
        yield None, f"Error processing audio: {str(e)}"

@spaces.GPU(duration=250)
def stream_decode_tokens(token_file):
    """Decode tokens to audio in streaming chunks"""
    # Ensure the file exists and has content
    if not token_file or not os.path.exists(token_file):
        yield None, "Error: Empty or missing token file"
        return
    
    try:
        # Load tokens using pickle instead of numpy load
        with open(token_file, "rb") as f:
            token_data = pickle.load(f)
        
        tokens = token_data['tokens']
        intended_device = token_data.get('device', model_device)
        print(f"Loaded tokens with shape {tokens.shape}, intended device: {intended_device}")
        
        # Ensure model is on the right device
        semanticodec.to(model_device)
        
        # If tokens are too small, decode all at once
        if tokens.shape[1] < 1500:  # under ~15 s of audio at 100 tokens/sec
            # Convert to torch tensor with Long dtype for embedding
            tokens_tensor = torch.tensor(tokens, dtype=torch.long)
            tokens_tensor = tokens_tensor.to(model_device)
            
            # Decode the tokens
            waveform = semanticodec.decode(tokens_tensor)
            if isinstance(waveform, torch.Tensor):
                waveform = waveform.cpu().numpy()
            
            audio_data = waveform[0, 0]
            yield (SAMPLE_RATE, audio_data), f"Decoded {tokens.shape[1]} tokens to audio"
            return
        
        # Split tokens into chunks for streaming
        chunk_size = 1500  # ~15 seconds of audio at 100 tokens/sec
        num_chunks = (tokens.shape[1] + chunk_size - 1) // chunk_size  # Ceiling division
        
        # First status update
        yield None, f"Starting decoding of {tokens.shape[1]} tokens in {num_chunks} chunks..."
        
        all_audio_chunks = []
        
        for i in range(num_chunks):
            start_idx = i * chunk_size
            end_idx = min((i + 1) * chunk_size, tokens.shape[1])
            
            print(f"Decoding chunk {i+1}/{num_chunks}, tokens {start_idx} to {end_idx}")
            
            # Extract chunk of tokens
            token_chunk = tokens[:, start_idx:end_idx, :]
            
            # Convert to torch tensor with Long dtype
            tokens_tensor = torch.tensor(token_chunk, dtype=torch.long)
            tokens_tensor = tokens_tensor.to(model_device)
            
            # Ensure model is on the expected device
            semanticodec.to(model_device)
            
            # Decode the tokens
            waveform = semanticodec.decode(tokens_tensor)
            if isinstance(waveform, torch.Tensor):
                waveform = waveform.cpu().numpy()
            
            # Extract audio data
            audio_chunk = waveform[0, 0]
            all_audio_chunks.append(audio_chunk)
            
            # Combine all chunks we have so far
            combined_audio = np.concatenate(all_audio_chunks)
            
            # Yield the combined audio for streaming playback
            yield (SAMPLE_RATE, combined_audio), f"Decoded chunk {i+1}/{num_chunks} ({end_idx}/{tokens.shape[1]} tokens)"
            
            # Small delay to allow Gradio to update UI
            time.sleep(0.1)
        
        # Final complete audio
        combined_audio = np.concatenate(all_audio_chunks)
        yield (SAMPLE_RATE, combined_audio), f"Completed decoding all {tokens.shape[1]} tokens"
        
    except Exception as e:
        print(f"Streaming decode error: {str(e)}")
        yield None, f"Error decoding tokens: {str(e)}"

# Create Gradio interface
with gr.Blocks(title="Oterin Audio Codec") as demo:
    gr.Markdown("# Oterin Audio Codec")
    gr.Markdown("Upload an audio file to encode it to semantic tokens, decode tokens back to audio, or do both.")
    
    with gr.Tab("Encode Audio"):
        with gr.Row():
            encode_input = gr.Audio(type="filepath", label="Input Audio")
            encode_output = gr.File(label="Encoded Tokens (.oterin)", file_types=[".oterin"])
        encode_status = gr.Textbox(label="Status")
        encode_btn = gr.Button("Encode")
        encode_btn.click(encode_audio, inputs=encode_input, outputs=[encode_output, encode_status])
    
    with gr.Tab("Decode Tokens"):
        with gr.Row():
            decode_input = gr.File(label="Token File (.oterin)", file_types=[".oterin"])
            decode_output = gr.Audio(label="Decoded Audio")
        decode_status = gr.Textbox(label="Status")
        decode_btn = gr.Button("Decode")
        decode_btn.click(decode_tokens, inputs=decode_input, outputs=[decode_output, decode_status])
    
    with gr.Tab("Stream Decode (Listen while decoding)"):
        with gr.Row():
            stream_decode_input = gr.File(label="Token File (.oterin)", file_types=[".oterin"])
            stream_decode_output = gr.Audio(label="Streaming Audio Output")
        stream_decode_status = gr.Textbox(label="Status")
        stream_decode_btn = gr.Button("Start Streaming Decode")
        stream_decode_btn.click(
            stream_decode_tokens, 
            inputs=stream_decode_input, 
            outputs=[stream_decode_output, stream_decode_status],
            show_progress=True
        )
    
    with gr.Tab("Both (Encode & Decode)"):
        with gr.Row():
            both_input = gr.Audio(type="filepath", label="Input Audio")
            both_output = gr.Audio(label="Reconstructed Audio")
        both_status = gr.Textbox(label="Status")
        both_btn = gr.Button("Process")
        both_btn.click(process_both, inputs=both_input, outputs=[both_output, both_status])
    
    with gr.Tab("Both Streaming (Encode & Stream Decode)"):
        with gr.Row():
            stream_both_input = gr.Audio(type="filepath", label="Input Audio")
            stream_both_output = gr.Audio(label="Streaming Reconstructed Audio")
        stream_both_status = gr.Textbox(label="Status")
        stream_both_btn = gr.Button("Encode & Stream Decode")
        stream_both_btn.click(
            stream_both, 
            inputs=stream_both_input, 
            outputs=[stream_both_output, stream_both_status],
            show_progress=True
        )
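
# Programmatic round-trip sketch (illustrative; "input.wav" is a hypothetical
# local file and this bypasses the Gradio UI):
#
#   token_path, status = encode_audio("input.wav")
#   (rate, audio), status = decode_tokens(token_path)
#   # audio is a numpy array at SAMPLE_RATE; write it out with e.g. soundfile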

if __name__ == "__main__":
    demo.launch(share=True)