owiedotch committed
Commit be68594 · verified · 1 Parent(s): 16120e1

Update app.py

Files changed (1): app.py (+119 -9)
app.py CHANGED
@@ -15,10 +15,17 @@ from pathlib import Path
 
 # Initialize the model and ensure it's on the correct device
 def load_model():
-    model = SemantiCodec(token_rate=100, semantic_vocab_size=16384) # 1.35 kbps
+    model = SemantiCodec(token_rate=25, semantic_vocab_size=32768) # 0.35 kbps
     if torch.cuda.is_available():
-        # Move the model to CUDA
-        model.to("cuda:0")
+        # Move the model to CUDA and ensure it's fully initialized on CUDA
+        model = model.to("cuda:0")
+        # Force CUDA initialization
+        dummy_input = torch.zeros(1, 1, 1, dtype=torch.long).cuda()
+        try:
+            with torch.no_grad():
+                _ = model.decoder(dummy_input)
+        except:
+            print("Dummy forward pass failed, but CUDA initialization attempted")
     return model
 
 # Initialize model
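
A quick sanity check on the bitrate comments (a rough approximation counting only the semantic tokens; SemantiCodec also carries acoustic information, so the labelled figures differ slightly): each token encodes about log2(vocab_size) bits, multiplied by the token rate.

    import math

    def approx_kbps(token_rate, vocab_size):
        # Bits per semantic token times tokens per second, in kbps.
        return token_rate * math.log2(vocab_size) / 1000

    print(approx_kbps(100, 16384))  # ~1.4 kbps (old setting, labelled 1.35 kbps)
    print(approx_kbps(25, 32768))   # ~0.375 kbps (new setting, labelled 0.35 kbps)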
@@ -28,13 +35,16 @@ model_device = "cuda:0" if torch.cuda.is_available() else "cpu"
 print(f"Model initialized on device: {model_device}")
 
 # Define sample rate as a constant
-SAMPLE_RATE = 32000
+# Changed from 32000 to 16000 to fix playback speed
+SAMPLE_RATE = 16000
 
 @spaces.GPU(duration=20)
 def encode_audio(audio_path):
     """Encode audio file to tokens and return them as a file"""
     try:
         print(f"Encoding audio on device: {model_device}")
+        # Ensure model is on the right device
+        semanticodec.to(model_device)
         tokens = semanticodec.encode(audio_path)
         print(f"Tokens device after encode: {tokens.device if isinstance(tokens, torch.Tensor) else 'numpy'}")
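
The SAMPLE_RATE change matters because Gradio's Audio component plays a (rate, samples) tuple at whatever rate is declared: audio decoded at 16 kHz but labelled 32000 plays back at double speed in half the time. The duration arithmetic, illustrated:

    import numpy as np

    samples = np.zeros(16000 * 3)  # 3 s of decoded audio at 16 kHz
    print(len(samples) / 16000)    # 3.0 -- playback seconds with the correct label
    print(len(samples) / 32000)    # 1.5 -- the old 32000 label halved the duration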
 
@@ -88,6 +98,10 @@ def decode_tokens(token_file):
         intended_device = token_data.get('device', model_device)
         print(f"Loaded tokens with shape {tokens.shape}, intended device: {intended_device}")
 
+        # Ensure model is on the right device first
+        semanticodec.to(model_device)
+        print(f"Model device before tensor creation: {next(semanticodec.parameters()).device}")
+
         # Convert to torch tensor with Long dtype for embedding
         tokens_tensor = torch.tensor(tokens, dtype=torch.long)
         print(f"Tokens tensor created on device: {tokens_tensor.device} with dtype: {tokens_tensor.dtype}")
@@ -96,10 +110,6 @@ def decode_tokens(token_file):
         tokens_tensor = tokens_tensor.to(model_device)
         print(f"Tokens moved to device: {tokens_tensor.device}")
 
-        # Also ensure model is on the expected device
-        semanticodec.to(model_device)
-        print(f"Model device before decode: {next(semanticodec.parameters()).device}")
-
         # Decode the tokens
         waveform = semanticodec.decode(tokens_tensor)
         print(f"Waveform device after decode: {waveform.device if isinstance(waveform, torch.Tensor) else 'numpy'}")
@@ -124,6 +134,8 @@ def process_both(audio_path):
     """Encode and then decode the audio without saving intermediate files"""
     try:
         print(f"Processing both on device: {model_device}")
+        # Ensure model is on the right device
+        semanticodec.to(model_device)
         # Encode
         tokens = semanticodec.encode(audio_path)
         print(f"Tokens device after encode: {tokens.device if isinstance(tokens, torch.Tensor) else 'numpy'}")
@@ -144,7 +156,7 @@ def process_both(audio_path):
         tokens_tensor = tokens_tensor.to(model_device)
         print(f"Tokens moved to device: {tokens_tensor.device}")
 
-        # Also ensure model is on the expected device
+        # Ensure model is on the right device again before decoding
        semanticodec.to(model_device)
         print(f"Model device before decode: {next(semanticodec.parameters()).device}")
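
The recurring semanticodec.to(model_device) calls fit the ZeroGPU model that @spaces.GPU implies: the GPU is only attached for the duration of a decorated call, so each handler re-asserts the model's placement rather than trusting import-time state. A minimal sketch of the pattern with a stand-in module (the Linear layer is hypothetical, standing in for the codec):

    import torch
    from torch import nn

    model = nn.Linear(4, 4)  # hypothetical stand-in for the codec
    model_device = "cuda:0" if torch.cuda.is_available() else "cpu"

    def handler(x: torch.Tensor) -> torch.Tensor:
        # Re-assert placement on every entry point, as app.py now does.
        model.to(model_device)
        return model(x.to(model_device))

    print(handler(torch.zeros(1, 4)).shape)  # torch.Size([1, 4])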
 
@@ -167,6 +179,88 @@ def process_both(audio_path):
         print(f"Processing error: {str(e)}")
         return None, f"Error processing audio: {str(e)}"
 
+@spaces.GPU(duration=360)
+def stream_both(audio_path):
+    """Encode and then stream decode the audio"""
+    try:
+        print(f"Processing both (streaming) on device: {model_device}")
+        # Ensure model is on the right device
+        semanticodec.to(model_device)
+
+        # First encode the audio
+        tokens = semanticodec.encode(audio_path)
+        if isinstance(tokens, torch.Tensor):
+            tokens = tokens.cpu().numpy()
+
+        # Ensure tokens are in the right shape for decoding
+        if tokens.ndim == 1:
+            tokens = tokens.reshape(1, -1, 1)
+
+        print(f"Encoded audio to {tokens.shape[1]} tokens, now streaming decoding...")
+        yield None, f"Encoded to {tokens.shape[1]} tokens, starting decoding..."
+
+        # If tokens are too small, decode all at once
+        if tokens.shape[1] < 500:
+            # Convert to torch tensor with Long dtype for embedding
+            tokens_tensor = torch.tensor(tokens, dtype=torch.long).to(model_device)
+
+            # Decode the tokens
+            semanticodec.to(model_device)
+            waveform = semanticodec.decode(tokens_tensor)
+            if isinstance(waveform, torch.Tensor):
+                waveform = waveform.cpu().numpy()
+
+            audio_data = waveform[0, 0]
+            yield (SAMPLE_RATE, audio_data), f"Encoded to {tokens.shape[1]} tokens and decoded to audio"
+            return
+
+        # Split tokens into chunks for streaming
+        chunk_size = 500 # Number of tokens per chunk
+        num_chunks = (tokens.shape[1] + chunk_size - 1) // chunk_size # Ceiling division
+
+        all_audio_chunks = []
+
+        for i in range(num_chunks):
+            start_idx = i * chunk_size
+            end_idx = min((i + 1) * chunk_size, tokens.shape[1])
+
+            print(f"Decoding chunk {i+1}/{num_chunks}, tokens {start_idx} to {end_idx}")
+
+            # Extract chunk of tokens
+            token_chunk = tokens[:, start_idx:end_idx, :]
+
+            # Convert to torch tensor with Long dtype
+            tokens_tensor = torch.tensor(token_chunk, dtype=torch.long).to(model_device)
+
+            # Ensure model is on the expected device
+            semanticodec.to(model_device)
+
+            # Decode the tokens
+            waveform = semanticodec.decode(tokens_tensor)
+            if isinstance(waveform, torch.Tensor):
+                waveform = waveform.cpu().numpy()
+
+            # Extract audio data
+            audio_chunk = waveform[0, 0]
+            all_audio_chunks.append(audio_chunk)
+
+            # Combine all chunks we have so far
+            combined_audio = np.concatenate(all_audio_chunks)
+
+            # Yield the combined audio for streaming playback
+            yield (SAMPLE_RATE, combined_audio), f"Encoded to {tokens.shape[1]} tokens\nDecoded chunk {i+1}/{num_chunks} ({end_idx}/{tokens.shape[1]} tokens)"
+
+            # Small delay to allow Gradio to update UI
+            time.sleep(0.1)
+
+        # Final complete audio
+        combined_audio = np.concatenate(all_audio_chunks)
+        yield (SAMPLE_RATE, combined_audio), f"Completed: Encoded to {tokens.shape[1]} tokens and fully decoded"
+
+    except Exception as e:
+        print(f"Streaming process error: {str(e)}")
+        yield None, f"Error processing audio: {str(e)}"
+
 @spaces.GPU(duration=360)
 def stream_decode_tokens(token_file):
     """Decode tokens to audio in streaming chunks"""
@@ -184,6 +278,9 @@ def stream_decode_tokens(token_file):
         intended_device = token_data.get('device', model_device)
         print(f"Loaded tokens with shape {tokens.shape}, intended device: {intended_device}")
 
+        # Ensure model is on the right device
+        semanticodec.to(model_device)
+
         # If tokens are too small, decode all at once
         if tokens.shape[1] < 500:
             # Convert to torch tensor with Long dtype for embedding
@@ -291,6 +388,19 @@ with gr.Blocks(title="Oterin Audio Codec") as demo:
         both_status = gr.Textbox(label="Status")
         both_btn = gr.Button("Process")
         both_btn.click(process_both, inputs=both_input, outputs=[both_output, both_status])
+
+    with gr.Tab("Both Streaming (Encode & Stream Decode)"):
+        with gr.Row():
+            stream_both_input = gr.Audio(type="filepath", label="Input Audio")
+            stream_both_output = gr.Audio(label="Streaming Reconstructed Audio")
+        stream_both_status = gr.Textbox(label="Status")
+        stream_both_btn = gr.Button("Encode & Stream Decode")
+        stream_both_btn.click(
+            stream_both,
+            inputs=stream_both_input,
+            outputs=[stream_both_output, stream_both_status],
+            show_progress=True
+        )
 
 if __name__ == "__main__":
     demo.launch(share=True)
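
The new tab streams because Gradio treats a generator event handler as a stream: each yield updates the bound outputs immediately instead of waiting for the function to return. A self-contained sketch of that pattern:

    import time
    import gradio as gr

    def ticker():
        # Each yield is pushed to the Textbox as soon as it is produced.
        for i in range(3):
            yield f"chunk {i + 1}/3"
            time.sleep(0.5)

    with gr.Blocks() as demo:
        status = gr.Textbox(label="Status")
        gr.Button("Run").click(ticker, outputs=status)

    demo.launch()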
 