owiedotch commited on
Commit
24e47df
·
verified ·
1 Parent(s): e5f91c1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +72 -56
app.py CHANGED
@@ -9,19 +9,31 @@ import torch
9
  import tempfile
10
  import io
11
  import uuid
 
12
  from pathlib import Path
13
 
14
- # Initialize the model
15
  def load_model():
16
- return SemantiCodec(token_rate=100, semantic_vocab_size=32768) # 1.40 kbps
 
 
 
 
17
 
 
18
  semanticodec = load_model()
 
 
 
19
 
20
  @spaces.GPU(duration=20)
21
  def encode_audio(audio_path):
22
  """Encode audio file to tokens and return them as a file"""
23
  try:
 
24
  tokens = semanticodec.encode(audio_path)
 
 
25
  # Move tokens to CPU before converting to numpy
26
  if isinstance(tokens, torch.Tensor):
27
  tokens = tokens.cpu().numpy()
@@ -31,23 +43,21 @@ def encode_audio(audio_path):
31
  # Reshape to match expected format [batch, seq_len, features]
32
  tokens = tokens.reshape(1, -1, 1)
33
 
34
- # Save to a BytesIO buffer first
35
- buffer = io.BytesIO()
36
- np.save(buffer, tokens)
37
- buffer.seek(0)
38
-
39
- # Verify the buffer has content
40
- if buffer.getbuffer().nbytes == 0:
41
- raise Exception("Failed to create token buffer")
42
 
43
  # Create a temporary file in /tmp which is writable in Spaces
44
  temp_dir = "/tmp"
45
  os.makedirs(temp_dir, exist_ok=True)
46
  temp_file_path = os.path.join(temp_dir, f"tokens_{uuid.uuid4()}.oterin")
47
 
48
- # Write buffer to the temporary file
49
  with open(temp_file_path, "wb") as f:
50
- f.write(buffer.getvalue())
51
 
52
  # Verify the file exists and has content
53
  if not os.path.exists(temp_file_path) or os.path.getsize(temp_file_path) == 0:
@@ -55,9 +65,10 @@ def encode_audio(audio_path):
55
 
56
  return temp_file_path, f"Encoded to {tokens.shape[1]} tokens"
57
  except Exception as e:
 
58
  return None, f"Error encoding audio: {str(e)}"
59
 
60
- @spaces.GPU(duration=60)
61
  def decode_tokens(token_file):
62
  """Decode tokens to audio"""
63
  # Ensure the file exists and has content
@@ -65,25 +76,29 @@ def decode_tokens(token_file):
65
  return None, "Error: Empty or missing token file"
66
 
67
  try:
68
- # Load tokens from file
69
- tokens = np.load(token_file, allow_pickle=True)
70
-
71
- # Convert to torch tensor with proper dimensions
72
- if isinstance(tokens, np.ndarray):
73
- # Ensure tokens are in the right shape
74
- if tokens.ndim == 1:
75
- # Reshape to match expected format [batch, seq_len, features]
76
- tokens = tokens.reshape(1, -1, 1)
77
-
78
- # Convert to torch tensor (on CPU first)
79
- tokens = torch.tensor(tokens)
80
-
81
- # Explicitly move tokens to CUDA
82
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
83
- tokens = tokens.to(device)
 
 
 
84
 
85
  # Decode the tokens
86
- waveform = semanticodec.decode(tokens)
 
87
 
88
  # Move waveform to CPU for audio processing
89
  if isinstance(waveform, torch.Tensor):
@@ -100,14 +115,18 @@ def decode_tokens(token_file):
100
 
101
  return output_buffer, f"Decoded {tokens.shape[1]} tokens to audio"
102
  except Exception as e:
 
103
  return None, f"Error decoding tokens: {str(e)}"
104
 
105
- @spaces.GPU(duration=80)
106
  def process_both(audio_path):
