File size: 3,139 Bytes
9a26655 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 |
import torch
import numpy as np
import joblib
from s3prl.nn import S3PRLUpstream
import soundfile as sf
from argparse import ArgumentParser
import sentencepiece as spm
import os
from tqdm import tqdm
class ApplyKmeans:
    """Map feature vectors to the index of their nearest k-means centroid.

    The nearest centroid is found with the expanded squared Euclidean
    distance ``||x||^2 - 2 * x @ C + ||c||^2``, so a whole matrix of frames
    is scored against every centroid with a single matrix product.

    Args:
        km_path: Path handed to ``joblib.load``. The loaded object must expose
            a ``cluster_centers_`` array (presumably a scikit-learn k-means
            model with one centroid per row -- confirm against the model
            file that is actually shipped).
        use_gpu: If true and CUDA is available, the torch copies of the
            centroids are moved to the GPU.
    """

    def __init__(self, km_path, use_gpu):
        # NOTE(security): joblib.load unpickles the file and can therefore
        # execute arbitrary code; only load models from a trusted source.
        self.km_model = joblib.load(km_path)

        # Centroids are kept transposed so that `x @ C` scores every row of x
        # against every centroid at once.
        self.C_np = self.km_model.cluster_centers_.transpose()
        # Squared norm of each centroid, shaped (1, n_centroids) for broadcasting.
        self.Cnorm_np = (self.C_np**2).sum(0, keepdims=True)

        # Torch views of the same data for tensor inputs.
        self.C = torch.from_numpy(self.C_np)
        self.Cnorm = torch.from_numpy(self.Cnorm_np)
        if use_gpu and torch.cuda.is_available():
            self.C = self.C.cuda()
            self.Cnorm = self.Cnorm.cuda()

    def __call__(self, x):
        """Return the nearest-centroid index for every row of ``x``.

        Args:
            x: 2-D ``torch.Tensor`` or NumPy array whose rows are feature
                vectors with the same dimensionality as the centroids.

        Returns:
            A 1-D NumPy array of centroid indices, one per row of ``x``.
        """
        if isinstance(x, torch.Tensor):
            # Align device AND dtype with the centroids: torch.matmul refuses
            # mixed dtypes (e.g. float32 features vs. float64 centroids),
            # whereas the NumPy branch below promotes silently.
            x = x.to(device=self.C.device, dtype=self.C.dtype)
            dist = (
                x.pow(2).sum(1, keepdim=True)
                - 2 * torch.matmul(x, self.C)
                + self.Cnorm
            )
            return dist.argmin(dim=1).cpu().numpy()

        dist = (
            (x**2).sum(1, keepdims=True)
            - 2 * np.matmul(x, self.C_np)
            + self.Cnorm_np
        )
        return np.argmin(dist, axis=1)
def streaming_extract(wav, window_size=60):
    """Cut a waveform into consecutive windows ready to be fed to the model.

    This is the chunking half of the (unfinished) streaming mode: it only
    prepares the inputs and does not run the SSL model.

    Args:
        wav: 1-D array of audio samples (anything sliceable that exposes
            ``.shape[0]`` and is accepted by ``torch.tensor``).
        window_size: Window length; each window spans ``window_size * 16``
            samples (presumably milliseconds at 16 kHz -- TODO confirm).

    Returns:
        A list of ``(chunk, length)`` pairs in input order, where ``chunk`` is
        a ``(1, n_samples)`` tensor and ``length`` is a one-element tensor
        holding the number of samples actually present in that chunk.
    """
    samples_per_window = window_size * 16
    chunk_audios = []
    for start in tqdm(range(0, wav.shape[0], samples_per_window)):
        chunk = torch.tensor(wav[start : start + samples_per_window])
        # Use the real chunk length: the final window is usually shorter than
        # samples_per_window, and an overstated length would mislead the model.
        chunk_audios.append(
            (chunk.unsqueeze(0), torch.tensor([chunk.shape[0]]))
        )
    # TODO: run each (chunk, length) pair through the SSL model, keep the
    # layer-20 hidden states and concatenate them along the time axis.
    return chunk_audios
if __name__ == "__main__":
    # Pipeline: audio -> SSL hidden states -> k-means units -> characters -> BPE.
    parser = ArgumentParser()
    parser.add_argument("--km_model", default='./km_2000.mdl', help="Path to the kmeans model")
    parser.add_argument("--bpe_model", default='./bpe.model', help="Path to the bpe model")
    parser.add_argument("--audio", required=True, help="Path to the audio file")
    parser.add_argument("-s", action='store_true', help="Streaming mode")
    args = parser.parse_args()

    kmeans_path = args.km_model
    bpe_path = args.bpe_model
    audio_file = args.audio
    streaming = args.s

    if streaming:
        # Raise instead of `assert False` (asserts vanish under `python -O`),
        # and do it before the expensive model loading below.
        # TODO: all_hs = streaming_extract(wav) once streaming is implemented.
        raise NotImplementedError("streaming mode is still developing")

    apply_kmeans = ApplyKmeans(kmeans_path, use_gpu=True)
    ssl_model = S3PRLUpstream("hf_hubert_custom", path_or_url='TencentGameMate/chinese-hubert-large')
    ssl_model.eval()
    sp = spm.SentencePieceProcessor(model_file=bpe_path)

    # Lookup table: k-means unit id -> single character. Each non-empty line
    # is "<unit id> <character>". The file name suggests CJK text, so decode
    # explicitly as UTF-8 rather than with the platform default encoding.
    unit_to_char = {}
    with open('distinct_cjk_token_lists', encoding='utf-8') as token_file:
        for line in token_file:
            fields = line.split()
            if not fields:
                continue  # tolerate blank lines (e.g. a trailing newline)
            unit_to_char[int(fields[0])] = fields[1]

    # NOTE(review): the sample rate is not checked; the model presumably
    # expects 16 kHz input -- confirm and resample if needed.
    wav, sr = sf.read(audio_file)
    with torch.no_grad():
        # Batch of one utterance: add a leading batch axis and pass its length.
        all_hs, all_hs_len = ssl_model(
            torch.from_numpy(wav).unsqueeze(0),
            torch.tensor([wav.shape[0]]),
        )

    # Quantise the hidden states at index 20 of the first (only) batch item.
    ssl_units = apply_kmeans(all_hs[20][0, :, :].numpy())
    print(ssl_units)

    ssl_char = "".join(unit_to_char[c] for c in ssl_units)
    ssl_char_bpe = sp.encode(ssl_char, out_type=str)
    print(ssl_char)
    print(ssl_char_bpe)
|