# Source: ChatTTS-Forge / modules/denoise.py
# (web-view metadata: zhzluke96 · update · da8d589 · 1.33 kB)
import os
from typing import Union
import torch
import torchaudio
from modules.Denoiser.AudioDenoiser import AudioDenoiser
from modules.utils.constants import MODELS_DIR
from modules.devices import devices
import soundfile as sf
# Module-level cache for the shared AudioDenoiser instance; stays None until
# TTSAudioDenoiser.load_ad() constructs it on first use.
ad: Union[AudioDenoiser, None] = None
class TTSAudioDenoiser:
    """Thin wrapper around a lazily created, module-wide ``AudioDenoiser``."""

    def load_ad(self):
        """Return the shared ``AudioDenoiser``, constructing it on first call.

        The instance is stored in the module-level ``ad`` variable, so every
        ``TTSAudioDenoiser`` object reuses the same denoiser.
        """
        global ad
        if ad is None:
            model_dir = os.path.join(
                MODELS_DIR, "Denoise", "audio-denoiser-512-32-v1"
            )
            ad = AudioDenoiser(model_dir, device=devices.device)
        # Done on every call, not only on creation, so the model is moved to
        # whatever ``devices.device`` currently is.
        ad.model.to(devices.device)
        return ad

    def denoise(self, audio_data, sample_rate, auto_scale=False):
        """Run the shared denoiser over ``audio_data``.

        Args:
            audio_data: Waveform forwarded unchanged to
                ``AudioDenoiser.process_waveform``.
            sample_rate: Sample rate of ``audio_data``.
            auto_scale: Forwarded unchanged to ``process_waveform``.

        Returns:
            A ``(sample_rate, waveform)`` tuple, where ``sample_rate`` is the
            denoiser's ``model_sample_rate`` and ``waveform`` is the result of
            ``process_waveform``.
        """
        denoiser = self.load_ad()
        output_rate = denoiser.model_sample_rate
        cleaned = denoiser.process_waveform(audio_data, sample_rate, auto_scale)
        return output_rate, cleaned
if __name__ == "__main__":
    # Manual smoke test: denoise "test.wav" and write "denoised.wav".
    tts_deno = TTSAudioDenoiser()

    samples, input_rate = sf.read("test.wav")
    # Prepend a leading dimension and cast to float32 before denoising.
    audio_tensor = torch.from_numpy(samples)
    audio_tensor = audio_tensor.unsqueeze(0)
    audio_tensor = audio_tensor.float()
    print(audio_tensor)

    # NOTE: an earlier variant loaded the file with torchaudio.load("test.wav")
    # and moved the tensor to devices.device instead of going through soundfile.
    output_rate, denoised = tts_deno.denoise(
        audio_data=audio_tensor,
        sample_rate=input_rate,
    )

    # Move the result to CPU before writing it to disk.
    denoised = denoised.cpu()
    torchaudio.save("denoised.wav", denoised, sample_rate=output_rate)