107
  """Encode and then decode the audio without saving intermediate files"""
108
  try:
 
109
  # Encode
110
  tokens = semanticodec.encode(audio_path)
 
 
111
  if isinstance(tokens, torch.Tensor):
112
  tokens = tokens.cpu().numpy()
113
 
@@ -117,14 +136,20 @@ def process_both(audio_path):
117
  tokens = tokens.reshape(1, -1, 1)
118
 
119
  # Convert back to torch tensor (on CPU first)
120
- tokens_tensor = torch.tensor(tokens)
 
121
 
122
- # Explicitly move tokens to CUDA
123
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
124
- tokens_tensor = tokens_tensor.to(device)
 
 
 
 
125
 
126
  # Decode
127
  waveform = semanticodec.decode(tokens_tensor)
 
128
 
129
  # Move waveform to CPU for audio processing
130
  if isinstance(waveform, torch.Tensor):
@@ -141,31 +166,15 @@ def process_both(audio_path):
141
 
142
  return output_buffer, f"Encoded to {tokens.shape[1]} tokens\nDecoded {tokens.shape[1]} tokens to audio"
143
  except Exception as e:
 
144
  return None, f"Error processing audio: {str(e)}"
145
 
146
  # Create Gradio interface
147
  with gr.Blocks(title="Oterin Audio Codec") as demo:
148
  gr.Markdown("# Oterin Audio Codec")
149
- gr.Markdown("Upload an audio file to encode it to semantic tokens and decode back to audio.")
150
 
151
- # Make "Both" the primary default tab
152
- with gr.Tab("Encode & Decode"):
153
- with gr.Row():
154
- both_input = gr.Audio(type="filepath", label="Input Audio")
155
- both_output = gr.Audio(label="Reconstructed Audio")
156
- both_status = gr.Textbox(label="Status")
157
- both_btn = gr.Button("Process", variant="primary")
158
- both_btn.click(process_both, inputs=both_input, outputs=[both_output, both_status])
159
-
160
- gr.Markdown("""
161
- ## How it works
162
- This option encodes your audio to semantic tokens and immediately decodes it back to audio.
163
- It's the recommended way to use the codec as it handles all device management internally.
164
- """)
165
-
166
- # Keep separate functions as secondary options with warning
167
- with gr.Tab("Advanced (Encode Only)"):
168
- gr.Markdown("⚠️ **DEPRECATED**: Using separate encode/decode can lead to device mismatch errors. The combined Encode & Decode tab is recommended.")
169
  with gr.Row():
170
  encode_input = gr.Audio(type="filepath", label="Input Audio")
171
  encode_output = gr.File(label="Encoded Tokens (.oterin)", file_types=[".oterin"])
@@ -173,14 +182,21 @@ with gr.Blocks(title="Oterin Audio Codec") as demo:
173
  encode_btn = gr.Button("Encode")
174
  encode_btn.click(encode_audio, inputs=encode_input, outputs=[encode_output, encode_status])
175
 
176
- with gr.Tab("Advanced (Decode Only)"):
177
- gr.Markdown("⚠️ **DEPRECATED**: Using separate encode/decode can lead to device mismatch errors. The combined Encode & Decode tab is recommended.")
178
  with gr.Row():
179
  decode_input = gr.File(label="Token File (.oterin)", file_types=[".oterin"])
180
  decode_output = gr.Audio(label="Decoded Audio")
181
  decode_status = gr.Textbox(label="Status")
182
  decode_btn = gr.Button("Decode")
183
  decode_btn.click(decode_tokens, inputs=decode_input, outputs=[decode_output, decode_status])
 
 
 
 
 
 
 
 
184
 
185
  if __name__ == "__main__":
186
  demo.launch(share=True)
 
9
  import tempfile
10
  import io
11
  import uuid
12
+ import pickle
13
  from pathlib import Path
14
 
15
+ # Initialize the model and ensure it's on the correct device
16
  def load_model():
