# visualisation tools for mimic2
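# Example invocation, assuming this script sits next to the outputs of
# preprocess.py (the script name analyze.py and the paths are illustrative):
#   python analyze.py --train_file_path=training/train.txt --save_to=visuals --cmu_dict_path=cmudict-0.7b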
import argparse
import csv
import os
import random
from statistics import StatisticsError, mean, median, mode, stdev

import matplotlib.pyplot as plt
import seaborn as sns

from text.cmudict import CMUDict


def get_audio_seconds(frames):
    # each frame is assumed to cover a 12.5 ms hop
    return (frames * 12.5) / 1000
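# Quick sanity check of the conversion above (frame count is illustrative):
#   get_audio_seconds(800) -> (800 * 12.5) / 1000 = 10.0 seconds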


def append_data_statistics(meta_data):
    # compute summary statistics of audio length for each character-count bucket
    for char_cnt in meta_data:
        data = meta_data[char_cnt]["data"]
        audio_len_list = [d["audio_len"] for d in data]
        mean_audio_len = mean(audio_len_list)
        try:
            # take the mode over lengths rounded to 2 decimals so ties are more likely
            mode_audio_list = [round(d["audio_len"], 2) for d in data]
            mode_audio_len = mode(mode_audio_list)
        except StatisticsError:
            # no unique mode; fall back to the first sample's length
            mode_audio_len = audio_len_list[0]
        median_audio_len = median(audio_len_list)
        try:
            std = stdev(audio_len_list)
        except StatisticsError:
            # fewer than two samples; the standard deviation is undefined
            std = 0
        meta_data[char_cnt]["mean"] = mean_audio_len
        meta_data[char_cnt]["median"] = median_audio_len
        meta_data[char_cnt]["mode"] = mode_audio_len
        meta_data[char_cnt]["std"] = std
    return meta_data


def process_meta_data(path):
    meta_data = {}

    # load the metadata file and bucket utterances by character count
    with open(path, "r", encoding="utf-8") as f:
        data = csv.reader(f, delimiter="|")
        for row in data:
            frames = int(row[2])
            utt = row[3]
            audio_len = get_audio_seconds(frames)
            char_count = len(utt)
            if not meta_data.get(char_count):
                meta_data[char_count] = {"data": []}
            meta_data[char_count]["data"].append(
                {
                    "utt": utt,
                    "frames": frames,
                    "audio_len": audio_len,
                    "row": "{}|{}|{}|{}".format(row[0], row[1], row[2], row[3]),
                }
            )

    meta_data = append_data_statistics(meta_data)
    return meta_data
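# The reader above expects pipe-delimited rows as written by preprocess.py,
# where row[2] holds the frame count and row[3] the utterance text, e.g.
# (file names and text are illustrative):
#   ljspeech-spec-00001.npy|ljspeech-mel-00001.npy|800|The quick brown fox.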


def get_data_points(meta_data):
    # x axis: character counts; each y_* list is aligned with x
    x = list(meta_data.keys())
    y_avg = [meta_data[d]["mean"] for d in meta_data]
    y_mode = [meta_data[d]["mode"] for d in meta_data]
    y_median = [meta_data[d]["median"] for d in meta_data]
    y_std = [meta_data[d]["std"] for d in meta_data]
    y_num_samples = [len(meta_data[d]["data"]) for d in meta_data]
    return {
        "x": x,
        "y_avg": y_avg,
        "y_mode": y_mode,
        "y_median": y_median,
        "y_std": y_std,
        "y_num_samples": y_num_samples,
    }
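# The returned dict is keyed by series name; a toy example of its shape
# (values are illustrative only):
#   {"x": [23, 47], "y_avg": [1.8, 3.4], "y_mode": [1.8, 3.4],
#    "y_median": [1.8, 3.4], "y_std": [0.2, 0.5], "y_num_samples": [12, 7]}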


def save_training(file_path, meta_data):
    # flatten the buckets back into rows and write them out in random order
    rows = []
    for char_cnt in meta_data:
        data = meta_data[char_cnt]["data"]
        for d in data:
            rows.append(d["row"] + "\n")
    random.shuffle(rows)
    with open(file_path, "w+", encoding="utf-8") as f:
        for row in rows:
            f.write(row)
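# save_training is not wired into main() below; a typical use is writing a
# shuffled copy of an existing metadata file (paths are illustrative):
#   meta = process_meta_data("training/train.txt")
#   save_training("training/train_shuffled.txt", meta)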


def plot(meta_data, save_path=None):
    save = bool(save_path)
    graph_data = get_data_points(meta_data)
    x = graph_data["x"]

    # one scatter plot per statistic: (y values, y-axis label, output file name)
    plots = [
        (graph_data["y_avg"], "avg seconds", "char_len_vs_avg_secs"),
        (graph_data["y_mode"], "mode seconds", "char_len_vs_mode_secs"),
        (graph_data["y_median"], "median seconds", "char_len_vs_med_secs"),
        (graph_data["y_std"], "standard deviation", "char_len_vs_std"),
        (graph_data["y_num_samples"], "number of samples", "char_len_vs_num_samples"),
    ]
    for y, ylabel, name in plots:
        plt.figure()
        plt.plot(x, y, "ro")
        plt.xlabel("character lengths", fontsize=30)
        plt.ylabel(ylabel, fontsize=30)
        if save:
            plt.savefig(os.path.join(save_path, name))


def plot_phonemes(train_path, cmu_dict_path, save_path):
    cmudict = CMUDict(cmu_dict_path)
    phonemes = {}

    # count every ARPAbet phoneme in the training text; words missing from the
    # dictionary are tallied under "None"
    with open(train_path, "r", encoding="utf-8") as f:
        data = csv.reader(f, delimiter="|")
        phonemes["None"] = 0
        for row in data:
            words = row[3].split()
            for word in words:
                pho = cmudict.lookup(word)
                if pho:
                    # use the first pronunciation and split it into phonemes
                    for phoneme in pho[0].split():
                        phonemes[phoneme] = phonemes.get(phoneme, 0) + 1
                else:
                    phonemes["None"] += 1

    x, y = [], []
    for k, v in phonemes.items():
        x.append(k)
        y.append(v)

    plt.figure()
    plt.rcParams["figure.figsize"] = (50, 20)
    barplot = sns.barplot(x=x, y=y)
    if save_path:
        fig = barplot.get_figure()
        fig.savefig(os.path.join(save_path, "phoneme_dist"))
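# CMUDict.lookup returns a list of ARPAbet pronunciations for a word; e.g. a
# call like cmudict.lookup("water") might return ["W AO1 T ER0"], which the
# loop above splits into the phonemes W, AO1, T, ER0 (illustrative output, not
# checked against cmudict-0.7b).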


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--train_file_path",
        required=True,
        help="path to the train.txt file created by the preprocess.py script",
    )
    parser.add_argument("--save_to", help="directory to save the data charts to")
    parser.add_argument("--cmu_dict_path", help="path to cmudict-0.7b to plot the phoneme distribution")
    args = parser.parse_args()

    meta_data = process_meta_data(args.train_file_path)
    plt.rcParams["figure.figsize"] = (10, 5)
    plot(meta_data, save_path=args.save_to)

    if args.cmu_dict_path:
        plt.rcParams["figure.figsize"] = (30, 10)
        plot_phonemes(args.train_file_path, args.cmu_dict_path, args.save_to)

    plt.show()


if __name__ == "__main__":
    main()