Kr08 committed
Commit 8dbd60e · verified · 1 Parent(s): d8c8b8a

Update audio_processing.py

Files changed (1):
  audio_processing.py +30 -89
audio_processing.py CHANGED
@@ -1,78 +1,38 @@
 import torch
-import spaces
 import whisper
-import subprocess
-import numpy as np
-import gradio as gr
-import soundfile as sf
 import torchaudio as ta
-
+import gradio as gr
 from model_utils import get_processor, get_model, get_whisper_model_small, get_device
 from config import SAMPLING_RATE, CHUNK_LENGTH_S
-
-
-# def resample_with_ffmpeg(input_file, output_file, target_sr=16000):
-#     command = [
-#         'ffmpeg', '-i', input_file, '-ar', str(target_sr), output_file
-#     ]
-#     subprocess.run(command, check=True)
+import spaces
 
 
 @spaces.GPU
-def load_and_resample_audio(file):
-    try:
-        # First attempt: Use torchaudio.load()
-        waveform, sample_rate = torchaudio.load(file)
-    except Exception as e:
-        print(f"torchaudio.load() failed: {e}")
-        try:
-            # Second attempt: Use soundfile
-            waveform, sample_rate = sf.read(file)
-            waveform = torch.from_numpy(waveform.T).float()
-            if waveform.dim() == 1:
-                waveform = waveform.unsqueeze(0)
-        except Exception as e:
-            print(f"soundfile.read() failed: {e}")
-            raise ValueError(f"Failed to load audio file: {file}")
-
-    print(f"Original audio shape: {waveform.shape}, Sample rate: {sample_rate}")
+def load_and_resample_audio(audio):
+    if isinstance(audio, str):  # If audio is a file path
+        waveform, sample_rate = ta.load(audio)
+    else:  # If audio is already loaded (sample_rate, waveform)
+        sample_rate, waveform = audio
+        waveform = torch.tensor(waveform).float()
 
     if sample_rate != SAMPLING_RATE:
-        try:
-            waveform = F.resample(waveform, sample_rate, SAMPLING_RATE)
-        except Exception as e:
-            print(f"Resampling failed: {e}")
-            raise ValueError(f"Failed to resample audio from {sample_rate} to {SAMPLING_RATE}")
+        waveform = ta.functional.resample(waveform, sample_rate, SAMPLING_RATE)
 
     # Ensure the audio is in the correct shape (mono)
     if waveform.dim() > 1 and waveform.shape[0] > 1:
         waveform = waveform.mean(dim=0, keepdim=True)
-
-    print(f"Processed audio shape: {waveform.shape}, New sample rate: {SAMPLING_RATE}")
+    elif waveform.dim() == 1:
+        waveform = waveform.unsqueeze(0)
 
     return waveform, SAMPLING_RATE
 
+
 @spaces.GPU
-def detect_language(audio):
+def detect_language(waveform):
     whisper_model = get_whisper_model_small()
 
-    # Save the input audio to a temporary file
-    ta.save("input_audio.wav", torch.tensor(audio[1]).unsqueeze(0), audio[0])
-
-    # Resample if necessary using ffmpeg
-    if audio[0] != SAMPLING_RATE:
-        resample_with_ffmpeg("input_audio.wav", "resampled_audio.wav", target_sr=SAMPLING_RATE)
-        audio_tensor, _ = ta.load("resampled_audio.wav")
-    else:
-        audio_tensor = torch.tensor(audio[1]).float()
-
-    # Ensure the audio is in the correct shape (mono)
-    if audio_tensor.dim() == 2:
-        audio_tensor = audio_tensor.mean(dim=0)
-
     # Use Whisper's preprocessing
-    audio_tensor = whisper.pad_or_trim(audio_tensor)
-    print(f"Audio length after pad/trim: {audio_tensor.shape[-1] / SAMPLING_RATE} seconds")
+    audio_tensor = whisper.pad_or_trim(waveform.squeeze())
     mel = whisper.log_mel_spectrogram(audio_tensor).to(whisper_model.device)
 
     # Detect language
@@ -88,45 +48,20 @@ def detect_language(audio):
 
 
 @spaces.GPU
-def process_long_audio(audio, task="transcribe", language=None):
-    if audio[0] != SAMPLING_RATE:
-        # Save the input audio to a file for ffmpeg processing
-        ta.save("input_audio_1.wav", torch.tensor(audio[1]).unsqueeze(0), audio[0])
-
-        # Resample using ffmpeg
-        try:
-            resample_with_ffmpeg("input_audio_1.wav", "resampled_audio_2.wav", target_sr=SAMPLING_RATE)
-        except subprocess.CalledProcessError as e:
-            print(f"ffmpeg failed: {e.stderr}")
-            raise e
-
-        waveform, _ = ta.load("resampled_audio_2.wav")
-    else:
-        waveform = torch.tensor(audio[1]).float()
-
-    # Ensure the audio is in the correct shape (mono)
-    if waveform.dim() == 2:
-        waveform = waveform.mean(dim=0)
-
-    print(f"Waveform shape after processing: {waveform.shape}")
-
-    if waveform.numel() == 0:
-        raise ValueError("Waveform is empty. Please check the input audio file.")
+def process_long_audio(waveform, sample_rate, task="transcribe", language=None):
+    input_length = waveform.shape[1]
+    chunk_length = int(CHUNK_LENGTH_S * sample_rate)
 
-    input_length = waveform.shape[0]  # Since waveform is 1D, access the length with shape[0]
-    chunk_length = int(CHUNK_LENGTH_S * SAMPLING_RATE)
+    chunks = [waveform[:, i:i + chunk_length] for i in range(0, input_length, chunk_length)]
 
-    # Corrected slicing for 1D tensor
-    chunks = [waveform[i:i + chunk_length] for i in range(0, input_length, chunk_length)]
-
-    # Initialize the processor
     processor = get_processor()
     model = get_model()
    device = get_device()
 
     results = []
     for chunk in chunks:
-        input_features = processor(chunk, sampling_rate=SAMPLING_RATE, return_tensors="pt").input_features.to(device)
+        input_features = processor(chunk.squeeze(), sampling_rate=sample_rate, return_tensors="pt").input_features.to(
+            device)
 
         with torch.no_grad():
             if task == "translate":
@@ -149,12 +84,15 @@ def process_audio(audio):
     if audio is None:
         return "No file uploaded", "", ""
 
-    detected_lang = detect_language(audio)
-    transcription = process_long_audio(audio, task="transcribe")
-    translation = process_long_audio(audio, task="translate", language=detected_lang)
+    waveform, sample_rate = load_and_resample_audio(audio)
+
+    detected_lang = detect_language(waveform)
+    transcription = process_long_audio(waveform, sample_rate, task="transcribe")
+    translation = process_long_audio(waveform, sample_rate, task="translate", language=detected_lang)
 
     return detected_lang, transcription, translation
 
+
 # Gradio interface
 iface = gr.Interface(
     fn=process_audio,
@@ -168,4 +106,7 @@ iface = gr.Interface(
     description="Upload an audio file to detect its language, transcribe, and translate it.",
     allow_flagging="never",
     css=".output-textbox { font-family: 'Noto Sans Devanagari', sans-serif; font-size: 18px; }"
- )
+)
+
+if __name__ == "__main__":
+    iface.launch()
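
For reference, a minimal sketch (not part of the commit) of how the refactored helpers could be exercised outside the Gradio UI. It assumes audio_processing.py and its project-local modules (model_utils, config) are importable, that the @spaces.GPU decorators run in a suitable Spaces environment, and that "sample.wav" is a placeholder input file.

# Hypothetical usage sketch of the refactored audio pipeline.
from audio_processing import load_and_resample_audio, detect_language, process_long_audio

# load_and_resample_audio accepts either a file path or a
# (sample_rate, numpy_array) tuple as produced by a Gradio audio input.
waveform, sample_rate = load_and_resample_audio("sample.wav")  # "sample.wav" is a placeholder

detected_lang = detect_language(waveform)
transcription = process_long_audio(waveform, sample_rate, task="transcribe")
translation = process_long_audio(waveform, sample_rate, task="translate", language=detected_lang)

print(detected_lang, transcription, translation)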