ylacombe committed
Commit 0271b59
1 Parent(s): a18836f

Update analyze.py

Files changed (1):
  1. analyze.py +96 -91
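The gist of the change: the per-batch enrichment that previously lived inside `run_dataspeech`'s loop moves into `analyze`, which is now decorated with `@spaces.GPU(duration=60)` so the GPU-bound work runs inside a ZeroGPU allocation on Hugging Face Spaces. A minimal sketch of that decorator pattern, assuming only the `spaces` package; `heavy_fn` and its body are illustrative, not code from this repo:

import spaces
import torch

@spaces.GPU(duration=60)  # request a GPU for up to 60 s per call on a ZeroGPU Space
def heavy_fn(texts):
    # On ZeroGPU hardware, CUDA is only usable while a @spaces.GPU function runs.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    lengths = torch.tensor([float(len(t)) for t in texts], device=device)
    return lengths.cpu().tolist()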
analyze.py CHANGED
@@ -2,6 +2,7 @@ from itertools import count, islice
 from typing import Any, Iterable, Literal, Optional, TypeVar, Union, overload, Dict, List, Tuple
 from collections import defaultdict
 import json
+import spaces
 
 import torch
 
 
@@ -36,113 +37,117 @@ def batched(
     while batch := list(islice(it, n)):
         yield (list(islice(indices, len(batch))), batch) if with_indices else batch
 
-
+@spaces.GPU(duration=60)
 def analyze(
     batch: List[Dict[str, Any]],
+    audio_column_name: str, text_column_name: str,
     cache: Optional[Dict[str, List[Any]]] = None,
 ) -> List[List[Any]]:
-    cache = {} if cache is None else cache
-    return batch
-
-
-def run_dataspeech(
-    rows: Iterable[Row], audio_column_name: str, text_column_name: str
-) -> Iterable[Any]:
-    cache: Dict[str, List[Any]] = {}
-
+    cache = {} if cache is None else cache
+
     # TODO: add speaker and gender to app
     speaker_id_column_name = "speaker_id"
     gender_column_name = "gender"
 
-    for batch in batched(rows, BATCH_SIZE):
-        tmp_dict = defaultdict(list)
-        for sample in batch:
-            for key in sample:
-                if key in [audio_column_name, text_column_name, speaker_id_column_name, gender_column_name]:
-                    tmp_dict[key].append(sample[key]) if key != audio_column_name else tmp_dict[key].append(sample[key][0]["src"])
-
-        tmp_dataset = Dataset.from_dict(tmp_dict).cast_column(audio_column_name, Audio())
-
-
-        ## 1. Extract continuous tags
-        pitch_dataset = tmp_dataset.map(
-            pitch_apply,
-            batched=True,
-            batch_size=BATCH_SIZE,
-            with_rank=True if torch.cuda.device_count() > 0 else False,
-            num_proc=torch.cuda.device_count(),
-            remove_columns=[audio_column_name],  # trick to avoid rewriting audio
-            fn_kwargs={"audio_column_name": audio_column_name, "penn_batch_size": 4096},
-        )
-
-        snr_dataset = tmp_dataset.map(
-            snr_apply,
-            batched=True,
-            batch_size=BATCH_SIZE,
-            with_rank=True if torch.cuda.device_count() > 0 else False,
-            num_proc=torch.cuda.device_count(),
-            remove_columns=[audio_column_name],  # trick to avoid rewriting audio
-            fn_kwargs={"audio_column_name": audio_column_name},
-        )
-
-        rate_dataset = tmp_dataset.map(
-            rate_apply,
-            with_rank=False,
-            num_proc=1,
-            remove_columns=[audio_column_name],  # trick to avoid rewriting audio
-            fn_kwargs={"audio_column_name": audio_column_name, "text_column_name": text_column_name},
-        )
-
-        enriched_dataset = pitch_dataset.add_column("snr", snr_dataset["snr"]).add_column("c50", snr_dataset["c50"])
-        enriched_dataset = enriched_dataset.add_column("speaking_rate", rate_dataset["speaking_rate"]).add_column("phonemes", rate_dataset["phonemes"])
-
-
-        ## 2. Map continuous tags to text tags
-
-        text_bins_dict = {}
-        with open("./v01_text_bins.json") as json_file:
-            text_bins_dict = json.load(json_file)
-
-        bin_edges_dict = {}
-        with open("./v01_bin_edges.json") as json_file:
-            bin_edges_dict = json.load(json_file)
-
-        speaker_level_pitch_bins = text_bins_dict.get("speaker_level_pitch_bins")
-        speaker_rate_bins = text_bins_dict.get("speaker_rate_bins")
-        snr_bins = text_bins_dict.get("snr_bins")
-        reverberation_bins = text_bins_dict.get("reverberation_bins")
-        utterance_level_std = text_bins_dict.get("utterance_level_std")
-
-        enriched_dataset = [enriched_dataset]
+    tmp_dict = defaultdict(list)
+    for sample in batch:
+        for key in sample:
+            if key in [audio_column_name, text_column_name, speaker_id_column_name, gender_column_name]:
+                tmp_dict[key].append(sample[key]) if key != audio_column_name else tmp_dict[key].append(sample[key][0]["src"])
+
+    tmp_dataset = Dataset.from_dict(tmp_dict).cast_column(audio_column_name, Audio())
+
+
+    ## 1. Extract continuous tags
+    pitch_dataset = tmp_dataset.map(
+        pitch_apply,
+        batched=True,
+        batch_size=BATCH_SIZE,
+        with_rank=True if torch.cuda.device_count() > 0 else False,
+        num_proc=torch.cuda.device_count(),
+        remove_columns=[audio_column_name],  # trick to avoid rewriting audio
+        fn_kwargs={"audio_column_name": audio_column_name, "penn_batch_size": 4096},
+    )
+
+    snr_dataset = tmp_dataset.map(
+        snr_apply,
+        batched=True,
+        batch_size=BATCH_SIZE,
+        with_rank=True if torch.cuda.device_count() > 0 else False,
+        num_proc=torch.cuda.device_count(),
+        remove_columns=[audio_column_name],  # trick to avoid rewriting audio
+        fn_kwargs={"audio_column_name": audio_column_name},
+    )
+
+    rate_dataset = tmp_dataset.map(
+        rate_apply,
+        with_rank=False,
+        num_proc=1,
+        remove_columns=[audio_column_name],  # trick to avoid rewriting audio
+        fn_kwargs={"audio_column_name": audio_column_name, "text_column_name": text_column_name},
+    )
+
+    enriched_dataset = pitch_dataset.add_column("snr", snr_dataset["snr"]).add_column("c50", snr_dataset["c50"])
+    enriched_dataset = enriched_dataset.add_column("speaking_rate", rate_dataset["speaking_rate"]).add_column("phonemes", rate_dataset["phonemes"])
+
+
+    ## 2. Map continuous tags to text tags
+
+    text_bins_dict = {}
+    with open("./v01_text_bins.json") as json_file:
+        text_bins_dict = json.load(json_file)
+
+    bin_edges_dict = {}
+    with open("./v01_bin_edges.json") as json_file:
+        bin_edges_dict = json.load(json_file)
+
+    speaker_level_pitch_bins = text_bins_dict.get("speaker_level_pitch_bins")
+    speaker_rate_bins = text_bins_dict.get("speaker_rate_bins")
+    snr_bins = text_bins_dict.get("snr_bins")
+    reverberation_bins = text_bins_dict.get("reverberation_bins")
+    utterance_level_std = text_bins_dict.get("utterance_level_std")
+
+    enriched_dataset = [enriched_dataset]
+    if "gender" in batch[0] and "speaker_id" in batch[0]:
+        bin_edges = None
+        if "pitch_bins_male" in bin_edges_dict and "pitch_bins_female" in bin_edges_dict:
+            bin_edges = {"male": bin_edges_dict["pitch_bins_male"], "female": bin_edges_dict["pitch_bins_female"]}
+
+        enriched_dataset, _ = speaker_level_relative_to_gender(enriched_dataset, speaker_level_pitch_bins, "speaker_id", "gender", "utterance_pitch_mean", "pitch", batch_size=20, num_workers=1, std_tolerance=None, save_dir=None, only_save_plot=False, bin_edges=bin_edges)
+
+    enriched_dataset, _ = bins_to_text(enriched_dataset, speaker_rate_bins, "speaking_rate", "speaking_rate", batch_size=20, num_workers=1, leading_split_for_bins=None, std_tolerance=None, save_dir=None, only_save_plot=False, bin_edges=bin_edges_dict.get("speaking_rate", None))
+    enriched_dataset, _ = bins_to_text(enriched_dataset, snr_bins, "snr", "noise", batch_size=20, num_workers=1, leading_split_for_bins=None, std_tolerance=None, save_dir=None, only_save_plot=False, bin_edges=bin_edges_dict.get("noise", None), lower_range=None)
+    enriched_dataset, _ = bins_to_text(enriched_dataset, reverberation_bins, "c50", "reverberation", batch_size=20, num_workers=1, leading_split_for_bins=None, std_tolerance=None, save_dir=None, only_save_plot=False, bin_edges=bin_edges_dict.get("reverberation", None))
+    enriched_dataset, _ = bins_to_text(enriched_dataset, utterance_level_std, "utterance_pitch_std", "speech_monotony", batch_size=20, num_workers=1, leading_split_for_bins=None, std_tolerance=None, save_dir=None, only_save_plot=False, bin_edges=bin_edges_dict.get("speech_monotony", None))
+
+
+    enriched_dataset = enriched_dataset[0]
+
+    for i, sample in enumerate(batch):
+        new_sample = {}
+        new_sample[audio_column_name] = f"<audio src='{sample[audio_column_name][0]['src']}' controls></audio>"
+        for col in ["speaking_rate", "reverberation", "noise", "speech_monotony", "c50", "snr"]:  # phonemes, speaking_rate, utterance_pitch_std, utterance_pitch_mean
+            new_sample[col] = enriched_dataset[col][i]
         if "gender" in batch[0] and "speaker_id" in batch[0]:
-            bin_edges = None
-            if "pitch_bins_male" in bin_edges_dict and "pitch_bins_female" in bin_edges_dict:
-                bin_edges = {"male": bin_edges_dict["pitch_bins_male"], "female": bin_edges_dict["pitch_bins_female"]}
-
-            enriched_dataset, _ = speaker_level_relative_to_gender(enriched_dataset, speaker_level_pitch_bins, "speaker_id", "gender", "utterance_pitch_mean", "pitch", batch_size=20, num_workers=1, std_tolerance=None, save_dir=None, only_save_plot=False, bin_edges=bin_edges)
-
-        enriched_dataset, _ = bins_to_text(enriched_dataset, speaker_rate_bins, "speaking_rate", "speaking_rate", batch_size=20, num_workers=1, leading_split_for_bins=None, std_tolerance=None, save_dir=None, only_save_plot=False, bin_edges=bin_edges_dict.get("speaking_rate", None))
-        enriched_dataset, _ = bins_to_text(enriched_dataset, snr_bins, "snr", "noise", batch_size=20, num_workers=1, leading_split_for_bins=None, std_tolerance=None, save_dir=None, only_save_plot=False, bin_edges=bin_edges_dict.get("noise", None), lower_range=None)
-        enriched_dataset, _ = bins_to_text(enriched_dataset, reverberation_bins, "c50", "reverberation", batch_size=20, num_workers=1, leading_split_for_bins=None, std_tolerance=None, save_dir=None, only_save_plot=False, bin_edges=bin_edges_dict.get("reverberation", None))
-        enriched_dataset, _ = bins_to_text(enriched_dataset, utterance_level_std, "utterance_pitch_std", "speech_monotony", batch_size=20, num_workers=1, leading_split_for_bins=None, std_tolerance=None, save_dir=None, only_save_plot=False, bin_edges=bin_edges_dict.get("speech_monotony", None))
-
-
-        enriched_dataset = enriched_dataset[0]
-
-        for i, sample in enumerate(batch):
-            new_sample = {}
-            new_sample[audio_column_name] = f"<audio src='{sample[audio_column_name][0]['src']}' controls></audio>"
-            for col in ["speaking_rate", "reverberation", "noise", "speech_monotony", "c50", "snr"]:  # phonemes, speaking_rate, utterance_pitch_std, utterance_pitch_mean
-                new_sample[col] = enriched_dataset[col][i]
-            if "gender" in batch[0] and "speaker_id" in batch[0]:
-                new_sample["pitch"] = enriched_dataset["pitch"][i]
-                new_sample[gender_column_name] = sample[gender_column_name]
-                new_sample[speaker_id_column_name] = sample[speaker_id_column_name]
-
-            new_sample[text_column_name] = sample[text_column_name]
-            batch[i] = new_sample
-
+            new_sample["pitch"] = enriched_dataset["pitch"][i]
+            new_sample[gender_column_name] = sample[gender_column_name]
+            new_sample[speaker_id_column_name] = sample[speaker_id_column_name]
+
+        new_sample[text_column_name] = sample[text_column_name]
+        batch[i] = new_sample
+    return batch
+
+
+def run_dataspeech(
+    rows: Iterable[Row], audio_column_name: str, text_column_name: str
+) -> Iterable[Any]:
+    cache: Dict[str, List[Any]] = {}
+
+
+    for batch in batched(rows, BATCH_SIZE):
         yield analyze(
             batch=batch,
+            audio_column_name=audio_column_name,
+            text_column_name=text_column_name,
            cache=cache,
         )
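A note on the mapping recipe above: for `pitch_apply` and `snr_apply`, `with_rank=True` makes `datasets` pass a worker rank into the mapped function, and `num_proc=torch.cuda.device_count()` starts one worker per GPU, so batches are sharded across devices. A sketch of how a mapped function typically consumes that rank; the body is illustrative, not `pitch_apply`'s actual implementation:

import torch

def apply_sketch(batch, rank=None, audio_column_name="audio"):
    # With with_rank=True, Dataset.map passes each worker's rank as an extra
    # argument; pinning the worker to "cuda:<rank>" spreads work across GPUs.
    if torch.cuda.is_available():
        device = f"cuda:{(rank or 0) % torch.cuda.device_count()}"
    else:
        device = "cpu"
    # ...move the feature-extraction model to `device` and run it on the batch...
    batch["utterance_pitch_mean"] = [0.0] * len(batch[audio_column_name])  # placeholder output
    return batch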
 
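For reference, a hedged usage sketch of the refactored entry point. It assumes rows are dict-like, with the audio column holding a list of `{"src": ...}` payloads (matching the `sample[audio_column_name][0]["src"]` access in the diff), and that `BATCH_SIZE` is defined elsewhere in the module:

rows = [
    {"audio": [{"src": "https://example.com/utt1.wav"}], "text": "hello world"},
    {"audio": [{"src": "https://example.com/utt2.wav"}], "text": "good morning"},
]

for enriched in run_dataspeech(rows, audio_column_name="audio", text_column_name="text"):
    for sample in enriched:
        print(sample["speaking_rate"], sample["noise"], sample["reverberation"], sample["speech_monotony"])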