Kr08 committed
Commit 43f1b5e · verified · 1 Parent(s): 6c36e37

Update audio_processing.py

Files changed (1)
  1. audio_processing.py +54 -1
audio_processing.py CHANGED
@@ -47,7 +47,60 @@ def detect_language(audio):
     return detected_lang
 
 def process_long_audio(audio, task="transcribe", language=None):
-    # ... (rest of the function remains the same)
+    if audio[0] != SAMPLING_RATE:
+        # Save the input audio to a file for ffmpeg processing
+        ta.save("input_audio_1.wav", torch.tensor(audio[1]).unsqueeze(0), audio[0])
+
+        # Resample using ffmpeg
+        try:
+            resample_with_ffmpeg("input_audio_1.wav", "resampled_audio_2.wav", target_sr=SAMPLING_RATE)
+        except subprocess.CalledProcessError as e:
+            print(f"ffmpeg failed: {e.stderr}")
+            raise e
+
+        waveform, _ = ta.load("resampled_audio_2.wav")
+    else:
+        waveform = torch.tensor(audio[1]).float()
+
+    # Ensure the audio is in the correct shape (mono)
+    if waveform.dim() == 2:
+        waveform = waveform.mean(dim=0)
+
+    print(f"Waveform shape after processing: {waveform.shape}")
+
+    if waveform.numel() == 0:
+        raise ValueError("Waveform is empty. Please check the input audio file.")
+
+    input_length = waveform.shape[0]  # Since waveform is 1D, access the length with shape[0]
+    chunk_length = int(CHUNK_LENGTH_S * SAMPLING_RATE)
+
+    # Corrected slicing for 1D tensor
+    chunks = [waveform[i:i + chunk_length] for i in range(0, input_length, chunk_length)]
+
+    # Initialize the processor
+    processor = get_processor()
+    model = get_model()
+    device = get_device()
+
+    results = []
+    for chunk in chunks:
+        input_features = processor(chunk, sampling_rate=SAMPLING_RATE, return_tensors="pt").input_features.to(device)
+
+        with torch.no_grad():
+            if task == "translate":
+                forced_decoder_ids = processor.get_decoder_prompt_ids(language=language, task="translate")
+                generated_ids = model.generate(input_features, forced_decoder_ids=forced_decoder_ids)
+            else:
+                generated_ids = model.generate(input_features)
+
+        transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)
+        results.extend(transcription)
+
+        # Clear GPU cache
+        torch.cuda.empty_cache()
+
+    return " ".join(results)
+
 
 def process_audio(audio):
     if audio is None:
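
For context, the added code path depends on module-level names that are outside this hunk: ta (torchaudio), subprocess, the constants SAMPLING_RATE and CHUNK_LENGTH_S, and the helper resample_with_ffmpeg. Below is a minimal sketch of what that surrounding module might look like, assuming Whisper's usual 16 kHz input, 30-second chunks, and an ffmpeg call driven through subprocess.run; neither these values nor the helper body are shown in the commit, so treat them as illustrative assumptions.

import subprocess

import torch
import torchaudio as ta

SAMPLING_RATE = 16000   # assumed: Whisper checkpoints expect 16 kHz audio
CHUNK_LENGTH_S = 30     # assumed: 30-second windows, Whisper's native input length

def resample_with_ffmpeg(in_path, out_path, target_sr=SAMPLING_RATE):
    # Resample to mono at the target rate by shelling out to ffmpeg.
    # check=True turns a non-zero exit into CalledProcessError, which
    # process_long_audio catches and re-raises; capture_output=True and
    # text=True make e.stderr a readable string for the error message.
    subprocess.run(
        ["ffmpeg", "-y", "-i", in_path, "-ar", str(target_sr), "-ac", "1", out_path],
        check=True,
        capture_output=True,
        text=True,
    )

Throughout the new function, audio is treated as a (sample_rate, samples) pair (e.g. the tuple a Gradio Audio input produces), which is why audio[0] is compared against SAMPLING_RATE and audio[1] is wrapped in torch.tensor before saving or chunking. Likewise, get_processor, get_model, and get_device are assumed to be existing helpers elsewhere in the file that return the Whisper processor, model, and torch device.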