17
+ model = SemantiCodec(token_rate=100, semantic_vocab_size=32768) # 1.40 kbps
18
+ if torch.cuda.is_available():
19
+ # Move the model to CUDA
20
+ model.to("cuda:0")
21
+ return model
22
 
23
+ # Initialize model
24
  semanticodec = load_model()
25
+ # Get the device of the model
26
+ model_device = "cuda:0" if torch.cuda.is_available() else "cpu"
27
+ print(f"Model initialized on device: {model_device}")
28
 
29
  @spaces.GPU(duration=20)
30
  def encode_audio(audio_path):
31
  """Encode audio file to tokens and return them as a file"""
32
  try:
33
+ print(f"Encoding audio on device: {model_device}")
34
  tokens = semanticodec.encode(audio_path)
35
+ print(f"Tokens device after encode: {tokens.device if isinstance(tokens, torch.Tensor) else 'numpy'}")
36
+
37
  # Move tokens to CPU before converting to numpy
38
  if isinstance(tokens, torch.Tensor):
39
  tokens = tokens.cpu().numpy()
 
43
  # Reshape to match expected format [batch, seq_len, features]
44
  tokens = tokens.reshape(1, -1, 1)
45
 
46
+ # Save tokens in a way that preserves shape information
47
+ token_data = {
48
+ 'tokens': tokens,
49
+ 'shape': tokens.shape,
50
+ 'device': str(model_device) # Store intended device information
51
+ }
 
 
52
 
53
  # Create a temporary file in /tmp which is writable in Spaces
54
  temp_dir = "/tmp"
55
  os.makedirs(temp_dir, exist_ok=True)
56
  temp_file_path = os.path.join(temp_dir, f"tokens_{uuid.uuid4()}.oterin")
57
 
58
+ # Write using pickle instead of numpy save
59
  with open(temp_file_path, "wb") as f:
60
+ pickle.dump(token_data, f)
61
 
62
  # Verify the file exists and has content
63
  if not os.path.exists(temp_file_path) or os.path.getsize(temp_file_path) == 0:
 
65
 
66
  return temp_file_path, f"Encoded to {tokens.shape[1]} tokens"
67
  except Exception as e:
68
+ print(f"Encoding error: {str(e)}")
69
  return None, f"Error encoding audio: {str(e)}"
70
 
71
+ @spaces.GPU(duration=340)
72
  def decode_tokens(token_file):
73
  """Decode tokens to audio"""
74
  # Ensure the file exists and has content
 
76
  return None, "Error: Empty or missing token file"
77
 
78
  try:
79
+ # Load tokens using pickle instead of numpy load
80
+ with open(token_file, "rb") as f:
81
+ token_data = pickle.load(f)
82
+
83
+ tokens = token_data['tokens']
84
+ intended_device = token_data.get('device', model_device)
85
+ print(f"Loaded tokens with shape {tokens.shape}, intended device: {intended_device}")
86
+
87
+ # Convert to torch tensor
88
+ tokens_tensor = torch.tensor(tokens, dtype=torch.float32)
89
+ print(f"Tokens tensor created on device: {tokens_tensor.device}")
90
+
91
+ # Explicitly move tokens to the model's device
92
+ tokens_tensor = tokens_tensor.to(model_device)
93
+ print(f"Tokens moved to device: {tokens_tensor.device}")
94
+
95
+ # Also ensure model is on the expected device
96
+ semanticodec.to(model_device)
97
+ print(f"Model device before decode: {next(semanticodec.parameters()).device}")
98
 
99
  # Decode the tokens
100
+ waveform = semanticodec.decode(tokens_tensor)
101
+ print(f"Waveform device after decode: {waveform.device if isinstance(waveform, torch.Tensor) else 'numpy'}")
102
 
103
  # Move waveform to CPU for audio processing
104
  if isinstance(waveform, torch.Tensor):
 
