InterpreTalk / backend /Client.py
benjolo's picture
Upload 43 files
1778490 verified
raw
history blame
2.85 kB
import os
import wave
from typing import Iterator, Tuple

import torchaudio

from vad import EnergyVAD
# Sample rate (Hz) that buffered audio is resampled to and that the VAD is built for.
TARGET_SAMPLING_RATE = 16000


def create_frames(data: bytes, frame_duration: int) -> Tuple[Iterator[bytes], int]:
    """Split ``data`` into consecutive frames of ``frame_duration`` milliseconds.

    The frame length is computed at ``TARGET_SAMPLING_RATE`` and is measured in
    elements of ``data`` (bytes, for a ``bytes`` object). The final frame may be
    shorter than the others.

    Args:
        data: Sliceable sequence to split.
        frame_duration: Frame length in milliseconds.

    Returns:
        ``(frames, frame_size)``: a lazy iterator over the frame slices and the
        number of elements in every full frame.

    Raises:
        ValueError: if ``frame_duration`` yields a frame size below one element.
    """
    # Multiply before dividing so integer durations are computed exactly
    # instead of going through the float ``frame_duration / 1000``.
    frame_size = int(TARGET_SAMPLING_RATE * frame_duration // 1000)
    if frame_size <= 0:
        raise ValueError(
            f"frame_duration={frame_duration!r} ms gives an empty frame at "
            f"{TARGET_SAMPLING_RATE} Hz"
        )
    # NOTE(review): frame_size is a sample count but is used to slice bytes;
    # for 16-bit PCM each frame therefore spans half the nominal duration.
    # Confirm which unit callers expect.
    frames = (data[start:start + frame_size] for start in range(0, len(data), frame_size))
    return frames, frame_size
def detect_activity(energies: list, *, min_run: int = 12, density_divisor: float = 12) -> bool:
    """Decide whether a sequence of per-frame VAD decisions indicates activity.

    Both conditions must hold:

    1. ``sum(energies)`` is at least ``len(energies) / density_divisor``. This is a
       cheap global gate that rejects mostly-inactive sequences, even when they
       contain one short burst.
    2. Somewhere in the sequence there are ``min_run`` consecutive elements
       equal to ``1`` (``True`` compares equal to ``1``).

    Args:
        energies: Per-frame activity flags. Only elements equal to ``1`` extend a
            run, and ``sum`` must be defined over them.
        min_run: Number of consecutive active frames required. Must be >= 1.
        density_divisor: Divisor applied to the sequence length to get the
            minimum ``sum(energies)`` needed to pass the density gate.

    Returns:
        ``True`` when both conditions hold, otherwise ``False``. An empty
        sequence returns ``False``.

    Raises:
        ValueError: if ``min_run`` is less than 1.
    """
    if min_run < 1:
        raise ValueError(f"min_run must be >= 1, got {min_run!r}")
    # Density gate: too few active frames overall -> no activity.
    if sum(energies) < len(energies) / density_divisor:
        return False
    run = 0
    for energy in energies:
        if energy == 1:
            run += 1
            if run == min_run:
                return True
        else:
            # Any inactive frame breaks the current run.
            run = 0
    return False
class Client:
    """Per-connection audio buffer with resampling and voice-activity detection.

    Raw PCM bytes are accumulated with ``add_bytes``. ``resample_and_clear``
    writes them to a temporary WAV at ``original_sr`` and resamples them to
    ``TARGET_SAMPLING_RATE``. ``vad_analyse`` then checks the result for
    activity. Temporary files derived from ``sid`` are removed when the object
    is finalised.
    """

    def __init__(self, sid, client_id, call_id=None, original_sr=None):
        """Create an empty buffer for one client.

        Args:
            sid: Session identifier. It is the prefix of this client's temporary
                WAV file names, so it must be a ``str``.
            client_id: Stored as-is.
            call_id: Stored as-is.
            original_sr: Sample rate (Hz) of the bytes passed to ``add_bytes``.
                It must be set before ``resample_and_clear`` is called.
        """
        self.sid = sid
        self.client_id = client_id
        self.call_id = call_id
        self.buffer = bytearray()  # pending raw audio; emptied by resample_and_clear
        self.output_path = self.sid + "_output_audio.wav"
        self.target_language = None
        self.original_sr = original_sr
        self.vad = EnergyVAD(
            sample_rate=TARGET_SAMPLING_RATE,
            frame_length=25,
            frame_shift=20,
            energy_threshold=0.05,
            pre_emphasis=0.95,
        )  # PM - Default values given in the docs for this class

    @property
    def _original_audio_path(self):
        """Temporary WAV file holding the buffer at its original sample rate."""
        return self.sid + "_OG.wav"

    def add_bytes(self, new_bytes):
        """Append ``new_bytes`` to the pending audio buffer."""
        self.buffer += new_bytes

    def resample_and_clear(self):
        """Resample the buffered audio to ``TARGET_SAMPLING_RATE`` and empty the buffer.

        The buffer is written as mono, 16-bit, uncompressed WAV at
        ``original_sr`` to a temporary file. That file is loaded back with
        torchaudio and resampled. The buffer is only cleared after resampling
        succeeds, so a failure leaves the pending audio in place.

        Returns:
            The waveform produced by ``torchaudio.transforms.Resample``.
        """
        og_path = self._original_audio_path
        # Log the file that is actually written here (output_path is only
        # written later, by write_to_file / vad_analyse).
        print(f"📥 [ClientAudioBuffer] Writing {len(self.buffer)} bytes to {og_path}")
        with wave.open(og_path, "wb") as wf:
            wf.setnchannels(1)
            wf.setsampwidth(2)  # 2 bytes per sample
            wf.setframerate(self.original_sr)
            # The wave module corrects the frame count in the header on close.
            wf.setnframes(0)
            wf.setcomptype("NONE", "not compressed")
            wf.writeframes(self.buffer)
        waveform, sample_rate = torchaudio.load(og_path)
        resampler = torchaudio.transforms.Resample(
            sample_rate, TARGET_SAMPLING_RATE, dtype=waveform.dtype
        )
        resampled_waveform = resampler(waveform)
        self.buffer = bytearray()
        return resampled_waveform

    def vad_analyse(self, resampled_waveform):
        """Save ``resampled_waveform`` to ``output_path`` and report whether it is active.

        Returns:
            The result of ``detect_activity`` on the output of ``self.vad``.
        """
        self.write_to_file(resampled_waveform)
        vad_array = self.vad(resampled_waveform)
        print(f"VAD OUTPUT: {vad_array}")
        return detect_activity(vad_array)

    def write_to_file(self, resampled_waveform):
        """Save ``resampled_waveform`` to ``output_path`` at ``TARGET_SAMPLING_RATE``."""
        torchaudio.save(self.output_path, resampled_waveform, TARGET_SAMPLING_RATE)

    def get_length(self):
        """Return the number of pending bytes in the buffer."""
        return len(self.buffer)

    def __del__(self):
        """Warn about unflushed audio and delete this client's temporary WAV files.

        Attributes are read with ``getattr`` because ``__del__`` also runs on
        objects whose ``__init__`` failed part-way. Files that were never
        written, or were already removed, are ignored.
        """
        sid = getattr(self, "sid", None)
        buffer = getattr(self, "buffer", None)
        if buffer:
            print(f"🚨 [ClientAudioBuffer] Buffer not empty for {sid} ({len(buffer)} bytes)!")
        paths = [getattr(self, "output_path", None)]
        if isinstance(sid, str):
            paths.append(sid + "_OG.wav")
        for path in paths:
            if path is None:
                continue
            # EAFP: avoids the exists()/remove() race of checking first.
            try:
                os.remove(path)
            except FileNotFoundError:
                pass