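"""Speech separation inference: load a trained svoice (MulCat block) model
and write the estimated per-speaker sources of each mixture as wav files."""
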
import argparse
import logging
import os
import sys

import librosa
import torch
import tqdm

from .data.data import EvalDataLoader, EvalDataset
from . import distrib
from .utils import deserialize_model, remove_pad

logger = logging.getLogger(__name__)

def load_model():
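    """Load 'checkpoint.th' from the working directory into the module-level
    globals (`model`, `device`, `pkg`) used by `separate_demo`."""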
    global device
    global model
    global pkg
    print("Loading svoice model if available...")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    pkg = torch.load('checkpoint.th', map_location=device)
    if 'model' in pkg:
        model = pkg['model']
    else:
        model = pkg
    model = deserialize_model(model)
    logger.debug(model)
    model.eval()
    model.to(device)
    print("svoice model loaded.")
    print("Device: {}".format(device))

# Command-line interface used when this file is run as a script.
parser = argparse.ArgumentParser("Speech separation using MulCat blocks")
parser.add_argument("model_path", type=str, help="Path to the model checkpoint")
parser.add_argument("out_dir", type=str, default="exp/result",
                    help="Directory in which to write the separated wav files")
parser.add_argument("--mix_dir", type=str, default=None,
                    help="Directory containing the mixture wav files")
parser.add_argument("--mix_json", type=str, default=None,
                    help="JSON file listing the mixture wav files")
parser.add_argument('--device', default="cuda")
parser.add_argument("--sample_rate", default=8000,
                    type=int, help="Sample rate")
parser.add_argument("--batch_size", default=1, type=int, help="Batch size")
parser.add_argument('-v', '--verbose', action='store_const', const=logging.DEBUG,
                    default=logging.INFO, help="More logging")

def save_wavs(estimate_source, mix_sig, lengths, filenames, out_dir, sr=16000):
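    """Write the mixture and each of its estimated sources to `out_dir`,
    naming the sources `<mixture>_s<index>.wav`."""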
    # Remove padding and flatten
    flat_estimate = remove_pad(estimate_source, lengths)
    mix_sig = remove_pad(mix_sig, lengths)
    # Write result
    for i, filename in enumerate(filenames):
        filename = os.path.join(
            out_dir, os.path.splitext(os.path.basename(filename))[0])
        write(mix_sig[i], filename + ".wav", sr=sr)
        C = flat_estimate[i].shape[0]
        # future support for wave playing
        for c in range(C):
            write(flat_estimate[i][c], filename + f"_s{c + 1}.wav", sr=sr)


def write(inputs, filename, sr=8000):
    # Requires librosa < 0.8: the librosa.output module was removed in 0.8.0.
    librosa.output.write_wav(filename, inputs, sr, norm=True)
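
# For librosa >= 0.8, where librosa.output no longer exists, a minimal
# sketch of an equivalent writer (hypothetical helper; assumes the
# `soundfile` package is installed, with peak normalization standing in
# for the old norm=True behavior):
def write_soundfile(inputs, filename, sr=8000):
    import numpy as np
    import soundfile as sf
    # Scale to [-1, 1] by the waveform's absolute peak, as norm=True did
    peak = float(np.abs(inputs).max())
    if peak > 0:
        inputs = inputs / peak
    sf.write(filename, inputs, sr)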

def separate_demo(mix_dir='mix/', batch_size=1, sample_rate=16000):
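    """Separate every mixture wav in `mix_dir` into 'separated/' using the
    globals set by `load_model`, and return the absolute paths of the
    written files (excluding any '*original.wav')."""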
    mix_json = None
    out_dir = 'separated'
    # Load data
    eval_dataset = EvalDataset(
        mix_dir,
        mix_json,
        batch_size=batch_size,
        sample_rate=sample_rate,
    )
    eval_loader = distrib.loader(
        eval_dataset, batch_size=1, klass=EvalDataLoader)

    if distrib.rank == 0:
        os.makedirs(out_dir, exist_ok=True)
    distrib.barrier()

    with torch.no_grad():
        for i, data in enumerate(tqdm.tqdm(eval_loader, ncols=120)):
            # Get batch data
            mixture, lengths, filenames = data
            mixture = mixture.to(device)
            lengths = lengths.to(device)
            # Forward
            estimate_sources = model(mixture)[-1]
            # save wav files
            save_wavs(estimate_sources, mixture, lengths,
                      filenames, out_dir, sr=sample_rate)

    separated_files = [os.path.join(out_dir, f) for f in os.listdir(out_dir)]
    separated_files = [os.path.abspath(f) for f in separated_files]
    separated_files = [f for f in separated_files if not f.endswith('original.wav')]
    return separated_files

def get_mix_paths(args):
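    """Resolve the mixture directory and JSON list from `args`, preferring
    the dataset config (`args.dset`) over the top-level arguments."""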
    mix_dir = None
    mix_json = None
    # Prefer the dataset config (args.dset) when present; otherwise fall
    # back to the top-level command-line arguments.
    try:
        if args.dset.mix_dir:
            mix_dir = args.dset.mix_dir
    except AttributeError:
        mix_dir = args.mix_dir

    try:
        if args.dset.mix_json:
            mix_json = args.dset.mix_json
    except AttributeError:
        mix_json = args.mix_json
    return mix_dir, mix_json


def separate(args, model=None, local_out_dir=None):
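    """Run separation over the mixtures described by `args`, loading the
    checkpoint from `args.model_path` unless a `model` is passed in, and
    write results to `local_out_dir` or `args.out_dir`."""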
    mix_dir, mix_json = get_mix_paths(args)
    if not mix_json and not mix_dir:
        logger.error("Must provide mix_dir or mix_json! "
                     "When providing mix_dir, mix_json is ignored.")
        return
    # Load model
    if not model:
        pkg = torch.load(args.model_path, map_location=args.device)
        if 'model' in pkg:
            model = pkg['model']
        else:
            model = pkg
        model = deserialize_model(model)
        logger.debug(model)
    model.eval()
    model.to(args.device)
    if local_out_dir:
        out_dir = local_out_dir
    else:
        out_dir = args.out_dir

    # Load data
    eval_dataset = EvalDataset(
        mix_dir,
        mix_json,
        batch_size=args.batch_size,
        sample_rate=args.sample_rate,
    )
    eval_loader = distrib.loader(
        eval_dataset, batch_size=1, klass=EvalDataLoader)

    if distrib.rank == 0:
        os.makedirs(out_dir, exist_ok=True)
    distrib.barrier()

    with torch.no_grad():
        for i, data in enumerate(tqdm.tqdm(eval_loader, ncols=120)):
            # Get batch data
            mixture, lengths, filenames = data
            mixture = mixture.to(args.device)
            lengths = lengths.to(args.device)
            # Forward
            estimate_sources = model(mixture)[-1]
            # save wav files
            save_wavs(estimate_sources, mixture, lengths,
                      filenames, out_dir, sr=args.sample_rate)


if __name__ == "__main__":
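    # Example invocation (module path assumed from the package layout):
    #   python -m svoice.separate <checkpoint> <out_dir> --mix_dir=<mix dir>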
    args = parser.parse_args()
    logging.basicConfig(stream=sys.stderr, level=args.verbose)
    logger.debug(args)
    separate(args, local_out_dir=args.out_dir)