115
 
116
  return output_buffer, f"Decoded {tokens.shape[1]} tokens to audio"
117
  except Exception as e:
118
+ print(f"Decoding error: {str(e)}")
119
  return None, f"Error decoding tokens: {str(e)}"
120
 
121
+ @spaces.GPU(duration=360)
122
  def process_both(audio_path):
123
  """Encode and then decode the audio without saving intermediate files"""
124
  try:
125
+ print(f"Processing both on device: {model_device}")
126
  # Encode
127
  tokens = semanticodec.encode(audio_path)
128
+ print(f"Tokens device after encode: {tokens.device if isinstance(tokens, torch.Tensor) else 'numpy'}")
129
+
130
  if isinstance(tokens, torch.Tensor):
131
  tokens = tokens.cpu().numpy()
132
 
 
136
  tokens = tokens.reshape(1, -1, 1)
137
 
138
  # Convert back to torch tensor (on CPU first)
139
+ tokens_tensor = torch.tensor(tokens, dtype=torch.float32)
140
+ print(f"Tokens tensor created on device: {tokens_tensor.device}")
141
 
142
+ # Explicitly move tokens to the model's device
143
+ tokens_tensor = tokens_tensor.to(model_device)
144
+ print(f"Tokens moved to device: {tokens_tensor.device}")
145
+
146
+ # Also ensure model is on the expected device
147
+ semanticodec.to(model_device)
148
+ print(f"Model device before decode: {next(semanticodec.parameters()).device}")
149
 
150
  # Decode
151
  waveform = semanticodec.decode(tokens_tensor)
152
+ print(f"Waveform device after decode: {waveform.device if isinstance(waveform, torch.Tensor) else 'numpy'}")
153
 
154
  # Move waveform to CPU for audio processing
155
  if isinstance(waveform, torch.Tensor):
 
166
 
167
  return output_buffer, f"Encoded to {tokens.shape[1]} tokens\nDecoded {tokens.shape[1]} tokens to audio"
168
  except Exception as e:
169
+ print(f"Processing error: {str(e)}")
170
  return None, f"Error processing audio: {str(e)}"
171
 
172
  # Create Gradio interface
173
  with gr.Blocks(title="Oterin Audio Codec") as demo:
174
  gr.Markdown("# Oterin Audio Codec")
175
+ gr.Markdown("Upload an audio file to encode it to semantic tokens, decode tokens back to audio, or do both.")
176
 
177
+ with gr.Tab("Encode Audio"):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
178
  with gr.Row():
179
  encode_input = gr.Audio(type="filepath", label="Input Audio")
180
  encode_output = gr.File(label="Encoded Tokens (.oterin)", file_types=[".oterin"])
 
182
  encode_btn = gr.Button("Encode")
183
  encode_btn.click(encode_audio, inputs=encode_input, outputs=[encode_output, encode_status])
184
 
185
+ with gr.Tab("Decode Tokens"):
 
186
  with gr.Row():
187
  decode_input = gr.File(label="Token File (.oterin)", file_types=[".oterin"])
188
  decode_output = gr.Audio(label="Decoded Audio")
189
  decode_status = gr.Textbox(label="Status")
190
  decode_btn = gr.Button("Decode")
191
  decode_btn.click(decode_tokens, inputs=decode_input, outputs=[decode_output, decode_status])
192
+
193
+ with gr.Tab("Both (Encode & Decode)"):
194
+ with gr.Row():
195
+ both_input = gr.Audio(type="filepath", label="Input Audio")
196
+ both_output = gr.Audio(label="Reconstructed Audio")
197
+ both_status = gr.Textbox(label="Status")
198
+ both_btn = gr.Button("Process")
199
+ both_btn.click(process_both, inputs=both_input, outputs=[both_output, both_status])
200
 
201
  if __name__ == "__main__":
202
  demo.launch(share=True)