File size: 3,046 Bytes
ed57756
 
 
 
 
 
 
cc15592
ed57756
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c342b39
 
 
ed57756
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import pathlib
import yaml
import torch
import torchaudio
import numpy as np
from lightning_module import SSLDualLightningModule
import gradio as gr
import subprocess

def normalize_waveform(wav, sr, db=-3):
    """Run the sox ``norm`` effect over a waveform.

    Args:
        wav: Waveform tensor. A leading dimension is added before the sox
            call and removed again from the result.
        sr: Sampling rate handed to sox together with the waveform.
        db: Argument given to the sox ``norm`` effect (defaults to ``-3``).

    Returns:
        The processed waveform tensor with the leading dimension removed.
    """
    batched = wav.unsqueeze(0)
    effects = [["norm", f"{db}"]]
    # The sampling rate returned by sox is not needed by callers.
    processed, _ = torchaudio.sox_effects.apply_effects_tensor(
        batched, sr, effects
    )
    return processed.squeeze(0)

def calc_spectrogram(wav, config):
    """Compute a floored, scaled log mel spectrogram of ``wav``.

    Args:
        wav: Waveform tensor passed to ``torchaudio.transforms.MelSpectrogram``.
        config: Mapping whose ``"preprocess"`` entry provides
            ``sampling_rate``, ``fft_length``, ``frame_length``,
            ``frame_shift``, ``fmin``, ``fmax``, ``n_mels``,
            ``min_magnitude`` and ``comp_factor``.

    Returns:
        ``log(clamp_min(mel, min_magnitude) * comp_factor)`` as a
        ``torch.float32`` tensor.
    """
    prep = config["preprocess"]
    mel_kwargs = dict(
        sample_rate=prep["sampling_rate"],
        n_fft=prep["fft_length"],
        win_length=prep["frame_length"],
        hop_length=prep["frame_shift"],
        f_min=prep["fmin"],
        f_max=prep["fmax"],
        n_mels=prep["n_mels"],
        power=1,  # magnitude (not power) spectrogram
        center=True,
        norm="slaney",
        mel_scale="slaney",
    )
    to_mel = torchaudio.transforms.MelSpectrogram(**mel_kwargs)
    mel = to_mel(wav)

    # Floor the magnitudes before the logarithm so log() never sees zero.
    floored = torch.clamp_min(mel, prep["min_magnitude"])
    scaled = floored * prep["comp_factor"]
    return torch.log(scaled).to(torch.float32)

