owiedotch committed
Commit eb0f782 · verified · 1 Parent(s): 9b97aff

Update app.py

Files changed (1):
  1. app.py +99 -63

app.py CHANGED
@@ -1,88 +1,124 @@
  import gradio as gr
- import jax
- import jax.numpy as jnp
- import librosa
- import dac_jax
- from dac_jax.audio_utils import volume_norm, db2linear
- import spaces
  import tempfile
- import os
  import numpy as np

- # Download a model and bind variables to it.
- model, variables = dac_jax.load_model(model_type="44khz")
- model = model.bind(variables)

- @spaces.GPU
- def encode(audio_file_path):
-     try:
-         # Load audio with librosa, specifying duration
-         signal, sample_rate = librosa.load(audio_file_path, sr=44100, mono=True)  # Set duration as needed

-         signal = jnp.array(signal, dtype=jnp.float32)
-         while signal.ndim < 3:
-             signal = jnp.expand_dims(signal, axis=0)

-         target_db = -16  # Normalize audio to -16 dB
-         x, input_db = volume_norm(signal, target_db, sample_rate)

-         # Encode audio signal
-         x = model.preprocess(x, sample_rate)
-         z, codes, latents, commitment_loss, codebook_loss = model.encode(x, train=False)

-         # Save encoded data to a temporary file (using numpy.savez for now)
-         with tempfile.NamedTemporaryFile(delete=False, suffix=".npz") as temp_file:
-             np.savez(temp_file.name, z=z, codes=codes, latents=latents, input_db=input_db)

-         return temp_file.name

      except Exception as e:
-         gr.Warning(f"An error occurred during encoding: {e}")
-         return None

- @spaces.GPU
- def decode(compressed_file_path):  # Changed input to compressed_file_path
      try:
-         # Load encoded data directly from the file path
-         data = np.load(compressed_file_path)  # No need for temporary files
-         z = data['z']
-         codes = data['codes']
-         latents = data['latents']
-         input_db = data['input_db']

-         # Decode audio signal
-         y = model.decode(z, length=z.shape[1] * model.hop_length)

-         # Undo previous loudness normalization
-         y = y * db2linear(input_db - (-16))  # Using -16 as the target_db

-         decoded_audio = np.array(y).squeeze()
-         return (44100, decoded_audio)

      except Exception as e:
-         gr.Warning(f"An error occurred during decoding: {e}")
-         return None

- # Gradio interface
  with gr.Blocks() as demo:
-     gr.Markdown("<h1 style='text-align: center;'>Audio Compression with DAC-JAX</h1>")

      with gr.Tab("Encode"):
-         with gr.Row():
-             audio_input = gr.Audio(type="filepath", label="Input Audio")
-             encode_button = gr.Button("Encode", variant="primary")
-         with gr.Row():
-             encoded_output = gr.File(label="Compressed Audio (.npz)")

-         encode_button.click(encode, inputs=audio_input, outputs=encoded_output)

      with gr.Tab("Decode"):
-         with gr.Row():
-             compressed_input = gr.File(label="Compressed Audio (.npz)")
-             decode_button = gr.Button("Decode", variant="primary")
-         with gr.Row():
-             decoded_output = gr.Audio(label="Decompressed Audio")

-         decode_button.click(decode, inputs=compressed_input, outputs=decoded_output)

  demo.queue().launch()
  import gradio as gr
+ import torch
+ import torchaudio
+ from agc import AGC
  import tempfile
  import numpy as np
+ import lz4.frame
+ import os
+ from typing import Generator
+ import spaces

+ # Attempt to use GPU, fallback to CPU
+ try:
+     torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     print(f"Using device: {torch_device}")
+ except Exception as e:
+     print(f"Error detecting GPU. Using CPU. Error: {e}")
+     torch_device = torch.device("cpu")

+ # Load the AGC model
+ @spaces.GPU(duration=180)
+ def load_agc_model():
+     return AGC.from_pretrained("Audiogen/agc-continuous").to(torch_device)

+ agc = load_agc_model()

+ @spaces.GPU(duration=180)
+ def encode_audio(audio_file_path):
+     try:
+         # Load the audio file
+         waveform, sample_rate = torchaudio.load(audio_file_path)

+         # Convert to stereo if necessary
+         if waveform.size(0) == 1:
+             waveform = waveform.repeat(2, 1)

+         # Encode the audio
+         audio = waveform.unsqueeze(0).to(torch_device)
+         with torch.no_grad():
+             z = agc.encode(audio)

+         # Convert to NumPy and save to a temporary .owie file
+         z_numpy = z.detach().cpu().numpy()
+         temp_fd, temp_file_path = tempfile.mkstemp(suffix=".owie")
+         os.close(temp_fd)  # Close the file descriptor to avoid issues with os.fdopen
+         with open(temp_file_path, 'wb') as temp_file:
+             compressed_data = lz4.frame.compress(z_numpy.tobytes())
+             temp_file.write(compressed_data)

+         return temp_file_path

      except Exception as e:
+         return f"Encoding error: {e}"

+ @spaces.GPU(duration=180)
+ def decode_audio(encoded_file_path):
      try:
+         # Load encoded data from the .owie file
+         with open(encoded_file_path, 'rb') as temp_file:
+             compressed_data = temp_file.read()
+         z_numpy_bytes = lz4.frame.decompress(compressed_data)
+         z_numpy = np.frombuffer(z_numpy_bytes, dtype=np.float32).reshape(1, 32, -1)
+         z = torch.from_numpy(z_numpy).to(torch_device)

+         # Decode the audio
+         with torch.no_grad():
+             reconstructed_audio = agc.decode(z)

+         # Save to a temporary WAV file
+         temp_wav_path = tempfile.mktemp(suffix=".wav")
+         torchaudio.save(temp_wav_path, reconstructed_audio.squeeze(0).cpu(), sample_rate)
+         return temp_wav_path

+     except Exception as e:
+         return f"Decoding error: {e}"

+ @spaces.GPU(duration=180)
+ def stream_decode_audio(encoded_file_path) -> Generator[np.ndarray, None, None]:
+     try:
+         # Load encoded data from the .owie file
+         with open(encoded_file_path, 'rb') as temp_file:
+             compressed_data = temp_file.read()
+         z_numpy_bytes = lz4.frame.decompress(compressed_data)
+         z_numpy = np.frombuffer(z_numpy_bytes, dtype=np.float32).reshape(1, 32, -1)
+         z = torch.from_numpy(z_numpy).to(torch_device)

+         # Decode the audio in chunks
+         chunk_size = 16000  # 1 second of audio at 16kHz
+         with torch.no_grad():
+             for i in range(0, z.shape[2], chunk_size):
+                 z_chunk = z[:, :, i:i+chunk_size]
+                 audio_chunk = agc.decode(z_chunk)
+                 yield audio_chunk.squeeze(0).cpu().numpy()

      except Exception as e:
+         yield np.zeros((2, chunk_size))  # Return silence in case of error
+         print(f"Streaming decoding error: {e}")

+ # Gradio Interface
  with gr.Blocks() as demo:
+     gr.Markdown("## Audio Compression with AGC (GPU/CPU)")

      with gr.Tab("Encode"):
+         input_audio = gr.Audio(label="Input Audio", type="filepath")
+         encode_button = gr.Button("Encode")
+         encoded_output = gr.File(label="Encoded File (.owie)", type="filepath")

+         encode_button.click(encode_audio, inputs=input_audio, outputs=encoded_output)

      with gr.Tab("Decode"):
+         input_encoded = gr.File(label="Encoded File (.owie)", type="filepath")
+         decode_button = gr.Button("Decode")
+         decoded_output = gr.Audio(label="Decoded Audio", type="filepath")

+         decode_button.click(decode_audio, inputs=input_encoded, outputs=decoded_output)

+     with gr.Tab("Streaming"):
+         input_encoded_stream = gr.File(label="Encoded File (.owie)", type="filepath")
+         stream_button = gr.Button("Start Streaming")
+         audio_output = gr.Audio(label="Streaming Audio Output", streaming=True)

+         stream_button.click(stream_decode_audio, inputs=input_encoded_stream, outputs=audio_output)

  demo.queue().launch()
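
A note for readers of the new encode_audio/decode_audio pair: the .owie file is nothing more than the raw float32 bytes of the AGC latents, LZ4-compressed, with the (1, 32, -1) shape reimposed on load. Below is a minimal sketch of that round trip, independent of the model; the dtype and shape are taken from the diff above, while write_owie and read_owie are hypothetical helper names, not part of the app.

import lz4.frame
import numpy as np

def write_owie(z_numpy: np.ndarray, path: str) -> None:
    # Latents are stored as raw float32 bytes, LZ4-compressed (mirrors encode_audio above)
    with open(path, "wb") as f:
        f.write(lz4.frame.compress(z_numpy.astype(np.float32).tobytes()))

def read_owie(path: str) -> np.ndarray:
    # Only the channel count (32) is fixed; the time axis is inferred, as in decode_audio
    with open(path, "rb") as f:
        raw = lz4.frame.decompress(f.read())
    return np.frombuffer(raw, dtype=np.float32).reshape(1, 32, -1)

# Round-trip check with random latents (the shape is an assumption carried over from the diff)
z = np.random.randn(1, 32, 1500).astype(np.float32)
write_owie(z, "roundtrip.owie")
assert np.array_equal(z, read_owie("roundtrip.owie"))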
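
One caveat in the new decode_audio: sample_rate is only defined inside encode_audio (returned by torchaudio.load), so the torchaudio.save call here will raise a NameError at runtime. A sketch of one way to close the gap, assuming the codec's output rate is fixed; the 48 kHz constant is an assumption to be checked against the AGC model card, not a value from the diff.

import torch
import torchaudio

OUTPUT_SAMPLE_RATE = 48_000  # assumption: replace with the AGC codec's actual output rate

def save_decoded(reconstructed_audio: torch.Tensor, wav_path: str) -> str:
    # reconstructed_audio: (1, channels, samples), as returned by agc.decode in the diff
    torchaudio.save(wav_path, reconstructed_audio.squeeze(0).cpu(), OUTPUT_SAMPLE_RATE)
    return wav_path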
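
Similarly, stream_decode_audio slices the latent tensor along its last axis, so chunk_size = 16000 counts latent frames rather than audio samples, and the "1 second of audio at 16kHz" comment only holds if one frame corresponds to one sample. A sketch that expresses the chunk length in seconds instead; LATENT_FRAMES_PER_SECOND is an assumed placeholder that should be read off the model's hop length, not a value from the diff.

import torch

LATENT_FRAMES_PER_SECOND = 50  # assumption: depends on the AGC hop length

def iter_latent_chunks(z: torch.Tensor, seconds: float = 1.0):
    # z has shape (1, 32, T); yield slices covering roughly `seconds` of audio each
    frames = max(1, int(seconds * LATENT_FRAMES_PER_SECOND))
    for start in range(0, z.shape[2], frames):
        yield z[:, :, start:start + frames]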