|
import torch |
|
import numpy as np |
|
import joblib |
|
from s3prl.nn import S3PRLUpstream |
|
import soundfile as sf |
|
from argparse import ArgumentParser |
|
import sentencepiece as spm |
|
import os |
|
from tqdm import tqdm |
|
|
|
|
|
class ApplyKmeans(object):
    """Assign feature frames to the index of their nearest k-means centroid.

    The centroids are read from a fitted model's ``cluster_centers_`` and kept
    both as NumPy arrays (``C_np``, ``Cnorm_np``) and as torch tensors
    (``C``, ``Cnorm``) so that either kind of input can be quantised.
    """

    def __init__(self, km_path, use_gpu):
        """Load the k-means model and cache its centroids.

        Args:
            km_path: Path to a joblib-serialised model exposing
                ``cluster_centers_``.
            use_gpu: When true and CUDA is available, the torch copies of the
                centroids are moved to the GPU.
        """
        # NOTE: joblib.load is pickle-based and can execute arbitrary code;
        # only load model files from a trusted source.
        self.km_model = joblib.load(km_path)

        # Transposed so that ``x @ C`` yields one column per centroid.
        # Assumes cluster_centers_ is (n_clusters, dim) -- TODO confirm.
        self.C_np = self.km_model.cluster_centers_.transpose()
        # Squared norm of every centroid, kept as a single broadcastable row.
        self.Cnorm_np = (self.C_np ** 2).sum(0, keepdims=True)

        self.C = torch.from_numpy(self.C_np)
        self.Cnorm = torch.from_numpy(self.Cnorm_np)
        if use_gpu and torch.cuda.is_available():
            self.C = self.C.cuda()
            self.Cnorm = self.Cnorm.cuda()

    def __call__(self, x):
        """Return the nearest-centroid index for every row of ``x``.

        Distances are squared Euclidean, computed through the expansion
        ``|x|^2 - 2 * x.c + |c|^2``.

        Args:
            x: 2-D ``torch.Tensor`` or NumPy array with one feature vector
                per row.

        Returns:
            A NumPy array holding one centroid index per row of ``x``.
        """
        if isinstance(x, torch.Tensor):
            centroids = self.C
            centroid_norms = self.Cnorm
            x = x.to(centroids.device)
            if x.dtype != centroids.dtype:
                # torch.matmul rejects mixed dtypes (e.g. float32 features
                # against float64 centroids), so promote both sides to their
                # common dtype, as NumPy does implicitly on the other path.
                common = torch.promote_types(x.dtype, centroids.dtype)
                x = x.to(common)
                centroids = centroids.to(common)
                centroid_norms = centroid_norms.to(common)
            dist = (
                x.pow(2).sum(1, keepdim=True)
                - 2 * torch.matmul(x, centroids)
                + centroid_norms
            )
            return dist.argmin(dim=1).cpu().numpy()

        dist = (
            (x ** 2).sum(1, keepdims=True)
            - 2 * np.matmul(x, self.C_np)
            + self.Cnorm_np
        )
        return np.argmin(dist, axis=1)
|
|
|
|
|
def streaming_extract(wav, window_size=60):
    """Cut a waveform into consecutive windows shaped for batched inference.

    Args:
        wav: Array of audio samples, sliced along its first axis.
        window_size: Window length; every window spans ``window_size * 16``
            samples.

    Returns:
        A list of ``(audio, length)`` tuples, one per window, where ``audio``
        is the window with a leading batch axis of size 1 and ``length`` is a
        one-element tensor holding the number of samples in that window.
    """
    # NOTE(review): the factor 16 looks like samples-per-millisecond at
    # 16 kHz, which would make window_size a duration in ms -- confirm.
    window_samples = window_size * 16

    chunk_audios = []
    for start in tqdm(range(0, wav.shape[0], window_samples)):
        chunk = torch.tensor(wav[start : start + window_samples])
        # The final window can be shorter than window_samples, so report the
        # real chunk length rather than the nominal window length.
        batched_audio = (chunk.unsqueeze(0), torch.tensor([chunk.shape[0]]))
        chunk_audios.append(batched_audio)
    return chunk_audios
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: audio file -> SSL hidden states -> k-means
    # units -> one character per unit -> BPE pieces. Each stage is printed.
    parser = ArgumentParser()
    parser.add_argument("--km_model", default='./km_2000.mdl', help="Path to the kmeans model")
    parser.add_argument("--bpe_model", default='./bpe.model', help="Path to the bpe model")
    parser.add_argument("--audio", required=True, help="Path to the audio file")
    parser.add_argument("-s", action='store_true', help="Streaming mode")

    args = parser.parse_args()
    kmeans_path = args.km_model
    bpe_path = args.bpe_model
    audio_file = args.audio
    streaming = args.s

    if streaming:
        # Raised explicitly (an assert would be stripped under `python -O`)
        # and before any model is loaded, so the unsupported mode fails fast.
        raise NotImplementedError("streaming mode is still developing")

    apply_kmeans = ApplyKmeans(kmeans_path, use_gpu=True)
    ssl_model = S3PRLUpstream("hf_hubert_custom", path_or_url='TencentGameMate/chinese-hubert-large')
    ssl_model.eval()
    sp = spm.SentencePieceProcessor(model_file=bpe_path)

    # Each non-empty line of the mapping file is "<unit id> <character>".
    # Assumes the file is UTF-8 encoded -- confirm if it was produced elsewhere.
    unit_to_char = {}
    with open('distinct_cjk_token_lists', encoding='utf-8') as token_file:
        for line in token_file:
            fields = line.split()
            if not fields:
                continue  # tolerate blank lines (e.g. a trailing newline)
            unit_to_char[int(fields[0])] = fields[1]

    # NOTE(review): the sample rate `sr` is never checked against what the
    # upstream model expects -- confirm callers supply suitable audio.
    wav, sr = sf.read(audio_file)
    with torch.no_grad():
        # from_numpy avoids the slow list-of-ndarray path of torch.tensor([wav])
        # while producing the same (1, n_samples) batch.
        all_hs, all_hs_len = ssl_model(
            torch.from_numpy(wav).unsqueeze(0),
            torch.tensor([wav.shape[0]]),
        )

    # Quantise the hidden states at index 20 for the single batch item.
    ssl_units = apply_kmeans(all_hs[20][0, :, :].numpy())
    print(ssl_units)
    ssl_char = "".join([unit_to_char[c] for c in ssl_units])
    ssl_char_bpe = sp.encode(ssl_char, out_type=str)
    print(ssl_char)
    print(ssl_char_bpe)
|
|
|
|