def transfer(audio):
    """Apply the channel features of the bundled source recording to ``audio``.

    The mel spectrogram of ``aet_sample/src.wav`` is encoded by the pretrained
    model, turned into channel features, and those features are applied to the
    user-supplied waveform by the model's channel module.

    Args:
        audio: ``(sample_rate, samples)`` pair as delivered by the Gradio
            ``"audio"`` input. ``samples`` is converted with ``np.asarray``;
            a 2-D array is assumed to be ``(samples, channels)`` and is
            down-mixed to mono -- TODO confirm against the Gradio version
            in use.

    Returns:
        ``(sr, waveform)`` where ``sr`` is the sampling rate of
        ``aet_sample/src.wav`` and ``waveform`` is a NumPy array holding the
        first item of the model output.

    Raises:
        ValueError: If no audio was supplied or it contains no samples.
    """
    if audio is None:
        raise ValueError("No input audio was provided.")
    sr_inp, wav_tar = audio

    # Work in float64 so integer PCM input cannot overflow in np.abs()
    # (abs(int16(-32768)) wraps around) and the scaling stays precise.
    wav_tar = np.asarray(wav_tar, dtype=np.float64)
    if wav_tar.size == 0:
        raise ValueError("Input audio contains no samples.")
    if wav_tar.ndim > 1:
        # Presumably (samples, channels); average the channels so that the
        # resampler below operates along the time axis, not the channel axis.
        wav_tar = wav_tar.mean(axis=1)

    # Scale the peak to 1/1.1; skip for all-zero input to avoid 0/0 -> NaN.
    peak = np.max(np.abs(wav_tar))
    if peak > 0:
        wav_tar = wav_tar / (peak * 1.1)
    wav_tar = torch.from_numpy(wav_tar.astype(np.float32))

    wp_src = pathlib.Path("aet_sample/src.wav")
    wav_src, sr = torchaudio.load(wp_src)

    # Bring the user audio to the sampling rate of the source recording.
    resampler = torchaudio.transforms.Resample(
        orig_freq=sr_inp,
        new_freq=sr,
    )
    wav_tar = resampler(wav_tar)

    config_path = pathlib.Path("configs/test/melspec/ssl_tono.yaml")
    # Context manager closes the config file instead of leaking the handle.
    with open(config_path, "r") as config_file:
        config = yaml.load(config_file, Loader=yaml.FullLoader)

    melspec_src = calc_spectrogram(
        normalize_waveform(wav_src.squeeze(0), sr), config
    )
    # wav_tar is already 1-D (mono) at this point.
    wav_tar = normalize_waveform(wav_tar, sr)

    # NOTE(review): the __main__ block downloads "tono_aet_melspec.ckpt" into
    # aet_sample/, while this path names "tono_melspec_aet.ckpt" -- confirm
    # which file name is the intended one.
    ckpt_path = pathlib.Path("aet_sample/tono_melspec_aet.ckpt")
    # load_from_checkpoint is a classmethod: call it on the class rather than
    # on a throw-away instance (newer Lightning releases reject the latter).
    # Everything below runs on CPU tensors, so map the weights to CPU as well.
    src_model = SSLDualLightningModule.load_from_checkpoint(
        checkpoint_path=ckpt_path,
        map_location="cpu",
        config=config,
    )

    encoder_src = src_model.encoder
    channelfeats_src = src_model.channelfeats
    channel_src = src_model.channel

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        _, enc_hidden_src = encoder_src(
            melspec_src.unsqueeze(0).unsqueeze(1).transpose(2, 3)
        )
        chfeats_src = channelfeats_src(enc_hidden_src)
        wav_transfer = channel_src(wav_tar.unsqueeze(0), chfeats_src)

    wav_transfer = wav_transfer.detach().numpy()[0, :]
    return sr, wav_transfer

if __name__ == "__main__":
    # Fetch the pretrained checkpoint into aet_sample/ once. The download is
    # best-effort: a failure is reported but does not stop the demo launching.
    ckpt_url = "https://sarulab.sakura.ne.jp/saeki/selfremaster/pretrained/tono_aet_melspec.ckpt"
    ckpt_file = pathlib.Path("aet_sample") / "tono_aet_melspec.ckpt"
    if not ckpt_file.exists():
        ckpt_file.parent.mkdir(parents=True, exist_ok=True)
        # Download to a temporary name so an interrupted transfer is never
        # mistaken for a complete checkpoint on the next launch.
        partial_file = ckpt_file.with_name(ckpt_file.name + ".part")
        # -f: fail on HTTP errors instead of saving the error page;
        # -L: follow redirects; -o: write straight to the target location.
        download = subprocess.run(
            ["curl", "-fL", "-o", str(partial_file), ckpt_url]
        )
        if download.returncode == 0:
            partial_file.replace(ckpt_file)
        else:
            if partial_file.exists():
                partial_file.unlink()
            print(
                "WARNING: checkpoint download failed "
                "(curl exit code {}): {}".format(download.returncode, ckpt_url)
            )

    iface = gr.Interface(
        transfer,
        "audio",
        gr.outputs.Audio(type="numpy"),
        examples=[
            ["aet_sample/tar.wav"]
        ],
        title='Audio effect transfer demo',
        description='Add channel feature of Japanese old audio recording to any high-quality audio'
    )

    iface.launch()