Build error
Update analyze.py
analyze.py +96 -91
analyze.py
CHANGED
@@ -2,6 +2,7 @@ from itertools import count, islice
 from typing import Any, Iterable, Literal, Optional, TypeVar, Union, overload, Dict, List, Tuple
 from collections import defaultdict
 import json
+import spaces
 
 import torch
 
@@ -36,113 +37,117 @@ def batched(
     while batch := list(islice(it, n)):
         yield (list(islice(indices, len(batch))), batch) if with_indices else batch
 
-
+@spaces.GPU(duration=60)
 def analyze(
     batch: List[Dict[str, Any]],
+    audio_column_name: str, text_column_name: str,
     cache: Optional[Dict[str, List[Any]]] = None,
 ) -> List[List[Any]]:
-    cache = {} if cache is None else cache
-
-
-
-def run_dataspeech(
-    rows: Iterable[Row], audio_column_name: str, text_column_name: str
-) -> Iterable[Any]:
-    cache: Dict[str, List[Any]] = {}
-
+    cache = {} if cache is None else cache
+
     # TODO: add speaker and gender to app
     speaker_id_column_name = "speaker_id"
     gender_column_name = "gender"
 
-
-
-        for …
-
-            if key …
-                tmp_dict[key].append(sample[key]) if key != audio_column_name else tmp_dict[key].append(sample[key][0]["src"])
-
-        tmp_dataset = Dataset.from_dict(tmp_dict).cast_column(audio_column_name, Audio())
-
-
-        ## 1. Extract continous tags
-        pitch_dataset = tmp_dataset.map(
-            pitch_apply,
-            batched=True,
-            batch_size=BATCH_SIZE,
-            with_rank=True if torch.cuda.device_count()>0 else False,
-            num_proc=torch.cuda.device_count(),
-            remove_columns=[audio_column_name], # tricks to avoid rewritting audio
-            fn_kwargs={"audio_column_name": audio_column_name, "penn_batch_size": 4096},
-        )
+    tmp_dict = defaultdict(list)
+    for sample in batch:
+        for key in sample:
+            if key in [audio_column_name, text_column_name, speaker_id_column_name, gender_column_name]:
+                tmp_dict[key].append(sample[key]) if key != audio_column_name else tmp_dict[key].append(sample[key][0]["src"])
 
-        snr_dataset = tmp_dataset.map(
-            snr_apply,
-            batched=True,
-            batch_size=BATCH_SIZE,
-            with_rank=True if torch.cuda.device_count()>0 else False,
-            num_proc=torch.cuda.device_count(),
-            remove_columns=[audio_column_name], # tricks to avoid rewritting audio
-            fn_kwargs={"audio_column_name": audio_column_name},
-        )
-
-        rate_dataset = tmp_dataset.map(
-            rate_apply,
-            with_rank=False,
-            num_proc=1,
-            remove_columns=[audio_column_name], # tricks to avoid rewritting audio
-            fn_kwargs={"audio_column_name": audio_column_name, "text_column_name": text_column_name},
-        )
+    tmp_dataset = Dataset.from_dict(tmp_dict).cast_column(audio_column_name, Audio())
 
-        … (19 deleted lines not legible in this view)
+
+    ## 1. Extract continous tags
+    pitch_dataset = tmp_dataset.map(
+        pitch_apply,
+        batched=True,
+        batch_size=BATCH_SIZE,
+        with_rank=True if torch.cuda.device_count()>0 else False,
+        num_proc=torch.cuda.device_count(),
+        remove_columns=[audio_column_name], # tricks to avoid rewritting audio
+        fn_kwargs={"audio_column_name": audio_column_name, "penn_batch_size": 4096},
+    )
+
+    snr_dataset = tmp_dataset.map(
+        snr_apply,
+        batched=True,
+        batch_size=BATCH_SIZE,
+        with_rank=True if torch.cuda.device_count()>0 else False,
+        num_proc=torch.cuda.device_count(),
+        remove_columns=[audio_column_name], # tricks to avoid rewritting audio
+        fn_kwargs={"audio_column_name": audio_column_name},
+    )
+
+    rate_dataset = tmp_dataset.map(
+        rate_apply,
+        with_rank=False,
+        num_proc=1,
+        remove_columns=[audio_column_name], # tricks to avoid rewritting audio
+        fn_kwargs={"audio_column_name": audio_column_name, "text_column_name": text_column_name},
+    )
+
+    enriched_dataset = pitch_dataset.add_column("snr", snr_dataset["snr"]).add_column("c50", snr_dataset["c50"])
+    enriched_dataset = enriched_dataset.add_column("speaking_rate", rate_dataset["speaking_rate"]).add_column("phonemes", rate_dataset["phonemes"])
+
+
+    ## 2. Map continuous tags to text tags
+
+    text_bins_dict = {}
+    with open("./v01_text_bins.json") as json_file:
+        text_bins_dict = json.load(json_file)
+
+    bin_edges_dict = {}
+    with open("./v01_bin_edges.json") as json_file:
+        bin_edges_dict = json.load(json_file)
+
+    speaker_level_pitch_bins = text_bins_dict.get("speaker_level_pitch_bins")
+    speaker_rate_bins = text_bins_dict.get("speaker_rate_bins")
+    snr_bins = text_bins_dict.get("snr_bins")
+    reverberation_bins = text_bins_dict.get("reverberation_bins")
+    utterance_level_std = text_bins_dict.get("utterance_level_std")
+
+    enriched_dataset = [enriched_dataset]
+    if "gender" in batch[0] and "speaker_id" in batch[0]:
+        bin_edges = None
+        if "pitch_bins_male" in bin_edges_dict and "pitch_bins_female" in bin_edges_dict:
+            bin_edges = {"male": bin_edges_dict["pitch_bins_male"], "female": bin_edges_dict["pitch_bins_female"]}
+
+        enriched_dataset, _ = speaker_level_relative_to_gender(enriched_dataset, speaker_level_pitch_bins, "speaker_id", "gender", "utterance_pitch_mean", "pitch", batch_size=20, num_workers=1, std_tolerance=None, save_dir=None, only_save_plot=False, bin_edges=bin_edges)
 
-
+    enriched_dataset, _ = bins_to_text(enriched_dataset, speaker_rate_bins, "speaking_rate", "speaking_rate", batch_size=20, num_workers=1, leading_split_for_bins=None, std_tolerance=None, save_dir=None, only_save_plot=False, bin_edges=bin_edges_dict.get("speaking_rate",None))
+    enriched_dataset, _ = bins_to_text(enriched_dataset, snr_bins, "snr", "noise", batch_size=20, num_workers=1, leading_split_for_bins=None, std_tolerance=None, save_dir=None, only_save_plot=False, bin_edges=bin_edges_dict.get("noise",None), lower_range=None)
+    enriched_dataset, _ = bins_to_text(enriched_dataset, reverberation_bins, "c50", "reverberation", batch_size=20, num_workers=1, leading_split_for_bins=None, std_tolerance=None, save_dir=None, only_save_plot=False, bin_edges=bin_edges_dict.get("reverberation",None))
+    enriched_dataset, _ = bins_to_text(enriched_dataset, utterance_level_std, "utterance_pitch_std", "speech_monotony", batch_size=20, num_workers=1, leading_split_for_bins=None, std_tolerance=None, save_dir=None, only_save_plot=False, bin_edges=bin_edges_dict.get("speech_monotony",None))
+
+
+    enriched_dataset = enriched_dataset[0]
+
+    for i,sample in enumerate(batch):
+        new_sample = {}
+        new_sample[audio_column_name] = f"<audio src='{sample[audio_column_name][0]['src']}' controls></audio>"
+        for col in ["speaking_rate", "reverberation", "noise", "speech_monotony", "c50", "snr",]: # phonemes, speaking_rate, utterance_pitch_std, utterance_pitch_mean
+            new_sample[col] = enriched_dataset[col][i]
         if "gender" in batch[0] and "speaker_id" in batch[0]:
-            … (3 deleted lines not legible in this view)
+            new_sample["pitch"] = enriched_dataset["pitch"][i]
+            new_sample[gender_column_name] = sample[col]
+            new_sample[speaker_id_column_name] = sample[col]
+
+        new_sample[text_column_name] = sample[text_column_name]
+        batch[i] = new_sample
+    return batch
 
-            enriched_dataset, _ = speaker_level_relative_to_gender(enriched_dataset, speaker_level_pitch_bins, "speaker_id", "gender", "utterance_pitch_mean", "pitch", batch_size=20, num_workers=1, std_tolerance=None, save_dir=None, only_save_plot=False, bin_edges=bin_edges)
-
-        enriched_dataset, _ = bins_to_text(enriched_dataset, speaker_rate_bins, "speaking_rate", "speaking_rate", batch_size=20, num_workers=1, leading_split_for_bins=None, std_tolerance=None, save_dir=None, only_save_plot=False, bin_edges=bin_edges_dict.get("speaking_rate",None))
-        enriched_dataset, _ = bins_to_text(enriched_dataset, snr_bins, "snr", "noise", batch_size=20, num_workers=1, leading_split_for_bins=None, std_tolerance=None, save_dir=None, only_save_plot=False, bin_edges=bin_edges_dict.get("noise",None), lower_range=None)
-        enriched_dataset, _ = bins_to_text(enriched_dataset, reverberation_bins, "c50", "reverberation", batch_size=20, num_workers=1, leading_split_for_bins=None, std_tolerance=None, save_dir=None, only_save_plot=False, bin_edges=bin_edges_dict.get("reverberation",None))
-        enriched_dataset, _ = bins_to_text(enriched_dataset, utterance_level_std, "utterance_pitch_std", "speech_monotony", batch_size=20, num_workers=1, leading_split_for_bins=None, std_tolerance=None, save_dir=None, only_save_plot=False, bin_edges=bin_edges_dict.get("speech_monotony",None))
 
+def run_dataspeech(
+    rows: Iterable[Row], audio_column_name: str, text_column_name: str
+) -> Iterable[Any]:
+    cache: Dict[str, List[Any]] = {}
 
-
-
-        for i,sample in enumerate(batch):
-            new_sample = {}
-            new_sample[audio_column_name] = f"<audio src='{sample[audio_column_name][0]['src']}' controls></audio>"
-            for col in ["speaking_rate", "reverberation", "noise", "speech_monotony", "c50", "snr",]: # phonemes, speaking_rate, utterance_pitch_std, utterance_pitch_mean
-                new_sample[col] = enriched_dataset[col][i]
-            if "gender" in batch[0] and "speaker_id" in batch[0]:
-                new_sample["pitch"] = enriched_dataset["pitch"][i]
-                new_sample[gender_column_name] = sample[col]
-                new_sample[speaker_id_column_name] = sample[col]
-
-            new_sample[text_column_name] = sample[text_column_name]
-            batch[i] = new_sample
-
+
+    for batch in batched(rows, BATCH_SIZE):
         yield analyze(
             batch=batch,
+            audio_column_name=audio_column_name,
+            text_column_name=text_column_name,
             cache=cache,
         )
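
For context, a minimal sketch of how the updated entry point could be driven. The module path, column names, and row values below are assumptions for illustration; they are not part of this commit. It only relies on what the diff shows: run_dataspeech(rows, audio_column_name, text_column_name) batches the rows and yields each batch after analyze() has attached the text tags, with analyze() itself scheduled on a GPU through the @spaces.GPU decorator.

# Hypothetical driver (assumed module path, column names, and sample URL).
from analyze import run_dataspeech

# Rows shaped as the code expects: the audio cell is a list of {"src": url} dicts.
rows = [
    {
        "audio": [{"src": "https://example.com/sample_0.wav"}],  # assumed audio column name
        "text": "Hello there, how are you today?",               # assumed text column name
    },
]

for processed_batch in run_dataspeech(rows, audio_column_name="audio", text_column_name="text"):
    for sample in processed_batch:
        # Each sample now carries the text tags derived from the continuous measures.
        print(sample["speaking_rate"], sample["noise"], sample["reverberation"], sample["speech_monotony"])

run_dataspeech stays a plain generator, so a caller can consume results batch by batch while the decorated analyze() does the GPU-bound pitch, SNR, and C50 work.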