JustinLin610's picture
first commit
ee21b96
raw
history blame
3.41 kB
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import gc
import os
import random
import shutil
import numpy as np
import torch
import tqdm
from examples.textless_nlp.gslm.speech2unit.pretrained.cpc_feature_reader import (
CpcFeatureReader,
)
from examples.textless_nlp.gslm.speech2unit.pretrained.hubert_feature_reader import (
HubertFeatureReader,
)
from examples.textless_nlp.gslm.speech2unit.pretrained.logmel_feature_reader import (
LogMelFeatureReader,
)
from examples.textless_nlp.gslm.speech2unit.pretrained.w2v2_feature_reader import (
Wav2VecFeatureReader,
)
def get_feature_reader(feature_type):
    """Return the feature-reader class that corresponds to ``feature_type``.

    Recognised values are ``"logmel"``, ``"hubert"``, ``"w2v2"`` and
    ``"cpc"``. The class itself is returned, not an instance; the caller
    is responsible for constructing it.

    Raises:
        NotImplementedError: if ``feature_type`` is none of the values above.
    """
    # Guard clauses: each supported name maps to exactly one reader class.
    if feature_type == "logmel":
        return LogMelFeatureReader
    if feature_type == "hubert":
        return HubertFeatureReader
    if feature_type == "w2v2":
        return Wav2VecFeatureReader
    if feature_type == "cpc":
        return CpcFeatureReader

    raise NotImplementedError(f"{feature_type} is not supported.")
def get_feature_iterator(
    feature_type, checkpoint_path, layer, manifest_path, sample_pct
):
    """Prepare a lazy feature generator over the files listed in a manifest.

    The first line of the manifest is taken as the root directory. On every
    following non-empty line, the text before the first tab is joined onto
    that root to form a file path. When ``sample_pct`` is below 1.0, a
    random subset of ``int(sample_pct * number_of_paths)`` paths is kept.

    The reader is constructed once, here, with ``checkpoint_path`` and
    ``layer``; the returned generator function reuses it for every file.

    Returns:
        A ``(iterate, num_files)`` tuple. ``iterate`` is a zero-argument
        generator function that yields ``reader.get_feats(path).cpu().numpy()``
        for each selected path; ``num_files`` is the number of selected paths.

    Raises:
        NotImplementedError: if ``feature_type`` is not supported.
    """
    reader_cls = get_feature_reader(feature_type)

    with open(manifest_path, "r") as manifest:
        # Header line holds the root directory; the rest are per-file entries.
        header, *entries = manifest.read().split("\n")
    root_dir = header.strip()

    selected_paths = []
    for entry in entries:
        if len(entry) > 0:
            relative_path = entry.split("\t")[0]
            selected_paths.append(os.path.join(root_dir, relative_path))

    if sample_pct < 1.0:
        keep_count = int(sample_pct * len(selected_paths))
        selected_paths = random.sample(selected_paths, keep_count)

    num_files = len(selected_paths)
    reader = reader_cls(checkpoint_path=checkpoint_path, layer=layer)

    def iterate():
        # Features are computed one file at a time, only when requested.
        for path in selected_paths:
            yield reader.get_feats(path).cpu().numpy()

    return iterate, num_files
def get_features(
    feature_type, checkpoint_path, layer, manifest_path, sample_pct, flatten
):
    """Extract features for the files of a manifest and hold them in memory.

    All arguments except ``flatten`` are forwarded unchanged to
    ``get_feature_iterator``. A ``tqdm`` progress bar tracks the extraction.

    Returns:
        If ``flatten`` is truthy, the result of ``np.concatenate`` over the
        per-file arrays; otherwise the list of per-file arrays, in the order
        they were produced.
    """
    make_iterator, num_files = get_feature_iterator(
        feature_type=feature_type,
        checkpoint_path=checkpoint_path,
        layer=layer,
        manifest_path=manifest_path,
        sample_pct=sample_pct,
    )
    feature_stream = make_iterator()
    collected = list(tqdm.tqdm(feature_stream, total=num_files))

    # Explicit clean up: drop the generator and its factory before forcing a
    # garbage collection and emptying the CUDA cache (presumably so the
    # reader they reference can be reclaimed first -- confirm with readers).
    del feature_stream
    del make_iterator
    gc.collect()
    torch.cuda.empty_cache()

    return np.concatenate(collected) if flatten else collected
def _as_saveable_array(features):
    """Return ``features`` in a form that ``np.save`` can always write.

    Non-list inputs are returned untouched. A list is converted with
    ``np.asarray``; if that fails because the per-file arrays have different
    shapes (a ``ValueError`` on NumPy >= 1.24), the list is packed into a
    1-D object array holding one entry per file instead.
    """
    if not isinstance(features, list):
        return features
    try:
        return np.asarray(features)
    except ValueError:
        # Fill element by element so NumPy never tries to broadcast the
        # differently-shaped arrays into a single rectangular block.
        packed = np.empty(len(features), dtype=object)
        for idx, item in enumerate(features):
            packed[idx] = item
        return packed


def get_and_dump_features(
    feature_type,
    checkpoint_path,
    layer,
    manifest_path,
    sample_pct,
    flatten,
    out_features_path,
):
    """Extract features, write them to disk with ``np.save``, and return them.

    The extraction arguments are forwarded unchanged to ``get_features``.
    The directory of ``out_features_path`` is created if needed, and a copy
    of the manifest is placed in that directory under its original file name.

    Args:
        out_features_path: Destination passed to ``np.save``. A bare file
            name (no directory component) writes into the current directory.

    Returns:
        The features exactly as returned by ``get_features`` (a single array
        when ``flatten`` is truthy, otherwise a list of per-file arrays).
    """
    # Feature extraction
    features_batch = get_features(
        feature_type=feature_type,
        checkpoint_path=checkpoint_path,
        layer=layer,
        manifest_path=manifest_path,
        sample_pct=sample_pct,
        flatten=flatten,
    )

    # Save features
    # dirname() is "" for a bare file name and os.makedirs("") raises
    # FileNotFoundError, so fall back to the current directory.
    out_dir_path = os.path.dirname(out_features_path) or os.curdir
    os.makedirs(out_dir_path, exist_ok=True)

    manifest_copy_path = os.path.join(
        out_dir_path, os.path.basename(manifest_path)
    )
    try:
        shutil.copyfile(manifest_path, manifest_copy_path)
    except shutil.SameFileError:
        # The manifest already lives in the output directory: nothing to copy.
        pass

    # Unflattened features are a list of arrays that usually differ in
    # length; convert explicitly so np.save does not fail on ragged input.
    np.save(out_features_path, _as_saveable_array(features_batch))

    return features_batch