tonic commited on
Commit
33d9042
ยท
1 Parent(s): 89d01e6

Laion WhisperSpeech Demo

Browse files
README.md CHANGED
@@ -1,12 +1,12 @@
1
  ---
2
- title: Laion Whisper
3
- emoji: ๐Ÿ†
4
  colorFrom: pink
5
  colorTo: gray
6
  sdk: gradio
7
  sdk_version: 4.15.0
8
  app_file: app.py
9
- pinned: false
10
  license: mit
11
  ---
12
 
 
1
  ---
2
+ title: WhisperSpeech
3
+ emoji: ๐ŸŒฌ๏ธ๐Ÿ’ฌ๐Ÿ“
4
  colorFrom: pink
5
  colorTo: gray
6
  sdk: gradio
7
  sdk_version: 4.15.0
8
  app_file: app.py
9
+ pinned: True
10
  license: mit
11
  ---
12
 
app.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import spaces
2
+ import gradio as gr
3
+ import os
4
+ from whisperspeech.pipeline import Pipeline
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from whisperspeech.languages import LANGUAGES
8
+ from whisperspeech.pipeline import Pipeline
9
+ import tempfil
10
+
11
+ title = """#๐Ÿ™‹๐Ÿปโ€โ™‚๏ธ Welcome to๐ŸŒŸTonic's๐ŸŒฌ๏ธ๐Ÿ’ฌ๐Ÿ“WhisperSpeech
12
+ You can use this ZeroGPU Space to test out the current model [๐ŸŒฌ๏ธ๐Ÿ’ฌ๐Ÿ“collabora/whisperspeech](https://huggingface.co/collabora/whisperspeech). ๐ŸŒฌ๏ธ๐Ÿ’ฌ๐Ÿ“collabora/whisperspeech is An Open Source text-to-speech system built by inverting Whisper. Previously known as spear-tts-pytorch. It's like Stable Diffusion but for speech โ€“ both powerful and easily customizable.
13
+ You can also use ๐ŸŒฌ๏ธ๐Ÿ’ฌ๐Ÿ“WhisperSpeech by cloning this space. ๐Ÿงฌ๐Ÿ”ฌ๐Ÿ” Simply click here: <a style="display:inline-block" href="https://huggingface.co/spaces/Tonic/laion-whisper?duplicate=true"><img src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=&logoWidth=14" alt="Duplicate Space"></a></h3>
14
+ Join us : ๐ŸŒŸTeamTonic๐ŸŒŸ is always making cool demos! Join our active builder's๐Ÿ› ๏ธcommunity ๐Ÿ‘ป [![Join us on Discord](https://img.shields.io/discord/1109943800132010065?label=Discord&logo=discord&style=flat-square)](https://discord.gg/GWpVpekp) On ๐Ÿค—Huggingface: [TeamTonic](https://huggingface.co/TeamTonic) & [MultiTransformer](https://huggingface.co/MultiTransformer) On ๐ŸŒGithub: [Polytonic](https://github.com/tonic-ai) & contribute to ๐ŸŒŸ [Poly](https://github.com/tonic-ai/poly) ๐Ÿค—Big thanks to Yuvi Sharma and all the folks at huggingface for the community grant ๐Ÿค—
15
+ """
16
+
17
+ @spaces.GPU
18
+
19
+ def whisper_speech_demo(text, lang, speaker_audio=None, mix_lang=None, mix_text=None):
20
+ pipe = Pipeline(s2a_ref='collabora/whisperspeech:s2a-q4-tiny-en+pl.model')
21
+
22
+ # Use uploaded speaker audio if provided
23
+ speaker_url = None
24
+ if speaker_audio is not None:
25
+ speaker_url = speaker_audio.name
26
+
27
+ if mix_lang and mix_text:
28
+ mixed_langs = lang.split(',') + mix_lang.split(',')
29
+ mixed_texts = [text] + mix_text.split(',')
30
+ stoks = pipe.t2s.generate(mixed_texts, lang=mixed_langs)
31
+ audio_data = pipe.generate(stoks, speaker_url, lang=mixed_langs[0])
32
+ else:
33
+ audio_data = pipe.generate(text, speaker_url, lang)
34
+
35
+ with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as tmp_file:
36
+ tmp_file_name = tmp_file.name
37
+ with open(tmp_file_name, 'wb') as file:
38
+ file.write(audio_data)
39
+
40
+ return tmp_file_name
41
+
42
+ with gr.Blocks() as demo:
43
+ gr.Markdown(title)
44
+ with gr.Row():
45
+ text_input = gr.Textbox(label="Enter text")
46
+ lang_input = gr.Dropdown(choices=list(LANGUAGES.keys()), label="Language")
47
+ speaker_input = gr.File(label="Upload Speaker Audio (optional)", type="file", accepts=["audio/*"])
48
+ with gr.Row():
49
+ mix_lang_input = gr.Textbox(label="Mixed Languages (optional, comma-separated)", placeholder="e.g., en,pl")
50
+ mix_text_input = gr.Textbox(label="Mixed Texts (optional, for mixed languages)", placeholder="e.g., Hello, Czeล›ฤ‡")
51
+ with gr.Row():
52
+ submit_button = gr.Button("Generate Speech")
53
+ output_audio = gr.Audio(label="Generated Speech")
54
+
55
+ submit_button.click(
56
+ whisper_speech_demo,
57
+ inputs=[text_input, lang_input, speaker_input, mix_lang_input, mix_text_input],
58
+ outputs=output_audio
59
+ )
60
+
61
+ demo.launch()
requirements.txt CHANGED
@@ -1 +1,3 @@
1
- whisperspeech
 
 
 
1
+ torch
2
+ transformers
3
+ accelerate
whisperspeech/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ __version__ = "0.5.6"
whisperspeech/_modidx.py ADDED
@@ -0,0 +1,615 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Autogenerated by nbdev
2
+
3
+ d = { 'settings': { 'branch': 'master',
4
+ 'doc_baseurl': '/WhisperSpeech',
5
+ 'doc_host': 'https://collabora.github.io',
6
+ 'git_url': 'https://github.com/collabora/WhisperSpeech',
7
+ 'lib_path': 'whisperspeech'},
8
+ 'syms': { 'whisperspeech.a2wav': { 'whisperspeech.a2wav.Vocoder': ('6. quality-boosting vocoder.html#vocoder', 'whisperspeech/a2wav.py'),
9
+ 'whisperspeech.a2wav.Vocoder.__init__': ( '6. quality-boosting vocoder.html#vocoder.__init__',
10
+ 'whisperspeech/a2wav.py'),
11
+ 'whisperspeech.a2wav.Vocoder.decode': ( '6. quality-boosting vocoder.html#vocoder.decode',
12
+ 'whisperspeech/a2wav.py'),
13
+ 'whisperspeech.a2wav.Vocoder.decode_to_file': ( '6. quality-boosting '
14
+ 'vocoder.html#vocoder.decode_to_file',
15
+ 'whisperspeech/a2wav.py'),
16
+ 'whisperspeech.a2wav.Vocoder.decode_to_notebook': ( '6. quality-boosting '
17
+ 'vocoder.html#vocoder.decode_to_notebook',
18
+ 'whisperspeech/a2wav.py')},
19
+ 'whisperspeech.extract_acoustic': { 'whisperspeech.extract_acoustic.extract_Atoks': ( '1. acoustic token '
20
+ 'extraction.html#extract_atoks',
21
+ 'whisperspeech/extract_acoustic.py'),
22
+ 'whisperspeech.extract_acoustic.extract_acoustic': ( '1. acoustic token '
23
+ 'extraction.html#extract_acoustic',
24
+ 'whisperspeech/extract_acoustic.py'),
25
+ 'whisperspeech.extract_acoustic.load': ( '1. acoustic token extraction.html#load',
26
+ 'whisperspeech/extract_acoustic.py'),
27
+ 'whisperspeech.extract_acoustic.load_model': ( '1. acoustic token '
28
+ 'extraction.html#load_model',
29
+ 'whisperspeech/extract_acoustic.py')},
30
+ 'whisperspeech.extract_semb': { 'whisperspeech.extract_semb.encode_semantic': ( '2c. whisper semantic embedding '
31
+ 'extraction.html#encode_semantic',
32
+ 'whisperspeech/extract_semb.py'),
33
+ 'whisperspeech.extract_semb.extract_semantic': ( '2c. whisper semantic embedding '
34
+ 'extraction.html#extract_semantic',
35
+ 'whisperspeech/extract_semb.py'),
36
+ 'whisperspeech.extract_semb.load_model': ( '2c. whisper semantic embedding '
37
+ 'extraction.html#load_model',
38
+ 'whisperspeech/extract_semb.py')},
39
+ 'whisperspeech.fetch_models': { 'whisperspeech.fetch_models.main': ( '0. download models.html#main',
40
+ 'whisperspeech/fetch_models.py')},
41
+ 'whisperspeech.modules': { 'whisperspeech.modules.Decoder': ('a. neural modules.html#decoder', 'whisperspeech/modules.py'),
42
+ 'whisperspeech.modules.Decoder.__init__': ( 'a. neural modules.html#decoder.__init__',
43
+ 'whisperspeech/modules.py'),
44
+ 'whisperspeech.modules.Decoder.forward': ( 'a. neural modules.html#decoder.forward',
45
+ 'whisperspeech/modules.py'),
46
+ 'whisperspeech.modules.Encoder': ('a. neural modules.html#encoder', 'whisperspeech/modules.py'),
47
+ 'whisperspeech.modules.Encoder.__init__': ( 'a. neural modules.html#encoder.__init__',
48
+ 'whisperspeech/modules.py'),
49
+ 'whisperspeech.modules.Encoder.forward': ( 'a. neural modules.html#encoder.forward',
50
+ 'whisperspeech/modules.py'),
51
+ 'whisperspeech.modules.LayerNorm': ('a. neural modules.html#layernorm', 'whisperspeech/modules.py'),
52
+ 'whisperspeech.modules.LayerNorm.forward': ( 'a. neural modules.html#layernorm.forward',
53
+ 'whisperspeech/modules.py'),
54
+ 'whisperspeech.modules.LinearHead': ( 'a. neural modules.html#linearhead',
55
+ 'whisperspeech/modules.py'),
56
+ 'whisperspeech.modules.MultiHeadAttention': ( 'a. neural modules.html#multiheadattention',
57
+ 'whisperspeech/modules.py'),
58
+ 'whisperspeech.modules.MultiHeadAttention.__init__': ( 'a. neural '
59
+ 'modules.html#multiheadattention.__init__',
60
+ 'whisperspeech/modules.py'),
61
+ 'whisperspeech.modules.MultiHeadAttention.forward': ( 'a. neural '
62
+ 'modules.html#multiheadattention.forward',
63
+ 'whisperspeech/modules.py'),
64
+ 'whisperspeech.modules.MultiHeadAttention.qkv_attention_pth20': ( 'a. neural '
65
+ 'modules.html#multiheadattention.qkv_attention_pth20',
66
+ 'whisperspeech/modules.py'),
67
+ 'whisperspeech.modules.MultiHeadAttention.qkv_attention_vanilla': ( 'a. neural '
68
+ 'modules.html#multiheadattention.qkv_attention_vanilla',
69
+ 'whisperspeech/modules.py'),
70
+ 'whisperspeech.modules.MultiHeadAttention.qkv_attention_xformers': ( 'a. neural '
71
+ 'modules.html#multiheadattention.qkv_attention_xformers',
72
+ 'whisperspeech/modules.py'),
73
+ 'whisperspeech.modules.QueryHead': ('a. neural modules.html#queryhead', 'whisperspeech/modules.py'),
74
+ 'whisperspeech.modules.ResidualAttentionBlock': ( 'a. neural modules.html#residualattentionblock',
75
+ 'whisperspeech/modules.py'),
76
+ 'whisperspeech.modules.ResidualAttentionBlock.__init__': ( 'a. neural '
77
+ 'modules.html#residualattentionblock.__init__',
78
+ 'whisperspeech/modules.py'),
79
+ 'whisperspeech.modules.ResidualAttentionBlock.forward': ( 'a. neural '
80
+ 'modules.html#residualattentionblock.forward',
81
+ 'whisperspeech/modules.py'),
82
+ 'whisperspeech.modules.Rotary': ('a. neural modules.html#rotary', 'whisperspeech/modules.py'),
83
+ 'whisperspeech.modules.Rotary.__init__': ( 'a. neural modules.html#rotary.__init__',
84
+ 'whisperspeech/modules.py'),
85
+ 'whisperspeech.modules.Rotary.forward': ( 'a. neural modules.html#rotary.forward',
86
+ 'whisperspeech/modules.py'),
87
+ 'whisperspeech.modules.SumDecoder': ( 'a. neural modules.html#sumdecoder',
88
+ 'whisperspeech/modules.py'),
89
+ 'whisperspeech.modules.SumDecoder.__init__': ( 'a. neural modules.html#sumdecoder.__init__',
90
+ 'whisperspeech/modules.py'),
91
+ 'whisperspeech.modules.SumDecoder.forward': ( 'a. neural modules.html#sumdecoder.forward',
92
+ 'whisperspeech/modules.py'),
93
+ 'whisperspeech.modules.apply_rotary_pos_emb': ( 'a. neural modules.html#apply_rotary_pos_emb',
94
+ 'whisperspeech/modules.py'),
95
+ 'whisperspeech.modules.init_transformer': ( 'a. neural modules.html#init_transformer',
96
+ 'whisperspeech/modules.py'),
97
+ 'whisperspeech.modules.rotate_half': ( 'a. neural modules.html#rotate_half',
98
+ 'whisperspeech/modules.py'),
99
+ 'whisperspeech.modules.sinusoids': ('a. neural modules.html#sinusoids', 'whisperspeech/modules.py')},
100
+ 'whisperspeech.pipeline': { 'whisperspeech.pipeline.Pipeline': ('7. pipeline.html#pipeline', 'whisperspeech/pipeline.py'),
101
+ 'whisperspeech.pipeline.Pipeline.__init__': ( '7. pipeline.html#pipeline.__init__',
102
+ 'whisperspeech/pipeline.py'),
103
+ 'whisperspeech.pipeline.Pipeline.generate': ( '7. pipeline.html#pipeline.generate',
104
+ 'whisperspeech/pipeline.py'),
105
+ 'whisperspeech.pipeline.Pipeline.generate_atoks': ( '7. pipeline.html#pipeline.generate_atoks',
106
+ 'whisperspeech/pipeline.py'),
107
+ 'whisperspeech.pipeline.Pipeline.generate_to_file': ( '7. pipeline.html#pipeline.generate_to_file',
108
+ 'whisperspeech/pipeline.py'),
109
+ 'whisperspeech.pipeline.Pipeline.generate_to_notebook': ( '7. '
110
+ 'pipeline.html#pipeline.generate_to_notebook',
111
+ 'whisperspeech/pipeline.py')},
112
+ 'whisperspeech.prepare_s2a_dataset': { 'whisperspeech.prepare_s2a_dataset.flac_to_s2a_name': ( '4a. s2a dataset '
113
+ 'preparation.html#flac_to_s2a_name',
114
+ 'whisperspeech/prepare_s2a_dataset.py'),
115
+ 'whisperspeech.prepare_s2a_dataset.prepare_s2a': ( '4a. s2a dataset '
116
+ 'preparation.html#prepare_s2a',
117
+ 'whisperspeech/prepare_s2a_dataset.py'),
118
+ 'whisperspeech.prepare_s2a_dataset.resampler': ( '4a. s2a dataset '
119
+ 'preparation.html#resampler',
120
+ 'whisperspeech/prepare_s2a_dataset.py')},
121
+ 'whisperspeech.prepare_t2s_dataset': { 'whisperspeech.prepare_t2s_dataset.Transcriber': ( '5a. t2s dataset '
122
+ 'preparation.html#transcriber',
123
+ 'whisperspeech/prepare_t2s_dataset.py'),
124
+ 'whisperspeech.prepare_t2s_dataset.Transcriber.__init__': ( '5a. t2s dataset '
125
+ 'preparation.html#transcriber.__init__',
126
+ 'whisperspeech/prepare_t2s_dataset.py'),
127
+ 'whisperspeech.prepare_t2s_dataset.Transcriber.transcribe': ( '5a. t2s dataset '
128
+ 'preparation.html#transcriber.transcribe',
129
+ 'whisperspeech/prepare_t2s_dataset.py'),
130
+ 'whisperspeech.prepare_t2s_dataset.flac_to_t2s_name': ( '5a. t2s dataset '
131
+ 'preparation.html#flac_to_t2s_name',
132
+ 'whisperspeech/prepare_t2s_dataset.py'),
133
+ 'whisperspeech.prepare_t2s_dataset.prepare_t2s': ( '5a. t2s dataset '
134
+ 'preparation.html#prepare_t2s',
135
+ 'whisperspeech/prepare_t2s_dataset.py')},
136
+ 'whisperspeech.s2a_delar_mup_wds': { 'whisperspeech.s2a_delar_mup_wds.CMLMVisual': ( '4b. semantic to acoustic token '
137
+ 'modeling.html#cmlmvisual',
138
+ 'whisperspeech/s2a_delar_mup_wds.py'),
139
+ 'whisperspeech.s2a_delar_mup_wds.CMLMVisual.__init__': ( '4b. semantic to acoustic token '
140
+ 'modeling.html#cmlmvisual.__init__',
141
+ 'whisperspeech/s2a_delar_mup_wds.py'),
142
+ 'whisperspeech.s2a_delar_mup_wds.CMLMVisual.add_data': ( '4b. semantic to acoustic token '
143
+ 'modeling.html#cmlmvisual.add_data',
144
+ 'whisperspeech/s2a_delar_mup_wds.py'),
145
+ 'whisperspeech.s2a_delar_mup_wds.CMLMVisual.add_table_row': ( '4b. semantic to acoustic '
146
+ 'token '
147
+ 'modeling.html#cmlmvisual.add_table_row',
148
+ 'whisperspeech/s2a_delar_mup_wds.py'),
149
+ 'whisperspeech.s2a_delar_mup_wds.CMLMVisual.hide': ( '4b. semantic to acoustic token '
150
+ 'modeling.html#cmlmvisual.hide',
151
+ 'whisperspeech/s2a_delar_mup_wds.py'),
152
+ 'whisperspeech.s2a_delar_mup_wds.CMLMVisual.on_iter': ( '4b. semantic to acoustic token '
153
+ 'modeling.html#cmlmvisual.on_iter',
154
+ 'whisperspeech/s2a_delar_mup_wds.py'),
155
+ 'whisperspeech.s2a_delar_mup_wds.CMLMVisual.plot': ( '4b. semantic to acoustic token '
156
+ 'modeling.html#cmlmvisual.plot',
157
+ 'whisperspeech/s2a_delar_mup_wds.py'),
158
+ 'whisperspeech.s2a_delar_mup_wds.CMLMVisual.show': ( '4b. semantic to acoustic token '
159
+ 'modeling.html#cmlmvisual.show',
160
+ 'whisperspeech/s2a_delar_mup_wds.py'),
161
+ 'whisperspeech.s2a_delar_mup_wds.DelSumDecoder': ( '4b. semantic to acoustic token '
162
+ 'modeling.html#delsumdecoder',
163
+ 'whisperspeech/s2a_delar_mup_wds.py'),
164
+ 'whisperspeech.s2a_delar_mup_wds.DelSumDecoder.__init__': ( '4b. semantic to acoustic '
165
+ 'token '
166
+ 'modeling.html#delsumdecoder.__init__',
167
+ 'whisperspeech/s2a_delar_mup_wds.py'),
168
+ 'whisperspeech.s2a_delar_mup_wds.DelSumDecoder.forward': ( '4b. semantic to acoustic '
169
+ 'token '
170
+ 'modeling.html#delsumdecoder.forward',
171
+ 'whisperspeech/s2a_delar_mup_wds.py'),
172
+ 'whisperspeech.s2a_delar_mup_wds.EmbeddingProjector': ( '4b. semantic to acoustic token '
173
+ 'modeling.html#embeddingprojector',
174
+ 'whisperspeech/s2a_delar_mup_wds.py'),
175
+ 'whisperspeech.s2a_delar_mup_wds.MultiHeadAttention': ( '4b. semantic to acoustic token '
176
+ 'modeling.html#multiheadattention',
177
+ 'whisperspeech/s2a_delar_mup_wds.py'),
178
+ 'whisperspeech.s2a_delar_mup_wds.MultiHeadAttention.__init__': ( '4b. semantic to '
179
+ 'acoustic token '
180
+ 'modeling.html#multiheadattention.__init__',
181
+ 'whisperspeech/s2a_delar_mup_wds.py'),
182
+ 'whisperspeech.s2a_delar_mup_wds.MultiHeadAttention.forward': ( '4b. semantic to acoustic '
183
+ 'token '
184
+ 'modeling.html#multiheadattention.forward',
185
+ 'whisperspeech/s2a_delar_mup_wds.py'),
186
+ 'whisperspeech.s2a_delar_mup_wds.MultiHeadAttention.qkv_attention_pth20': ( '4b. semantic '
187
+ 'to acoustic '
188
+ 'token '
189
+ 'modeling.html#multiheadattention.qkv_attention_pth20',
190
+ 'whisperspeech/s2a_delar_mup_wds.py'),
191
+ 'whisperspeech.s2a_delar_mup_wds.MultiHeadAttention.qkv_attention_xformers': ( '4b. '
192
+ 'semantic '
193
+ 'to '
194
+ 'acoustic '
195
+ 'token '
196
+ 'modeling.html#multiheadattention.qkv_attention_xformers',
197
+ 'whisperspeech/s2a_delar_mup_wds.py'),
198
+ 'whisperspeech.s2a_delar_mup_wds.ResidualAttentionBlock': ( '4b. semantic to acoustic '
199
+ 'token '
200
+ 'modeling.html#residualattentionblock',
201
+ 'whisperspeech/s2a_delar_mup_wds.py'),
202
+ 'whisperspeech.s2a_delar_mup_wds.ResidualAttentionBlock.__init__': ( '4b. semantic to '
203
+ 'acoustic token '
204
+ 'modeling.html#residualattentionblock.__init__',
205
+ 'whisperspeech/s2a_delar_mup_wds.py'),
206
+ 'whisperspeech.s2a_delar_mup_wds.ResidualAttentionBlock.forward': ( '4b. semantic to '
207
+ 'acoustic token '
208
+ 'modeling.html#residualattentionblock.forward',
209
+ 'whisperspeech/s2a_delar_mup_wds.py'),
210
+ 'whisperspeech.s2a_delar_mup_wds.Rotary': ( '4b. semantic to acoustic token '
211
+ 'modeling.html#rotary',
212
+ 'whisperspeech/s2a_delar_mup_wds.py'),
213
+ 'whisperspeech.s2a_delar_mup_wds.Rotary.__init__': ( '4b. semantic to acoustic token '
214
+ 'modeling.html#rotary.__init__',
215
+ 'whisperspeech/s2a_delar_mup_wds.py'),
216
+ 'whisperspeech.s2a_delar_mup_wds.Rotary.forward': ( '4b. semantic to acoustic token '
217
+ 'modeling.html#rotary.forward',
218
+ 'whisperspeech/s2a_delar_mup_wds.py'),
219
+ 'whisperspeech.s2a_delar_mup_wds.SADelARTransformer': ( '4b. semantic to acoustic token '
220
+ 'modeling.html#sadelartransformer',
221
+ 'whisperspeech/s2a_delar_mup_wds.py'),
222
+ 'whisperspeech.s2a_delar_mup_wds.SADelARTransformer.__init__': ( '4b. semantic to '
223
+ 'acoustic token '
224
+ 'modeling.html#sadelartransformer.__init__',
225
+ 'whisperspeech/s2a_delar_mup_wds.py'),
226
+ 'whisperspeech.s2a_delar_mup_wds.SADelARTransformer.device': ( '4b. semantic to acoustic '
227
+ 'token '
228
+ 'modeling.html#sadelartransformer.device',
229
+ 'whisperspeech/s2a_delar_mup_wds.py'),
230
+ 'whisperspeech.s2a_delar_mup_wds.SADelARTransformer.embed_stoks': ( '4b. semantic to '
231
+ 'acoustic token '
232
+ 'modeling.html#sadelartransformer.embed_stoks',
233
+ 'whisperspeech/s2a_delar_mup_wds.py'),
234
+ 'whisperspeech.s2a_delar_mup_wds.SADelARTransformer.forward': ( '4b. semantic to acoustic '
235
+ 'token '
236
+ 'modeling.html#sadelartransformer.forward',
237
+ 'whisperspeech/s2a_delar_mup_wds.py'),
238
+ 'whisperspeech.s2a_delar_mup_wds.SADelARTransformer.generate': ( '4b. semantic to '
239
+ 'acoustic token '
240
+ 'modeling.html#sadelartransformer.generate',
241
+ 'whisperspeech/s2a_delar_mup_wds.py'),
242
+ 'whisperspeech.s2a_delar_mup_wds.SADelARTransformer.get_extra_state': ( '4b. semantic to '
243
+ 'acoustic token '
244
+ 'modeling.html#sadelartransformer.get_extra_state',
245
+ 'whisperspeech/s2a_delar_mup_wds.py'),
246
+ 'whisperspeech.s2a_delar_mup_wds.SADelARTransformer.get_metrics': ( '4b. semantic to '
247
+ 'acoustic token '
248
+ 'modeling.html#sadelartransformer.get_metrics',
249
+ 'whisperspeech/s2a_delar_mup_wds.py'),
250
+ 'whisperspeech.s2a_delar_mup_wds.SADelARTransformer.init_transformer': ( '4b. semantic to '
251
+ 'acoustic token '
252
+ 'modeling.html#sadelartransformer.init_transformer',
253
+ 'whisperspeech/s2a_delar_mup_wds.py'),
254
+ 'whisperspeech.s2a_delar_mup_wds.SADelARTransformer.load_checkpoint': ( '4b. semantic to '
255
+ 'acoustic token '
256
+ 'modeling.html#sadelartransformer.load_checkpoint',
257
+ 'whisperspeech/s2a_delar_mup_wds.py'),
258
+ 'whisperspeech.s2a_delar_mup_wds.SADelARTransformer.load_frozen_semantic_embeddings': ( '4b. '
259
+ 'semantic '
260
+ 'to '
261
+ 'acoustic '
262
+ 'token '
263
+ 'modeling.html#sadelartransformer.load_frozen_semantic_embeddings',
264
+ 'whisperspeech/s2a_delar_mup_wds.py'),
265
+ 'whisperspeech.s2a_delar_mup_wds.SADelARTransformer.load_model': ( '4b. semantic to '
266
+ 'acoustic token '
267
+ 'modeling.html#sadelartransformer.load_model',
268
+ 'whisperspeech/s2a_delar_mup_wds.py'),
269
+ 'whisperspeech.s2a_delar_mup_wds.SADelARTransformer.save_model': ( '4b. semantic to '
270
+ 'acoustic token '
271
+ 'modeling.html#sadelartransformer.save_model',
272
+ 'whisperspeech/s2a_delar_mup_wds.py'),
273
+ 'whisperspeech.s2a_delar_mup_wds.SADelARTransformer.set_extra_state': ( '4b. semantic to '
274
+ 'acoustic token '
275
+ 'modeling.html#sadelartransformer.set_extra_state',
276
+ 'whisperspeech/s2a_delar_mup_wds.py'),
277
+ 'whisperspeech.s2a_delar_mup_wds.SADelARTransformer.setup': ( '4b. semantic to acoustic '
278
+ 'token '
279
+ 'modeling.html#sadelartransformer.setup',
280
+ 'whisperspeech/s2a_delar_mup_wds.py'),
281
+ 'whisperspeech.s2a_delar_mup_wds.Tunables': ( '4b. semantic to acoustic token '
282
+ 'modeling.html#tunables',
283
+ 'whisperspeech/s2a_delar_mup_wds.py'),
284
+ 'whisperspeech.s2a_delar_mup_wds.Tunables.__post_init__': ( '4b. semantic to acoustic '
285
+ 'token '
286
+ 'modeling.html#tunables.__post_init__',
287
+ 'whisperspeech/s2a_delar_mup_wds.py'),
288
+ 'whisperspeech.s2a_delar_mup_wds.Tunables.upgrade': ( '4b. semantic to acoustic token '
289
+ 'modeling.html#tunables.upgrade',
290
+ 'whisperspeech/s2a_delar_mup_wds.py'),
291
+ 'whisperspeech.s2a_delar_mup_wds._make_model': ( '4b. semantic to acoustic token '
292
+ 'modeling.html#_make_model',
293
+ 'whisperspeech/s2a_delar_mup_wds.py'),
294
+ 'whisperspeech.s2a_delar_mup_wds.apply_rotary_pos_emb': ( '4b. semantic to acoustic token '
295
+ 'modeling.html#apply_rotary_pos_emb',
296
+ 'whisperspeech/s2a_delar_mup_wds.py'),
297
+ 'whisperspeech.s2a_delar_mup_wds.load_datasets': ( '4b. semantic to acoustic token '
298
+ 'modeling.html#load_datasets',
299
+ 'whisperspeech/s2a_delar_mup_wds.py'),
300
+ 'whisperspeech.s2a_delar_mup_wds.make_model': ( '4b. semantic to acoustic token '
301
+ 'modeling.html#make_model',
302
+ 'whisperspeech/s2a_delar_mup_wds.py'),
303
+ 'whisperspeech.s2a_delar_mup_wds.pad_samples': ( '4b. semantic to acoustic token '
304
+ 'modeling.html#pad_samples',
305
+ 'whisperspeech/s2a_delar_mup_wds.py'),
306
+ 'whisperspeech.s2a_delar_mup_wds.rand': ( '4b. semantic to acoustic token '
307
+ 'modeling.html#rand',
308
+ 'whisperspeech/s2a_delar_mup_wds.py'),
309
+ 'whisperspeech.s2a_delar_mup_wds.random_trunc': ( '4b. semantic to acoustic token '
310
+ 'modeling.html#random_trunc',
311
+ 'whisperspeech/s2a_delar_mup_wds.py'),
312
+ 'whisperspeech.s2a_delar_mup_wds.rotate_half': ( '4b. semantic to acoustic token '
313
+ 'modeling.html#rotate_half',
314
+ 'whisperspeech/s2a_delar_mup_wds.py'),
315
+ 'whisperspeech.s2a_delar_mup_wds.speaker_id_extractor': ( '4b. semantic to acoustic token '
316
+ 'modeling.html#speaker_id_extractor',
317
+ 'whisperspeech/s2a_delar_mup_wds.py')},
318
+ 'whisperspeech.t2s_up_wds': { 'whisperspeech.t2s_up_wds.CharTokenizer': ( '5b. text to semantic token '
319
+ 'modeling.html#chartokenizer',
320
+ 'whisperspeech/t2s_up_wds.py'),
321
+ 'whisperspeech.t2s_up_wds.CharTokenizer.decode': ( '5b. text to semantic token '
322
+ 'modeling.html#chartokenizer.decode',
323
+ 'whisperspeech/t2s_up_wds.py'),
324
+ 'whisperspeech.t2s_up_wds.CharTokenizer.encode': ( '5b. text to semantic token '
325
+ 'modeling.html#chartokenizer.encode',
326
+ 'whisperspeech/t2s_up_wds.py'),
327
+ 'whisperspeech.t2s_up_wds.Decoder': ( '5b. text to semantic token modeling.html#decoder',
328
+ 'whisperspeech/t2s_up_wds.py'),
329
+ 'whisperspeech.t2s_up_wds.Decoder.__init__': ( '5b. text to semantic token '
330
+ 'modeling.html#decoder.__init__',
331
+ 'whisperspeech/t2s_up_wds.py'),
332
+ 'whisperspeech.t2s_up_wds.Decoder.forward': ( '5b. text to semantic token '
333
+ 'modeling.html#decoder.forward',
334
+ 'whisperspeech/t2s_up_wds.py'),
335
+ 'whisperspeech.t2s_up_wds.EmbeddingProjector': ( '5b. text to semantic token '
336
+ 'modeling.html#embeddingprojector',
337
+ 'whisperspeech/t2s_up_wds.py'),
338
+ 'whisperspeech.t2s_up_wds.Encoder': ( '5b. text to semantic token modeling.html#encoder',
339
+ 'whisperspeech/t2s_up_wds.py'),
340
+ 'whisperspeech.t2s_up_wds.Encoder.__init__': ( '5b. text to semantic token '
341
+ 'modeling.html#encoder.__init__',
342
+ 'whisperspeech/t2s_up_wds.py'),
343
+ 'whisperspeech.t2s_up_wds.Encoder.forward': ( '5b. text to semantic token '
344
+ 'modeling.html#encoder.forward',
345
+ 'whisperspeech/t2s_up_wds.py'),
346
+ 'whisperspeech.t2s_up_wds.TSARTransformer': ( '5b. text to semantic token '
347
+ 'modeling.html#tsartransformer',
348
+ 'whisperspeech/t2s_up_wds.py'),
349
+ 'whisperspeech.t2s_up_wds.TSARTransformer.__init__': ( '5b. text to semantic token '
350
+ 'modeling.html#tsartransformer.__init__',
351
+ 'whisperspeech/t2s_up_wds.py'),
352
+ 'whisperspeech.t2s_up_wds.TSARTransformer.device': ( '5b. text to semantic token '
353
+ 'modeling.html#tsartransformer.device',
354
+ 'whisperspeech/t2s_up_wds.py'),
355
+ 'whisperspeech.t2s_up_wds.TSARTransformer.ensure_tokenizer': ( '5b. text to semantic token '
356
+ 'modeling.html#tsartransformer.ensure_tokenizer',
357
+ 'whisperspeech/t2s_up_wds.py'),
358
+ 'whisperspeech.t2s_up_wds.TSARTransformer.forward': ( '5b. text to semantic token '
359
+ 'modeling.html#tsartransformer.forward',
360
+ 'whisperspeech/t2s_up_wds.py'),
361
+ 'whisperspeech.t2s_up_wds.TSARTransformer.generate': ( '5b. text to semantic token '
362
+ 'modeling.html#tsartransformer.generate',
363
+ 'whisperspeech/t2s_up_wds.py'),
364
+ 'whisperspeech.t2s_up_wds.TSARTransformer.generate_batch': ( '5b. text to semantic token '
365
+ 'modeling.html#tsartransformer.generate_batch',
366
+ 'whisperspeech/t2s_up_wds.py'),
367
+ 'whisperspeech.t2s_up_wds.TSARTransformer.init_transformer': ( '5b. text to semantic token '
368
+ 'modeling.html#tsartransformer.init_transformer',
369
+ 'whisperspeech/t2s_up_wds.py'),
370
+ 'whisperspeech.t2s_up_wds.TSARTransformer.load_checkpoint': ( '5b. text to semantic token '
371
+ 'modeling.html#tsartransformer.load_checkpoint',
372
+ 'whisperspeech/t2s_up_wds.py'),
373
+ 'whisperspeech.t2s_up_wds.TSARTransformer.load_frozen_semantic_embeddings': ( '5b. text to '
374
+ 'semantic token '
375
+ 'modeling.html#tsartransformer.load_frozen_semantic_embeddings',
376
+ 'whisperspeech/t2s_up_wds.py'),
377
+ 'whisperspeech.t2s_up_wds.TSARTransformer.load_model': ( '5b. text to semantic token '
378
+ 'modeling.html#tsartransformer.load_model',
379
+ 'whisperspeech/t2s_up_wds.py'),
380
+ 'whisperspeech.t2s_up_wds.TSARTransformer.save_model': ( '5b. text to semantic token '
381
+ 'modeling.html#tsartransformer.save_model',
382
+ 'whisperspeech/t2s_up_wds.py'),
383
+ 'whisperspeech.t2s_up_wds.TSARTransformer.setup': ( '5b. text to semantic token '
384
+ 'modeling.html#tsartransformer.setup',
385
+ 'whisperspeech/t2s_up_wds.py'),
386
+ 'whisperspeech.t2s_up_wds.Tunables': ( '5b. text to semantic token modeling.html#tunables',
387
+ 'whisperspeech/t2s_up_wds.py'),
388
+ 'whisperspeech.t2s_up_wds.Tunables.__post_init__': ( '5b. text to semantic token '
389
+ 'modeling.html#tunables.__post_init__',
390
+ 'whisperspeech/t2s_up_wds.py'),
391
+ 'whisperspeech.t2s_up_wds._make_model': ( '5b. text to semantic token modeling.html#_make_model',
392
+ 'whisperspeech/t2s_up_wds.py'),
393
+ 'whisperspeech.t2s_up_wds.ar_padder': ( '5b. text to semantic token modeling.html#ar_padder',
394
+ 'whisperspeech/t2s_up_wds.py'),
395
+ 'whisperspeech.t2s_up_wds.build_speaker_map': ( '5b. text to semantic token '
396
+ 'modeling.html#build_speaker_map',
397
+ 'whisperspeech/t2s_up_wds.py'),
398
+ 'whisperspeech.t2s_up_wds.char_per_seconder': ( '5b. text to semantic token '
399
+ 'modeling.html#char_per_seconder',
400
+ 'whisperspeech/t2s_up_wds.py'),
401
+ 'whisperspeech.t2s_up_wds.load_datasets': ( '5b. text to semantic token '
402
+ 'modeling.html#load_datasets',
403
+ 'whisperspeech/t2s_up_wds.py'),
404
+ 'whisperspeech.t2s_up_wds.make_model': ( '5b. text to semantic token modeling.html#make_model',
405
+ 'whisperspeech/t2s_up_wds.py'),
406
+ 'whisperspeech.t2s_up_wds.rand': ( '5b. text to semantic token modeling.html#rand',
407
+ 'whisperspeech/t2s_up_wds.py'),
408
+ 'whisperspeech.t2s_up_wds.speaker_id_extractor': ( '5b. text to semantic token '
409
+ 'modeling.html#speaker_id_extractor',
410
+ 'whisperspeech/t2s_up_wds.py'),
411
+ 'whisperspeech.t2s_up_wds.tokenizer': ( '5b. text to semantic token modeling.html#tokenizer',
412
+ 'whisperspeech/t2s_up_wds.py')},
413
+ 'whisperspeech.train': { 'whisperspeech.train.SimpleVisual': ('b1. training.html#simplevisual', 'whisperspeech/train.py'),
414
+ 'whisperspeech.train.SimpleVisual.__init__': ( 'b1. training.html#simplevisual.__init__',
415
+ 'whisperspeech/train.py'),
416
+ 'whisperspeech.train.SimpleVisual.add_data': ( 'b1. training.html#simplevisual.add_data',
417
+ 'whisperspeech/train.py'),
418
+ 'whisperspeech.train.SimpleVisual.add_table_row': ( 'b1. training.html#simplevisual.add_table_row',
419
+ 'whisperspeech/train.py'),
420
+ 'whisperspeech.train.SimpleVisual.hide': ( 'b1. training.html#simplevisual.hide',
421
+ 'whisperspeech/train.py'),
422
+ 'whisperspeech.train.SimpleVisual.on_iter': ( 'b1. training.html#simplevisual.on_iter',
423
+ 'whisperspeech/train.py'),
424
+ 'whisperspeech.train.SimpleVisual.plot': ( 'b1. training.html#simplevisual.plot',
425
+ 'whisperspeech/train.py'),
426
+ 'whisperspeech.train.SimpleVisual.show': ( 'b1. training.html#simplevisual.show',
427
+ 'whisperspeech/train.py'),
428
+ 'whisperspeech.train.train': ('b1. training.html#train', 'whisperspeech/train.py'),
429
+ 'whisperspeech.train.validate': ('b1. training.html#validate', 'whisperspeech/train.py')},
430
+ 'whisperspeech.train_multi': { 'whisperspeech.train_multi.TrainingTask': ( 'b2. training (lightning).html#trainingtask',
431
+ 'whisperspeech/train_multi.py'),
432
+ 'whisperspeech.train_multi.TrainingTask.__init__': ( 'b2. training '
433
+ '(lightning).html#trainingtask.__init__',
434
+ 'whisperspeech/train_multi.py'),
435
+ 'whisperspeech.train_multi.TrainingTask.configure_optimizers': ( 'b2. training '
436
+ '(lightning).html#trainingtask.configure_optimizers',
437
+ 'whisperspeech/train_multi.py'),
438
+ 'whisperspeech.train_multi.TrainingTask.on_fit_start': ( 'b2. training '
439
+ '(lightning).html#trainingtask.on_fit_start',
440
+ 'whisperspeech/train_multi.py'),
441
+ 'whisperspeech.train_multi.TrainingTask.on_validation_epoch_end': ( 'b2. training '
442
+ '(lightning).html#trainingtask.on_validation_epoch_end',
443
+ 'whisperspeech/train_multi.py'),
444
+ 'whisperspeech.train_multi.TrainingTask.test_step': ( 'b2. training '
445
+ '(lightning).html#trainingtask.test_step',
446
+ 'whisperspeech/train_multi.py'),
447
+ 'whisperspeech.train_multi.TrainingTask.training_step': ( 'b2. training '
448
+ '(lightning).html#trainingtask.training_step',
449
+ 'whisperspeech/train_multi.py'),
450
+ 'whisperspeech.train_multi.TrainingTask.validation_step': ( 'b2. training '
451
+ '(lightning).html#trainingtask.validation_step',
452
+ 'whisperspeech/train_multi.py'),
453
+ 'whisperspeech.train_multi.parse_and_call': ( 'b2. training (lightning).html#parse_and_call',
454
+ 'whisperspeech/train_multi.py')},
455
+ 'whisperspeech.vad': { 'whisperspeech.vad.extract_segments': ( '1b. voice activity detection.html#extract_segments',
456
+ 'whisperspeech/vad.py'),
457
+ 'whisperspeech.vad.fix_dots_in_names': ( '1b. voice activity detection.html#fix_dots_in_names',
458
+ 'whisperspeech/vad.py'),
459
+ 'whisperspeech.vad.flac_to_vad_name': ( '1b. voice activity detection.html#flac_to_vad_name',
460
+ 'whisperspeech/vad.py'),
461
+ 'whisperspeech.vad.load_dataset': ( '1b. voice activity detection.html#load_dataset',
462
+ 'whisperspeech/vad.py'),
463
+ 'whisperspeech.vad.process_shard': ( '1b. voice activity detection.html#process_shard',
464
+ 'whisperspeech/vad.py'),
465
+ 'whisperspeech.vad.segment_audio': ( '1b. voice activity detection.html#segment_audio',
466
+ 'whisperspeech/vad.py')},
467
+ 'whisperspeech.verify_wds': { 'whisperspeech.verify_wds.process_shard': ( '0. verify webdataset archives.html#process_shard',
468
+ 'whisperspeech/verify_wds.py')},
469
+ 'whisperspeech.vq_stoks': { 'whisperspeech.vq_stoks.RQBottleneckTransformer': ( '2b. whisper quantization (semantic token) '
470
+ 'model.html#rqbottlenecktransformer',
471
+ 'whisperspeech/vq_stoks.py'),
472
+ 'whisperspeech.vq_stoks.RQBottleneckTransformer.__init__': ( '2b. whisper quantization (semantic '
473
+ 'token) '
474
+ 'model.html#rqbottlenecktransformer.__init__',
475
+ 'whisperspeech/vq_stoks.py'),
476
+ 'whisperspeech.vq_stoks.RQBottleneckTransformer.decode_text': ( '2b. whisper quantization '
477
+ '(semantic token) '
478
+ 'model.html#rqbottlenecktransformer.decode_text',
479
+ 'whisperspeech/vq_stoks.py'),
480
+ 'whisperspeech.vq_stoks.RQBottleneckTransformer.dequantize': ( '2b. whisper quantization (semantic '
481
+ 'token) '
482
+ 'model.html#rqbottlenecktransformer.dequantize',
483
+ 'whisperspeech/vq_stoks.py'),
484
+ 'whisperspeech.vq_stoks.RQBottleneckTransformer.device': ( '2b. whisper quantization (semantic '
485
+ 'token) '
486
+ 'model.html#rqbottlenecktransformer.device',
487
+ 'whisperspeech/vq_stoks.py'),
488
+ 'whisperspeech.vq_stoks.RQBottleneckTransformer.downsample_embeddings': ( '2b. whisper '
489
+ 'quantization (semantic '
490
+ 'token) '
491
+ 'model.html#rqbottlenecktransformer.downsample_embeddings',
492
+ 'whisperspeech/vq_stoks.py'),
493
+ 'whisperspeech.vq_stoks.RQBottleneckTransformer.encode_audio': ( '2b. whisper quantization '
494
+ '(semantic token) '
495
+ 'model.html#rqbottlenecktransformer.encode_audio',
496
+ 'whisperspeech/vq_stoks.py'),
497
+ 'whisperspeech.vq_stoks.RQBottleneckTransformer.encode_mel': ( '2b. whisper quantization (semantic '
498
+ 'token) '
499
+ 'model.html#rqbottlenecktransformer.encode_mel',
500
+ 'whisperspeech/vq_stoks.py'),
501
+ 'whisperspeech.vq_stoks.RQBottleneckTransformer.ensure_whisper': ( '2b. whisper quantization '
502
+ '(semantic token) '
503
+ 'model.html#rqbottlenecktransformer.ensure_whisper',
504
+ 'whisperspeech/vq_stoks.py'),
505
+ 'whisperspeech.vq_stoks.RQBottleneckTransformer.extract_teacher': ( '2b. whisper quantization '
506
+ '(semantic token) '
507
+ 'model.html#rqbottlenecktransformer.extract_teacher',
508
+ 'whisperspeech/vq_stoks.py'),
509
+ 'whisperspeech.vq_stoks.RQBottleneckTransformer.forward': ( '2b. whisper quantization (semantic '
510
+ 'token) '
511
+ 'model.html#rqbottlenecktransformer.forward',
512
+ 'whisperspeech/vq_stoks.py'),
513
+ 'whisperspeech.vq_stoks.RQBottleneckTransformer.get_metrics': ( '2b. whisper quantization '
514
+ '(semantic token) '
515
+ 'model.html#rqbottlenecktransformer.get_metrics',
516
+ 'whisperspeech/vq_stoks.py'),
517
+ 'whisperspeech.vq_stoks.RQBottleneckTransformer.init_transformer': ( '2b. whisper quantization '
518
+ '(semantic token) '
519
+ 'model.html#rqbottlenecktransformer.init_transformer',
520
+ 'whisperspeech/vq_stoks.py'),
521
+ 'whisperspeech.vq_stoks.RQBottleneckTransformer.load_checkpoint': ( '2b. whisper quantization '
522
+ '(semantic token) '
523
+ 'model.html#rqbottlenecktransformer.load_checkpoint',
524
+ 'whisperspeech/vq_stoks.py'),
525
+ 'whisperspeech.vq_stoks.RQBottleneckTransformer.load_model': ( '2b. whisper quantization (semantic '
526
+ 'token) '
527
+ 'model.html#rqbottlenecktransformer.load_model',
528
+ 'whisperspeech/vq_stoks.py'),
529
+ 'whisperspeech.vq_stoks.RQBottleneckTransformer.quantize': ( '2b. whisper quantization (semantic '
530
+ 'token) '
531
+ 'model.html#rqbottlenecktransformer.quantize',
532
+ 'whisperspeech/vq_stoks.py'),
533
+ 'whisperspeech.vq_stoks.RQBottleneckTransformer.save_model': ( '2b. whisper quantization (semantic '
534
+ 'token) '
535
+ 'model.html#rqbottlenecktransformer.save_model',
536
+ 'whisperspeech/vq_stoks.py'),
537
+ 'whisperspeech.vq_stoks.RQBottleneckTransformer.setup': ( '2b. whisper quantization (semantic '
538
+ 'token) '
539
+ 'model.html#rqbottlenecktransformer.setup',
540
+ 'whisperspeech/vq_stoks.py'),
541
+ 'whisperspeech.vq_stoks.Tunables': ( '2b. whisper quantization (semantic token) '
542
+ 'model.html#tunables',
543
+ 'whisperspeech/vq_stoks.py'),
544
+ 'whisperspeech.vq_stoks.Tunables.__post_init__': ( '2b. whisper quantization (semantic token) '
545
+ 'model.html#tunables.__post_init__',
546
+ 'whisperspeech/vq_stoks.py'),
547
+ 'whisperspeech.vq_stoks.Tunables.upgrade': ( '2b. whisper quantization (semantic token) '
548
+ 'model.html#tunables.upgrade',
549
+ 'whisperspeech/vq_stoks.py'),
550
+ 'whisperspeech.vq_stoks.add_masks': ( '2b. whisper quantization (semantic token) '
551
+ 'model.html#add_masks',
552
+ 'whisperspeech/vq_stoks.py'),
553
+ 'whisperspeech.vq_stoks.derived_dataset': ( '2b. whisper quantization (semantic token) '
554
+ 'model.html#derived_dataset',
555
+ 'whisperspeech/vq_stoks.py'),
556
+ 'whisperspeech.vq_stoks.load_datasets': ( '2b. whisper quantization (semantic token) '
557
+ 'model.html#load_datasets',
558
+ 'whisperspeech/vq_stoks.py'),
559
+ 'whisperspeech.vq_stoks.logrand': ( '2b. whisper quantization (semantic token) model.html#logrand',
560
+ 'whisperspeech/vq_stoks.py'),
561
+ 'whisperspeech.vq_stoks.make_model': ( '2b. whisper quantization (semantic token) '
562
+ 'model.html#make_model',
563
+ 'whisperspeech/vq_stoks.py'),
564
+ 'whisperspeech.vq_stoks.merge_in': ( '2b. whisper quantization (semantic token) '
565
+ 'model.html#merge_in',
566
+ 'whisperspeech/vq_stoks.py'),
567
+ 'whisperspeech.vq_stoks.rand': ( '2b. whisper quantization (semantic token) model.html#rand',
568
+ 'whisperspeech/vq_stoks.py'),
569
+ 'whisperspeech.vq_stoks.tokenize_text': ( '2b. whisper quantization (semantic token) '
570
+ 'model.html#tokenize_text',
571
+ 'whisperspeech/vq_stoks.py')},
572
+ 'whisperspeech.wer_metrics': { 'whisperspeech.wer_metrics.DfBuilder': ( 'c. word error rate metrics.html#dfbuilder',
573
+ 'whisperspeech/wer_metrics.py'),
574
+ 'whisperspeech.wer_metrics.DfBuilder.__init__': ( 'c. word error rate '
575
+ 'metrics.html#dfbuilder.__init__',
576
+ 'whisperspeech/wer_metrics.py'),
577
+ 'whisperspeech.wer_metrics.DfBuilder.df': ( 'c. word error rate metrics.html#dfbuilder.df',
578
+ 'whisperspeech/wer_metrics.py'),
579
+ 'whisperspeech.wer_metrics.DfBuilder.push': ( 'c. word error rate metrics.html#dfbuilder.push',
580
+ 'whisperspeech/wer_metrics.py'),
581
+ 'whisperspeech.wer_metrics.WERStats': ( 'c. word error rate metrics.html#werstats',
582
+ 'whisperspeech/wer_metrics.py'),
583
+ 'whisperspeech.wer_metrics.WERStats.__init__': ( 'c. word error rate '
584
+ 'metrics.html#werstats.__init__',
585
+ 'whisperspeech/wer_metrics.py'),
586
+ 'whisperspeech.wer_metrics.WERStats.push_sample': ( 'c. word error rate '
587
+ 'metrics.html#werstats.push_sample',
588
+ 'whisperspeech/wer_metrics.py'),
589
+ 'whisperspeech.wer_metrics.librispeech_data': ( 'c. word error rate '
590
+ 'metrics.html#librispeech_data',
591
+ 'whisperspeech/wer_metrics.py'),
592
+ 'whisperspeech.wer_metrics.whisper_normalize': ( 'c. word error rate '
593
+ 'metrics.html#whisper_normalize',
594
+ 'whisperspeech/wer_metrics.py')},
595
+ 'whisperspeech.wh_transcribe': { 'whisperspeech.wh_transcribe.chunk_merger': ( '2a. whisper quantization dataset '
596
+ 'preparation.html#chunk_merger',
597
+ 'whisperspeech/wh_transcribe.py'),
598
+ 'whisperspeech.wh_transcribe.flac_to_txt_name': ( '2a. whisper quantization dataset '
599
+ 'preparation.html#flac_to_txt_name',
600
+ 'whisperspeech/wh_transcribe.py'),
601
+ 'whisperspeech.wh_transcribe.merge_in': ( '2a. whisper quantization dataset '
602
+ 'preparation.html#merge_in',
603
+ 'whisperspeech/wh_transcribe.py'),
604
+ 'whisperspeech.wh_transcribe.process_shard': ( '2a. whisper quantization dataset '
605
+ 'preparation.html#process_shard',
606
+ 'whisperspeech/wh_transcribe.py'),
607
+ 'whisperspeech.wh_transcribe.random_cutter': ( '2a. whisper quantization dataset '
608
+ 'preparation.html#random_cutter',
609
+ 'whisperspeech/wh_transcribe.py'),
610
+ 'whisperspeech.wh_transcribe.split_to_chunks': ( '2a. whisper quantization dataset '
611
+ 'preparation.html#split_to_chunks',
612
+ 'whisperspeech/wh_transcribe.py'),
613
+ 'whisperspeech.wh_transcribe.wds_compose': ( '2a. whisper quantization dataset '
614
+ 'preparation.html#wds_compose',
615
+ 'whisperspeech/wh_transcribe.py')}}}
whisperspeech/a2wav.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/6. Quality-boosting vocoder.ipynb.
2
+
3
+ # %% auto 0
4
+ __all__ = ['Vocoder']
5
+
6
+ # %% ../nbs/6. Quality-boosting vocoder.ipynb 1
7
+ from vocos import Vocos
8
+ import torch
9
+ import torchaudio
10
+
11
+ # %% ../nbs/6. Quality-boosting vocoder.ipynb 2
12
+ class Vocoder:
13
+ def __init__(self, repo_id="charactr/vocos-encodec-24khz"):
14
+ self.vocos = Vocos.from_pretrained(repo_id).cuda()
15
+
16
+ def is_notebook(self):
17
+ try:
18
+ return get_ipython().__class__.__name__ == "ZMQInteractiveShell"
19
+ except:
20
+ return False
21
+
22
+ @torch.no_grad()
23
+ def decode(self, atoks):
24
+ if len(atoks.shape) == 3:
25
+ b,q,t = atoks.shape
26
+ atoks = atoks.permute(1,0,2)
27
+ else:
28
+ q,t = atoks.shape
29
+
30
+ features = self.vocos.codes_to_features(atoks)
31
+ bandwidth_id = torch.tensor({2:0,4:1,8:2}[q]).cuda()
32
+ return self.vocos.decode(features, bandwidth_id=bandwidth_id)
33
+
34
+ def decode_to_file(self, fname, atoks):
35
+ audio = self.decode(atoks)
36
+ torchaudio.save(fname, audio.cpu(), 24000)
37
+ if self.is_notebook():
38
+ from IPython.display import display, HTML, Audio
39
+ display(HTML(f'<a href="{fname}" target="_blank">Listen to {fname}</a>'))
40
+
41
+ def decode_to_notebook(self, atoks):
42
+ from IPython.display import display, HTML, Audio
43
+
44
+ audio = self.decode(atoks)
45
+ display(Audio(audio.cpu().numpy(), rate=24000))
whisperspeech/extract_acoustic.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/1. Acoustic token extraction.ipynb.
2
+
3
+ # %% auto 0
4
+ __all__ = ['load', 'load_model', 'extract_Atoks', 'extract_acoustic']
5
+
6
+ # %% ../nbs/1. Acoustic token extraction.ipynb 2
7
+ import torch
8
+ import torchaudio
9
+ import gc
10
+
11
+ from pathlib import Path
12
+ from fastcore.script import *
13
+ from fastprogress import progress_bar, master_bar
14
+
15
+ # %% ../nbs/1. Acoustic token extraction.ipynb 5
16
+ def load(fname, newsr=24000):
17
+ """Load an audio file to the GPU and resample to `newsr`."""
18
+ x, sr = torchaudio.load(fname)
19
+ _tform = torchaudio.transforms.Resample(sr, newsr)
20
+ return _tform(x).cuda().unsqueeze(0)
21
+
22
+ # %% ../nbs/1. Acoustic token extraction.ipynb 6
23
+ def load_model():
24
+ "Load the pretrained EnCodec model"
25
+ from encodec.model import EncodecModel
26
+ model = EncodecModel.encodec_model_24khz()
27
+ model.set_target_bandwidth(1.5)
28
+ model.cuda().eval();
29
+ return model
30
+
31
+ # %% ../nbs/1. Acoustic token extraction.ipynb 7
32
+ def extract_Atoks(model, audio):
33
+ """Extract EnCodec tokens for the given `audio` tensor (or file path)
34
+ using the given `model` (see `load_model`)."""
35
+ if isinstance(audio, (Path, str)):
36
+ audio = load(audio)
37
+ with torch.no_grad():
38
+ frames = torch.cat([model.encode(segment)[0][0]
39
+ for segment in torch.split(audio, 320*20000, dim=-1)], dim=-1)
40
+ return frames
41
+
42
+ # %% ../nbs/1. Acoustic token extraction.ipynb 8
43
+ @call_parse
44
+ def extract_acoustic(
45
+ srcdir:Path, # source dir, should contain *.flac files
46
+ outdir:Path, # output dir, will get the *.encodec files
47
+ ):
48
+ "Convert audio files to .encodec files with tensors of tokens"
49
+ model = load_model()
50
+ outdir.mkdir(exist_ok=True, parents=True)
51
+ for name in progress_bar(list(srcdir.rglob('*.flac'))):
52
+ outname = outdir/name.with_suffix('.encodec').name
53
+ tokens = extract_Atoks(model, name)
54
+ torch.save(tokens, outname)
55
+ del tokens
56
+ gc.collect()
whisperspeech/fetch_models.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/0. Download models.ipynb.
2
+
3
+ # %% auto 0
4
+ __all__ = []
5
+
6
+ # %% ../nbs/0. Download models.ipynb 1
7
+ from fastcore.script import call_parse
8
+ import whisperx
9
+ import whisper
10
+
11
+ # %% ../nbs/0. Download models.ipynb 3
12
+ @call_parse
13
+ def main():
14
+ whisper.load_model('base.en')
15
+ whisper.load_model('small.en')
16
+ whisperx.vad.load_vad_model('cpu')
17
+ whisperx.asr.load_model('medium.en', "cpu", compute_type="float16", language='en')
whisperspeech/languages.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/B. Languages.ipynb.
2
+
3
+ # %% auto 0
4
+ __all__ = ['to_id']
5
+
6
+ # %% ../nbs/B. Languages.ipynb 3
7
+ LANGUAGES = {
8
+ "en": "english",
9
+ "zh": "chinese",
10
+ "de": "german",
11
+ "es": "spanish",
12
+ "ru": "russian",
13
+ "ko": "korean",
14
+ "fr": "french",
15
+ "ja": "japanese",
16
+ "pt": "portuguese",
17
+ "tr": "turkish",
18
+ "pl": "polish",
19
+ "ca": "catalan",
20
+ "nl": "dutch",
21
+ "ar": "arabic",
22
+ "sv": "swedish",
23
+ "it": "italian",
24
+ "id": "indonesian",
25
+ "hi": "hindi",
26
+ "fi": "finnish",
27
+ "vi": "vietnamese",
28
+ "he": "hebrew",
29
+ "uk": "ukrainian",
30
+ "el": "greek",
31
+ "ms": "malay",
32
+ "cs": "czech",
33
+ "ro": "romanian",
34
+ "da": "danish",
35
+ "hu": "hungarian",
36
+ "ta": "tamil",
37
+ "no": "norwegian",
38
+ "th": "thai",
39
+ "ur": "urdu",
40
+ "hr": "croatian",
41
+ "bg": "bulgarian",
42
+ "lt": "lithuanian",
43
+ "la": "latin",
44
+ "mi": "maori",
45
+ "ml": "malayalam",
46
+ "cy": "welsh",
47
+ "sk": "slovak",
48
+ "te": "telugu",
49
+ "fa": "persian",
50
+ "lv": "latvian",
51
+ "bn": "bengali",
52
+ "sr": "serbian",
53
+ "az": "azerbaijani",
54
+ "sl": "slovenian",
55
+ "kn": "kannada",
56
+ "et": "estonian",
57
+ "mk": "macedonian",
58
+ "br": "breton",
59
+ "eu": "basque",
60
+ "is": "icelandic",
61
+ "hy": "armenian",
62
+ "ne": "nepali",
63
+ "mn": "mongolian",
64
+ "bs": "bosnian",
65
+ "kk": "kazakh",
66
+ "sq": "albanian",
67
+ "sw": "swahili",
68
+ "gl": "galician",
69
+ "mr": "marathi",
70
+ "pa": "punjabi",
71
+ "si": "sinhala",
72
+ "km": "khmer",
73
+ "sn": "shona",
74
+ "yo": "yoruba",
75
+ "so": "somali",
76
+ "af": "afrikaans",
77
+ "oc": "occitan",
78
+ "ka": "georgian",
79
+ "be": "belarusian",
80
+ "tg": "tajik",
81
+ "sd": "sindhi",
82
+ "gu": "gujarati",
83
+ "am": "amharic",
84
+ "yi": "yiddish",
85
+ "lo": "lao",
86
+ "uz": "uzbek",
87
+ "fo": "faroese",
88
+ "ht": "haitian creole",
89
+ "ps": "pashto",
90
+ "tk": "turkmen",
91
+ "nn": "nynorsk",
92
+ "mt": "maltese",
93
+ "sa": "sanskrit",
94
+ "lb": "luxembourgish",
95
+ "my": "myanmar",
96
+ "bo": "tibetan",
97
+ "tl": "tagalog",
98
+ "mg": "malagasy",
99
+ "as": "assamese",
100
+ "tt": "tatar",
101
+ "haw": "hawaiian",
102
+ "ln": "lingala",
103
+ "ha": "hausa",
104
+ "ba": "bashkir",
105
+ "jw": "javanese",
106
+ "su": "sundanese",
107
+ }
108
+
109
+ # %% ../nbs/B. Languages.ipynb 4
110
+ # language code lookup by name, with a few language aliases
111
+ TO_LANGUAGE_CODE = {
112
+ **{language: code for code, language in LANGUAGES.items()},
113
+ "burmese": "my",
114
+ "valencian": "ca",
115
+ "flemish": "nl",
116
+ "haitian": "ht",
117
+ "letzeburgesch": "lb",
118
+ "pushto": "ps",
119
+ "panjabi": "pa",
120
+ "moldavian": "ro",
121
+ "moldovan": "ro",
122
+ "sinhalese": "si",
123
+ "castilian": "es",
124
+ }
125
+
126
+ # %% ../nbs/B. Languages.ipynb 5
127
+ languages = tuple(LANGUAGES.keys())
128
+
129
+ # %% ../nbs/B. Languages.ipynb 6
130
+ def to_id(lang):
131
+ return languages.index(TO_LANGUAGE_CODE.get(lang, lang))
whisperspeech/modules.py ADDED
@@ -0,0 +1,331 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/A. Neural modules.ipynb.
2
+
3
+ # %% auto 0
4
+ __all__ = ['LayerNorm', 'LinearHead', 'QueryHead', 'init_transformer', 'sinusoids', 'MultiHeadAttention',
5
+ 'ResidualAttentionBlock', 'BaseDecoder', 'EmbeddingProjector', 'FlexEmbeddings']
6
+
7
+ # %% ../nbs/A. Neural modules.ipynb 2
8
+ import torch
9
+ import numpy as np
10
+ import math
11
+
12
+ from torch import Tensor, nn
13
+ import torch.nn.functional as F
14
+ from typing import Dict, Iterable, Optional
15
+
16
+ # import xformers.ops as xops
17
+
18
+ # %% ../nbs/A. Neural modules.ipynb 3
19
+ # Code in this file is mostly borrowed from
20
+ # https://github.com/openai/whisper/blob/main/whisper/model.py
21
+ # and is under the MIT License
22
+
23
+ class LayerNorm(nn.LayerNorm):
24
+ def forward(self, x):
25
+ return super().forward(x.float()).type(x.dtype)
26
+
27
+ # Used in ฮผP to initialize the weights and configure the optimizer
28
+ # These two layers map the transformer width into a fixed dimension
29
+ class LinearHead(nn.Linear):
30
+ pass
31
+
32
+ class QueryHead(nn.Linear):
33
+ pass
34
+
35
+ # based on https://github.com/karpathy/minGPT/blob/master/mingpt/model.py#L163
36
+ def init_transformer(m):
37
+ if isinstance(m, (nn.Linear, nn.Embedding)):
38
+ torch.nn.init.trunc_normal_(m.weight, std=.02)
39
+ if isinstance(m, nn.Linear) and m.bias is not None:
40
+ torch.nn.init.constant_(m.bias, 0)
41
+ elif isinstance(m, nn.LayerNorm):
42
+ torch.nn.init.constant_(m.bias, 0)
43
+ torch.nn.init.constant_(m.weight, 1.0)
44
+
45
+ # %% ../nbs/A. Neural modules.ipynb 4
46
+ def sinusoids(length, channels, max_timescale=10000):
47
+ """Returns sinusoids for positional embedding"""
48
+ assert channels % 2 == 0
49
+ log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
50
+ inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
51
+ scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
52
+ return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
53
+
54
+ # %% ../nbs/A. Neural modules.ipynb 5
55
+ class MultiHeadAttention(nn.Module):
56
+ def __init__(self, n_state: int, n_head: int, qk_scale: float = 1, rope: bool = False, cross=False):
57
+ super().__init__()
58
+ self.n_state = n_state
59
+ self.n_head = n_head
60
+ self.sqrt_qk_scale = math.sqrt(qk_scale)
61
+ self.query = QueryHead(n_state, n_state)
62
+ self.key = nn.Linear(n_state, n_state, bias=False)
63
+ self.value = nn.Linear(n_state, n_state)
64
+ self.out = nn.Linear(n_state, n_state)
65
+ self.cross = cross
66
+ self.query_subsampling = 1
67
+ self.key_subsampling = 1
68
+
69
+ self.cached_kvx = None
70
+ self.register_buffer('k_cache', None)
71
+ self.register_buffer('v_cache', None)
72
+
73
+ self.rotary = None
74
+ if rope:
75
+ self.rotary = Rotary(n_state // n_head)
76
+ self.qkv = None
77
+ self.kv = None
78
+
79
+ def setup_kv_cache(self, max_batch_size, max_seq_len, dtype=torch.float32):
80
+ cache_shape = (max_batch_size, self.n_head, max_seq_len, self.n_state//self.n_head)
81
+ self.k_cache = torch.zeros(cache_shape, dtype=dtype, device=self.key.weight.device)
82
+ self.v_cache = torch.zeros(cache_shape, dtype=dtype, device=self.value.weight.device)
83
+
84
+ def merge_linears(self, layers, mults):
85
+ bias = [x.bias for x in layers if x.bias is not None][0]
86
+ din, dout = layers[0].weight.shape
87
+ new = nn.Linear(din, len(layers) * dout).to(layers[0].weight.device)
88
+ with torch.no_grad():
89
+ new.weight[:] = torch.cat([x.weight * m for x,m in zip(layers, mults)])
90
+ new.bias[:] = torch.cat([torch.zeros_like(bias) if x.bias is None else x.bias * m for x, m in zip(layers, mults)])
91
+ return new
92
+
93
+ def convert_for_eval(self):
94
+ if self.qkv or self.kv: raise AttributeError("already converted")
95
+
96
+ self.odim = self.key.weight.shape[1]
97
+ if self.cross:
98
+ self.q = self.merge_linears([self.query], [self.sqrt_qk_scale])
99
+ self.kv = self.merge_linears([self.key, self.value],
100
+ [self.sqrt_qk_scale, 1])
101
+ else:
102
+ self.qkv = self.merge_linears([self.query, self.key, self.value],
103
+ [self.sqrt_qk_scale, self.sqrt_qk_scale, 1])
104
+
105
+ def split_heads(self, x, x_positions, rope=False, subsampling=1):
106
+ x = x.view(*x.shape[:2], self.n_head, -1)
107
+ if rope:
108
+ x = rope_rotate(x, x_positions * subsampling, *self.rotary(x))
109
+ return x.permute(0, 2, 1, 3)
110
+
111
+ def forward(
112
+ self,
113
+ qx,
114
+ q_positions,
115
+ kvx,
116
+ kv_positions,
117
+ causal = False,
118
+ mask=None,
119
+ ):
120
+ if self.qkv:
121
+ q,k,v = self.qkv(qx).split(self.odim, dim=-1)
122
+ elif self.kv:
123
+ q = self.q(qx)
124
+ k,v = self.kv(kvx).split(self.odim, dim=-1)
125
+ else:
126
+ q,k,v = None,None,None
127
+
128
+ if q is None: q = self.query(qx) * self.sqrt_qk_scale
129
+ q = self.split_heads(q, q_positions, rope = self.rotary, subsampling = self.query_subsampling)
130
+
131
+ if kvx is not self.cached_kvx:
132
+ if k is None: k = self.key(kvx) * self.sqrt_qk_scale
133
+ k = self.split_heads(k, kv_positions, rope = self.rotary, subsampling = self.key_subsampling)
134
+ if v is None: v = self.value(kvx)
135
+ v = self.split_heads(v, kv_positions)
136
+ if self.k_cache is not None:
137
+ self.k_cache[:,:,kv_positions] = k
138
+ self.v_cache[:,:,kv_positions] = v
139
+
140
+ if self.k_cache is not None:
141
+ k, v = self.k_cache, self.v_cache
142
+
143
+ if mask is not None:
144
+ mask = mask[q_positions]
145
+
146
+ wv = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0, is_causal=causal)
147
+
148
+ return self.out(wv.permute(0, 2, 1, 3).flatten(start_dim=2))
149
+
150
+ # %% ../nbs/A. Neural modules.ipynb 6
151
+ # modified from https://blog.eleuther.ai/rotary-embeddings/
152
+
153
+ import torch
154
+
155
+ class Rotary(torch.nn.Module):
156
+ def __init__(self, dim, base=10000):
157
+ super().__init__()
158
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
159
+ self.register_buffer("inv_freq", inv_freq)
160
+ self.seq_len_cached = None
161
+ self.cos_cached = None
162
+ self.sin_cached = None
163
+
164
+ def forward(self, x, seq_dim=1):
165
+ seq_len = x.shape[seq_dim]
166
+ if not self.seq_len_cached or seq_len > self.seq_len_cached:
167
+ self.seq_len_cached = 2500
168
+ # self.seq_len_cached = seq_len
169
+
170
+ t = torch.arange(self.seq_len_cached, device=x.device).type_as(self.inv_freq)
171
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
172
+ emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
173
+ self.cos_cached = emb.cos()[None, :, None, :]
174
+ self.sin_cached = emb.sin()[None, :, None, :]
175
+ return self.cos_cached, self.sin_cached
176
+
177
+
178
+ # rotary pos emb helpers:
179
+ def rotate_half(x):
180
+ x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
181
+ return torch.cat(
182
+ (-x2, x1), dim=len(x.shape)-1
183
+ )
184
+
185
+ def rope_rotate(x, positions, cos, sin):
186
+ return x * cos[:,positions] + rotate_half(x) * sin[:,positions]
187
+
188
+ # %% ../nbs/A. Neural modules.ipynb 7
189
+ class ResidualAttentionBlock(nn.Module):
190
+ def __init__(self, n_state: int, n_head: int, cross_attention: bool = False, rope: bool = False,
191
+ qk_scale: float = 1, ffn_mult: int = 4):
192
+ super().__init__()
193
+ self.attn = MultiHeadAttention(n_state, n_head, qk_scale=qk_scale, rope=rope)
194
+ self.attn_ln = LayerNorm(n_state)
195
+
196
+ self.cross_attn = (
197
+ MultiHeadAttention(n_state, n_head, qk_scale=qk_scale, rope=rope, cross=True) if cross_attention else None
198
+ )
199
+ self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None
200
+
201
+ n_mlp = n_state * ffn_mult
202
+ self.mlp = nn.Sequential(
203
+ nn.Linear(n_state, n_mlp), nn.GELU(), nn.Linear(n_mlp, n_state)
204
+ )
205
+ self.mlp_ln = LayerNorm(n_state)
206
+
207
+ def setup_kv_cache(self, max_batch_size, max_seq_len, max_cross_seq_len=None):
208
+ self.attn.setup_kv_cache(max_batch_size, max_seq_len)
209
+ if self.cross_attn:
210
+ self.cross_attn.setup_kv_cache(max_batch_size, max_cross_seq_len)
211
+
212
+ def forward(
213
+ self,
214
+ x: Tensor,
215
+ x_positions: Tensor = None,
216
+ xa: Optional[Tensor] = None,
217
+ xa_positions: Optional[Tensor] = None,
218
+ causal = False,
219
+ mask=None,
220
+ ):
221
+ lnx = self.attn_ln(x)
222
+ x = x + self.attn(lnx, x_positions, lnx, x_positions, causal=causal, mask=mask)
223
+ if self.cross_attn:
224
+ lnx = self.cross_attn_ln(x)
225
+ x = x + self.cross_attn(lnx, x_positions, xa, xa_positions)
226
+ x = x + self.mlp(self.mlp_ln(x))
227
+ return x
228
+
229
+ # %% ../nbs/A. Neural modules.ipynb 8
230
+ class BaseDecoder(nn.Module):
231
+ def __init__(self, depth=6, n_head=6, width=384, qk_scale=1, ffn_mult=4, length=2250, rope=False):
232
+ super().__init__()
233
+ self.length = length
234
+ self.width = width
235
+ self.layers = nn.ModuleList([
236
+ ResidualAttentionBlock(
237
+ self.width, n_head, qk_scale=qk_scale, ffn_mult=ffn_mult, cross_attention=True, rope=rope
238
+ ) for _ in range(math.floor(depth))
239
+ ])
240
+
241
+ self.ln_post = LayerNorm(width)
242
+
243
+ mask = torch.empty(length, length).fill_(-torch.inf).triu_(1)
244
+ self.register_buffer("mask", mask, persistent=False)
245
+
246
+ def forward(self, x, x_positions, xenc, xenc_positions):
247
+ for i,l in enumerate(self.layers):
248
+ x = l(x, x_positions, xenc, xenc_positions, causal=False, mask=self.mask)
249
+
250
+ x = self.ln_post(x)
251
+
252
+ return x
253
+
254
+ # %% ../nbs/A. Neural modules.ipynb 9
255
+ class EmbeddingProjector(nn.Linear):
256
+ pass
257
+
258
+ class FlexEmbeddings(nn.Module):
259
+ def __init__(self, codes, width, special_codes=None, frozen_width=None, special_embedding=None, unembed=True):
260
+ super().__init__()
261
+ self.codes = codes
262
+ self.special_codes = special_codes
263
+ if frozen_width is None: frozen_width = width
264
+
265
+ self.main = nn.Embedding(codes, frozen_width or width)
266
+ self.emb_to_hidden = EmbeddingProjector(frozen_width, width) if frozen_width != width else None
267
+ self.hidden_to_emb = EmbeddingProjector(width, frozen_width) if unembed and frozen_width != width else None
268
+ if special_codes:
269
+ self.special = special_embedding or nn.Embedding(special_codes, width)
270
+
271
+ self.register_buffer('merged_in', None)
272
+ self.register_buffer('merged_out', None)
273
+ self.register_buffer('bias_out', None)
274
+
275
+ def set_frozen_embeddings(self, values):
276
+ with torch.no_grad():
277
+ self.main.weight[:] = values
278
+ self.main.lr_scale = 0
279
+
280
+ @torch.no_grad()
281
+ def convert_for_eval(self):
282
+ if not self.special_codes: return
283
+ # in
284
+ main_w = self.main.weight
285
+ if self.emb_to_hidden is not None: main_w = self.emb_to_hidden(main_w)
286
+ weight = torch.cat([main_w, self.special.weight], dim=0)
287
+ self.merged_in = nn.Embedding(*weight.shape, _weight=weight)
288
+
289
+ # out
290
+ weight = self.main.weight
291
+ if self.hidden_to_emb: weight = weight @ self.hidden_to_emb.weight
292
+ self.merged_out = torch.cat([weight.T, self.special.weight.T], dim=1).T.contiguous() # T is for F.linear
293
+ if self.hidden_to_emb:
294
+ self.bias_out = torch.cat([
295
+ self.hidden_to_emb.bias @ self.main.weight.T,
296
+ torch.zeros(self.special.weight.shape[0], device=weight.device, dtype=weight.dtype)
297
+ ], dim=0)
298
+ else:
299
+ self.bias_out = None
300
+
301
+ def forward(self, toks):
302
+ if not self.training and self.merged_in is not None:
303
+ return self.merged_in(toks)
304
+
305
+ if self.special_codes:
306
+ special_mask = toks >= self.codes
307
+ embs = self.main(torch.where(special_mask, 0, toks))
308
+ else:
309
+ embs = self.main(toks)
310
+
311
+ if self.emb_to_hidden: embs = self.emb_to_hidden(embs)
312
+
313
+ if self.special_codes:
314
+ embs[special_mask] = self.special(toks[special_mask] - self.codes).to(embs.dtype)
315
+
316
+ return embs
317
+
318
+ def unembed(self, embs):
319
+ if not self.training and self.merged_out is not None:
320
+ return F.linear(embs, self.merged_out, self.bias_out) # embs @ self.merged_out + self.bias_out
321
+
322
+ orig_embs = embs
323
+ if self.hidden_to_emb: embs = self.hidden_to_emb(embs)
324
+
325
+ main_logits = (embs @ self.main.weight.to(embs.dtype).T).float()
326
+
327
+ if not self.special_codes:
328
+ return main_logits
329
+
330
+ special_logits = (orig_embs @ self.special.weight.to(orig_embs.dtype).T).float()
331
+ return torch.cat([main_logits, special_logits], dim=-1)
whisperspeech/pipeline.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/7. Pipeline.ipynb.
2
+
3
+ # %% auto 0
4
+ __all__ = ['Pipeline']
5
+
6
+ # %% ../nbs/7. Pipeline.ipynb 1
7
+ import torch
8
+ from whisperspeech.t2s_up_wds_mlang_enclm import TSARTransformer
9
+ from whisperspeech.s2a_delar_mup_wds_mlang import SADelARTransformer
10
+ from whisperspeech.a2wav import Vocoder
11
+ import traceback
12
+ from pathlib import Path
13
+
14
+ # %% ../nbs/7. Pipeline.ipynb 2
15
+ class Pipeline:
16
+ default_speaker = torch.tensor(
17
+ [-0.2929, -0.4503, 0.4155, -0.1417, 0.0473, -0.1624, -0.2322, 0.7071,
18
+ 0.4800, 0.5496, 0.0410, 0.6236, 0.4729, 0.0587, 0.2194, -0.0466,
19
+ -0.3036, 0.0497, 0.5028, -0.1703, 0.5039, -0.6464, 0.3857, -0.7350,
20
+ -0.1605, 0.4808, 0.5397, -0.4851, 0.1774, -0.8712, 0.5789, 0.1785,
21
+ -0.1417, 0.3039, 0.4232, -0.0186, 0.2685, 0.6153, -0.3103, -0.5706,
22
+ -0.4494, 0.3394, -0.6184, -0.3617, 1.1041, -0.1178, -0.1885, 0.1997,
23
+ 0.5571, -0.2906, -0.0477, -0.4048, -0.1062, 1.4779, 0.1639, -0.3712,
24
+ -0.1776, -0.0568, -0.6162, 0.0110, -0.0207, -0.1319, -0.3854, 0.7248,
25
+ 0.0343, 0.5724, 0.0670, 0.0486, -0.3813, 0.1738, 0.3017, 1.0502,
26
+ 0.1550, 0.5708, 0.0366, 0.5093, 0.0294, -0.7091, -0.8220, -0.1583,
27
+ -0.2343, 0.1366, 0.7372, -0.0631, 0.1505, 0.4600, -0.1252, -0.5245,
28
+ 0.7523, -0.0386, -0.2587, 1.0066, -0.2037, 0.1617, -0.3800, 0.2790,
29
+ 0.0184, -0.5111, -0.7291, 0.1627, 0.2367, -0.0192, 0.4822, -0.4458,
30
+ 0.1457, -0.5884, 0.1909, 0.2563, -0.2035, -0.0377, 0.7771, 0.2139,
31
+ 0.3801, 0.6047, -0.6043, -0.2563, -0.0726, 0.3856, 0.3217, 0.0823,
32
+ -0.1302, 0.3287, 0.5693, 0.2453, 0.8231, 0.0072, 1.0327, 0.6065,
33
+ -0.0620, -0.5572, 0.5220, 0.2485, 0.1520, 0.0222, -0.2179, -0.7392,
34
+ -0.3855, 0.1822, 0.1042, 0.7133, 0.3583, 0.0606, -0.0424, -0.9189,
35
+ -0.4882, -0.5480, -0.5719, -0.1660, -0.3439, -0.5814, -0.2542, 0.0197,
36
+ 0.4942, 0.0915, -0.0420, -0.0035, 0.5578, 0.1051, -0.0891, 0.2348,
37
+ 0.6876, -0.6685, 0.8215, -0.3692, -0.3150, -0.0462, -0.6806, -0.2661,
38
+ -0.0308, -0.0050, 0.6756, -0.1647, 1.0734, 0.0049, 0.4969, 0.0259,
39
+ -0.8949, 0.0731, 0.0886, 0.3442, -0.1433, -0.6804, 0.2204, 0.1859,
40
+ 0.2702, 0.1699, -0.1443, -0.9614, 0.3261, 0.1718, 0.3545, -0.0686]
41
+ )
42
+
43
+ def __init__(self, t2s_ref=None, s2a_ref=None, optimize=True, torch_compile=False):
44
+ args = dict()
45
+ try:
46
+ if t2s_ref:
47
+ args["ref"] = t2s_ref
48
+ self.t2s = TSARTransformer.load_model(**args).cuda()
49
+ if optimize: self.t2s.optimize(torch_compile=torch_compile)
50
+ except:
51
+ print("Failed to load the T2S model:")
52
+ print(traceback.format_exc())
53
+ try:
54
+ if s2a_ref:
55
+ args["ref"] = s2a_ref
56
+ self.s2a = SADelARTransformer.load_model(**args).cuda()
57
+ if optimize: self.s2a.optimize(torch_compile=torch_compile)
58
+ except:
59
+ print("Failed to load the S2A model:")
60
+ print(traceback.format_exc())
61
+ self.vocoder = Vocoder()
62
+ self.encoder = None
63
+
64
+ def extract_spk_emb(self, fname):
65
+ """Extracts a speaker embedding from the first 30 seconds of the give audio file.
66
+ """
67
+ import torchaudio
68
+ if self.encoder is None:
69
+ from speechbrain.pretrained import EncoderClassifier
70
+ self.encoder = EncoderClassifier.from_hparams("speechbrain/spkrec-ecapa-voxceleb",
71
+ savedir="~/.cache/speechbrain/",
72
+ run_opts={"device": "cuda"})
73
+ samples, sr = torchaudio.load(fname)
74
+ samples = self.encoder.audio_normalizer(samples[0,:30*sr], sr)
75
+ spk_emb = self.encoder.encode_batch(samples)
76
+ return spk_emb[0,0]
77
+
78
+ def generate_atoks(self, text, speaker=None, lang='en', cps=15, step_callback=None):
79
+ if speaker is None: speaker = self.default_speaker
80
+ elif isinstance(speaker, (str, Path)): speaker = self.extract_spk_emb(speaker)
81
+ text = text.replace("\n", " ")
82
+ stoks = self.t2s.generate(text, cps=cps, lang=lang, step=step_callback)
83
+ atoks = self.s2a.generate(stoks, speaker.unsqueeze(0), step=step_callback)
84
+ return atoks
85
+
86
+ def generate(self, text, speaker=None, lang='en', cps=15, step_callback=None):
87
+ return self.vocoder.decode(self.generate_atoks(text, speaker, lang=lang, cps=cps, step_callback=step_callback))
88
+
89
+ def generate_to_file(self, fname, text, speaker=None, lang='en', cps=15, step_callback=None):
90
+ self.vocoder.decode_to_file(fname, self.generate_atoks(text, speaker, lang=lang, cps=cps, step_callback=None))
91
+
92
+ def generate_to_notebook(self, text, speaker=None, lang='en', cps=15, step_callback=None):
93
+ self.vocoder.decode_to_notebook(self.generate_atoks(text, speaker, lang=lang, cps=cps, step_callback=None))
whisperspeech/prepare_s2a_dataset.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/4A. S2A dataset preparation.ipynb.
2
+
3
+ # %% auto 0
4
+ __all__ = ['flac_to_s2a_name']
5
+
6
+ # %% ../nbs/4A. S2A dataset preparation.ipynb 2
7
+ import sys
8
+ import os
9
+ import itertools
10
+ from pathlib import Path
11
+
12
+ import numpy as np
13
+ import torch
14
+ import torchaudio
15
+ import torch.nn.functional as F
16
+ from torch.profiler import profile, record_function, ProfilerActivity
17
+
18
+ from fastprogress import progress_bar
19
+ from fastcore.script import *
20
+
21
+ import whisper
22
+ from . import vad, wh_transcribe, vq_stoks, extract_acoustic
23
+ import webdataset as wds
24
+
25
+ # %% ../nbs/4A. S2A dataset preparation.ipynb 4
26
+ def flac_to_s2a_name(input):
27
+ if '-flac-' in input:
28
+ return input.rsplit("/", 1)[1].replace('flac', 's2a') + ".gz"
29
+ else:
30
+ return input.rsplit("/", 1)[1].replace('raw', 's2a') + ".gz"
31
+
32
+ # %% ../nbs/4A. S2A dataset preparation.ipynb 6
33
+ def resampler(newsr = 24000, key = 'samples_24k'):
34
+ _last_sr = None
35
+ tform = None
36
+
37
+ def _resample(samples):
38
+ for s in samples:
39
+ sr = s['sample_rate']
40
+ if sr != newsr:
41
+ if sr != _last_sr: tform = torchaudio.transforms.Resample(sr, newsr)
42
+ s[key] = tform(s['samples'])
43
+ else:
44
+ s[key] = s['samples']
45
+ yield s
46
+
47
+ return _resample
48
+
49
+ # %% ../nbs/4A. S2A dataset preparation.ipynb 9
50
+ @call_parse
51
+ def prepare_s2a(
52
+ input:str, # FLAC webdataset file path (or - to read the names from stdin)
53
+ proc_dataset_path:Path, # processed VAD files path
54
+ output:str=None, # output file name
55
+ vq_model:str="collabora/spear-tts-pytorch:whisper-vq-stoks.model", # the model path (use repo_id:filename to download it from hugginface)
56
+ n_samples:int=None, # process a limited amount of samples
57
+ batch_size:int=1, # process several segments at once
58
+ fix_dots:bool=False, # fix dots in file names
59
+ ):
60
+ if ":" in vq_model:
61
+ repo, fname = vq_model.split(":", 1)
62
+ vq_model = vq_stoks.RQBottleneckTransformer.load_model(repo, fname).cuda()
63
+ else:
64
+ vq_model = vq_stoks.RQBottleneckTransformer.load_model(local_filename=vq_model).cuda()
65
+ amodel = extract_acoustic.load_model()
66
+ amodel.set_target_bandwidth(3)
67
+
68
+ if input == "-":
69
+ input = [f.strip() for f in sys.stdin.readlines()]
70
+ assert output, "please provide the output shard name"
71
+ else:
72
+ if output is None: output = flac_to_s2a_name(input)
73
+ input = [input]
74
+
75
+ total = n_samples//batch_size if n_samples else 'noinfer'
76
+
77
+ ds = wds.WebDataset(input, shardshuffle=True, rename_files=vad.fix_dots_in_names if fix_dots else None).compose(
78
+ wds.decode(wds.torch_audio),
79
+ wds.select(lambda x: 'wav' in x or 'flac' in x),
80
+ vq_stoks.merge_in(vq_stoks.derived_dataset(proc_dataset_path, 'vad')),
81
+ wds.map_dict(**{"vad.npy":wh_transcribe.chunk_merger}),
82
+ lambda x: wh_transcribe.split_to_chunks(x),
83
+ resampler(),
84
+ resampler(16000, 'samples_16k'),
85
+ wds.to_tuple('__key__', 'rpad_s', 'samples_16k', 'samples_24k'),
86
+ wds.batched(64),
87
+ )
88
+
89
+ dl = wds.WebLoader(ds, num_workers=4, batch_size=None).unbatched().shuffle(2000).batched(batch_size)
90
+
91
+ speakers = set()
92
+ tmp = output+".tmp"
93
+ with wds.TarWriter(tmp) as sink:
94
+ for keys, rpad_ss, samples, samples24k in progress_bar(dl, total=total):
95
+ with record_function('to_cuda'):
96
+ samples, samples24k = samples.cuda(), samples24k.unsqueeze(1).cuda()
97
+ with record_function('encodec'):
98
+ atoks = amodel.encode(samples24k)[0][0]
99
+ with record_function('vq_stoks'):
100
+ stoks = vq_model.encode_audio(samples)
101
+ with record_function('from_cuda'):
102
+ atoks, stoks = atoks.cpu().numpy().astype(np.int16), stoks.cpu().numpy().astype(np.int16)
103
+ for key, rpad_s, _atoks, _stoks in zip(keys, rpad_ss, atoks, stoks):
104
+ speakers.add(key.split('/')[1])
105
+ sink.write({
106
+ "__key__": key,
107
+ "atoks.npy": _atoks[:,:int(-rpad_s * 75)],
108
+ "stoks.npy": _stoks[:int(-rpad_s * 25)],
109
+ })
110
+ with open(output+".speakers.txt", "w") as f: f.write("\n".join(speakers))
111
+ if not n_samples:
112
+ os.rename(tmp, output)
whisperspeech/prepare_t2s_dataset.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/5A. T2S dataset preparation.ipynb.
2
+
3
+ # %% auto 0
4
+ __all__ = []
5
+
6
+ # %% ../nbs/5A. T2S dataset preparation.ipynb 2
7
+ import sys
8
+ import os
9
+ import itertools
10
+ from pathlib import Path
11
+
12
+ import numpy as np
13
+ import torch
14
+ import torchaudio
15
+ import torch.nn.functional as F
16
+ from torch.profiler import profile, record_function, ProfilerActivity
17
+
18
+ from fastprogress import progress_bar
19
+ from fastcore.script import *
20
+
21
+ import whisper, whisperx
22
+ from . import vad, wh_transcribe, vq_stoks, extract_acoustic
23
+ import webdataset as wds
24
+
25
+ # %% ../nbs/5A. T2S dataset preparation.ipynb 4
26
+ def flac_to_t2s_name(input):
27
+ return input.rsplit("/", 1)[1].replace('flac', 't2s') + ".gz"
28
+
29
+ # %% ../nbs/5A. T2S dataset preparation.ipynb 6
30
+ class Transcriber:
31
+ """
32
+ A helper class to transcribe a batch of 30 second audio chunks.
33
+ """
34
+ def __init__(self, model_size, lang=False):
35
+ self.model = whisperx.asr.load_model(model_size, "cuda", compute_type="float16", language=lang)
36
+ # without calling vad_model at least once the rest segfaults for some reason...
37
+ self.model.vad_model({"waveform": torch.zeros(1, 16000), "sample_rate": 16000})
38
+
39
+ def transcribe(self, batch):
40
+ batch = whisper.log_mel_spectrogram(batch)
41
+ embs = self.model.model.encode(batch.cpu().numpy())
42
+ return self.model.tokenizer.tokenizer.decode_batch([x.sequences_ids[0] for x in
43
+ self.model.model.model.generate(
44
+ embs,
45
+ [self.model.model.get_prompt(self.model.tokenizer, [], without_timestamps=True)]*len(batch),
46
+ )])
47
+
48
+ # %% ../nbs/5A. T2S dataset preparation.ipynb 7
49
+ @call_parse
50
+ def prepare_t2s(
51
+ input:str, # FLAC webdataset file path (or - to read the names from stdin)
52
+ proc_dataset_path:Path, # processed VAD files path
53
+ output:str=None, # output file name
54
+ vq_model:str="collabora/spear-tts-pytorch:whisper-vq-stoks.model", # the model path (use repo_id:filename to download it from hugginface)
55
+ n_samples:int=None, # process a limited amount of samples
56
+ batch_size:int=1, # process several segments at once
57
+ transcription_model:str="small.en",
58
+ ):
59
+ if ":" in vq_model:
60
+ repo, fname = vq_model.split(":", 1)
61
+ vq_model = vq_stoks.RQBottleneckTransformer.load_model(repo, fname).cuda()
62
+ else:
63
+ vq_model = vq_stoks.RQBottleneckTransformer.load_model(local_filename=vq_model).cuda()
64
+ transcriber = Transcriber(transcription_model)
65
+
66
+ if input == "-":
67
+ input = [f.strip() for f in sys.stdin.readlines()]
68
+ assert output, "please provide the output shard name"
69
+ else:
70
+ if output is None: output = flac_to_t2s_name(input)
71
+ input = [input]
72
+
73
+ total = n_samples//batch_size if n_samples else 'noinfer'
74
+ if n_samples: print(f"Benchmarking run of {n_samples} samples ({total} batches)")
75
+
76
+ ds = wds.WebDataset(input, shardshuffle=True, rename_files=vad.fix_dots_in_names).compose(
77
+ wds.decode(wds.torch_audio),
78
+ vq_stoks.merge_in(vq_stoks.derived_dataset(proc_dataset_path, 'vad')),
79
+ wds.map_dict(**{"vad.npy": lambda s: wh_transcribe.chunk_merger(s, wh_transcribe.random_cutter)}),
80
+ lambda x: wh_transcribe.split_to_chunks(x),
81
+ # drop the first and last segment because they tend to be inaccurate
82
+ # (the transcriptions don't have the "LibriVox" header and "end of chapter" suffix)
83
+ wds.select(lambda x: x['i'] != 0 and x['i'] != x['imax']),
84
+ wds.to_tuple('__key__', 'rpad', 'samples'),
85
+ wds.batched(64),
86
+ )
87
+
88
+ dl = wds.WebLoader(ds, num_workers=4, batch_size=None).unbatched().shuffle(2000).batched(batch_size)
89
+
90
+ speakers = set()
91
+ tmp = output+".tmp"
92
+ with wds.TarWriter(tmp) as sink:
93
+ for keys, rpads, samples in progress_bar(dl, total=total):
94
+ with record_function('to_cuda'):
95
+ csamples = samples.cuda()
96
+ with record_function('transcribe'):
97
+ txts = transcriber.transcribe(csamples)
98
+ with record_function('vq_stoks'):
99
+ stoks = vq_model.encode_audio(csamples)
100
+ with record_function('from_cuda'):
101
+ stoks = stoks.cpu().numpy().astype(np.int16)
102
+ for key, rpad, txt, _stoks in zip(keys, rpads, txts, stoks):
103
+ speakers.add(key.split('/')[1])
104
+ sink.write({
105
+ "__key__": key,
106
+ "txt": txt,
107
+ "stoks.npy": _stoks[:int(-rpad/16000 * 25)],
108
+ })
109
+ with open(output+".speakers.txt", "w") as f: f.write("\n".join(speakers))
110
+ if not n_samples:
111
+ os.rename(tmp, output)
whisperspeech/s2a_delar_mup_wds.py ADDED
@@ -0,0 +1,688 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/4B. Semantic to acoustic token modeling.ipynb.
2
+
3
+ # %% auto 0
4
+ __all__ = ['load_datasets', 'CMLMVisual', 'Rotary', 'rotate_half', 'apply_rotary_pos_emb', 'ResidualAttentionBlock',
5
+ 'MultiHeadAttention', 'DelSumDecoder', 'EmbeddingProjector', 'rand', 'Tunables', 'SADelARTransformer']
6
+
7
+ # %% ../nbs/4B. Semantic to acoustic token modeling.ipynb 1
8
+ import io
9
+ import time
10
+ import math
11
+ import random
12
+ import dataclasses
13
+
14
+ # %% ../nbs/4B. Semantic to acoustic token modeling.ipynb 2
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
+ from torch.profiler import profile, record_function, ProfilerActivity, schedule
19
+ from fastcore.basics import store_attr
20
+ from huggingface_hub import hf_hub_download
21
+
22
+ # %% ../nbs/4B. Semantic to acoustic token modeling.ipynb 3
23
+ from pathlib import Path
24
+ import json
25
+ from fastprogress import progress_bar, master_bar
26
+ import webdataset as wds
27
+
28
+ # %% ../nbs/4B. Semantic to acoustic token modeling.ipynb 4
29
+ from .train import *
30
+ from .modules import *
31
+ from . import vq_stoks
32
+
33
+ # %% ../nbs/4B. Semantic to acoustic token modeling.ipynb 8
34
+ def rand(start, end):
35
+ return random.random() * (end - start) + start
36
+
37
+ # %% ../nbs/4B. Semantic to acoustic token modeling.ipynb 9
38
+ def random_trunc(random_trunc_p, atoks_len = 2250, stoks_len = 750):
39
+ atoks_per_second = atoks_len / 30
40
+ def _trunc(samples):
41
+ for s in samples:
42
+ if random.random() < random_trunc_p:
43
+ seconds = rand(0.3, 30)
44
+ s['atoks.npy'] = s['atoks.npy'][:,:math.ceil(seconds * atoks_per_second)]
45
+ s['stoks.npy'] = s['stoks.npy'][:math.ceil(s['atoks.npy'].shape[-1]/atoks_len*stoks_len)]
46
+ yield s
47
+ return _trunc
48
+
49
+ def pad_samples(atoks_len = 2250, stoks_len = 750, stoks_pad_token = 4096):
50
+ def _pad(samples):
51
+ for s in samples:
52
+ s['stoks.npy'] = F.pad(torch.tensor(s['stoks.npy']), (0, stoks_len - s['stoks.npy'].shape[-1]), value=stoks_pad_token)
53
+ s['atoks.npy'] = F.pad(torch.tensor(s['atoks.npy']), (0, atoks_len - s['atoks.npy'].shape[-1]), value=-100)
54
+ yield s
55
+ return _pad
56
+
57
+ # %% ../nbs/4B. Semantic to acoustic token modeling.ipynb 10
58
+ def speaker_id_extractor(speaker_map):
59
+ def _extractor(samples):
60
+ for s in samples:
61
+ s['speaker'] = torch.tensor(speaker_map[s['__key__'].split("/")[1]])
62
+ yield s
63
+ return _extractor
64
+
65
+ # %% ../nbs/4B. Semantic to acoustic token modeling.ipynb 14
66
+ def load_datasets(
67
+ input:str, # webdataset folder
68
+ samples:int, # samples per epoch
69
+ subsample:float=1, # use a fraction of the files
70
+ val_samples:int=512,
71
+ random_trunc_p:float=0,# probability of truncating the input to less than 30 seconds
72
+ stoks_pad_token=4096,
73
+ ):
74
+
75
+ if isinstance(input, (Path, str)):
76
+ path = Path(input)
77
+ if path.is_dir():
78
+ glob = '*-s2a-*.tar.gz'
79
+ else:
80
+ glob = path.name
81
+ path = path.parent
82
+ input = Path(path).glob(glob)
83
+ elif isinstance(input, list):
84
+ pass
85
+ else:
86
+ raise ArgumentError("input should be either a list or a path with an optional glob specifier")
87
+ shards = [str(x) for x in input]
88
+
89
+ speakers = set()
90
+ for shard in shards:
91
+ with open(shard+'.speakers.txt') as f: speakers = speakers.union(set(x.strip() for x in f.readlines()))
92
+ speakers = {id:i for i,id in enumerate(sorted(speakers))}
93
+
94
+ def ds(shards, length):
95
+ ds = wds.WebDataset(wds.ResampledShards(shards)).compose(
96
+ wds.decode(),
97
+ speaker_id_extractor(speakers),
98
+ random_trunc(random_trunc_p) if random_trunc_p > 0 else lambda x: x,
99
+ pad_samples(stoks_pad_token=stoks_pad_token),
100
+ wds.to_tuple('stoks.npy', 'atoks.npy', 'speaker'),
101
+ wds.batched(64),
102
+ )
103
+ ds.speakers = speakers
104
+ ds.total_samples = length
105
+ return ds.compose(wds.slice(length // 64)).with_epoch(length // 64).with_length(length // 64)
106
+
107
+ return (
108
+ ds(shards[1:], samples),
109
+ ds(shards[:1], val_samples),
110
+ )
111
+
112
+ # %% ../nbs/4B. Semantic to acoustic token modeling.ipynb 33
113
+ import pylab as plt
114
+ import fastprogress
115
+ import IPython
116
+ import numpy as np
117
+
118
+ class CMLMVisual:
119
+ """Visualize training progress"""
120
+ def __init__ (self, model, masterbar, total_steps):
121
+ self.model = model
122
+ self.masterbar = masterbar
123
+ self.total_steps = total_steps
124
+ self.epochs = total_steps // masterbar.main_bar.total
125
+
126
+ gs = plt.GridSpec(3, 1, height_ratios=[2,2,1])
127
+ graph_fig = plt.figure(figsize=(10,6))
128
+ self.graph_fig = graph_fig
129
+ self.loss_p = graph_fig.add_subplot(gs[0])
130
+ self.acc_p = graph_fig.add_subplot(gs[1], sharex=self.loss_p)
131
+ self.acc_p.tick_params('x', labelbottom=False)
132
+ self.lr_p = graph_fig.add_subplot(gs[2], sharex=self.loss_p)
133
+ self.lr_p.tick_params('x', labelbottom=False)
134
+ self.graph_out = None
135
+
136
+ self.its = []
137
+ self.train_losses = []
138
+ self.val_losses = []
139
+ self.lr_history = []
140
+ self.acc = np.nan
141
+ self.acc_history = []
142
+ self.pacc_history = []
143
+
144
+ def show(self):
145
+ self.start_t = time.time()
146
+ self.masterbar.write(["samples", "train", "val", "time"], table=True)
147
+ self.graph_out = display(self.graph_fig, display_id=True)
148
+ self.acc_out = display(IPython.display.HTML(''), display_id=True)
149
+
150
+ def hide(self):
151
+ if self.graph_out is not None:
152
+ self.graph_out.update(IPython.display.HTML(''))
153
+
154
+ def plot(self):
155
+ loss_p, acc_p, lr_p = self.loss_p, self.acc_p, self.lr_p
156
+ loss_p.clear()
157
+ loss_p.plot(self.its, self.train_losses)
158
+ loss_p.plot(self.its, self.val_losses)
159
+ loss_p.set_xlim(0, self.total_steps)
160
+ loss_p.set_yscale('log')
161
+ acc_p.clear()
162
+ for k in self.acc_history[-1].keys():
163
+ acc_p.plot(self.its, [x[k] for x in self.acc_history], ':')
164
+ # acc_p.plot(self.its, np.stack(self.pacc_history), label=range(len(self.pacc_history[0])))
165
+ lr_p.clear()
166
+ lrs = np.array(self.lr_history)
167
+ lr_p.plot(self.its, lrs)
168
+ self.graph_out.update(self.graph_fig)
169
+
170
+ def add_data(self, it, lr, train_loss, val_los):
171
+ self.its.append(it)
172
+ self.train_losses.append(train_loss)
173
+ self.val_losses.append(val_los)
174
+ self.lr_history.append(lr)
175
+ metrics = self.model.get_metrics()
176
+ self.acc_history.append(metrics)
177
+ # self.acc_out.update(f"Accuracy: {self.entropy_history[-1]:.2f}")
178
+ # self.pacc_history.append((self.model.pval_true / self.model.pval_total).cpu().numpy())
179
+ # if self.acc_history:
180
+ html = "<h5>Accuracies:</h5><table>"
181
+ html += "<thead>"+(''.join([f"<td>{k}<td>" for k,x in metrics.items()]))+"</thead>"
182
+ html += "<tr>"+(''.join([f"<td>{x*100:.1f}%<td>" for k,x in metrics.items()]))+"</tr>"
183
+ html += "</table>"
184
+ self.acc_out.update(IPython.display.HTML(html))
185
+ self.plot()
186
+
187
+ def add_table_row(self, it, avg_train_loss, val_loss):
188
+ elapsed_t = time.time() - self.start_t
189
+ self.masterbar.write([it, f"{avg_train_loss:.5f}", f"{val_loss:.5f}", fastprogress.core.format_time(elapsed_t)], table=True)
190
+
191
+ def on_iter(self, bar, it, avg_train_loss, val_loss):
192
+ epoch = math.ceil(it / self.total_steps * self.epochs)
193
+ bar.comment = f"#{epoch}/{self.epochs} loss: {avg_train_loss:.3f} / {val_loss:.3f}"
194
+
195
+ # %% ../nbs/4B. Semantic to acoustic token modeling.ipynb 34
196
+ # modified from https://blog.eleuther.ai/rotary-embeddings/
197
+ import torch
198
+
199
+ class Rotary(torch.nn.Module):
200
+ def __init__(self, dim, base=10000):
201
+ super().__init__()
202
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
203
+ self.register_buffer("inv_freq", inv_freq)
204
+ self.seq_len_cached = None
205
+ self.cos_cached = None
206
+ self.sin_cached = None
207
+
208
+ def forward(self, x, seq_dim=1):
209
+ seq_len = x.shape[seq_dim]
210
+ if seq_len != self.seq_len_cached:
211
+ self.seq_len_cached = seq_len
212
+ t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq)
213
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
214
+ emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
215
+ self.cos_cached = emb.cos()[None, :, None, :]
216
+ self.sin_cached = emb.sin()[None, :, None, :]
217
+ return self.cos_cached, self.sin_cached
218
+
219
+
220
+ # rotary pos emb helpers:
221
+ def rotate_half(x):
222
+ x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
223
+ return torch.cat(
224
+ (-x2, x1), dim=-1
225
+ )
226
+
227
+ #@torch.jit.script
228
+ def apply_rotary_pos_emb(q, k, cos, sin):
229
+ return (q * cos[:,:q.shape[1]]) + (rotate_half(q) * sin[:,:q.shape[1]]), (k * cos) + (rotate_half(k) * sin)
230
+
231
+ # %% ../nbs/4B. Semantic to acoustic token modeling.ipynb 35
232
+ from torch import Tensor, nn
233
+ import torch.nn.functional as F
234
+ from typing import Dict, Iterable, Optional
235
+
236
+ class ResidualAttentionBlock(nn.Module):
237
+ def __init__(self, n_state: int, n_head: int, cross_attention: bool = False, rope: bool = False,
238
+ qk_scale: float = 1, ffn_mult: int = 4):
239
+ super().__init__()
240
+
241
+ self.attn = MultiHeadAttention(n_state, n_head, qk_scale=qk_scale, rope=rope)
242
+ self.attn_ln = LayerNorm(n_state)
243
+
244
+ self.cross_attn = (
245
+ MultiHeadAttention(n_state, n_head, qk_scale=qk_scale, rope=rope) if cross_attention else None
246
+ )
247
+ self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None
248
+
249
+ n_mlp = n_state * ffn_mult
250
+ self.mlp = nn.Sequential(
251
+ nn.Linear(n_state, n_mlp), nn.GELU(), nn.Linear(n_mlp, n_state)
252
+ )
253
+ self.mlp_ln = LayerNorm(n_state)
254
+
255
+ def forward(
256
+ self,
257
+ x: Tensor,
258
+ xa: Optional[Tensor] = None,
259
+ causal = False,
260
+ kv_cache: Optional[dict] = None,
261
+ ):
262
+ x = x + self.attn(self.attn_ln(x), causal=causal, kv_cache=kv_cache)[0]
263
+ if self.cross_attn:
264
+ x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)[0]
265
+ x = x + self.mlp(self.mlp_ln(x))
266
+ return x
267
+
268
+ class MultiHeadAttention(nn.Module):
269
+ def __init__(self, n_state: int, n_head: int, qk_scale: float = 1, rope: bool = False):
270
+ super().__init__()
271
+ self.n_head = n_head
272
+ self.sqrt_qk_scale = math.sqrt(qk_scale)
273
+ self.query = QueryHead(n_state, n_state)
274
+ self.key = nn.Linear(n_state, n_state, bias=False)
275
+ self.value = nn.Linear(n_state, n_state)
276
+ self.out = nn.Linear(n_state, n_state)
277
+
278
+ self.rotary = None
279
+ if rope:
280
+ self.rotary = Rotary(n_state // n_head)
281
+
282
+ def forward(
283
+ self,
284
+ x: Tensor,
285
+ xa: Optional[Tensor] = None,
286
+ causal = False,
287
+ kv_cache: Optional[dict] = None,
288
+ ):
289
+ q = self.query(x)
290
+
291
+ if kv_cache is None or xa is None or self.key not in kv_cache:
292
+ # hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
293
+ # otherwise, perform key/value projections for self- or cross-attention as usual.
294
+ k = self.key(x if xa is None else xa)
295
+ v = self.value(x if xa is None else xa)
296
+ else:
297
+ # for cross-attention, calculate keys and values once and reuse in subsequent calls.
298
+ k = kv_cache[self.key]
299
+ v = kv_cache[self.value]
300
+
301
+ if self.sqrt_qk_scale != 1:
302
+ q *= self.sqrt_qk_scale
303
+ k *= self.sqrt_qk_scale
304
+
305
+ wv, qk = self.qkv_attention_pth20(q, k, v, causal)
306
+ # wv, qk = self.qkv_attention_xformers(q, k, v, causal)
307
+
308
+ return self.out(wv), qk
309
+
310
+ def qkv_attention_pth20(
311
+ self, q: Tensor, k: Tensor, v: Tensor, causal = False
312
+ ):
313
+ n_batch, n_ctx, n_state = q.shape
314
+ q = q.view(*q.shape[:2], self.n_head, -1)
315
+ k = k.view(*k.shape[:2], self.n_head, -1)
316
+ v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
317
+
318
+ #print('before rot:', q.shape, k.shape)
319
+ if self.rotary:
320
+ q, k = apply_rotary_pos_emb(q, k, *self.rotary(k))
321
+ #print(' after rot:', q.shape, k.shape)
322
+
323
+ k = k.permute(0, 2, 1, 3)
324
+ q = q.permute(0, 2, 1, 3)
325
+ # modified for better performance under PyTorch 2.0
326
+ wv = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0, is_causal=causal)
327
+
328
+ # previously we've returned q@k which we don't have now
329
+ # since it's not actually used anywhere else, let's just keep two return values for compatibility
330
+ return wv.permute(0, 2, 1, 3).flatten(start_dim=2), None
331
+
332
+ def qkv_attention_xformers(
333
+ self, q: Tensor, k: Tensor, v: Tensor, causal = False
334
+ ):
335
+ n_batch, n_ctx, n_state = q.shape
336
+ q = q.view(*q.shape[:2], self.n_head, -1)
337
+ k = k.view(*k.shape[:2], self.n_head, -1)
338
+ v = v.view(*v.shape[:2], self.n_head, -1)
339
+
340
+ if self.rotary:
341
+ q, k = apply_rotary_pos_emb(q, k, *self.rotary(k))
342
+
343
+ bias = xops.LowerTriangularMask() if causal else None
344
+ wv = xops.memory_efficient_attention(q,k,v, attn_bias=bias)
345
+
346
+ # previously we've returned q@k which we don't have now
347
+ # since it's not actually used anywhere else, let's just keep two return values for compatibility
348
+ return wv.flatten(start_dim=2), None
349
+
350
+ # %% ../nbs/4B. Semantic to acoustic token modeling.ipynb 36
351
+ class DelSumDecoder(nn.Module):
352
+ def __init__(self, depth=6, n_head=6, head_width=64, qk_scale=1, ffn_mult=4, length=2250, codes=1024, quantizers=8, linear_heads=True, rope=False, pos_embs=None):
353
+ super().__init__()
354
+ self.length = length
355
+ width = n_head * head_width
356
+ self.width = width
357
+ self.codes = codes
358
+ self.quantizers = quantizers
359
+ self.linear_heads = linear_heads
360
+
361
+ self.embeddings = nn.ModuleList([nn.Embedding(codes+1, width) for _ in range(quantizers)])
362
+ if pos_embs is not None:
363
+ self.register_buffer("positional_embedding", pos_embs)
364
+
365
+ self.layers = nn.ModuleList([
366
+ ResidualAttentionBlock(width, n_head, qk_scale=qk_scale, ffn_mult=ffn_mult, cross_attention=True, rope=rope) for _ in range(math.floor(depth))
367
+ ])
368
+
369
+ self.ln_post = LayerNorm(width)
370
+
371
+ if self.linear_heads:
372
+ self.heads = LinearHead(width, (codes+1) * quantizers, bias=False)
373
+ else:
374
+ self.splitter = nn.Sequential(
375
+ nn.Linear(width, width * quantizers),
376
+ nn.GELU(),
377
+ )
378
+ self.heads = nn.ModuleList([
379
+ LinearHead(width, codes+1, bias=True) for _ in range(quantizers)
380
+ ])
381
+
382
+ def forward(self, toks, xenc):
383
+ b,_,n = toks.shape
384
+ newn = min(n+1, self.length)
385
+ embs = torch.zeros((b,newn,self.width), dtype=xenc.dtype, device=xenc.device)
386
+ for i in range(self.quantizers):
387
+ embs[:,:i+1] += self.embeddings[i](torch.tensor([self.codes], device=xenc.device))
388
+ if i < n:
389
+ embs[:,i+1:] += self.embeddings[i](toks[:,i,:newn-i-1])
390
+
391
+ x = embs.to(xenc.dtype)
392
+
393
+ for l in self.layers:
394
+ x = l(x, xenc, causal=True)
395
+ x = self.ln_post(x)
396
+
397
+ if self.linear_heads:
398
+ logits = self.heads(x).view(b,newn,self.quantizers,self.codes+1).permute(0,2,1,3)
399
+ else:
400
+ split = self.splitter(x).view(b,newn,self.quantizers,self.width)
401
+ logits = torch.stack([self.heads[q](split[:,:,q]) for q in range(self.quantizers)], dim=1)
402
+
403
+ return logits
404
+
405
+ class EmbeddingProjector(nn.Linear):
406
+ pass
407
+
408
+ def rand(start, end):
409
+ return random.random() * (end - start) + start
410
+
411
+ @dataclasses.dataclass
412
+ class Tunables:
413
+ init_std :float = 9
414
+ embeddings_std :float = 0.2
415
+ embeddings_lr_scale: float = 10
416
+ output_mult :float = 5.6
417
+ # FIXME: try separate mults for self and cross attention
418
+ query_mult :float = .3
419
+ encoder_depth_ratio :float = 0.25
420
+ linear_heads :bool = False
421
+ rope :bool = True
422
+
423
+ lr0 :float = 3e-3
424
+ clip_gradient_norm :float = 2
425
+ weight_decay :float = 1e-3
426
+ warmup_steps :float = 2000
427
+
428
+ random :bool = False
429
+
430
+ def __post_init__(self):
431
+ # randomize the hyperparams if requested
432
+ if self.random:
433
+ self.init_std = 2*10**rand(0,1)
434
+ self.embeddings_std = 10**rand(-1.7,-0.22)
435
+ self.embeddings_lr_scale = 2**rand(2,4)
436
+ self.output_mult = 2**rand(1.5,3)
437
+ self.query_mult = 2**rand(-3,-1.3)
438
+ self.encoder_depth_ratio = random.choice([0.25,0.5])
439
+ self.linear_heads = False
440
+ self.rope = True
441
+
442
+ self.lr0 = 3e-3
443
+ self.clip_gradient_norm = 10**rand(-1,1)
444
+ self.warmup_steps = 100*(10**rand(1.18,1.3))
445
+
446
+ @staticmethod
447
+ def upgrade(args):
448
+ args = {k:v for k,v in args.items()}
449
+ def old_default(name, value):
450
+ if name not in args: args[name] = value
451
+ old_default('rope', False)
452
+ old_default('linear_heads', True)
453
+ return args
454
+
455
+ class SADelARTransformer(nn.Module):
456
+ def __init__(self, depth=3, ctx_n=2250, stoks_len=750, stoks_codes=4097, stoks_width=None, spk_width=None, n_head=3, head_width=64, ffn_mult=4,
457
+ quantizers=8, speaker_map={"1":0}, tunables=Tunables()):
458
+ super().__init__()
459
+ self.quantizers = quantizers
460
+ width = n_head * head_width
461
+ store_attr("depth,ctx_n,stoks_len,stoks_codes,stoks_width,spk_width,n_head,head_width,ffn_mult,quantizers,speaker_map")
462
+ self.width = width
463
+ self.base_width = 3 * head_width
464
+ self.tunables = tunables
465
+
466
+ if stoks_width is None: stoks_width = width
467
+ if spk_width is None: spk_width = width
468
+ self.emb_factor = width != stoks_width
469
+ self.spk_factor = width != spk_width
470
+
471
+ if tunables.rope:
472
+ self.positional_embeddings = None
473
+ else:
474
+ self.register_buffer('positional_embeddings', sinusoids(ctx_n, width))
475
+
476
+ self.speaker_embedding = nn.Embedding(len(speaker_map), width)
477
+ self.semantic_embedding = nn.Embedding(stoks_codes, stoks_width)
478
+ if self.emb_factor:
479
+ self.emb_to_hidden = nn.Linear(stoks_width, width)
480
+
481
+ if self.spk_factor:
482
+ self.spk_to_hidden = EmbeddingProjector(spk_width, width)
483
+
484
+ qk_scale = self.tunables.query_mult * 8 / math.sqrt(head_width)
485
+
486
+ encoder_depth = int(depth * 2 * tunables.encoder_depth_ratio)
487
+ decoder_depth = depth * 2 - encoder_depth
488
+ self.encoder = nn.Sequential(*[
489
+ ResidualAttentionBlock(width, n_head, qk_scale=qk_scale, ffn_mult=ffn_mult, rope=tunables.rope) for _ in range(encoder_depth)
490
+ ])
491
+ self.ln_post = LayerNorm(width)
492
+
493
+ self.decoder = DelSumDecoder(pos_embs=self.positional_embeddings, qk_scale=qk_scale,
494
+ length=ctx_n, n_head=n_head, head_width=head_width, ffn_mult=ffn_mult,
495
+ depth=decoder_depth, quantizers=quantizers,
496
+ linear_heads=tunables.linear_heads, rope=tunables.rope)
497
+
498
+ self.register_buffer('val_true', torch.zeros(self.quantizers).cuda())
499
+ self.register_buffer('val_total', torch.zeros(self.quantizers).cuda())
500
+ self.apply(self.init_transformer)
501
+
502
+ def setup(self, device):
503
+ pass
504
+
505
+ def load_frozen_semantic_embeddings(self, vqmodel):
506
+ with torch.no_grad():
507
+ self.semantic_embedding.weight[:] = vqmodel.rq.layers[0]._codebook.embed[0]
508
+ self.semantic_embedding.lr_scale = 0
509
+
510
+ def init_transformer(self, m):
511
+ if isinstance(m, LinearHead):
512
+ m.no_weight_decay = True
513
+ torch.nn.init.constant_(m.weight, 0)
514
+ elif isinstance(m, QueryHead):
515
+ m.lr_scale = 1/(m.weight.shape[1] / self.base_width)
516
+ torch.nn.init.constant_(m.weight, 0)
517
+ elif isinstance(m, nn.Embedding):
518
+ m.no_weight_decay = True
519
+ m.lr_scale = self.tunables.embeddings_lr_scale
520
+ std = self.tunables.embeddings_std
521
+ torch.nn.init.trunc_normal_(m.weight, std=std, a=-3*std, b=3*std)
522
+ elif isinstance(m, EmbeddingProjector):
523
+ m.lr_scale = self.tunables.embeddings_lr_scale/2
524
+ std = self.tunables.init_std
525
+ torch.nn.init.trunc_normal_(m.weight, std=std, a=-3*std, b=3*std)
526
+ elif isinstance(m, nn.Linear):
527
+ m.lr_scale = 1/(m.weight.shape[1] / self.base_width)
528
+ std = self.tunables.init_std / m.weight.shape[1]
529
+ torch.nn.init.trunc_normal_(m.weight, std=std, a=-3*std, b=3*std)
530
+ if m.bias is not None:
531
+ torch.nn.init.trunc_normal_(m.bias, std=std, a=-3*std, b=3*std)
532
+ elif isinstance(m, nn.LayerNorm):
533
+ m.no_weight_decay = True
534
+ torch.nn.init.constant_(m.bias, 0)
535
+ torch.nn.init.constant_(m.weight, 1)
536
+
537
+ def embed_stoks(self, Stoks):
538
+ b,n = Stoks.shape
539
+ if self.stoks_len == 1500:
540
+ # converts 50 toks/s to 75 toks/s by adding padding between every two tokens
541
+ x = Stoks.reshape(b,n//2,2)
542
+ x = x.repeat_interleave(2, -1)[:,:,:3]
543
+ x[:,:,1] = 1024
544
+ x = x.reshape(b,n//2*3)
545
+ else:
546
+ # it's a lot easier with 25 toks/s
547
+ x = Stoks.repeat_interleave(3, -1)
548
+ # embed semantic tokens
549
+ Sembs = self.semantic_embedding(x.to(torch.long))
550
+ if self.emb_factor:
551
+ Sembs = self.emb_to_hidden(Sembs)
552
+ return Sembs
553
+
554
+ def forward(self, Stoks, Atoks, speakers, noloss=False):
555
+ Atoks = Atoks.to(torch.long)
556
+ semb = self.embed_stoks(Stoks)
557
+ with record_function("encoder"):
558
+ if self.positional_embeddings is not None: semb = semb + self.positional_embeddings
559
+ xenc = self.ln_post(self.encoder(semb))
560
+ # xenc = torch.zeros_like(xenc)
561
+ with record_function("decoder"):
562
+ Atoks_gt = Atoks.clone()
563
+ Atoks_gt[Atoks == -100] = 1024
564
+ # we can randomize speaker ids during validation to measure
565
+ # the importance of the speaker embedding vs. just the acoustic prompt/prefix
566
+ # if not self.training: speakers = speakers[torch.randperm(speakers.nelement())]
567
+ spk_embs = self.speaker_embedding(speakers)
568
+ if self.spk_factor: spk_embs = self.spk_to_hidden(spk_embs)
569
+ logits = self.decoder(Atoks_gt, xenc + spk_embs.unsqueeze(1))
570
+ logits *= self.tunables.output_mult / (self.width / self.base_width)
571
+
572
+ if noloss:
573
+ return logits
574
+
575
+ with record_function("loss"):
576
+ N = Atoks.shape[-1]
577
+ loss = 0
578
+ for i in range(self.quantizers):
579
+ loss += F.cross_entropy(logits[:,i,i:].reshape(-1,logits.shape[-1]), Atoks[:,i,:N-i].reshape(-1))
580
+ loss /= self.quantizers
581
+
582
+ if not self.training:
583
+ for i in range(self.quantizers):
584
+ Atoks_i = Atoks[:,i,:N-i]
585
+ valid_Atoks = Atoks_i != -100
586
+ self.val_true[i] += (logits[:,i,i:].argmax(-1)[valid_Atoks] == Atoks_i[valid_Atoks]).float().sum()
587
+ self.val_total[i] += valid_Atoks.float().sum()
588
+
589
+ return logits, loss
590
+
591
+ def get_metrics(self):
592
+ metrics = {
593
+ f'acc_{i}':x.item() for i,x in enumerate(self.val_true / self.val_total)
594
+ }
595
+ self.val_true[:] = 0
596
+ self.val_total[:] = 0
597
+ return metrics
598
+
599
+ #
600
+ # inference
601
+ #
602
+ @classmethod
603
+ def load_model(cls, repo_id="collabora/whisperspeech", filename="s2a_up_wds.model", local_filename=None):
604
+ if not local_filename:
605
+ local_filename = hf_hub_download(repo_id=repo_id, filename=filename)
606
+ spec = torch.load(local_filename)
607
+ if '_extra_state' not in spec['state_dict']: spec['state_dict']['_extra_state'] = { 'speaker_map': spec['config']['speaker_map'] }
608
+ model = cls(**spec['config'], tunables=Tunables(**Tunables.upgrade(spec['tunables'])))
609
+ model.load_state_dict(spec['state_dict'])
610
+ model.eval()
611
+ return model
612
+
613
+ def get_extra_state(self):
614
+ return { 'speaker_map': self.speaker_map }
615
+
616
+ def set_extra_state(self, st):
617
+ self.speaker_map = st['speaker_map']
618
+
619
+ def load_checkpoint(self, local_filename):
620
+ spec = torch.load(local_filename, map_location='cpu')
621
+ assert 'pytorch-lightning_version' in spec, 'not a valid PyTorch Lightning checkpoint'
622
+ state_dict = {k.replace('model.', ''):v
623
+ for k,v in spec['state_dict'].items()}
624
+ self.load_state_dict(state_dict)
625
+ return self
626
+
627
+ def save_model(self, fname):
628
+ torch.save(dict(config = self.__stored_args__,
629
+ tunables = dataclasses.asdict(self.tunables),
630
+ state_dict = self.state_dict()), fname)
631
+
632
+ @property
633
+ def device(self):
634
+ return next(self.parameters()).device
635
+
636
+ @torch.no_grad()
637
+ def generate(self, stoks, speakers, N=None, T=0.7, top_k=None, show_progress_bar=True):
638
+ dev = self.device
639
+ if self.stoks_len == 1500:
640
+ N = N or len(stoks) * 3 // 2
641
+ else:
642
+ N = N or len(stoks) * 3
643
+ stoks = F.pad(stoks.to(dev), (0, self.stoks_len - len(stoks)), value=self.stoks_codes-1).unsqueeze(0)
644
+ speakers = torch.tensor([self.speaker_map[spk] for spk in speakers], device=dev)
645
+ toks = torch.zeros((1,self.quantizers,N), dtype=torch.long, device=dev)
646
+ it = range(0,N)
647
+ if show_progress_bar: it = progress_bar(it)
648
+ for i in it:
649
+ p = self(stoks, toks[:,:,:i], speakers, noloss=True)
650
+ last_p = p[0,:,-1]
651
+ if top_k:
652
+ last_p[last_p < torch.topk(last_p, top_k).values[:,-1,None]] = -torch.inf
653
+ for j,tok in enumerate(torch.multinomial((last_p / float(T)).softmax(-1), 1)):
654
+ toks[0,j,max(0,i-j)] = tok
655
+ if toks[0,0,i] == 1024: return toks[0,:,:i]
656
+ return toks[0]
657
+
658
+ # %% ../nbs/4B. Semantic to acoustic token modeling.ipynb 37
659
+ def _make_model(size:str, quantizers:int=4, tunables:Tunables=Tunables(), dataset:torch.utils.data.Dataset=None, **kwargs):
660
+ assert(dataset is not None)
661
+ kwargs = dict(speaker_map=dataset.speakers, quantizers=quantizers, tunables=tunables, **kwargs)
662
+ if size == 'micro':
663
+ return SADelARTransformer(depth=4, n_head=3, ffn_mult=2, **kwargs)
664
+ if size == 'tiny-narrow':
665
+ return SADelARTransformer(depth=4, n_head=6, ffn_mult=1, **kwargs)
666
+ if size == 'tiny':
667
+ return SADelARTransformer(depth=4, n_head=6, **kwargs)
668
+ if size == 'base':
669
+ return SADelARTransformer(depth=6, n_head=8, **kwargs)
670
+ if size == 'base-deep':
671
+ return SADelARTransformer(depth=9, n_head=8, **kwargs)
672
+ if size == 'base-wide':
673
+ return SADelARTransformer(depth=6, n_head=12, **kwargs)
674
+ if size == 'small/2':
675
+ return SADelARTransformer(depth=9, n_head=12, **kwargs)
676
+ if size == 'small':
677
+ return SADelARTransformer(depth=12, n_head=12, **kwargs)
678
+ if size == 'medium':
679
+ return SADelARTransformer(depth=24, n_head=16, **kwargs)
680
+
681
+ def make_model(size:str, quantizers:int=4, frozen_embeddings_model:str=None, tunables:Tunables=Tunables(), dataset:torch.utils.data.Dataset=None):
682
+ if frozen_embeddings_model:
683
+ vqmodel = vq_stoks.RQBottleneckTransformer.load_model(frozen_embeddings_model)
684
+ model = _make_model(size, quantizers, tunables, dataset, stoks_codes=vqmodel.vq_codes+1, stoks_width=vqmodel.rq.layers[0]._codebook.embed[0].shape[-1])
685
+ model.load_frozen_semantic_embeddings(vqmodel)
686
+ else:
687
+ model = _make_model(size, quantizers, tunables, dataset)
688
+ return model
whisperspeech/s2a_delar_mup_wds_mlang.py ADDED
@@ -0,0 +1,564 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/4B. Multi-language semantic to acoustic token modeling.ipynb.
2
+
3
+ # %% auto 0
4
+ __all__ = ['load_dataset', 'DelSumEmbedding', 'DelSumHead', 'rand', 'Tunables', 'SADelARTransformer']
5
+
6
+ # %% ../nbs/4B. Multi-language semantic to acoustic token modeling.ipynb 1
7
+ import io
8
+ import time
9
+ import math
10
+ import random
11
+ import dataclasses
12
+
13
+ # %% ../nbs/4B. Multi-language semantic to acoustic token modeling.ipynb 2
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+ import numpy as np
18
+ from torch.profiler import profile, record_function, ProfilerActivity, schedule
19
+ from fastcore.basics import store_attr
20
+ from huggingface_hub import hf_hub_download
21
+
22
+ # %% ../nbs/4B. Multi-language semantic to acoustic token modeling.ipynb 3
23
+ from pathlib import Path
24
+ import json
25
+ from fastprogress import progress_bar, master_bar
26
+
27
+ # %% ../nbs/4B. Multi-language semantic to acoustic token modeling.ipynb 4
28
+ from .modules import *
29
+
30
+ # %% ../nbs/4B. Multi-language semantic to acoustic token modeling.ipynb 8
31
+ def rand(start, end):
32
+ return random.random() * (end - start) + start
33
+
34
+ # %% ../nbs/4B. Multi-language semantic to acoustic token modeling.ipynb 9
35
+ def random_trunc(random_trunc_p, atoks_len = 2250, stoks_len = 750):
36
+ atoks_per_second = atoks_len / 30
37
+ def _trunc(samples):
38
+ for s in samples:
39
+ if random.random() < random_trunc_p:
40
+ seconds = rand(0.3, 30)
41
+ s['atoks.npy'] = s['atoks.npy'][:,:math.ceil(seconds * atoks_per_second)]
42
+ s['stoks.npy'] = s['stoks.npy'][:math.ceil(s['atoks.npy'].shape[-1]/atoks_len*stoks_len)]
43
+ yield s
44
+ return _trunc
45
+
46
+ def pad_samples(atoks_len = 2250, stoks_len = 750, stoks_pad_token = 4096):
47
+ def _pad(samples):
48
+ for s in samples:
49
+ s['stoks.npy'] = F.pad(torch.tensor(s['stoks.npy']), (1, stoks_len - s['stoks.npy'].shape[-1]-1), value=stoks_pad_token)
50
+ s['out_stoks'] = F.pad(torch.tensor(s['stoks.npy']), (0, stoks_len - s['stoks.npy'].shape[-1]), value=stoks_pad_token)
51
+ s['atoks.npy'] = F.pad(torch.tensor(s['atoks.npy']), (0, atoks_len - s['atoks.npy'].shape[-1]), value=-100)
52
+ yield s
53
+ return _pad
54
+
55
+ # %% ../nbs/4B. Multi-language semantic to acoustic token modeling.ipynb 10
56
+ def make_speaker_map(shards):
57
+ speakers = set()
58
+ for shard in shards:
59
+ with open(shard+'.speakers.txt') as f: speakers = speakers.union(set(x.strip() for x in f.readlines()))
60
+ return {id:i for i,id in enumerate(sorted(speakers))}
61
+
62
+ def speaker_id_extractor(speaker_map):
63
+ def _extractor(samples):
64
+ for s in samples:
65
+ s['speaker'] = torch.tensor(speaker_map[s['__key__'].split("/")[1]])
66
+ yield s
67
+ return _extractor
68
+
69
+ # %% ../nbs/4B. Multi-language semantic to acoustic token modeling.ipynb 27
70
+ def load_dataset(
71
+ atoks_shard_spec:str, # webdataset folder
72
+ stoks_shard_dir:str, # stoks webdataset base dir
73
+ samples:int, # samples per epoch
74
+ random_trunc_p:float=0,# probability of truncating the input to less than 30 seconds
75
+ vq_codes:int=4096,
76
+ language:str='en',
77
+ weight:float=1,
78
+ validation:bool=False,
79
+ exclude_files:str=None,
80
+ randomize_speakers:bool=False,
81
+ ):
82
+ import webdataset as wds
83
+ from whisperspeech import utils
84
+
85
+ shards = utils.shard_glob(atoks_shard_spec)
86
+ excludes = {x for file in exclude_files.split() for x in utils.readlines(file)} if exclude_files else set()
87
+
88
+ def check_for_nan(s):
89
+ if torch.tensor(s['spk_emb.npy']).isnan().any(): print("found NaN:", s['__key__'])
90
+ return s
91
+
92
+ def set_language(x):
93
+ x['language'] = language
94
+ return x
95
+
96
+ same_on_all_nodes = lambda urls: urls # will only be used for validation
97
+ ds = wds.WebDataset(shards, resampled=not validation, nodesplitter=same_on_all_nodes).compose(
98
+ wds.decode(),
99
+ utils.merge_in(utils.derived_dataset('maxvad-stoks', base='atoks-3kbps', suffix='', dir=stoks_shard_dir)),
100
+ wds.map(check_for_nan),
101
+ wds.select(lambda s: s['__key__'] not in excludes),
102
+ wds.map_dict(**{'spk_emb.npy':np.nan_to_num}), # remove nans from the speaker embedding model
103
+ random_trunc(random_trunc_p) if random_trunc_p > 0 else lambda x: x,
104
+ pad_samples(stoks_pad_token=vq_codes-1),
105
+ wds.map(set_language),
106
+ wds.to_tuple('stoks.npy', 'atoks.npy', 'spk_emb.npy', 'language', 'out_stoks'),
107
+ wds.shuffle(20000, initial=20000),
108
+ wds.batched(64),
109
+ )
110
+ if randomize_speakers:
111
+ rng = np.random.default_rng()
112
+ ds = ds.compose(
113
+ wds.map_tuple(None, None, lambda x: rng.permutation(x), None),
114
+ )
115
+ if validation:
116
+ ds = ds.slice(samples // 64)
117
+ ds.total_samples = samples
118
+ ds.weight = weight
119
+
120
+ return ds
121
+
122
+ # %% ../nbs/4B. Multi-language semantic to acoustic token modeling.ipynb 37
123
+ class DelSumEmbedding(nn.Module):
124
+ def __init__(self, n_head=6, head_width=64, atoks_width=None, length=2250, codes=1024, quantizers=8, pos_embs=None):
125
+ super().__init__()
126
+ self.length = length
127
+ width = n_head * head_width
128
+ if atoks_width is None: atoks_width = width
129
+ self.width = width
130
+ self.quantizers = quantizers
131
+
132
+ emb = None
133
+ embs = []
134
+ for _ in range(quantizers):
135
+ emb = FlexEmbeddings(codes, width, special_codes=2, frozen_width=atoks_width,
136
+ special_embedding=emb and emb.special)
137
+ embs.append(emb)
138
+ self.embeddings = nn.ModuleList(embs)
139
+ if pos_embs is not None:
140
+ self.register_buffer("positional_embedding", pos_embs)
141
+
142
+ def forward(self, toks, xenc):
143
+ with record_function("embeddings"):
144
+ b,_,n = toks.shape
145
+ newn = min(n, self.length)
146
+
147
+ embs = torch.zeros((b,newn,self.width), dtype=xenc.dtype, device=xenc.device)
148
+ for i in range(self.quantizers):
149
+ embs[:, :] += self.embeddings[i](toks[:,i,:])
150
+
151
+ x = embs.to(xenc.dtype)
152
+ return x
153
+
154
+ # %% ../nbs/4B. Multi-language semantic to acoustic token modeling.ipynb 38
155
+ class DelSumHead(nn.Module):
156
+ def __init__(self, quantizers=8, n_head=6, head_width=64):
157
+ super().__init__()
158
+ self.width = n_head * head_width
159
+ self.quantizers = quantizers
160
+ self.splitter = nn.Sequential(
161
+ nn.Linear(self.width, self.width * quantizers),
162
+ nn.GELU(),
163
+ )
164
+
165
+ def forward(self, x, embeddings=None):
166
+ b, newn, _ = x.shape
167
+ with record_function("splitter"):
168
+ split = self.splitter(x).view(b,newn,self.quantizers,self.width)
169
+ with record_function("unembed"):
170
+ logits = torch.stack([embeddings[q].unembed(split[:,:,q]) for q in range(self.quantizers)], dim=1)
171
+ return logits
172
+
173
+ def rand(start, end):
174
+ return random.random() * (end - start) + start
175
+
176
+ @dataclasses.dataclass
177
+ class Tunables:
178
+ init_std :float = 9
179
+ embeddings_std :float = 0.2
180
+ embeddings_lr_scale: float = 10
181
+ output_mult :float = 5.6
182
+ # FIXME: try separate mults for self and cross attention
183
+ query_mult :float = .3
184
+ encoder_depth_ratio :float = 0.25
185
+ linear_heads :bool = False
186
+ rope :bool = True
187
+
188
+ lr0 :float = 3e-3
189
+ clip_gradient_norm :float = 2
190
+ weight_decay :float = 1e-3
191
+ warmup_steps :float = 2000
192
+
193
+ random :bool = False
194
+
195
+ def __post_init__(self):
196
+ # randomize the hyperparams if requested
197
+ if self.random:
198
+ self.init_std = 2*10**rand(0,1)
199
+ self.embeddings_std = 10**rand(-1.7,-0.22)
200
+ self.embeddings_lr_scale = 2**rand(2,4)
201
+ self.output_mult = 2**rand(1.5,3)
202
+ self.query_mult = 2**rand(-3,-1.3)
203
+ self.encoder_depth_ratio = random.choice([0.25,0.5])
204
+ self.linear_heads = False
205
+ self.rope = True
206
+
207
+ self.lr0 = 3e-3
208
+ self.clip_gradient_norm = 10**rand(-1,1)
209
+ self.warmup_steps = 100*(10**rand(1.18,1.3))
210
+
211
+ @staticmethod
212
+ def upgrade(args):
213
+ args = {k:v for k,v in args.items()}
214
+ def old_default(name, value):
215
+ if name not in args: args[name] = value
216
+ old_default('rope', False)
217
+ old_default('linear_heads', True)
218
+ return args
219
+
220
+ class SADelARTransformer(nn.Module):
221
+ def __init__(self, depth=3, ctx_n=2250,
222
+ stoks_len=750, stoks_codes=4097, stoks_width=None,
223
+ spk_width=None,
224
+ atoks_width=None,
225
+ n_head=3, head_width=64, ffn_mult=4,
226
+ quantizers=8, speaker_map={"1":0}, tunables=Tunables()):
227
+ super().__init__()
228
+ self.quantizers = quantizers
229
+ self.codes = 1024
230
+ width = n_head * head_width
231
+ store_attr("depth,ctx_n,stoks_len,stoks_codes,stoks_width,spk_width,atoks_width,n_head,head_width,ffn_mult,quantizers,speaker_map")
232
+ self.width = width
233
+ self.base_width = 3 * head_width
234
+ self.tunables = tunables
235
+
236
+ if stoks_width is None: stoks_width = width
237
+ if spk_width is None: spk_width = width
238
+ self.emb_factor = width != stoks_width
239
+ self.spk_factor = width != spk_width
240
+
241
+ if tunables.rope:
242
+ self.positional_embeddings = None
243
+ else:
244
+ self.register_buffer('positional_embeddings', sinusoids(ctx_n, width))
245
+
246
+ # self.speaker_embedding = nn.Embedding(len(speaker_map), spk_width)
247
+ self.semantic_embedding = nn.Embedding(stoks_codes, stoks_width)
248
+ if self.emb_factor:
249
+ self.emb_to_hidden = nn.Linear(stoks_width, width)
250
+ self.hidden_to_emb = nn.Linear(width, stoks_width)
251
+
252
+ if self.spk_factor:
253
+ self.spk_to_hidden = nn.Linear(spk_width, width)
254
+
255
+ qk_scale = self.tunables.query_mult * 8 / math.sqrt(head_width)
256
+
257
+ encoder_depth = int(depth * 2 * tunables.encoder_depth_ratio)
258
+ decoder_depth = depth * 2 - encoder_depth
259
+ self.encoder = nn.Sequential(*[
260
+ ResidualAttentionBlock(width, n_head, qk_scale=qk_scale, ffn_mult=ffn_mult, rope=tunables.rope) for _ in range(encoder_depth)
261
+ ]) # FIXME: enclm requires causal attention here
262
+ self.ln_post = LayerNorm(width)
263
+
264
+ self.embds = DelSumEmbedding(
265
+ pos_embs=self.positional_embeddings, length=ctx_n,
266
+ n_head=n_head, head_width=head_width, atoks_width=atoks_width,
267
+ quantizers=quantizers,
268
+ )
269
+ self.decoder = BaseDecoder(qk_scale=qk_scale, length=ctx_n,
270
+ n_head=n_head, width=n_head * head_width,
271
+ ffn_mult=ffn_mult, depth=decoder_depth,
272
+ rope=tunables.rope)
273
+ self.head = DelSumHead(n_head=n_head, head_width=head_width, quantizers=quantizers)
274
+ for l in self.decoder.layers:
275
+ l.cross_attn.key_subsampling = 3
276
+ # for l in self.encoder:
277
+ # l.attn.key_subsampling = 3
278
+ # l.attn.query_subsampling = 3
279
+
280
+ self.register_buffer('val_true', torch.zeros(self.quantizers).cuda())
281
+ self.register_buffer('val_total', torch.zeros(self.quantizers).cuda())
282
+ self.apply(self.init_transformer)
283
+
284
+ def setup(self, device):
285
+ pass
286
+
287
+ def load_frozen_semantic_embeddings(self, vqmodel):
288
+ with torch.no_grad():
289
+ self.semantic_embedding.weight[:] = vqmodel.rq.layers[0]._codebook.embed[0]
290
+ self.semantic_embedding.lr_scale = 0
291
+
292
+ def load_frozen_acoustic_embeddings(self, amodel):
293
+ for i in range(self.quantizers):
294
+ self.decoder.embeddings[i].set_frozen_embeddings(amodel.quantizer.vq.layers[i].codebook)
295
+
296
+ def init_transformer(self, m):
297
+ if isinstance(m, LinearHead):
298
+ m.no_weight_decay = True
299
+ torch.nn.init.constant_(m.weight, 0)
300
+ elif isinstance(m, QueryHead):
301
+ m.lr_scale = 1/(m.weight.shape[1] / self.base_width)
302
+ torch.nn.init.constant_(m.weight, 0)
303
+ elif isinstance(m, nn.Embedding):
304
+ m.no_weight_decay = True
305
+ m.lr_scale = self.tunables.embeddings_lr_scale
306
+ std = self.tunables.embeddings_std
307
+ torch.nn.init.trunc_normal_(m.weight, std=std, a=-3*std, b=3*std)
308
+ # elif isinstance(m, EmbeddingProjector):
309
+ # m.lr_scale = self.tunables.embeddings_lr_scale #1/(m.weight.shape[1] / self.base_width)
310
+ # m.lr_scale = 2/(m.weight.shape[1] / self.base_width)
311
+ # std = self.tunables.init_std / m.weight.shape[1]
312
+ # torch.nn.init.trunc_normal_(m.weight, std=std, a=-3*std, b=3*std)
313
+ elif isinstance(m, nn.Linear):
314
+ m.lr_scale = 1/(m.weight.shape[1] / self.base_width)
315
+ std = self.tunables.init_std / m.weight.shape[1]
316
+ torch.nn.init.trunc_normal_(m.weight, std=std, a=-3*std, b=3*std)
317
+ if m.bias is not None:
318
+ torch.nn.init.trunc_normal_(m.bias, std=std, a=-3*std, b=3*std)
319
+ elif isinstance(m, nn.LayerNorm):
320
+ m.no_weight_decay = True
321
+ torch.nn.init.constant_(m.bias, 0)
322
+ torch.nn.init.constant_(m.weight, 1)
323
+
324
+ def embed_stoks(self, Stoks):
325
+ b,n = Stoks.shape
326
+ if self.stoks_len == 1500:
327
+ # converts 50 toks/s to 75 toks/s by adding padding between every two tokens
328
+ x = Stoks.reshape(b,n//2,2)
329
+ x = x.repeat_interleave(2, -1)[:,:,:3]
330
+ x[:,:,1] = 1024
331
+ x = x.reshape(b,n//2*3)
332
+ else:
333
+ # it's a lot easier with 25 toks/s
334
+ # x = Stoks.repeat_interleave(3, -1)
335
+ x = Stoks
336
+ # embed semantic tokens
337
+ Sembs = self.semantic_embedding(x.to(torch.long))
338
+ if self.emb_factor:
339
+ Sembs = self.emb_to_hidden(Sembs)
340
+ return Sembs
341
+
342
+ def _encoder(self, semb, positions):
343
+ x = semb
344
+ for l in self.encoder: x = l(x, positions)
345
+ return self.ln_post(x)
346
+
347
+ def run_encoder(self, Stoks, speakers):
348
+ semb = self.embed_stoks(Stoks)
349
+ with record_function("encoder"):
350
+ if self.positional_embeddings is not None: semb = semb + self.positional_embeddings
351
+ positions = torch.arange(0, semb.shape[1], device=semb.device)
352
+ xenc = self._encoder(semb, positions)
353
+ if self.training:
354
+ enc_logits = (self.hidden_to_emb(xenc) @ self.semantic_embedding.weight.to(xenc.dtype).T).float()
355
+ enc_logits = enc_logits * self.tunables.output_mult / (self.width / self.base_width)
356
+ else:
357
+ enc_logits = None
358
+ # print(xenc.shape, speakers.shape)
359
+ spk_embs = F.normalize(speakers, dim=-1) # use extracted embeddings
360
+ if self.spk_factor: spk_embs = self.spk_to_hidden(spk_embs)
361
+ return xenc + spk_embs.unsqueeze(1), positions, enc_logits
362
+
363
+ def forward(self, Stoks, Atoks, speakers, langs=None, out_stoks=None, noloss=False, xenc=None, xenc_positions=None, atoks_positions=None):
364
+ if xenc is None:
365
+ Atoks = Atoks.to(torch.long)
366
+ out_stoks = out_stoks.to(torch.long)
367
+ Atoks_gt = Atoks.clone()
368
+ Atoks_gt[Atoks == -100] = 1024
369
+ xenc, enc_logits = self.run_encoder(Stoks, speakers)
370
+ else:
371
+ Atoks_gt = Atoks
372
+ with record_function("decoder"):
373
+ embs = self.embds(Atoks, xenc)
374
+ if atoks_positions is None: atoks_positions = torch.arange(0, embs.shape[1], device=embs.device)
375
+ x = self.decoder(embs, atoks_positions, xenc, xenc_positions)
376
+ logits = self.head(x, embeddings=self.embds.embeddings)
377
+ logits *= self.tunables.output_mult / (self.width / self.base_width)
378
+
379
+ if noloss:
380
+ return logits
381
+
382
+ with record_function("loss"):
383
+ N = Atoks.shape[-1]
384
+ loss = 0
385
+ for i in range(self.quantizers):
386
+ loss += F.cross_entropy(logits[:,i,i:].reshape(-1,logits.shape[-1]), Atoks[:,i,:N-i].reshape(-1))
387
+ if self.training and i == 0:
388
+ loss *= 5
389
+ loss /= self.quantizers
390
+ if self.training:
391
+ loss += 0.1 * F.cross_entropy(enc_logits.transpose(-1,-2), out_stoks)
392
+
393
+ if not self.training:
394
+ for i in range(self.quantizers):
395
+ Atoks_i = Atoks[:,i,:N-i]
396
+ valid_Atoks = Atoks_i != -100
397
+ self.val_true[i] += (logits[:,i,i:].argmax(-1)[valid_Atoks] == Atoks_i[valid_Atoks]).float().sum()
398
+ self.val_total[i] += valid_Atoks.float().sum()
399
+
400
+ return logits, loss
401
+
402
+ def get_metrics(self):
403
+ metrics = {
404
+ f'acc_{i}':x.item() for i,x in enumerate(self.val_true / self.val_total)
405
+ }
406
+ self.val_true[:] = 0
407
+ self.val_total[:] = 0
408
+ return metrics
409
+
410
+ #
411
+ # inference
412
+ #
413
+ @classmethod
414
+ def load_model(cls, ref="collabora/whisperspeech:s2a-q4-small-en+pl.model",
415
+ repo_id=None, filename=None, local_filename=None):
416
+ if repo_id is None and filename is None and local_filename is None:
417
+ if ":" in ref:
418
+ repo_id, filename = ref.split(":", 1)
419
+ else:
420
+ local_filename = ref
421
+ if not local_filename:
422
+ local_filename = hf_hub_download(repo_id=repo_id, filename=filename)
423
+ spec = torch.load(local_filename)
424
+ if '_extra_state' not in spec['state_dict']: spec['state_dict']['_extra_state'] = { 'speaker_map': spec['config']['speaker_map'] }
425
+ model = cls(**spec['config'], tunables=Tunables(**Tunables.upgrade(spec['tunables'])))
426
+ model.load_state_dict(spec['state_dict'])
427
+ model.eval()
428
+ return model
429
+
430
+ def get_extra_state(self):
431
+ return { 'speaker_map': self.speaker_map }
432
+
433
+ def set_extra_state(self, st):
434
+ self.speaker_map = st['speaker_map']
435
+
436
+ def load_checkpoint(self, local_filename):
437
+ spec = torch.load(local_filename, map_location='cpu')
438
+ assert 'pytorch-lightning_version' in spec, 'not a valid PyTorch Lightning checkpoint'
439
+ state_dict = {k.replace('model.', ''):v
440
+ for k,v in spec['state_dict'].items()}
441
+ self.load_state_dict(state_dict)
442
+ return self
443
+
444
+ def save_model(self, fname):
445
+ torch.save(dict(config = self.__stored_args__,
446
+ tunables = dataclasses.asdict(self.tunables),
447
+ state_dict = self.state_dict()), fname)
448
+
449
+ def switch_dtypes(self, dtype=torch.float16):
450
+ self.dtype = dtype
451
+ for n,m in self.named_modules():
452
+ # convert every leaf layer apart from the LayerNorms
453
+ if isinstance(m, (nn.Linear, nn.Embedding)):
454
+ m.to(dtype)
455
+ # take care of buffers ([kv]_cache, masks) that are not in the leaf layers
456
+ for bn,b in m.named_buffers(recurse=False):
457
+ setattr(m,bn,b.to(dtype))
458
+
459
+ def optimize(self, max_batch_size=1, dtype=torch.float16, torch_compile=True):
460
+ for emb in self.embds.embeddings:
461
+ emb.convert_for_eval()
462
+ for l in self.encoder:
463
+ l.attn.convert_for_eval()
464
+ for l in self.decoder.layers:
465
+ l.attn.convert_for_eval()
466
+ l.cross_attn.convert_for_eval()
467
+ l.setup_kv_cache(max_batch_size, self.ctx_n, self.stoks_len)
468
+ self.switch_dtypes(dtype)
469
+ if torch_compile:
470
+ self.generate_next = torch.compile(self.generate_next, mode="reduce-overhead", fullgraph=True)
471
+
472
+ @property
473
+ def device(self):
474
+ return next(self.parameters()).device
475
+
476
+ # from https://github.com/pytorch-labs/gpt-fast/blob/main/generate.py
477
+ def multinomial_sample_one_no_sync(self, probs_sort): # Does multinomial sampling without a cuda synchronization
478
+ q = torch.empty_like(probs_sort).exponential_(1)
479
+ return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
480
+
481
+ def logits_to_probs(self, logits, T=1.0, top_k=None):
482
+ logits = logits / max(T, 1e-5)
483
+
484
+ if top_k is not None:
485
+ v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
486
+ pivot = v.select(-1, -1).unsqueeze(-1)
487
+ logits = torch.where(logits < pivot, -float("Inf"), logits)
488
+ probs = torch.nn.functional.softmax(logits, dim=-1)
489
+ return probs
490
+
491
+ def sample(self, logits, T=1.0, top_k=None):
492
+ probs = self.logits_to_probs(logits[0,:,-1], T, top_k)
493
+ idx_next = self.multinomial_sample_one_no_sync(probs)
494
+ return idx_next
495
+
496
+ def generate_one(self, toks, positions, langs, xenc, xenc_positions, T, top_k):
497
+ probs = self(None, toks, None, langs, noloss=True, xenc=xenc, xenc_positions=xenc_positions, atoks_positions=positions)
498
+ return self.sample(probs, T, top_k)
499
+
500
+ def generate_next(self, *args, **kwargs):
501
+ return self.generate_one(*args, **kwargs)
502
+
503
+ @torch.no_grad()
504
+ def generate(self, stoks, speakers, langs=None, N=None, T=0.7, top_k=None, show_progress_bar=True, step=None, subsample_enc=False):
505
+ dev = self.device
506
+ N = N or len(stoks) * 3
507
+ stoks = F.pad(stoks.to(dev), (1, self.stoks_len - len(stoks)-1), value=self.stoks_codes-1).unsqueeze(0)
508
+ speakers = speakers.to(device=dev, dtype=self.dtype)
509
+ toks = torch.full((1,self.quantizers,2250), self.codes+1, dtype=torch.long, device=dev)
510
+ it = range(1,min(N,2250-1))
511
+ if show_progress_bar: it = progress_bar(it)
512
+ with record_function("encode"):
513
+ xenc, xenc_positions, _ = self.run_encoder(stoks, speakers)
514
+ toks_positions = torch.arange(N, device=dev)
515
+ with record_function("prefill"):
516
+ toks[0,0,1] = self.generate_one(toks[:,:,:1], toks_positions[:1], langs, xenc, xenc_positions, T, top_k)[0,0]
517
+ with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
518
+ for i in it:
519
+ with record_function("generate_one"):
520
+ toks[0,:i+1,i+1] = self.generate_next(toks[:,:,i:i+1], toks_positions[i:i+1], langs, xenc, xenc_positions, T, top_k)[:i+1,0]
521
+
522
+ # for profiling, debugging or early exit
523
+ if step is not None: step()
524
+ # shift tokens
525
+ toks = toks[:,:,1:N]
526
+ for j in range(self.quantizers):
527
+ toks[0, j] = torch.roll(toks[0, j], -j)
528
+ return toks[0]
529
+
530
+ # %% ../nbs/4B. Multi-language semantic to acoustic token modeling.ipynb 39
531
+ def _make_model(size:str, quantizers:int=4, tunables:Tunables=Tunables(), **kwargs):
532
+ kwargs = dict(quantizers=quantizers, tunables=tunables, **kwargs)
533
+ if size == 'micro':
534
+ return SADelARTransformer(depth=4, n_head=3, ffn_mult=2, **kwargs)
535
+ if size == 'tiny-narrow':
536
+ return SADelARTransformer(depth=4, n_head=6, ffn_mult=1, **kwargs)
537
+ if size == 'tiny':
538
+ return SADelARTransformer(depth=4, n_head=6, **kwargs)
539
+ if size == 'base':
540
+ return SADelARTransformer(depth=6, n_head=8, **kwargs)
541
+ if size == 'base-deep':
542
+ return SADelARTransformer(depth=9, n_head=8, **kwargs)
543
+ if size == 'base-wide':
544
+ return SADelARTransformer(depth=6, n_head=12, **kwargs)
545
+ if size == 'small/2':
546
+ return SADelARTransformer(depth=9, n_head=12, **kwargs)
547
+ if size == 'small':
548
+ return SADelARTransformer(depth=12, n_head=12, **kwargs)
549
+ if size == 'medium':
550
+ return SADelARTransformer(depth=24, n_head=16, **kwargs)
551
+
552
+ def make_model(size:str, quantizers:int=4, frozen_embeddings_model:str=None, frozen_acoustic_embeddings:bool=False, spk_width:int=None, tunables:Tunables=Tunables(), dataset=None):
553
+ from encodec.model import EncodecModel
554
+ from whisperspeech import vq_stoks
555
+
556
+ amodel = EncodecModel.encodec_model_24khz() if frozen_acoustic_embeddings else None
557
+ vqmodel = vq_stoks.RQBottleneckTransformer.load_model(frozen_embeddings_model) if frozen_embeddings_model else None
558
+ model = _make_model(size, quantizers, tunables,
559
+ spk_width=spk_width,
560
+ atoks_width=amodel and amodel.quantizer.vq.layers[0]._codebook.embed.shape[-1],
561
+ stoks_codes=vqmodel.vq_codes+1, stoks_width=vqmodel.rq.layers[0]._codebook.embed[0].shape[-1])
562
+ if vqmodel: model.load_frozen_semantic_embeddings(vqmodel)
563
+ if amodel: model.load_frozen_acoustic_embeddings(amodel)
564
+ return model
whisperspeech/t2s_up_wds.py ADDED
@@ -0,0 +1,442 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/5B. Text to semantic token modeling.ipynb.
2
+
3
+ # %% auto 0
4
+ __all__ = ['load_datasets', 'rand', 'Tunables', 'Encoder', 'Decoder', 'TSARTransformer', 'make_model']
5
+
6
+ # %% ../nbs/5B. Text to semantic token modeling.ipynb 1
7
+ import dataclasses
8
+ import random
9
+ import math
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ from torch.profiler import record_function
14
+
15
+ from huggingface_hub import hf_hub_download
16
+ from fastcore.basics import store_attr
17
+ from fastprogress import progress_bar
18
+
19
+ import webdataset as wds
20
+
21
+ # %% ../nbs/5B. Text to semantic token modeling.ipynb 2
22
+ from pathlib import Path
23
+ import pylab as plt
24
+ import pandas as pd
25
+ import numpy as np
26
+
27
+ # %% ../nbs/5B. Text to semantic token modeling.ipynb 3
28
+ import whisper
29
+ from whisperspeech.train import *
30
+ from whisperspeech.modules import *
31
+ from whisperspeech import vq_stoks
32
+
33
+ # %% ../nbs/5B. Text to semantic token modeling.ipynb 8
34
+ import re
35
+
36
+ class CharTokenizer:
37
+ """Trivial tokenizer โ€“ just use UTF-8 bytes"""
38
+ eot = 0
39
+
40
+ def encode(self, txt):
41
+ return list(bytes(txt.strip(), 'utf-8'))
42
+
43
+ def decode(self, tokens):
44
+ return bytes(tokens).decode('utf-8')
45
+
46
+ def tokenizer(ikey, okey, length):
47
+ """Tokenizes a transcript"""
48
+ tok = CharTokenizer()
49
+ def _tokenizer(samples):
50
+ for s in samples:
51
+ toks = torch.tensor(tok.encode(s[ikey]))
52
+ s[okey] = F.pad(toks, (0, length - toks.shape[-1]), value=tok.eot)
53
+ yield s
54
+ return _tokenizer
55
+
56
+ def ar_padder(ikey, okey, length, pad_token):
57
+ """Pads the tokens for autoregresive training"""
58
+ def _ar_padder(samples):
59
+ for s in samples:
60
+ toks = s[ikey]
61
+ if isinstance(toks, (list, np.ndarray)): toks = torch.tensor(toks)
62
+ toks = toks.to(torch.long)
63
+ s['in_' +okey] = F.pad(toks, (1, length - toks.shape[-1] - 1), value=pad_token)
64
+ s['out_'+okey] = F.pad(toks, (0, length - toks.shape[-1]), value=pad_token)
65
+ yield s
66
+ return _ar_padder
67
+
68
+ def char_per_seconder(txt_key, stoks_key, cps_key, stoks_per_second=25):
69
+ """Adds the characters per second metric to the input data"""
70
+ def _char_per_seconder(samples):
71
+ for s in samples:
72
+ secs = s[stoks_key].shape[-1] / stoks_per_second
73
+ s[cps_key] = len(s[txt_key]) / secs
74
+ yield s
75
+ return _char_per_seconder
76
+
77
+ # %% ../nbs/5B. Text to semantic token modeling.ipynb 9
78
+ def build_speaker_map(shards):
79
+ speakers = set()
80
+ for shard in shards:
81
+ with open(shard+'.speakers.txt') as f: speakers = speakers.union(set(x.strip() for x in f.readlines()))
82
+ return {id:i for i,id in enumerate(speakers)}
83
+
84
+ def speaker_id_extractor(speaker_map):
85
+ def _extractor(samples):
86
+ for s in samples:
87
+ s['speaker'] = torch.tensor(speaker_map[s['__key__'].split("/")[1]])
88
+ yield s
89
+ return _extractor
90
+
91
+ # %% ../nbs/5B. Text to semantic token modeling.ipynb 10
92
+ def load_datasets(
93
+ input:str, # webdataset folder or shard list
94
+ samples:int, # samples per epoch
95
+ subsample:float=1, # use a fraction of the files
96
+ val_samples:int=512,
97
+ vq_codes:int=4096,
98
+ ):
99
+ if isinstance(input, (Path, str)):
100
+ path = Path(input)
101
+ if path.is_dir():
102
+ glob = '*-t2s-*.tar.gz'
103
+ else:
104
+ glob = path.name
105
+ path = path.parent
106
+ input = Path(path).glob(glob)
107
+ elif isinstance(input, list):
108
+ pass
109
+ else:
110
+ raise ArgumentError("input should be either a list of a path with an optional glob specifier")
111
+ shards = [str(x) for x in input]
112
+
113
+ speaker_map = build_speaker_map(shards)
114
+
115
+ def ds(shards, length):
116
+ ds = wds.WebDataset(wds.ResampledShards(shards)).compose(
117
+ wds.decode(),
118
+ speaker_id_extractor(speaker_map),
119
+ wds.select(lambda s: s['stoks.npy'].shape[-1] > 12), # select samples > .5s
120
+ tokenizer('txt', 'ttoks', length=550),
121
+ ar_padder('stoks.npy', 'stoks', length=750, pad_token=vq_codes-1),
122
+ char_per_seconder('txt', 'stoks.npy', 'cps', stoks_per_second=25),
123
+ wds.to_tuple('ttoks', 'speaker', 'cps', 'in_stoks', 'out_stoks'),
124
+ wds.batched(64)
125
+ )
126
+ ds.speakers = speaker_map
127
+ ds.total_samples = length
128
+ ds.stoks_len = 750
129
+ ds.stoks_codes = vq_codes
130
+ ds.ttoks_len = 550
131
+ return ds.compose(wds.slice(length // 64)).with_epoch(length // 64).with_length(length // 64)
132
+
133
+ return (
134
+ ds(shards[1:], samples),
135
+ ds(shards[:1], val_samples),
136
+ )
137
+
138
+ # %% ../nbs/5B. Text to semantic token modeling.ipynb 14
139
+ def rand(start, end):
140
+ return random.random() * (end - start) + start
141
+
142
+ @dataclasses.dataclass
143
+ class Tunables:
144
+ init_std :float = 1
145
+ embeddings_std :float = .01
146
+ embeddings_lr_scale: float = 5
147
+ embedding_projector_lr_scale: float = 2.5
148
+ output_mult :float = .35
149
+ query_mult :float = 1
150
+ encoder_depth_ratio :float = 0.25
151
+ eot_dropout_p :float = .5
152
+ cps_input: bool = True
153
+ cps_bins: int = 32
154
+
155
+ lr0 :float = 1.5e-3
156
+ clip_gradient_norm :float = .2
157
+ weight_decay :float = 1e-1
158
+ warmup_steps :float = 4000
159
+
160
+ random :bool = False
161
+
162
+ def __post_init__(self):
163
+ # randomize the hyperparams if requested
164
+ if self.random:
165
+ self.init_std = 10**rand(-1,1)
166
+ self.embeddings_std = 10**rand(-3,-.7)
167
+ self.embeddings_lr_scale = rand(2,6)
168
+ self.output_mult = rand(0.25,0.65)
169
+ self.query_mult = 2**rand(-2,3)
170
+ self.encoder_depth_ratio = 0.25
171
+
172
+ self.lr0 = rand(1,5)*1e-3
173
+ self.clip_gradient_norm = 10**rand(-3,0)
174
+ self.warmup_steps = 100*(10**rand(1,1.85))
175
+
176
+ # %% ../nbs/5B. Text to semantic token modeling.ipynb 15
177
+ class EmbeddingProjector(nn.Linear):
178
+ pass
179
+
180
+ # %% ../nbs/5B. Text to semantic token modeling.ipynb 16
181
+ class Encoder(nn.Module):
182
+ def __init__(self, depth=6, width=384, n_head=6, length=1500, codes=1024, emb_width=384, ffn_mult=4, pos_embs=None, tunables=Tunables()):
183
+ super().__init__()
184
+ self.emb_width = emb_width
185
+
186
+ self.emb_factor = width != emb_width
187
+
188
+ self.embedding = nn.Embedding(codes, emb_width)
189
+ if self.emb_factor:
190
+ self.emb_to_hidden = EmbeddingProjector(emb_width, width)
191
+
192
+ if pos_embs is None: pos_embs = sinusoids(length, width)
193
+ self.register_buffer("positional_embedding", pos_embs)
194
+
195
+ self.layers = nn.Sequential(*[
196
+ ResidualAttentionBlock(width, n_head,
197
+ qk_scale=tunables.query_mult*8/math.sqrt(width/n_head), ffn_mult=ffn_mult) for _ in range(depth)
198
+ ])
199
+
200
+ self.ln_post = LayerNorm(width)
201
+
202
+ def forward(self, Stoks):
203
+ xin = self.embedding(Stoks)
204
+ if self.emb_factor:
205
+ xin = self.emb_to_hidden(xin)
206
+
207
+ assert xin.shape[1:] == self.positional_embedding.shape, "incorrect semantic token shape"
208
+ xin = (xin + self.positional_embedding).to(xin.dtype)
209
+
210
+ return self.ln_post(self.layers(xin))
211
+
212
+ # %% ../nbs/5B. Text to semantic token modeling.ipynb 17
213
+ class Decoder(nn.Module):
214
+ def __init__(self, depth=6, stoks_width=384, width=384, n_head=6, length=1500, codes=1024, ffn_mult=4, pos_embs=None, tunables=Tunables()):
215
+ super().__init__()
216
+ self.length = length
217
+ self.codes = codes
218
+ self.width = width
219
+ self.stoks_width = stoks_width
220
+
221
+ self.emb_factor = width != stoks_width
222
+
223
+ # embed semantic tokens
224
+ self.embedding = nn.Embedding(codes, stoks_width)
225
+ if self.emb_factor:
226
+ self.emb_to_hidden = EmbeddingProjector(stoks_width, width)
227
+ self.hidden_to_emb = EmbeddingProjector(width, stoks_width)
228
+
229
+ if pos_embs is None: pos_embs = sinusoids(length, width)
230
+ self.register_buffer("positional_embedding", pos_embs)
231
+
232
+ self.layers = nn.ModuleList([
233
+ ResidualAttentionBlock(width, n_head, cross_attention=True,
234
+ qk_scale=tunables.query_mult*8/math.sqrt(width/n_head), ffn_mult=ffn_mult) for _ in range(depth)
235
+ ])
236
+ self.ln_post = LayerNorm(width)
237
+
238
+ def forward(self, Stoks, xenc, cps=None):
239
+ Sembs = self.embedding(Stoks)
240
+
241
+ if self.emb_factor:
242
+ Sembs = self.emb_to_hidden(Sembs)
243
+
244
+ xin = (Sembs + self.positional_embedding[:Sembs.shape[1]]).to(xenc.dtype)
245
+ if cps is not None: xin = xin + cps
246
+
247
+ x = xin
248
+ for l in self.layers: x = l(x, xenc, causal=True)
249
+
250
+ x = self.ln_post(x)
251
+
252
+ if self.emb_factor:
253
+ x = self.hidden_to_emb(x)
254
+
255
+ logits = (x @ self.embedding.weight.to(x.dtype).T).float()
256
+ return logits
257
+
258
+ # %% ../nbs/5B. Text to semantic token modeling.ipynb 18
259
+ class TSARTransformer(nn.Module):
260
+ def __init__(self, depth=6, n_head=6, head_width=64, ffn_mult=4, language='en',
261
+ ttoks_len=200, ttoks_codes=50364, ttoks_width=None,
262
+ stoks_len=1500, stoks_codes=1024, stoks_width=None,
263
+ tunables=Tunables()):
264
+ assert language == 'en', "only english is supported right now"
265
+ super().__init__()
266
+ store_attr("depth,n_head,head_width,ffn_mult,stoks_width,ttoks_width,ttoks_len,stoks_len,ttoks_codes,stoks_codes,language")
267
+
268
+ width = n_head * head_width
269
+ self.width = width
270
+ self.base_width = 3 * head_width
271
+ self.tunables = tunables
272
+ if self.stoks_width is None: self.stoks_width = self.width
273
+ if self.ttoks_width is None: self.ttoks_width = self.width
274
+
275
+ if tunables.cps_input:
276
+ self.cps_embeddings = nn.Embedding(tunables.cps_bins, self.width)
277
+ else:
278
+ self.cps_embeddings = None
279
+
280
+ encoder_depth = int(depth * 2 * tunables.encoder_depth_ratio)
281
+ decoder_depth = depth * 2 - encoder_depth
282
+ tformer_args = dict(width=width, n_head=n_head, ffn_mult=ffn_mult, tunables=tunables)
283
+ self.encoder = Encoder(length=ttoks_len, codes=ttoks_codes, emb_width=self.ttoks_width, depth=encoder_depth, **tformer_args)
284
+ self.decoder = Decoder(length=stoks_len, codes=stoks_codes, stoks_width=self.stoks_width, depth=decoder_depth, **tformer_args)
285
+
286
+ self.tokenizer = None
287
+
288
+ self.apply(self.init_transformer)
289
+
290
+ def load_frozen_semantic_embeddings(self, vqmodel):
291
+ with torch.no_grad():
292
+ self.decoder.embedding.weight[:] = vqmodel.rq.layers[0]._codebook.embed[0]
293
+ self.decoder.embedding.lr_scale = 0
294
+
295
+ def setup(self, device):
296
+ pass
297
+
298
+ def init_transformer(self, m):
299
+ if isinstance(m, LinearHead):
300
+ m.no_weight_decay = True
301
+ torch.nn.init.constant_(m.weight, 0)
302
+ elif isinstance(m, QueryHead):
303
+ m.lr_scale = 1/(m.weight.shape[1] / self.base_width)
304
+ torch.nn.init.constant_(m.weight, 0)
305
+ elif isinstance(m, nn.Embedding):
306
+ m.no_weight_decay = True
307
+ m.lr_scale = self.tunables.embeddings_lr_scale
308
+ std = self.tunables.embeddings_std
309
+ torch.nn.init.trunc_normal_(m.weight, std=std, a=-3*std, b=3*std)
310
+ elif isinstance(m, EmbeddingProjector):
311
+ m.lr_scale = self.tunables.embedding_projector_lr_scale
312
+ std = self.tunables.init_std
313
+ torch.nn.init.trunc_normal_(m.weight, std=std, a=-3*std, b=3*std)
314
+ elif isinstance(m, nn.Linear):
315
+ m.lr_scale = 1/(m.weight.shape[1] / self.base_width)
316
+ std = self.tunables.init_std / m.weight.shape[1]
317
+ torch.nn.init.trunc_normal_(m.weight, std=std, a=-3*std, b=3*std)
318
+ if m.bias is not None:
319
+ torch.nn.init.trunc_normal_(m.bias, std=std, a=-3*std, b=3*std)
320
+ elif isinstance(m, nn.LayerNorm):
321
+ m.no_weight_decay = True
322
+ torch.nn.init.constant_(m.bias, 0)
323
+ torch.nn.init.constant_(m.weight, 1)
324
+
325
+ def forward(self, Ttoks, speakers, cpss, in_stoks, out_stoks=None, loss=True):
326
+ with record_function("encoder"):
327
+ xenc = self.encoder(Ttoks.to(torch.long))
328
+ with record_function("decoder"):
329
+ if self.cps_embeddings:
330
+ cps_bin = (cpss / 20 * self.tunables.cps_bins).to(torch.long)
331
+ cps_bin[cps_bin >= self.tunables.cps_bins] = self.tunables.cps_bins-1
332
+ cps_embs = self.cps_embeddings(cps_bin).unsqueeze(1)
333
+ else:
334
+ cps_embs = None
335
+ logits = self.decoder(in_stoks, xenc, cps=cps_embs) * self.tunables.output_mult / (self.width / self.base_width)
336
+ if loss is not None:
337
+ with record_function("loss"):
338
+ loss = F.cross_entropy(logits.transpose(-1,-2), out_stoks)#, reduction='none')
339
+ return logits, loss
340
+
341
+ #
342
+ # inference
343
+ #
344
+ @classmethod
345
+ def load_model(cls, repo_id="collabora/whisperspeech", filename="t2s_up_wds.model", local_filename=None):
346
+ if not local_filename:
347
+ local_filename = hf_hub_download(repo_id=repo_id, filename=filename)
348
+ spec = torch.load(local_filename)
349
+ model = cls(**spec['config'], tunables=Tunables(**spec['tunables']))
350
+ model.load_state_dict(spec['state_dict'])
351
+ model.eval()
352
+ return model
353
+
354
+ def load_checkpoint(self, local_filename):
355
+ spec = torch.load(local_filename, map_location='cpu')
356
+ assert 'pytorch-lightning_version' in spec, 'not a valid PyTorch Lightning checkpoint'
357
+ state_dict = {k.replace('model.', ''):v
358
+ for k,v in spec['state_dict'].items()}
359
+ self.load_state_dict(state_dict)
360
+ return self
361
+
362
+ def save_model(self, fname):
363
+ torch.save(dict(config = self.__stored_args__,
364
+ tunables = dataclasses.asdict(self.tunables),
365
+ state_dict = self.state_dict()), fname)
366
+
367
+ def ensure_tokenizer(self):
368
+ assert not self.training
369
+ if self.tokenizer is None: self.tokenizer = CharTokenizer()
370
+ #whisper.tokenizer.get_tokenizer(multilingual=True)
371
+
372
+ @property
373
+ def device(self):
374
+ return next(self.parameters()).device
375
+
376
+ @torch.no_grad()
377
+ def generate(self, txt, cps=15, N=None, T=0.7, top_k=None, show_progress_bar=True):
378
+ self.ensure_tokenizer()
379
+ N = N or self.stoks_len
380
+ dev = self.device
381
+ ttoks = torch.tensor(self.tokenizer.encode(txt), device=dev)
382
+ ttoks = F.pad(ttoks, (0, self.ttoks_len - len(ttoks)), value=self.tokenizer.eot).unsqueeze(0)
383
+ cpss = torch.tensor([cps], device=dev)
384
+ toks = torch.zeros((1,N), dtype=torch.long, device=dev)
385
+ toks[0,0] = self.stoks_codes-1
386
+ it = range(1,N)
387
+ if show_progress_bar: it = progress_bar(it)
388
+ for i in it:
389
+ p, _ = self(ttoks, None, cpss, toks[:,:i], loss=None)
390
+ last_p = p[0,-1]
391
+ if top_k:
392
+ last_p[last_p < torch.topk(last_p, top_k).values[-1,None]] = -torch.inf
393
+ tok = torch.multinomial((last_p / float(T)).softmax(-1), 1)
394
+ toks[0,i] = tok
395
+ if toks[0,i] == self.stoks_codes-1: return toks[0,1:i]
396
+ return toks[0,1:]
397
+
398
+ @torch.no_grad()
399
+ def generate_batch(self, txts, N=None, T=1.1, top_k=7, show_progress_bar=True):
400
+ self.ensure_tokenizer()
401
+ N = self.stoks_len
402
+ dev = self.device
403
+ ttoks = []
404
+ for txt in txts:
405
+ ttoks_ = torch.tensor(self.tokenizer.encode(txt), device=dev)
406
+ ttoks_ = F.pad(ttoks_, (0, self.ttoks_len - len(ttoks_)), value=self.tokenizer.eot).unsqueeze(0)
407
+ ttoks.append(ttoks_)
408
+ ttoks = torch.cat(ttoks, dim=0)
409
+ toks = torch.zeros((len(ttoks),N), dtype=torch.long, device=dev)
410
+ it = range(N)
411
+ if show_progress_bar: it = progress_bar(it)
412
+ for i in it:
413
+ p, _ = self(ttoks, toks[:,:i], loss=None)
414
+ last_p = p[:,-1]
415
+ if top_k:
416
+ last_p[last_p < torch.topk(last_p, top_k).values[:,-1,None]] = -torch.inf
417
+ tok = torch.multinomial((last_p / float(T)).softmax(-1), 1)
418
+ toks[:,i] = tok[:,0]
419
+ if (toks[:,i] == self.stoks_codes-1).all(): return toks[:,:i]
420
+ return toks
421
+
422
+ # %% ../nbs/5B. Text to semantic token modeling.ipynb 19
423
+ def _make_model(size:str, tunables:Tunables=Tunables(), dataset=None, **kwargs):
424
+ kwargs = dict(stoks_len = dataset.stoks_len, ttoks_len = dataset.ttoks_len, tunables=tunables, **kwargs)
425
+ if 'stoks_codes' not in kwargs: kwargs['stoks_codes'] = dataset.stoks_codes
426
+ if size == 'micro':
427
+ return TSARTransformer(depth=2, n_head=3, ffn_mult=1, **kwargs)
428
+ if size == 'tiny':
429
+ return TSARTransformer(depth=4, n_head=6, **kwargs)
430
+ if size == 'base':
431
+ return TSARTransformer(depth=6, n_head=8, **kwargs)
432
+ if size == 'small':
433
+ return TSARTransformer(depth=12, n_head=16, **kwargs)
434
+
435
+ def make_model(size:str, frozen_embeddings_model:str=None, tunables:Tunables=Tunables(), dataset:torch.utils.data.Dataset=None):
436
+ if frozen_embeddings_model:
437
+ vqmodel = vq_stoks.RQBottleneckTransformer.load_model(frozen_embeddings_model)
438
+ model = _make_model(size, tunables, dataset, stoks_codes=vqmodel.vq_codes+1, stoks_width=vqmodel.rq.layers[0]._codebook.embed[0].shape[-1])
439
+ model.load_frozen_semantic_embeddings(vqmodel)
440
+ else:
441
+ model = _make_model(size, quantizers, tunables, dataset)
442
+ return model
whisperspeech/t2s_up_wds_mlang_enclm.py ADDED
@@ -0,0 +1,519 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/5B. Multi-lang text to semantic token modeling.ipynb.
2
+
3
+ # %% auto 0
4
+ __all__ = ['load_dataset', 'rand', 'Tunables', 'T2SEmbedding', 'Encoder', 'TSARTransformer', 'make_model']
5
+
6
+ # %% ../nbs/5B. Multi-lang text to semantic token modeling.ipynb 1
7
+ import dataclasses
8
+ import random
9
+ import math
10
+ import itertools
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ from torch.profiler import record_function
15
+
16
+ from huggingface_hub import hf_hub_download
17
+ from fastcore.basics import store_attr
18
+ from fastprogress import progress_bar
19
+
20
+ from pathlib import Path
21
+
22
+ # %% ../nbs/5B. Multi-lang text to semantic token modeling.ipynb 2
23
+ from whisperspeech.modules import *
24
+ from whisperspeech import languages
25
+
26
+ # %% ../nbs/5B. Multi-lang text to semantic token modeling.ipynb 6
27
+ import re
28
+
29
+ class CharTokenizer:
30
+ """Trivial tokenizer โ€“ just use UTF-8 bytes"""
31
+ eot = 0
32
+
33
+ def encode(self, txt):
34
+ return list(bytes(txt.strip(), 'utf-8'))
35
+
36
+ def decode(self, tokens):
37
+ return bytes(tokens).decode('utf-8')
38
+
39
+ def tokenizer(ikey, okey, length):
40
+ """Tokenizes a transcript"""
41
+ tok = CharTokenizer()
42
+ def _tokenizer(samples):
43
+ for s in samples:
44
+ toks = torch.tensor(tok.encode(s[ikey]))
45
+ s[okey] = F.pad(toks, (0, length - toks.shape[-1]), value=tok.eot)
46
+ yield s
47
+ return _tokenizer
48
+
49
+ def ar_padder(ikey, okey, length, pad_token):
50
+ """Pads the tokens for autoregresive training"""
51
+ import numpy as np
52
+
53
+ def _ar_padder(samples):
54
+ for s in samples:
55
+ toks = s[ikey]
56
+ if isinstance(toks, (list, np.ndarray)): toks = torch.tensor(toks)
57
+ toks = toks.to(torch.long)
58
+ s['in_' +okey] = F.pad(toks, (1, length - toks.shape[-1] - 1), value=pad_token)
59
+ s['out_'+okey] = F.pad(toks, (0, length - toks.shape[-1]), value=pad_token)
60
+ yield s
61
+ return _ar_padder
62
+
63
+ def char_per_seconder(txt_key, stoks_key, cps_key, stoks_per_second=25):
64
+ """Adds the characters per second metric to the input data"""
65
+ def _char_per_seconder(samples):
66
+ for s in samples:
67
+ secs = s[stoks_key].shape[-1] / stoks_per_second
68
+ s[cps_key] = len(s[txt_key]) / secs
69
+ yield s
70
+ return _char_per_seconder
71
+
72
+ # %% ../nbs/5B. Multi-lang text to semantic token modeling.ipynb 7
73
+ def load_dataset(
74
+ txt_shard_spec:str, # transcription webdataset shards
75
+ stoks_shard_dir:str, # stoks webdataset base dir
76
+ samples:int, # samples per epoch
77
+ txt_kind:str='small.en-txt',
78
+ vq_codes:int=4096,
79
+ language:str='en',
80
+ weight:float=1,
81
+ validation:bool=False,
82
+ exclude_files:str=None,
83
+ ):
84
+ import webdataset as wds
85
+ from whisperspeech import utils
86
+
87
+ shards = utils.shard_glob(txt_shard_spec)
88
+ excludes = {x for file in exclude_files.split() for x in utils.readlines(file)} if exclude_files else set()
89
+
90
+ language = languages.to_id(language)
91
+
92
+ def set_language(x):
93
+ x['language'] = language
94
+ return x
95
+
96
+ same_on_all_nodes = lambda urls: urls # will only be used for validation
97
+ ds = wds.WebDataset(shards, resampled=not validation, nodesplitter=same_on_all_nodes).compose(
98
+ wds.decode(),
99
+ utils.merge_in(utils.derived_dataset('eqvad-stoks', base=txt_kind, suffix='', dir=stoks_shard_dir)),
100
+ # discard validation samples, select samples > .5s
101
+ wds.select(lambda s: s['__key__'] not in excludes and s['stoks.npy'].shape[-1] > 12),
102
+ tokenizer('txt', 'ttoks', length=550),
103
+ ar_padder('stoks.npy', 'stoks', length=750, pad_token=vq_codes-1),
104
+ ar_padder('ttoks', 'ttoks', length=550, pad_token=CharTokenizer.eot),
105
+ char_per_seconder('txt', 'stoks.npy', 'cps', stoks_per_second=25),
106
+ wds.map(set_language),
107
+ wds.to_tuple('in_ttoks', 'out_ttoks', 'language', 'cps', 'in_stoks', 'out_stoks'),
108
+ wds.shuffle(20000, initial=20000),
109
+ wds.batched(64)
110
+ )
111
+ if validation:
112
+ ds = ds.slice(samples // 64)
113
+ ds.total_samples = samples
114
+ ds.stoks_len = 750
115
+ ds.stoks_codes = vq_codes
116
+ ds.ttoks_len = 550
117
+ ds.weight = weight
118
+
119
+ return ds
120
+
121
+ # %% ../nbs/5B. Multi-lang text to semantic token modeling.ipynb 14
122
+ def rand(start, end):
123
+ return random.random() * (end - start) + start
124
+
125
+ @dataclasses.dataclass
126
+ class Tunables:
127
+ init_std :float = 1
128
+ embeddings_std :float = .01
129
+ embeddings_lr_scale: float = 5
130
+ embedding_projector_lr_scale: float = 2.5
131
+ output_mult :float = .35
132
+ query_mult :float = 1
133
+ encoder_depth_ratio :float = 0.25
134
+ eot_dropout_p :float = .5
135
+ cps_input: bool = True
136
+ cps_bins: int = 32
137
+
138
+ lr0 :float = 1.5e-3
139
+ clip_gradient_norm :float = .2
140
+ weight_decay :float = 1e-1
141
+ warmup_steps :float = 4000
142
+
143
+ random :bool = False
144
+
145
+ def __post_init__(self):
146
+ # randomize the hyperparams if requested
147
+ if self.random:
148
+ self.init_std = 10**rand(-1,1)
149
+ self.embeddings_std = 10**rand(-3,-.7)
150
+ self.embeddings_lr_scale = rand(2,6)
151
+ self.output_mult = rand(0.25,0.65)
152
+ self.query_mult = 2**rand(-2,3)
153
+ self.encoder_depth_ratio = 0.25
154
+
155
+ self.lr0 = rand(1,5)*1e-3
156
+ self.clip_gradient_norm = 10**rand(-3,0)
157
+ self.warmup_steps = 100*(10**rand(1,1.85))
158
+
159
+ # %% ../nbs/5B. Multi-lang text to semantic token modeling.ipynb 15
160
+ class T2SEmbedding(nn.Module):
161
+ def __init__(self, length=1500, codes=1024, width=384, pos_embs=None, stoks_width=384):
162
+ super().__init__()
163
+ self.embedding = FlexEmbeddings(codes, width, special_codes=1, frozen_width=stoks_width)
164
+ if pos_embs is None: pos_embs = sinusoids(length, width)
165
+ self.register_buffer("positional_embedding", pos_embs)
166
+
167
+ def forward(self, Stoks, xenc, cps=None, offset=0):
168
+ Sembs = self.embedding(Stoks)
169
+ xin = (Sembs + self.positional_embedding[offset : offset + Sembs.shape[1]]).to(xenc.dtype)
170
+ if cps is not None: xin = xin + cps
171
+ return xin, offset
172
+
173
+ # %% ../nbs/5B. Multi-lang text to semantic token modeling.ipynb 16
174
+ class Encoder(nn.Module):
175
+ def __init__(self, depth=6, width=384, n_head=6, length=1500, codes=1024, emb_width=384, ffn_mult=4, pos_embs=None, tunables=Tunables()):
176
+ super().__init__()
177
+ self.emb_width = emb_width
178
+
179
+ self.embedding = FlexEmbeddings(codes, width, frozen_width=emb_width)
180
+
181
+ if pos_embs is None: pos_embs = sinusoids(length, width)
182
+ self.register_buffer("positional_embedding", pos_embs)
183
+
184
+ self.layers = nn.ModuleList([
185
+ ResidualAttentionBlock(width, n_head,
186
+ qk_scale=tunables.query_mult*8/math.sqrt(width/n_head), ffn_mult=ffn_mult) for _ in range(depth)
187
+ ])
188
+
189
+ self.ln_post = LayerNorm(width)
190
+
191
+ mask = torch.empty(length, length).fill_(-torch.inf).triu_(1)
192
+ self.register_buffer("mask", mask, persistent=False)
193
+
194
+ def forward(self, Stoks, positions, lang_emb=None):
195
+ xin = self.embedding(Stoks)
196
+
197
+ if lang_emb is not None: xin += lang_emb
198
+
199
+ # assert xin.shape[1:] == self.positional_embedding.shape, "incorrect semantic token shape"
200
+ x = (xin +
201
+ self.positional_embedding[positions]).to(xin.dtype)
202
+
203
+ for l in self.layers: x = l(x, positions, causal=False, mask=self.mask)
204
+
205
+ return self.ln_post(x)
206
+
207
+ # %% ../nbs/5B. Multi-lang text to semantic token modeling.ipynb 17
208
+ class TSARTransformer(nn.Module):
209
+ def __init__(self, depth=6, n_head=6, head_width=64, ffn_mult=4,
210
+ ttoks_len=200, ttoks_codes=256, ttoks_width=None,
211
+ stoks_len=1500, stoks_codes=1024, stoks_width=None,
212
+ tunables=Tunables()):
213
+ super().__init__()
214
+ store_attr("depth,n_head,head_width,ffn_mult,stoks_width,ttoks_width,ttoks_len,stoks_len,ttoks_codes,stoks_codes")
215
+
216
+ width = n_head * head_width
217
+ self.width = width
218
+ self.base_width = 3 * head_width
219
+ self.tunables = tunables
220
+ if self.stoks_width is None: self.stoks_width = self.width
221
+ if self.ttoks_width is None: self.ttoks_width = self.width
222
+
223
+ self.lang_embeddings = nn.Embedding(len(languages.languages), width)
224
+ if tunables.cps_input:
225
+ self.cps_embeddings = nn.Embedding(tunables.cps_bins, self.width)
226
+ else:
227
+ self.cps_embeddings = None
228
+
229
+ encoder_depth = int(depth * 2 * tunables.encoder_depth_ratio)
230
+ decoder_depth = depth * 2 - encoder_depth
231
+ tformer_args = dict(width=width, n_head=n_head, ffn_mult=ffn_mult, tunables=tunables)
232
+ self.encoder = Encoder(length=ttoks_len, codes=ttoks_codes, emb_width=self.ttoks_width, depth=encoder_depth, **tformer_args)
233
+ self.embeddings = T2SEmbedding(length=stoks_len, codes=stoks_codes, width=width, stoks_width=self.stoks_width)
234
+
235
+ self.decoder = BaseDecoder(
236
+ length=stoks_len,
237
+ depth=decoder_depth,
238
+ qk_scale=tunables.query_mult*8/math.sqrt(width/n_head),
239
+ width=width, n_head=n_head, ffn_mult=ffn_mult,
240
+ )
241
+ self.tokenizer = None
242
+
243
+ self.apply(self.init_transformer)
244
+
245
+ def load_frozen_semantic_embeddings(self, vqmodel):
246
+ self.embeddings.embedding.set_frozen_embeddings(vqmodel.rq.layers[0]._codebook.embed[0])
247
+
248
+ def setup(self, device):
249
+ pass
250
+
251
+ def init_transformer(self, m):
252
+ if isinstance(m, LinearHead):
253
+ m.no_weight_decay = True
254
+ torch.nn.init.constant_(m.weight, 0)
255
+ elif isinstance(m, QueryHead):
256
+ m.lr_scale = 1/(m.weight.shape[1] / self.base_width)
257
+ torch.nn.init.constant_(m.weight, 0)
258
+ elif isinstance(m, nn.Embedding):
259
+ m.no_weight_decay = True
260
+ m.lr_scale = self.tunables.embeddings_lr_scale
261
+ std = self.tunables.embeddings_std
262
+ torch.nn.init.trunc_normal_(m.weight, std=std, a=-3*std, b=3*std)
263
+ elif isinstance(m, EmbeddingProjector):
264
+ m.lr_scale = self.tunables.embedding_projector_lr_scale
265
+ std = self.tunables.init_std
266
+ torch.nn.init.trunc_normal_(m.weight, std=std, a=-3*std, b=3*std)
267
+ elif isinstance(m, nn.Linear):
268
+ m.lr_scale = 1/(m.weight.shape[1] / self.base_width)
269
+ std = self.tunables.init_std / m.weight.shape[1]
270
+ torch.nn.init.trunc_normal_(m.weight, std=std, a=-3*std, b=3*std)
271
+ if m.bias is not None:
272
+ torch.nn.init.trunc_normal_(m.bias, std=std, a=-3*std, b=3*std)
273
+ elif isinstance(m, nn.LayerNorm):
274
+ m.no_weight_decay = True
275
+ torch.nn.init.constant_(m.bias, 0)
276
+ torch.nn.init.constant_(m.weight, 1)
277
+
278
+ def _embed_cps(self, cpss):
279
+ if self.cps_embeddings is None: return None
280
+
281
+ cps_bin = (cpss / 20 * self.tunables.cps_bins).to(torch.long)
282
+ cps_bin[cps_bin >= self.tunables.cps_bins] = self.tunables.cps_bins-1
283
+ return self.cps_embeddings(cps_bin).unsqueeze(1)
284
+
285
+ def run_encoder(self, in_ttoks, languages, cpss):
286
+ if len(languages.shape) != 3: lang_embs = self.lang_embeddings(languages)
287
+ else: lang_embs = languages
288
+ if len(lang_embs.shape) == 2: lang_embs = lang_embs.unsqueeze(1)
289
+
290
+ cps_emb = self._embed_cps(cpss)
291
+
292
+ with record_function("encoder"):
293
+ positions = torch.arange(0, in_ttoks.shape[1], device=in_ttoks.device)
294
+ xenc = self.encoder(in_ttoks.to(torch.long), positions, lang_emb=lang_embs)
295
+
296
+ return xenc, positions, cps_emb
297
+
298
+ def forward(self, in_ttoks, out_ttoks, languages, cpss, in_stoks, in_stoks_positions, out_stoks=None, loss=True, offset=None, xenc=None, xenc_positions=None, cps_emb=None):
299
+ if xenc is None:
300
+ xenc, cps_emb = self.run_encoder(in_ttoks, languages, cpss)
301
+
302
+ with record_function("decoder"):
303
+ x = (self.embeddings.embedding(in_stoks) +
304
+ self.embeddings.positional_embedding[in_stoks_positions] +
305
+ cps_emb).to(xenc[0].dtype)
306
+ x = self.decoder(x, in_stoks_positions, xenc, xenc_positions)
307
+ logits = self.embeddings.embedding.unembed(x)
308
+ logits = logits * self.tunables.output_mult / (self.width / self.base_width)
309
+
310
+ if loss is not None:
311
+ enc_logits = self.encoder.embedding.unembed(xenc[0])
312
+ enc_logits = enc_logits * self.tunables.output_mult / (self.width / self.base_width)
313
+ with record_function("loss"):
314
+ loss = F.cross_entropy(logits.transpose(-1,-2), out_stoks)
315
+ if self.training:
316
+ loss += 0.1 * F.cross_entropy(enc_logits.transpose(-1,-2), out_ttoks)
317
+
318
+ return logits, loss
319
+
320
+ #
321
+ # inference
322
+ #
323
+ @classmethod
324
+ def load_model(cls, ref="collabora/whisperspeech:t2s-small-en+pl.model",
325
+ repo_id=None, filename=None, local_filename=None):
326
+ if repo_id is None and filename is None and local_filename is None:
327
+ if ":" in ref:
328
+ repo_id, filename = ref.split(":", 1)
329
+ else:
330
+ local_filename = ref
331
+ if not local_filename:
332
+ local_filename = hf_hub_download(repo_id=repo_id, filename=filename)
333
+ spec = torch.load(local_filename)
334
+ model = cls(**spec['config'], tunables=Tunables(**spec['tunables']))
335
+ model.load_state_dict(spec['state_dict'])
336
+ model.eval()
337
+ return model
338
+
339
+ def load_checkpoint(self, local_filename):
340
+ spec = torch.load(local_filename, map_location='cpu')
341
+ assert 'pytorch-lightning_version' in spec, 'not a valid PyTorch Lightning checkpoint'
342
+ state_dict = {k.replace('model.', ''):v
343
+ for k,v in spec['state_dict'].items()}
344
+ self.load_state_dict(state_dict)
345
+ return self
346
+
347
+ def save_model(self, fname):
348
+ torch.save(dict(config = self.__stored_args__,
349
+ tunables = dataclasses.asdict(self.tunables),
350
+ state_dict = self.state_dict()), fname)
351
+
352
+ def ensure_tokenizer(self):
353
+ assert not self.training
354
+ if self.tokenizer is None: self.tokenizer = CharTokenizer()
355
+
356
+ def switch_dtypes(self, dtype=torch.float16):
357
+ self.dtype = dtype
358
+ for n,m in self.named_modules():
359
+ # convert every leaf layer apart from the LayerNorms
360
+ if isinstance(m, (nn.Linear, nn.Embedding)):
361
+ m.to(dtype)
362
+ # take care of buffers ([kv]_cache, masks) that are not in the leaf layers
363
+ for bn,b in m.named_buffers(recurse=False):
364
+ setattr(m,bn,b.to(dtype))
365
+
366
+ def optimize(self, max_batch_size=1, dtype=torch.float16, torch_compile=True):
367
+ for emb in [self.embeddings.embedding, self.embeddings.embedding]:
368
+ emb.convert_for_eval()
369
+ for l in self.encoder.layers:
370
+ l.attn.convert_for_eval()
371
+ for l in self.decoder.layers:
372
+ l.attn.convert_for_eval()
373
+ l.cross_attn.convert_for_eval()
374
+ l.setup_kv_cache(max_batch_size, self.stoks_len, self.ttoks_len)
375
+ self.switch_dtypes(dtype)
376
+ if torch_compile:
377
+ self.generate_next = torch.compile(self.generate_next, mode="reduce-overhead", fullgraph=True)
378
+
379
+ @property
380
+ def device(self):
381
+ return next(self.parameters()).device
382
+
383
+ # from https://github.com/pytorch-labs/gpt-fast/blob/main/generate.py
384
+ def multinomial_sample_one_no_sync(self, probs_sort): # Does multinomial sampling without a cuda synchronization
385
+ q = torch.empty_like(probs_sort).exponential_(1)
386
+ return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
387
+
388
+ def logits_to_probs(self, logits, T=1.0, top_k=None):
389
+ logits = logits / max(T, 1e-5)
390
+
391
+ logits[self.embeddings.embedding.codes:] = -torch.inf
392
+ if top_k is not None:
393
+ v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
394
+ pivot = v.select(-1, -1).unsqueeze(-1)
395
+ logits = torch.where(logits < pivot, -float("Inf"), logits)
396
+
397
+ probs = torch.nn.functional.softmax(logits, dim=-1)
398
+ return probs
399
+
400
+ def sample(self, logits, T=1.0, top_k=None):
401
+ probs = self.logits_to_probs(logits[0,-1], T, top_k)
402
+ idx_next = self.multinomial_sample_one_no_sync(probs)
403
+ return idx_next
404
+
405
+ def generate_one(self, toks, toks_positions, cps_emb, xenc, xenc_positions, T, top_k):
406
+ probs, _ = self(None, None, None, None, toks, toks_positions, loss=None, xenc=xenc, xenc_positions=xenc_positions, cps_emb=cps_emb)
407
+ return self.sample(probs, T, top_k)
408
+
409
+ def generate_next(self, *args, **kwargs):
410
+ return self.generate_one(*args, **kwargs)
411
+
412
+ @torch.no_grad()
413
+ def prep(self, txt, cps=15, lang="en"):
414
+ dev = self.device
415
+ ttoks = torch.tensor(self.tokenizer.encode(txt), device=dev)
416
+ ttoks = F.pad(ttoks, (0, self.ttoks_len - len(ttoks)), value=self.tokenizer.eot).unsqueeze(0)
417
+ cpss = torch.tensor([cps], device=dev)
418
+ langs = torch.tensor([languages.to_id(lang)], device=dev)
419
+ return ttoks, cpss, langs
420
+
421
+ @torch.no_grad()
422
+ def generate(self, txt, cps=15, lang="en", N=None, T=0.7, top_k=None, step=None, show_progress_bar=True):
423
+ self.ensure_tokenizer()
424
+ N = N or self.stoks_len
425
+ dev = self.device
426
+ ttoks = []
427
+ langs = []
428
+ if isinstance(lang, list):
429
+ lang0 = lang[0]
430
+ assert isinstance(txt, list), "lang and txt have to be both lists or strings"
431
+ for txt, lang in zip(txt, lang):
432
+ tt = self.tokenizer.encode(txt)
433
+ ttoks += tt
434
+ langs += [languages.to_id(lang)] * len(tt)
435
+ elif isinstance(lang, torch.Tensor):
436
+ langs = lang
437
+ ttoks = self.tokenizer.encode(txt)
438
+ else:
439
+ lang0 = lang
440
+ ttoks = self.tokenizer.encode(txt)
441
+ langs = torch.tensor([languages.to_id(lang)], device=dev).unsqueeze(0)
442
+ ttoks = torch.tensor(ttoks, device=dev)
443
+ ttoks = F.pad(ttoks, (1, self.ttoks_len - len(ttoks) - 1), value=self.tokenizer.eot).unsqueeze(0)
444
+ cpss = torch.tensor([cps], device=dev)
445
+ if not isinstance(langs, torch.Tensor):
446
+ langs = torch.tensor(langs, device=dev)
447
+ langs = F.pad(langs, (1, self.ttoks_len - len(langs) - 1), value=languages.to_id(lang0)).unsqueeze(0)
448
+ it = range(0,N-1)
449
+ if show_progress_bar: it = progress_bar(it)
450
+
451
+ toks = torch.zeros((1,N), dtype=torch.long, device=dev)
452
+ toks[:,0] = self.stoks_codes-1
453
+ toks_positions = torch.arange(N, device=dev)
454
+ with record_function("encode"):
455
+ xenc, xenc_positions, cps_emb = self.run_encoder(ttoks, langs, cpss)
456
+ toks_positions = torch.arange(N+1, device=dev)
457
+ # contrary to S2A this model works without prefill and is actually a tiny bit faster
458
+ # with record_function("prefill"):
459
+ # toks[0,1] = self.generate_one(toks[:,:1], toks_positions[:1], cps_emb, xenc, xenc_positions, T, top_k)
460
+ with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
461
+ for i in it:
462
+ toks[0,i+1] = self.generate_next(toks[:,i:i+1], toks_positions[i:i+1], cps_emb, xenc, xenc_positions, T, top_k)
463
+ if i % 25 == 0 and toks[0,i+1] == self.stoks_codes-1: return toks[0,:i+1]
464
+
465
+ # for profiling, debugging or early exit
466
+ if step is not None: step()
467
+ return toks[0,:]
468
+
469
+ @torch.no_grad()
470
+ def generate_batch(self, txts, N=None, T=1.1, top_k=7, show_progress_bar=True):
471
+ self.ensure_tokenizer()
472
+ N = self.stoks_len
473
+ dev = self.device
474
+ ttoks = []
475
+ for txt in txts:
476
+ ttoks_ = torch.tensor(self.tokenizer.encode(txt), device=dev)
477
+ ttoks_ = F.pad(ttoks_, (0, self.ttoks_len - len(ttoks_)), value=self.tokenizer.eot).unsqueeze(0)
478
+ ttoks.append(ttoks_)
479
+ ttoks = torch.cat(ttoks, dim=0)
480
+ toks = torch.zeros((len(ttoks),N), dtype=torch.long, device=dev)
481
+ it = range(N)
482
+ if show_progress_bar: it = progress_bar(it)
483
+ for i in it:
484
+ p, _ = self(ttoks, toks[:,:i], loss=None)
485
+ last_p = p[:,-1]
486
+ if top_k:
487
+ last_p[last_p < torch.topk(last_p, top_k).values[:,-1,None]] = -torch.inf
488
+ tok = torch.multinomial((last_p / float(T)).softmax(-1), 1)
489
+ toks[:,i] = tok[:,0]
490
+ if (toks[:,i] == self.stoks_codes-1).all(): return toks[:,:i]
491
+ return toks
492
+
493
+ # %% ../nbs/5B. Multi-lang text to semantic token modeling.ipynb 18
494
+ def _make_model(size:str, tunables:Tunables=Tunables(), dataset=None, **kwargs):
495
+ kwargs = dict(stoks_len = dataset.stoks_len, ttoks_len = dataset.ttoks_len, tunables=tunables, **kwargs)
496
+ if 'stoks_codes' not in kwargs: kwargs['stoks_codes'] = dataset.stoks_codes
497
+ if size == 'micro':
498
+ return TSARTransformer(depth=2, n_head=3, ffn_mult=1, **kwargs)
499
+ if size == 'tiny':
500
+ return TSARTransformer(depth=4, n_head=6, **kwargs)
501
+ if size == 'base':
502
+ return TSARTransformer(depth=6, n_head=8, **kwargs)
503
+ if size == 'small':
504
+ return TSARTransformer(depth=12, n_head=12, **kwargs)
505
+ if size == 'small+':
506
+ return TSARTransformer(depth=12, n_head=16, **kwargs)
507
+ if size == 'medium':
508
+ return TSARTransformer(depth=24, n_head=16, **kwargs)
509
+
510
+ def make_model(size:str, frozen_embeddings_model:str=None, tunables:Tunables=Tunables(), dataset:torch.utils.data.Dataset=None):
511
+ from whisperspeech import vq_stoks
512
+
513
+ if frozen_embeddings_model:
514
+ vqmodel = vq_stoks.RQBottleneckTransformer.load_model(frozen_embeddings_model)
515
+ model = _make_model(size, tunables, dataset, stoks_codes=vqmodel.vq_codes+1, stoks_width=vqmodel.rq.layers[0]._codebook.embed[0].shape[-1])
516
+ model.load_frozen_semantic_embeddings(vqmodel)
517
+ else:
518
+ model = _make_model(size, tunables, dataset, mode=mode)
519
+ return model
whisperspeech/train.py ADDED
@@ -0,0 +1,271 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/B1. Training.ipynb.
2
+
3
+ # %% auto 0
4
+ __all__ = ['SimpleVisual', 'validate', 'train']
5
+
6
+ # %% ../nbs/B1. Training.ipynb 2
7
+ import io
8
+ import time
9
+ import random
10
+ from pathlib import Path
11
+
12
+ from fastprogress import progress_bar, master_bar
13
+ import fastprogress
14
+
15
+ import numpy as np
16
+ import pylab as plt
17
+ import math
18
+
19
+ import IPython
20
+
21
+ import torch
22
+ import torch.nn as nn
23
+ from torch.utils.data.dataloader import DataLoader
24
+ from torch.profiler import record_function
25
+
26
+ import webdataset as wds
27
+
28
+ torch.backends.cudnn.benchmark = True
29
+ torch.backends.cudnn.enabled = True
30
+ torch.backends.cuda.matmul.allow_tf32 = True
31
+ torch.set_float32_matmul_precision('medium')
32
+
33
+ # %% ../nbs/B1. Training.ipynb 3
34
+ class SimpleVisual:
35
+ def __init__ (self, model, masterbar, total_steps):
36
+ self.model = model
37
+ self.masterbar = masterbar
38
+ self.total_steps = total_steps
39
+ self.epochs = total_steps // masterbar.main_bar.total
40
+
41
+ gs = plt.GridSpec(2, 1, height_ratios=[3,1])
42
+ graph_fig = plt.figure(figsize=(10,6))
43
+ self.graph_fig = graph_fig
44
+ self.loss_p = graph_fig.add_subplot(gs[0])
45
+ self.lr_p = graph_fig.add_subplot(gs[1], sharex=self.loss_p)
46
+ self.lr_p.tick_params('x', labelbottom=False)
47
+ self.graph_out = None
48
+
49
+ self.its = []
50
+ self.train_losses = []
51
+ self.val_losses = []
52
+ self.lr_history = []
53
+
54
+ def show(self):
55
+ self.start_t = time.time()
56
+ self.masterbar.write(["samples", "train", "val", "time"], table=True)
57
+ self.graph_out = display(self.graph_fig, display_id=True, clear=True)
58
+
59
+ def hide(self):
60
+ if self.graph_out is not None:
61
+ self.graph_out.update(IPython.display.HTML(''))
62
+
63
+ def plot(self):
64
+ loss_p, lr_p = self.loss_p, self.lr_p
65
+ loss_p.clear()
66
+ loss_p.plot(self.its, self.train_losses)
67
+ loss_p.plot(self.its, self.val_losses)
68
+ loss_p.set_xlim(0, self.total_steps)
69
+ loss_p.set_yscale('log')
70
+ lr_p.clear()
71
+ lrs = np.array(self.lr_history)
72
+ lr_p.plot(self.its, lrs)
73
+ self.graph_out.update(self.graph_fig)
74
+
75
+ def add_data(self, it, lr, train_loss, val_los):
76
+ self.its.append(it)
77
+ self.train_losses.append(train_loss)
78
+ self.val_losses.append(val_los)
79
+ self.lr_history.append(lr)
80
+ self.plot()
81
+
82
+ def add_table_row(self, it, avg_train_loss, val_loss):
83
+ elapsed_t = time.time() - self.start_t
84
+ self.masterbar.write([it, f"{avg_train_loss:.5f}", f"{val_loss:.5f}", fastprogress.core.format_time(elapsed_t)], table=True)
85
+
86
+ def on_iter(self, bar, it, avg_train_loss, val_loss):
87
+ epoch = math.ceil(it / self.total_steps * self.epochs)
88
+ bar.comment = f"#{epoch}/{self.epochs} loss: {avg_train_loss:.3f} / {val_loss:.3f}"
89
+
90
+ # %% ../nbs/B1. Training.ipynb 4
91
+ # FIXME: we need to keep this synchronised with the validation code below...
92
+ def validate(model, val, half=True, bs=16, drop_last=False, dl_workers=8, device="cuda"):
93
+ if isinstance(val, torch.utils.data.IterableDataset):
94
+ val_loader = wds.WebLoader(val, batch_size=None, num_workers=dl_workers, drop_last=drop_last) \
95
+ .unbatched().shuffle(1024).batched(bs)
96
+ else:
97
+ val_loader = DataLoader(val, batch_size=bs, num_workers=dl_workers, pin_memory=True, drop_last=drop_last)
98
+
99
+ with torch.no_grad():
100
+ val_loss = 0
101
+ val_samples = 0
102
+ for args in val_loader:
103
+ args = [x.to(device, non_blocking=True) for x in args]
104
+ with torch.autocast(device_type=device, dtype=torch.float16 if half else torch.float32, enabled=device!='cpu'):
105
+ ps, loss = model(*args)
106
+ N = args[0].shape[0]
107
+ val_loss += loss.mean().item() * N
108
+ val_samples += N
109
+ val_loss = val_loss / val_samples
110
+
111
+ return val_loss
112
+
113
+ # %% ../nbs/B1. Training.ipynb 5
114
+ def train(checkpoint_path, model, train, val, half=True, bs=16, lr=1e-4, drop_last=False,
115
+ weight_decay=0.1, warmup_steps=10000, epochs=10, clip_gradient_norm=None,
116
+ dl_workers=8, visual_class = SimpleVisual, profiler=None,
117
+ run_valid_every_iters=8000, table_row_every_iters=80000, chkpt_every_iters=None,
118
+ device="cuda", trainable_params=None):
119
+ if chkpt_every_iters is None:
120
+ chkpt_every_iters = table_row_every_iters
121
+
122
+ mb = master_bar(range(epochs))
123
+ if isinstance(train, torch.utils.data.IterableDataset):
124
+ pct_start = min(0.3, warmup_steps / (epochs * (train.total_samples//bs)))
125
+ visual = visual_class(model, mb, epochs * train.total_samples)
126
+ # pct_start = min(0.3, warmup_steps / (epochs * len(train)))
127
+ # visual = visual_class(model, mb, epochs*len(train)*bs)
128
+ else:
129
+ pct_start = min(0.3, warmup_steps / (epochs * len(train) / bs))
130
+ visual = visual_class(model, mb, epochs*len(train))
131
+ model.visual = visual
132
+
133
+ Path(checkpoint_path).mkdir(exist_ok=True)
134
+
135
+ if isinstance(train, torch.utils.data.IterableDataset):
136
+ # train_loader = DataLoader(train, batch_size=None, num_workers=dl_workers, pin_memory=True, drop_last=False, shuffle=False)
137
+ # val_loader = DataLoader(val, batch_size=None, num_workers=dl_workers, pin_memory=True, drop_last=False)
138
+ train_loader = wds.WebLoader(train, batch_size=None, num_workers=dl_workers, drop_last=drop_last) \
139
+ .unbatched().shuffle(1024).batched(bs, partial=False)
140
+ val_loader = wds.WebLoader(val, batch_size=None, num_workers=dl_workers, drop_last=drop_last) \
141
+ .unbatched().shuffle(1024).batched(bs)
142
+ else:
143
+ train_loader = DataLoader(train, batch_size=bs, num_workers=dl_workers, pin_memory=True, drop_last=drop_last, shuffle=True)
144
+ val_loader = DataLoader(val, batch_size=bs, num_workers=dl_workers, pin_memory=True, drop_last=drop_last)
145
+
146
+ val_loss = torch.nan
147
+ avg_train_loss = torch.nan
148
+
149
+ if hasattr(model, 'setup'):
150
+ model.setup(device)
151
+
152
+ try:
153
+ scheduler = None
154
+
155
+ if trainable_params is None: trainable_params = model.parameters()
156
+ all_params = set(trainable_params)
157
+ customized_params = set()
158
+ groups = []
159
+ group_map = {}
160
+ for name,m in model.named_modules():
161
+ if hasattr(m, 'no_weight_decay') or hasattr(m, 'lr_scale'):
162
+ m_trainable = [x for x in m.parameters() if x in all_params]
163
+ if not m_trainable: continue
164
+ customized_params |= set(m_trainable)
165
+ m_wd = 0 if hasattr(m, 'no_weight_decay') else weight_decay
166
+ m_lr = lr * getattr(m, 'lr_scale', 1)
167
+ group = group_map.get((m_wd, m_lr), None)
168
+ if not group:
169
+ group = {"params": [], "names": [], "weight_decay": m_wd, "lr": m_lr}
170
+ groups.append(group)
171
+ group_map[(m_wd, m_lr)] = group
172
+ group['params'] += m_trainable
173
+ group['names'].append(name)
174
+
175
+ other_params = all_params - customized_params
176
+
177
+ if other_params:
178
+ groups = groups + [
179
+ {"names": ["other"], "params": list(other_params), "weight_decay": weight_decay },
180
+ ]
181
+
182
+ optimizer = torch.optim.AdamW(lr=lr, betas=(0.9, 0.95), fused=device!='cpu', params=groups)
183
+ model._optimizer = optimizer
184
+ scaler = torch.cuda.amp.GradScaler(enabled=half)
185
+ scheduler = torch.optim.lr_scheduler.OneCycleLR(
186
+ optimizer, pct_start=pct_start, steps_per_epoch=math.ceil(train.total_samples/bs), epochs=epochs,
187
+ max_lr=[pg.get('lr', lr) for pg in groups],
188
+ final_div_factor=25)
189
+
190
+ it = 0
191
+ next_val_it = it + 50
192
+ next_chkpt_it = chkpt_every_iters
193
+ next_table_it = table_row_every_iters
194
+
195
+ visual.show()
196
+
197
+ running_loss = [0]
198
+
199
+ for epoch in mb:
200
+ bar = progress_bar(train_loader, total=train.total_samples//bs, parent=mb)
201
+ for args in bar:
202
+ with record_function("forward"):
203
+ args = [x.to(device, non_blocking=True) for x in args]
204
+
205
+ # zero the parameter gradients
206
+ optimizer.zero_grad(set_to_none=True)
207
+
208
+ with torch.autocast(device_type=device, dtype=torch.float16 if half else torch.float32, enabled=device!='cpu'):
209
+ ps, loss = model(*args)
210
+ loss = loss.mean()
211
+
212
+ with record_function("backward"):
213
+ scaler.scale(loss).backward()
214
+
215
+ if clip_gradient_norm:
216
+ scaler.unscale_(optimizer)
217
+ # Since the gradients of optimizer's assigned params are unscaled, clips as usual:
218
+ torch.nn.utils.clip_grad_norm_(model.parameters(), clip_gradient_norm)
219
+
220
+ scaler.step(optimizer)
221
+ scaler.update()
222
+
223
+ scheduler.step()
224
+
225
+ if profiler is not None: profiler.step()
226
+
227
+ with record_function("running_loss"):
228
+ running_loss.append(loss.item())
229
+ running_loss = running_loss[-5:]
230
+ avg_train_loss = sum(running_loss)/len(running_loss)
231
+
232
+ if it >= next_chkpt_it:
233
+ with record_function("checkpoint"):
234
+ next_chkpt_it += chkpt_every_iters
235
+ torch.save(model.state_dict(), f'{checkpoint_path}/{it:08d}.pt')
236
+
237
+ if it >= next_val_it:
238
+ next_val_it += run_valid_every_iters
239
+ with record_function("validation"):
240
+ with record_function("model.eval"):
241
+ model.eval()
242
+ with torch.no_grad():
243
+ val_loss = 0
244
+ val_samples = 0
245
+ for args in val_loader:
246
+ args = [x.to(device, non_blocking=True) for x in args]
247
+ with torch.autocast(device_type=device, dtype=torch.float16 if half else torch.float32, enabled=device!='cpu'):
248
+ ps, loss = model(*args)
249
+ N = args[0].shape[0]
250
+ val_loss += loss.mean().item() * N
251
+ val_samples += N
252
+ val_loss = val_loss / val_samples
253
+ with record_function("model.train"):
254
+ model.train()
255
+ with record_function("plotting"):
256
+ visual.add_data(it, scheduler.get_last_lr(), avg_train_loss, val_loss)
257
+
258
+ if it >= next_table_it:
259
+ visual.add_table_row(it, avg_train_loss, val_loss)
260
+ next_table_it += table_row_every_iters
261
+
262
+ it += bs
263
+ visual.on_iter(bar, it, avg_train_loss, val_loss)
264
+ except KeyboardInterrupt:
265
+ mb.write(f"interrupted")
266
+ mb.show()
267
+ pass
268
+ finally:
269
+ visual.add_table_row(it, avg_train_loss, val_loss)
270
+ mb.show()
271
+ visual.hide()
whisperspeech/train_multi.py ADDED
@@ -0,0 +1,263 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/B2. Training (Lightning).ipynb.
2
+
3
+ # %% auto 0
4
+ __all__ = []
5
+
6
+ # %% ../nbs/B2. Training (Lightning).ipynb 2
7
+ import io
8
+ import time
9
+ import random
10
+ from pathlib import Path
11
+
12
+ from fastprogress import progress_bar, master_bar
13
+ import fastprogress
14
+ import wandb
15
+
16
+ import numpy as np
17
+ import pylab as plt
18
+
19
+ import torch
20
+ import torch.nn as nn
21
+ from torch.utils.data.dataloader import DataLoader
22
+ from torch.profiler import record_function
23
+
24
+ # %% ../nbs/B2. Training (Lightning).ipynb 3
25
+ import lightning.pytorch as pl
26
+ import math
27
+
28
+ class TrainingTask(pl.LightningModule):
29
+ def __init__(self, model, model_hparams=None):
30
+ super().__init__()
31
+ self.model = model
32
+ self.model_hparams = model_hparams
33
+
34
+ def on_fit_start(self):
35
+ if getattr(self.model, 'setup'):
36
+ self.model.setup(self.device)
37
+
38
+ def configure_optimizers(self):
39
+ """ Initialize AdamW optimizer"""
40
+ lr = self.model_hparams['lr0']
41
+ weight_decay = self.model_hparams['weight_decay']
42
+
43
+ all_params = set(model.parameters())
44
+ customized_params = set()
45
+ groups = []
46
+ group_map = {}
47
+ for name,m in model.named_modules():
48
+ if hasattr(m, 'no_weight_decay') or hasattr(m, 'lr_scale'):
49
+ customized_params |= set(m.parameters())
50
+ m_wd = 0 if hasattr(m, 'no_weight_decay') else weight_decay
51
+ m_lr = lr * getattr(m, 'lr_scale', 1)
52
+ group = group_map.get((m_wd, m_lr), None)
53
+ if not group:
54
+ group = {"params": [], "names": [], "weight_decay": m_wd, "lr": m_lr}
55
+ groups.append(group)
56
+ group_map[(m_wd, m_lr)] = group
57
+ group['params'] += m.parameters()
58
+ group['names'].append(name)
59
+
60
+ other_params = all_params - customized_params
61
+
62
+ param_groups = groups + [
63
+ {"names": ["other"], "params": list(other_params), "weight_decay": weight_decay },
64
+ ]
65
+
66
+ optimizer = torch.optim.AdamW(lr=lr, betas=(0.9, 0.95), params=param_groups)
67
+
68
+ # modified from https://github.com/Lightning-AI/lightning/issues/5449#issuecomment-1501597319
69
+ def num_steps_per_epoch() -> int:
70
+ """Get number of steps"""
71
+ # Accessing _data_source is flaky and might break
72
+ dataset = self.trainer.fit_loop._data_source.dataloader()
73
+ dataset_size = len(dataset)
74
+ # math.ceil so always overestimate (underestimating throws exceptions)
75
+ num_steps = math.ceil(dataset_size / self.trainer.accumulate_grad_batches)
76
+ return num_steps
77
+
78
+ total_steps = self.model_hparams['epochs'] * num_steps_per_epoch()
79
+ self.model_hparams['pct_start'] = min(0.3, self.model_hparams['warmup_steps'] / total_steps)
80
+
81
+ print(f"{self.model_hparams['epochs']=} epochs x {num_steps_per_epoch()=} steps")
82
+
83
+ lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
84
+ optimizer,
85
+ pct_start=self.model_hparams['pct_start'],
86
+ max_lr=[pg.get('lr', lr) for pg in param_groups],
87
+ steps_per_epoch=num_steps_per_epoch(),
88
+ epochs=int(self.model_hparams['epochs']),
89
+ final_div_factor=25
90
+ )
91
+
92
+ return [optimizer], [{'scheduler': lr_scheduler, 'interval': 'step'}]
93
+
94
+ def training_step(self, train_batch, batch_idx):
95
+ train_logits, train_loss = self.model.forward(*train_batch)
96
+
97
+ self.log("train_loss", train_loss, sync_dist=True)
98
+ return train_loss
99
+
100
+ def validation_step(self, val_batch, batch_idx):
101
+ val_logits, val_loss = self.model.forward(*val_batch)
102
+
103
+ self.log("val_loss", val_loss, sync_dist=True)
104
+ return val_loss
105
+
106
+ def on_validation_epoch_end(self):
107
+ if hasattr(self.model, 'get_metrics'):
108
+ self.log_dict({'metrics/'+k:v for k,v in self.model.get_metrics().items()}, sync_dist=True)
109
+
110
+ def test_step(self, val_batch, batch_idx):
111
+ test_logits, test_loss = self.model.forward(*val_batch)
112
+
113
+ self.log("test_loss", test_loss, sync_dist=True)
114
+ return test_loss
115
+
116
+ # %% ../nbs/B2. Training (Lightning).ipynb 4
117
+ from fastcore.script import anno_parser
118
+ import shlex
119
+
120
+ # watch out: we can only pass Python values as keyword arguments (not positional)
121
+ # everything else has to be a string
122
+ def parse_and_call(name, fun, args, kwargs={}, log_to_wandb=True):
123
+ p = anno_parser(fun)
124
+ args = p.parse_args(args).__dict__
125
+ args.pop('xtra'); args.pop('pdb')
126
+ args.update({k:v for k, v in kwargs.items()})
127
+ if log_to_wandb and type(wandb_logger.experiment.config) == wandb.sdk.wandb_config.Config:
128
+ wandb_logger.experiment.config[name] = {k:v for k,v in args.items() if k not in ['dataset', 'tunables']}
129
+ return fun(**args)
130
+
131
+ # %% ../nbs/B2. Training (Lightning).ipynb 8
132
+ import argparse
133
+
134
+ parser = argparse.ArgumentParser()
135
+ parser.add_argument('--task', type=str, help='Task to train')
136
+ parser.add_argument('--seed', type=int, default=0, help='Global training seed')
137
+ parser.add_argument('--batch-size', type=int, default=16, help='total batch size for all GPUs')
138
+ parser.add_argument('--workers', type=int, default=8, help='max dataloader workers (per RANK in DDP mode)')
139
+ parser.add_argument('--input-dir', type=str, default='', help='input data path') # fixed in the model for now
140
+ parser.add_argument("--checkpoint-dir", type=str, default="./checkpoints/", help="directory to save the checkpoints")
141
+ parser.add_argument('--epochs', type=int, default=10, help='total training epochs')
142
+ parser.add_argument('--validate-every-n-steps', type=int, default=500, help='how training steps to run between validations')
143
+ parser.add_argument('--weight-decay', type=float, default=1e-2, help='optimizer weight decay')
144
+ parser.add_argument('--lr0', type=float, default=1e-4, help='optimizer initial learning rate')
145
+ parser.add_argument('--clip-gradient-norm', type=float, default=None, help='enable gradient norm clipping')
146
+ parser.add_argument('--accumulate-grad-batches', type=int, default=1, help='perform the optimizer step only after going through several batches of samples')
147
+ parser.add_argument('--precision', type=str, default="16-mixed", help="floating point precision")
148
+ parser.add_argument('--warmup-steps', type=int, default=10000, help='total number steps during which the learning rate rises (defaults to 10k updates)')
149
+ parser.add_argument('--tunables', type=str, default="", help='tunable hyperparameters')
150
+ parser.add_argument('--resume-from', type=Path, default=None, help='resume training from the given checkpoint')
151
+ parser.add_argument('--strategy', type=str, default='ddp', help='distributed training strategy')
152
+ parser.add_argument('--wandb-suffix', type=str, default=None, help='W&B project name suffix')
153
+ parser.add_argument('--wandb-task-name', type=str, default=None, help='Task name for the W&B project name')
154
+
155
+ args = parser.parse_args().__dict__
156
+
157
+ task_args: list = shlex.split(args.pop("task"))
158
+ task_name, task_args = task_args[0], task_args[1:]
159
+ input_args: list = shlex.split(args.pop("input_dir"))
160
+ checkpoint_dir: str = args.pop("checkpoint_dir")
161
+ num_workers: int = args.pop("workers")
162
+ batch_size: int = args.pop("batch_size")
163
+ epochs: int = args.pop("epochs")
164
+ tunables_args: list = shlex.split(args.pop("tunables"))
165
+
166
+ hyp_params = {}
167
+ hyp_params['batch_size'] = batch_size
168
+ hyp_params['warmup_steps'] = args['warmup_steps']
169
+ hyp_params['weight_decay'] = args['weight_decay']
170
+ hyp_params['clip_gradient_norm'] = args['clip_gradient_norm']
171
+ hyp_params['accumulate_grad_batches'] = args['accumulate_grad_batches']
172
+ hyp_params['precision'] = args['precision']
173
+ hyp_params['lr0'] = args['lr0']
174
+ hyp_params['epochs'] = epochs
175
+ hyp_params['strategy'] = args['strategy']
176
+
177
+ # %% ../nbs/B2. Training (Lightning).ipynb 9
178
+ from lightning.pytorch.loggers import WandbLogger
179
+ from lightning.pytorch.callbacks import LearningRateMonitor
180
+ import datetime
181
+ import webdataset as wds
182
+ import importlib
183
+
184
+ torch.set_float32_matmul_precision('medium')
185
+
186
+ project = f"WhisperSpeech-{args['wandb_task_name'] or task_name}"
187
+ if args['wandb_suffix']:
188
+ project += "-"+args['wandb_suffix']
189
+
190
+ wandb_logger = WandbLogger(project=project)
191
+
192
+ ckpt_callback = pl.callbacks.ModelCheckpoint(
193
+ dirpath=f'{task_name}-{epochs}e',
194
+ filename=task_name+"-{epoch}-{step}-{val_loss:.2f}",
195
+ monitor="val_loss",
196
+ save_top_k=4,
197
+ train_time_interval=datetime.timedelta(minutes=5),
198
+ )
199
+
200
+ lr_monitor_callback = LearningRateMonitor(logging_interval='step')
201
+
202
+ from torch.utils.data import DataLoader
203
+
204
+ task = importlib.import_module("whisperspeech."+task_name)
205
+
206
+ train_ds, val_ds = parse_and_call('dataset', task.load_datasets, input_args)
207
+
208
+ tunables = None
209
+ if hasattr(task, "Tunables"):
210
+ import dataclasses
211
+ tunables = parse_and_call('tunables', task.Tunables, tunables_args, log_to_wandb=False)
212
+ if type(wandb_logger.experiment.config) == wandb.sdk.wandb_config.Config:
213
+ wandb_logger.experiment.config['tunables'] = dataclasses.asdict(tunables)
214
+
215
+ for name in ["lr0", "clip_gradient_norm", "weight_decay", "warmup_steps"]:
216
+ val = getattr(tunables, name, None)
217
+ if val is not None: hyp_params[name] = val
218
+
219
+ if isinstance(train_ds, torch.utils.data.IterableDataset):
220
+ dl_batch_size, dl_shuffle = None, False
221
+ pin_memory = False
222
+ else:
223
+ dl_batch_size, dl_shuffle = batch_size, True
224
+ pin_memory = True
225
+
226
+ val_loader = wds.WebLoader(val_ds,
227
+ batch_size=dl_batch_size,
228
+ num_workers=num_workers,
229
+ drop_last=False,
230
+ pin_memory=pin_memory).unbatched().shuffle(1024).batched(batch_size).with_length(val_ds.total_samples // batch_size)
231
+
232
+ train_loader = wds.WebLoader(train_ds,
233
+ batch_size=dl_batch_size,
234
+ num_workers=num_workers,
235
+ drop_last=False,
236
+ shuffle=dl_shuffle,
237
+ pin_memory=pin_memory).unbatched().shuffle(1024).batched(batch_size).with_length(train_ds.total_samples // batch_size)
238
+
239
+ model_kwargs = dict(dataset=train_ds)
240
+ if tunables is not None: model_kwargs['tunables'] = tunables
241
+ model = parse_and_call('model', task.make_model, task_args, model_kwargs)
242
+
243
+ task = TrainingTask(model, model_hparams=hyp_params)
244
+
245
+ trainer = pl.Trainer(strategy=hyp_params['strategy'],
246
+ max_epochs=hyp_params['epochs'],
247
+ accelerator="gpu",
248
+ profiler="simple",
249
+ precision=hyp_params['precision'],
250
+ gradient_clip_val=hyp_params['clip_gradient_norm'],
251
+ accumulate_grad_batches=hyp_params['accumulate_grad_batches'],
252
+ val_check_interval=args.pop("validate_every_n_steps"),
253
+ enable_checkpointing=True,
254
+ logger=wandb_logger,
255
+ callbacks=[ckpt_callback, lr_monitor_callback])
256
+
257
+ if type(wandb_logger.experiment.config) == wandb.sdk.wandb_config.Config:
258
+ wandb_logger.experiment.config.update(hyp_params)
259
+
260
+ kwargs = {}
261
+ if 'resume_from' in args:
262
+ kwargs['ckpt_path'] = args['resume_from']
263
+ trainer.fit(model=task, train_dataloaders=train_loader, val_dataloaders=val_loader, **kwargs)
whisperspeech/utils.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/D. Common dataset utilities.ipynb.
2
+
3
+ # %% auto 0
4
+ __all__ = ['shard_glob', 'join_datasets', 'resampler', 'derived_name', 'derived_dataset', 'merge_in', 'AtomicTarWriter',
5
+ 'readlines']
6
+
7
+ # %% ../nbs/D. Common dataset utilities.ipynb 1
8
+ import os
9
+ import torch
10
+ import torchaudio
11
+ from pathlib import Path
12
+ import webdataset as wds
13
+ from contextlib import contextmanager
14
+
15
+ import torch.nn.functional as F
16
+
17
+ # %% ../nbs/D. Common dataset utilities.ipynb 2
18
+ def shard_glob(input):
19
+ if '{' in input:
20
+ return wds.shardlists.expand_urls(input)
21
+ if isinstance(input, (Path, str)):
22
+ path = Path(input)
23
+ if path.is_dir():
24
+ glob = '*.tar.gz'
25
+ else:
26
+ glob = path.name
27
+ path = path.parent
28
+ input = Path(path).glob(glob)
29
+ else:
30
+ raise ArgumentError("input should be either a list or a path with an optional glob specifier")
31
+ return [str(x) for x in input]
32
+
33
+ # %% ../nbs/D. Common dataset utilities.ipynb 3
34
+ class join_datasets(torch.utils.data.IterableDataset):
35
+ def __init__(self, datasets):
36
+ self.datasets = datasets
37
+
38
+ def __iter__(self):
39
+ probs = torch.tensor([getattr(ds, 'weight', 1) for ds in self.datasets], dtype=torch.float)
40
+ its = [iter(ds) for ds in self.datasets]
41
+ while True:
42
+ try:
43
+ yield next(its[torch.multinomial(probs, 1)])
44
+ except StopIteration:
45
+ return
46
+
47
+ def __len__(self):
48
+ return sum([ds.total_samples for ds in self.datasets])
49
+
50
+ # %% ../nbs/D. Common dataset utilities.ipynb 5
51
+ def resampler(newsr = 24000, key = 'samples_24k'):
52
+ _last_sr = None
53
+ tform = None
54
+
55
+ def _resample(samples):
56
+ for s in samples:
57
+ sr = s['sample_rate']
58
+ if sr != newsr:
59
+ if sr != _last_sr: tform = torchaudio.transforms.Resample(sr, newsr)
60
+ s[key] = tform(s['samples'])
61
+ else:
62
+ s[key] = s['samples']
63
+ yield s
64
+
65
+ return _resample
66
+
67
+ # %% ../nbs/D. Common dataset utilities.ipynb 6
68
+ def derived_name(input, kind, base="audio", suffix=".gz", dir=None):
69
+ dir = Path(dir) if dir else Path(input).parent
70
+ return str(dir/(Path(input).name.replace(f"-{base}-", f"-{kind}-") + suffix))
71
+
72
+ # %% ../nbs/D. Common dataset utilities.ipynb 7
73
+ def derived_dataset(kind, base='audio', suffix=".gz", decoders=[], dir=None):
74
+ def deriver(url):
75
+ url = str(derived_name(url, kind, base=base, suffix=suffix, dir=dir))
76
+ return wds.WebDataset(
77
+ wds.SimpleShardList([url])
78
+ ).decode(*decoders)
79
+ return deriver
80
+
81
+ # %% ../nbs/D. Common dataset utilities.ipynb 8
82
+ def merge_in(dataset_fun):
83
+ """Merge a dataset into the current one returning samples with the union of keys. Pass in a function
84
+ that takes a URL of a sample and returns a dataset for it (called everytime the URL changes).
85
+
86
+ It requires (and validates) that both datasets have the same ordering of keys so you have
87
+ to use it before any sample shuffling. Shard shuffling is ok.
88
+ """
89
+ def merge_loop(main_samples):
90
+ #print("new merge loop:", dataset_fun)
91
+ merged_samples = None
92
+ cur_url = None
93
+ i = None
94
+ for s in main_samples:
95
+ url = s['__url__']
96
+ if url != cur_url:
97
+ # this will open a new file when we get the first sample with a new __url__
98
+ merged_samples = iter(dataset_fun(url))
99
+ cur_url = url
100
+ try:
101
+ merge_s = next(merged_samples)
102
+ except StopIteration:
103
+ # if the original shard got repeated we won't observe a __url__ change
104
+ # in this case restart the dataset from the beginning
105
+ merged_samples = iter(dataset_fun(url))
106
+ merge_s = next(merged_samples)
107
+ assert merge_s['__key__'] == s['__key__'], f"sample keys don't match: {merge_s['__key__']}, {s['__key__']} in file {s['__url__']}"
108
+ news = {}
109
+ news.update(merge_s)
110
+ news.update(s)
111
+ yield news
112
+ return merge_loop
113
+
114
+ # %% ../nbs/D. Common dataset utilities.ipynb 9
115
+ def split_to_chunks(stream, ikey='vad.npy', metakeys=[], pad_to_seconds=30, random_shift=False):
116
+ for s in stream:
117
+ audio, sr = s['audio']
118
+ imax = len(s[ikey]) - 1
119
+ for i,(ts,te) in enumerate(s[ikey]):
120
+ samples = audio[0,int(ts*sr):int(te*sr)]
121
+ if pad_to_seconds is not None:
122
+ padding = pad_to_seconds*sr-samples.shape[-1]
123
+ lpad = random.randint(0, padding) if random_shift else 0
124
+ samples = F.pad(samples, (lpad, padding-lpad))
125
+ subs = {"__key__": s['__key__'] + f"_{i:03d}",
126
+ "src_key": s['__key__'],
127
+ "__url__": s['__url__'],
128
+ "i": i, "imax": imax,
129
+ "tstart": ts, "tend": te, "total_seconds": audio.shape[-1]/sr,
130
+ "lpad": lpad, "rpad": padding-lpad,
131
+ "lpad_s": lpad/sr, "rpad_s": (padding-lpad)/sr,
132
+ "samples": samples, "sample_rate": sr}
133
+ for k in metakeys:
134
+ subs[k] = s[k][i]
135
+ yield subs
136
+
137
+ # %% ../nbs/D. Common dataset utilities.ipynb 10
138
+ def vad_dataset(shards, ikey='vad.npy', kind='vad'):
139
+ return wds.WebDataset(shards).compose(
140
+ wds.decode(wds.torch_audio),
141
+ merge_in(derived_dataset(kind)),
142
+ wds.select(lambda x: 'wav' in x or 'flac' in x or 'mp3' in x or 'ogg' in x), # skip samples without audio
143
+ wds.rename(audio="flac;mp3;wav;ogg"),
144
+ lambda x: split_to_chunks(x, ikey=ikey),
145
+ )
146
+
147
+ # %% ../nbs/D. Common dataset utilities.ipynb 11
148
+ @contextmanager
149
+ def AtomicTarWriter(name, throwaway=False):
150
+ tmp = name+".tmp"
151
+ with wds.TarWriter(tmp, compress=name.endswith('gz')) as sink:
152
+ yield sink
153
+ if not throwaway:
154
+ os.rename(tmp, name)
155
+
156
+ # %% ../nbs/D. Common dataset utilities.ipynb 12
157
+ def readlines(fname):
158
+ with open(fname) as file:
159
+ return [line.rstrip() for line in file]
whisperspeech/vad.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/1B. Voice activity detection.ipynb.
2
+
3
+ # %% auto 0
4
+ __all__ = []
5
+
6
+ # %% ../nbs/1B. Voice activity detection.ipynb 3
7
+ import os
8
+ import torch
9
+ import torchaudio
10
+
11
+ from pathlib import Path
12
+ from fastprogress import progress_bar
13
+ from fastcore.script import call_parse
14
+
15
+ import whisperx
16
+ import random
17
+ import numpy as np
18
+ import webdataset as wds
19
+
20
+ # %% ../nbs/1B. Voice activity detection.ipynb 5
21
+ # some of the original file names have a dot in their name
22
+ # webdataset does not like it so let's patch it
23
+ def fix_dots_in_names(name):
24
+ name, ext = name.rsplit('.', 1)
25
+ return ".".join((name.replace('.', '_'), ext))
26
+
27
+ def load_dataset(url, decode=True, rename_files=None):
28
+ ds = wds.WebDataset(url, rename_files=rename_files)
29
+ if not decode: return ds
30
+ return ds.decode(wds.torch_audio)
31
+
32
+ # %% ../nbs/1B. Voice activity detection.ipynb 7
33
+ def extract_segments(vad_result, max_duration):
34
+ binarize = whisperx.vad.Binarize(max_duration=max_duration)
35
+ segments = binarize(vad_result)
36
+ return [(x.start, x.end) for x in segments.get_timeline()]
37
+
38
+ def segment_audio(vad_model, audio, sr=16000):
39
+ vad_result = vad_model({"waveform": audio, "sample_rate": sr})
40
+ return extract_segments(vad_result, 30)
41
+
42
+ # %% ../nbs/1B. Voice activity detection.ipynb 13
43
+ def flac_to_vad_name(input):
44
+ if '-flac-' in input:
45
+ return input.rsplit("/", 1)[1].replace('flac', 'vad') + ".gz"
46
+ else:
47
+ return input.rsplit("/", 1)[1].replace('raw', 'vad') + ".gz"
48
+
49
+ @call_parse
50
+ def process_shard(
51
+ input:str, # input shard URL/path
52
+ output:str=None, # output shard URL/path
53
+ fix_dots:bool=False, # fix dots in LibriLight filenames
54
+ ):
55
+ if output is None: output = flac_to_vad_name(input)
56
+
57
+ ds = torch.utils.data.DataLoader(load_dataset(input, rename_files=fix_dots_in_names if fix_dots else None), num_workers=2, batch_size=None)
58
+ vad_model = whisperx.vad.load_vad_model('cuda')
59
+
60
+ tmp = output+".tmp"
61
+ with wds.TarWriter(tmp) as sink:
62
+ for s in progress_bar(ds, total='noinfer'):
63
+ audio, sr = s.get('flac', s.get('wav', (None, None)))
64
+ if audio is None:
65
+ print(f"warning: '{s['__key__']}' does not contain an audio file")
66
+ continue
67
+ sink.write({
68
+ "__key__": s['__key__'],
69
+ "vad.npy": np.array(segment_audio(vad_model, audio, sr=sr), dtype=np.float16)
70
+ })
71
+ os.rename(tmp, output)
whisperspeech/vq_stoks.py ADDED
@@ -0,0 +1,493 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/2B. Whisper quantization (semantic token) model.ipynb.
2
+
3
+ # %% auto 0
4
+ __all__ = ['RQBottleneckTransformer', 'make_model']
5
+
6
+ # %% ../nbs/2B. Whisper quantization (semantic token) model.ipynb 2
7
+ import io
8
+ import sys
9
+ import time
10
+ import torch
11
+ import torchaudio
12
+
13
+ # %% ../nbs/2B. Whisper quantization (semantic token) model.ipynb 3
14
+ from pathlib import Path
15
+ import json
16
+ from fastprogress import progress_bar, master_bar
17
+ import fastprogress
18
+ import numpy as np
19
+ import pylab as plt
20
+ import pandas as pd
21
+ import random
22
+
23
+ import whisper
24
+ from huggingface_hub import hf_hub_download
25
+ from fastcore.basics import store_attr
26
+
27
+ from torch import nn
28
+ import torch.optim as optim
29
+ import torch.nn.functional as F
30
+ from torch.utils.data.dataloader import DataLoader
31
+ import webdataset as wds
32
+ from . import utils
33
+
34
+ from vector_quantize_pytorch import ResidualVQ
35
+
36
+ from fastcore.script import *
37
+
38
+ # %% ../nbs/2B. Whisper quantization (semantic token) model.ipynb 9
39
+ def merge_in(dataset_fun):
40
+ """Merge a dataset into the current one returning samples with the union of keys. Pass in a function
41
+ that takes a URL of a sample and returns a dataset for it (called everytime the URL changes).
42
+
43
+ It requires (and validates) that both datasets have the same ordering of keys so you have
44
+ to use it before any sample shuffling. Shard shuffling is ok.
45
+ """
46
+ def merge_loop(main_samples):
47
+ #print("new merge loop:", dataset_fun)
48
+ merged_samples = None
49
+ cur_url = None
50
+ i = None
51
+ for s in main_samples:
52
+ url = s['__url__']
53
+ if url != cur_url:
54
+ # this will open a new file when we get the first sample with a new __url__
55
+ merged_samples = iter(dataset_fun(url))
56
+ cur_url = url
57
+ try:
58
+ merge_s = next(merged_samples)
59
+ except StopIteration:
60
+ # if the original shard got repeated we won't observe a __url__ change
61
+ # in this case restart the dataset from the beginning
62
+ merged_samples = iter(dataset_fun(url))
63
+ merge_s = next(merged_samples)
64
+ assert merge_s['__key__'] == s['__key__'], f"sample keys don't match: {merge_s['__key__']}, {s['__key__']} in file {s['__url__']}"
65
+ news = {}
66
+ news.update(merge_s)
67
+ news.update(s)
68
+ yield news
69
+ return merge_loop
70
+
71
+ # %% ../nbs/2B. Whisper quantization (semantic token) model.ipynb 10
72
+ def derived_dataset(kind, key='audio'):
73
+ def deriver(url):
74
+ url = str(Path(url).parent/(Path(url).name.replace(key, kind) + ".gz"))
75
+ return wds.WebDataset(
76
+ wds.SimpleShardList([url])
77
+ ).decode()
78
+ return deriver
79
+
80
+ # %% ../nbs/2B. Whisper quantization (semantic token) model.ipynb 17
81
+ def add_masks(samples):
82
+ for s in samples:
83
+ seconds = s['tend'] - s['tstart']
84
+ # a mask (downsampled to the Whisper encoder token rate of 50/s) is used
85
+ # to teach the model the concept of padding
86
+ # this let's us decode shorter sequences later
87
+ mask = torch.zeros(30*16000//320, dtype=torch.bool)
88
+ mask[:int(seconds * 16000) // 320] = 1
89
+ s['mask'] = mask
90
+ yield s
91
+
92
+ def tokenize_text(samples, ttoks_size=200, model="base.en", language="en"):
93
+ multilingual = not model.endswith(".en")
94
+ tokenizer = whisper.tokenizer.get_tokenizer(multilingual, language=language, task="transcribe")
95
+ for s in samples:
96
+ ttoks = tokenizer.encode(s['txt'])
97
+ tokens = list(tokenizer.sot_sequence) + ttoks
98
+ rpad = ttoks_size - len(tokens)
99
+ s['in_ttoks'] = F.pad(torch.tensor(tokens), (0, rpad), value=tokenizer.eot)
100
+ s['out_ttoks'] = F.pad(torch.tensor(tokens[1:] + [tokenizer.eot]), (0, rpad), value=-100)
101
+ yield s
102
+
103
+ # %% ../nbs/2B. Whisper quantization (semantic token) model.ipynb 22
104
+ def load_dataset(
105
+ shard_spec:str,
106
+ proc_dataset_path:Path, # processed VAD and txt files
107
+ samples:int, # set the per-GPU sample count
108
+ txt_label:str="base.en-txt", # the label of the files containing transcriptions
109
+ model:str="base.en",
110
+ key:str="flac",
111
+ language:str=None,
112
+ validation:bool=False,
113
+ ):
114
+ from . import wh_transcribe
115
+ shards = utils.shard_glob(shard_spec)
116
+
117
+ if not language and model.endswith('en'): language = 'en'
118
+ assert language, "please provide the dataset language for multilang models"
119
+
120
+ same_on_all_nodes = lambda urls: urls # will only be used for validation
121
+ ds = wds.WebDataset(shards, resampled=not validation, nodesplitter=same_on_all_nodes).compose(
122
+ wds.decode(wds.torch_audio),
123
+ wds.select(lambda x: 'wav' in x or 'flac' in x or 'mp3' in x or 'ogg' in x), # skip samples without audio
124
+ wds.rename(audio="flac;mp3;wav;ogg"),
125
+ merge_in(derived_dataset(proc_dataset_path, 'vad', key=key)),
126
+ wds.map_dict(**{"vad.npy":wh_transcribe.chunk_merger}),
127
+ wh_transcribe.split_to_chunks,
128
+ utils.resampler(16000, 'samples_16k'),
129
+ merge_in(derived_dataset(proc_dataset_path, txt_label, key=key)),
130
+ )
131
+ if 'librilight' in shards[0]:
132
+ ds = ds.compose(
133
+ # drop the first and last segment because they tend to be inaccurate
134
+ # (the transcriptions don't have the "LibriVox" headers and "end of chapter" suffixes)
135
+ wds.select(lambda x: x['i'] != 0 and x['i'] != x['imax']),
136
+ )
137
+ ds = ds.compose(
138
+ add_masks,
139
+ lambda x: tokenize_text(x, model=model, language=language),
140
+ wds.to_tuple('samples_16k', 'mask', 'in_ttoks', 'out_ttoks'),
141
+ wds.batched(32),
142
+ )
143
+ ds.total_samples = samples
144
+
145
+ return ds
146
+
147
+ # %% ../nbs/2B. Whisper quantization (semantic token) model.ipynb 28
148
+ from whisperspeech.train import *
149
+ from whisperspeech.modules import *
150
+
151
+ # %% ../nbs/2B. Whisper quantization (semantic token) model.ipynb 29
152
+ import dataclasses
153
+
154
+ def rand(start, end):
155
+ return random.random() * (end - start) + start
156
+
157
+ def logrand(start, end):
158
+ return 10**rand(math.log10(start), math.log10(end))
159
+
160
+ @dataclasses.dataclass
161
+ class Tunables:
162
+ init_std :float = 1.5
163
+ embeddings_std :float = 4.5e-2
164
+ embeddings_lr_scale: float = 1
165
+ output_mult :float = 1
166
+ query_mult :float = 2
167
+ rope :bool = True
168
+ mask_embs :bool = True # force embeddings corresponding to the input audio padding to a constant value
169
+ downsample_conv: bool = False
170
+ downsample_mean: bool = True
171
+
172
+ codebook_dim: int = 32
173
+ codebook_decay: float = 0.9
174
+
175
+ lr0 :float = .9e-3
176
+ clip_gradient_norm :float = 2
177
+ weight_decay :float = 1e-3
178
+ warmup_steps :float = 850
179
+
180
+ random :bool = False
181
+
182
+ def __post_init__(self):
183
+ # randomize the hyperparams if requested
184
+ if self.random:
185
+ self.init_std = logrand(1, 2)
186
+ self.embeddings_std = logrand(3e-2,6e-2)
187
+ self.embeddings_lr_scale = 2**rand(0,3)
188
+ self.output_mult = 2**rand(-3,3)
189
+ self.query_mult = logrand(1,8)
190
+ self.codebook_dim = int(logrand(30,50))
191
+ self.codebook_decay = logrand(0.86,0.95)
192
+ self.rope = True
193
+ self.mask_embs = True
194
+ self.downsample_mean = True
195
+
196
+ self.lr0 = logrand(.8e-3,1e-3)
197
+ self.clip_gradient_norm = 10**rand(-1,1)
198
+ self.warmup_steps = logrand(700,1000)
199
+
200
+ @staticmethod
201
+ def upgrade(args):
202
+ args = {k:v for k,v in args.items()}
203
+ def old_default(name, value):
204
+ if name not in args: args[name] = value
205
+ old_default('output_mult', 1)
206
+ old_default('query_mult', 1)
207
+ old_default('rope', False)
208
+ old_default('mask_embs', False)
209
+ old_default('downsample_conv', False)
210
+ old_default('downsample_mean', False)
211
+ if 'encoder_depth_ratio' in args: del args['encoder_depth_ratio']
212
+ if 'vq_codes' in args: del args['vq_codes']
213
+ return args
214
+
215
+ # %% ../nbs/2B. Whisper quantization (semantic token) model.ipynb 30
216
+ import math
217
+
218
+ # %% ../nbs/2B. Whisper quantization (semantic token) model.ipynb 31
219
+ class RQBottleneckTransformer(nn.Module):
220
+ def __init__(self, vq_codes=512, q_depth=12, depth=1, n_head=2, head_width=64, ffn_mult=4,
221
+ codebook_dim=2, threshold_ema_dead_code=2, use_cosine_sim = False, kl_loss_mul=1,
222
+ downsample=1,
223
+ whisper_model_name='tiny.en', tunables=Tunables()):
224
+ super().__init__()
225
+ width = n_head * head_width
226
+ store_attr("codebook_dim,vq_codes,q_depth,n_head,head_width,ffn_mult,depth,use_cosine_sim,downsample,whisper_model_name")
227
+ self.width = width
228
+ self.base_width = 3 * head_width
229
+ self.vq_codes = vq_codes
230
+ self.tunables = tunables
231
+ self.stoks_len = 1500//downsample
232
+ self.stoks_per_sec = self.stoks_len//30
233
+
234
+ qk_scale = self.tunables.query_mult * 8 / math.sqrt(head_width)
235
+
236
+ self.kl_loss_mul = kl_loss_mul
237
+
238
+ n_mlp = width * ffn_mult
239
+ self.mlp = nn.Sequential(
240
+ nn.Linear(width, n_mlp), nn.GELU(), nn.Linear(n_mlp, width)
241
+ )
242
+ self.mlp_ln = LayerNorm(width)
243
+
244
+ if tunables.downsample_conv:
245
+ self.downsample_conv = nn.Conv1d(width, width, kernel_size=3, stride=downsample, padding=1)
246
+ else:
247
+ self.downsample_conv = None
248
+
249
+ if tunables.mask_embs: vq_codes = vq_codes + 1
250
+ self.rq = ResidualVQ(
251
+ dim = width,
252
+ codebook_size = vq_codes, # codebook size
253
+ decay = tunables.codebook_decay, # the exponential moving average decay, lower means the dictionary will change faster
254
+ commitment_weight = 1., # the weight on the commitment loss
255
+ threshold_ema_dead_code = threshold_ema_dead_code,
256
+ use_cosine_sim = use_cosine_sim,
257
+ codebook_dim = codebook_dim,
258
+ num_quantizers= 1,
259
+ )
260
+
261
+ self.ce_lossf = nn.CrossEntropyLoss(ignore_index=-100)
262
+ self.kl_lossf = nn.KLDivLoss(reduction='batchmean')
263
+
264
+ self.positional_embedding = nn.Embedding(1500, width) # FIXME: should be self.stoks_len
265
+
266
+ self.out_blocks = nn.Sequential(*[
267
+ ResidualAttentionBlock(width, n_head, qk_scale=qk_scale, ffn_mult=ffn_mult, rope=tunables.rope) for _ in range(depth)
268
+ ])
269
+ self.ln_post = LayerNorm(width)
270
+
271
+ self.whmodel = None
272
+
273
+ self.apply(self.init_transformer)
274
+ self.register_buffer('val_true', torch.zeros(1).cuda())
275
+ self.register_buffer('val_total', torch.zeros(1).cuda())
276
+
277
+ def setup(self, device):
278
+ self.ensure_whisper(device)
279
+
280
+ def init_transformer(self, m):
281
+ if isinstance(m, LinearHead):
282
+ m.no_weight_decay = True
283
+ torch.nn.init.constant_(m.weight, 0)
284
+ elif isinstance(m, QueryHead):
285
+ m.lr_scale = 1/(m.weight.shape[1] / self.base_width)
286
+ torch.nn.init.constant_(m.weight, 0)
287
+ elif isinstance(m, nn.Embedding):
288
+ m.no_weight_decay = True
289
+ m.lr_scale = self.tunables.embeddings_lr_scale
290
+ std = self.tunables.embeddings_std
291
+ torch.nn.init.trunc_normal_(m.weight, std=std, a=-3*std, b=3*std)
292
+ elif isinstance(m, nn.Linear):
293
+ m.lr_scale = 1/(m.weight.shape[1] / self.base_width)
294
+ std = self.tunables.init_std / m.weight.shape[1]
295
+ torch.nn.init.trunc_normal_(m.weight, std=std, a=-3*std, b=3*std)
296
+ if m.bias is not None:
297
+ torch.nn.init.trunc_normal_(m.bias, std=std, a=-3*std, b=3*std)
298
+ elif isinstance(m, nn.LayerNorm):
299
+ m.no_weight_decay = True
300
+ torch.nn.init.constant_(m.bias, 0)
301
+ torch.nn.init.constant_(m.weight, 1)
302
+
303
+ @property
304
+ def device(self):
305
+ return next(self.parameters()).device
306
+
307
+ #
308
+ # training
309
+ #
310
+ @torch.no_grad()
311
+ def extract_teacher(self, samples, input_toks, output_toks):
312
+ embs = self.whmodel[0].encoder(whisper.log_mel_spectrogram(samples))
313
+ teacher_logits = self.whmodel[0].decoder(input_toks, embs)
314
+ # set teacher logits to 0 for padding positions so KLDivLoss ignores them
315
+ teacher_logits[output_toks == -100] = 0
316
+ return embs, teacher_logits
317
+
318
+ def downsample_embeddings(self, x):
319
+ if self.downsample_conv is not None:
320
+ return x[:,::self.downsample] + self.downsample_conv(x.transpose(-1,-2)).transpose(-2,-1)
321
+ elif self.tunables.downsample_mean:
322
+ bs,slen,depth = x.shape
323
+ return x.reshape(bs,slen//self.downsample,self.downsample,depth).mean(-2)
324
+ else:
325
+ return x[:,::self.downsample]
326
+
327
+ def forward(self, samples, mask, input_toks, output_toks):
328
+ embs, teacher_logits = self.extract_teacher(samples, input_toks, output_toks)
329
+
330
+ x = self.downsample_embeddings(embs)
331
+ x = x + self.mlp(self.mlp_ln(x))
332
+ # VQ bottleneck
333
+ quantized, self.indices, self.commit_loss = self.rq(x)
334
+ self.commit_loss = self.commit_loss.mean()
335
+
336
+ x = quantized.repeat_interleave(self.downsample, -2)
337
+ project_out = getattr(self.rq, 'project_out', None) or self.rq.layers[0].project_out
338
+ if self.tunables.mask_embs: x[~mask] = project_out(self.rq.layers[0]._codebook.embed[0,self.vq_codes])
339
+ positions = torch.arange(0, x.shape[-2], dtype=torch.long, device=x.device)
340
+ x = x + self.positional_embedding(positions)
341
+ x = self.ln_post(self.out_blocks(x))
342
+
343
+ logits = self.whmodel[0].decoder(input_toks, x)
344
+ self.ce_loss = self.ce_lossf(logits.view(-1,logits.shape[-1]), output_toks.view(-1))
345
+ self.kl_loss = self.kl_lossf(F.log_softmax(logits, dim=-1), F.softmax(teacher_logits, dim=-1))
346
+ loss = self.ce_loss + self.kl_loss_mul * self.kl_loss + self.commit_loss
347
+
348
+ if not self.training:
349
+ valid_toks = output_toks != -100
350
+ self.val_true += (logits.argmax(-1)[valid_toks] == output_toks[valid_toks]).float().sum()
351
+ self.val_total += valid_toks.float().sum()
352
+
353
+ return x, loss
354
+
355
+ def get_metrics(self):
356
+ metrics = {
357
+ 'acc_0': (self.val_true / self.val_total).item(),
358
+ }
359
+ self.val_true[:] = 0
360
+ self.val_total[:] = 0
361
+ return metrics
362
+
363
+ #
364
+ # inference
365
+ #
366
+ @classmethod
367
+ def load_model(cls, ref="collabora/spear-tts-pytorch:whisper-vq-stoks-medium-en+pl.model",
368
+ repo_id=None, filename=None, local_filename=None):
369
+ if repo_id is None and filename is None and local_filename is None:
370
+ if ":" in ref:
371
+ repo_id, filename = ref.split(":", 1)
372
+ else:
373
+ local_filename = ref
374
+ if not local_filename:
375
+ local_filename = hf_hub_download(repo_id=repo_id, filename=filename)
376
+ spec = torch.load(local_filename)
377
+ vqmodel = cls(**spec['config'], tunables=Tunables(**Tunables.upgrade(spec.get('tunables', {}))))
378
+ vqmodel.load_state_dict(spec['state_dict'])
379
+ vqmodel.eval()
380
+ return vqmodel
381
+
382
+ def load_checkpoint(self, local_filename):
383
+ spec = torch.load(local_filename, map_location='cpu')
384
+ assert 'pytorch-lightning_version' in spec, 'not a valid PyTorch Lightning checkpoint'
385
+ state_dict = {k.replace('model.', ''):v
386
+ for k,v in spec['state_dict'].items()}
387
+ self.load_state_dict(state_dict)
388
+ return self
389
+
390
+ def save_model(self, fname, store_parameters=True):
391
+ torch.save(dict(config = self.__stored_args__,
392
+ tunables = dataclasses.asdict(self.tunables),
393
+ state_dict = self.state_dict() if store_parameters else None), fname)
394
+
395
+ def ensure_whisper(self, device):
396
+ # the list wrapper is a hack to make sure the whole of Whisper is not sucked into self.parameters()
397
+ if self.whmodel is None: self.whmodel = [whisper.load_model(self.whisper_model_name, device=device)]
398
+ self.decoding_options = whisper.DecodingOptions()
399
+ multilingual = not self.whisper_model_name.endswith('.en')
400
+ self.tokenizer = whisper.tokenizer.get_tokenizer(multilingual)
401
+
402
+ def quantize(self, embs):
403
+ x = self.downsample_embeddings(embs)
404
+ x = x + self.mlp(self.mlp_ln(x))
405
+ _, stoks, _ = self.rq(x)
406
+ if self.q_depth == 1:
407
+ stoks = stoks.squeeze(-1)
408
+ return stoks
409
+
410
+ def dequantize(self, stoks):
411
+ assert self.q_depth == 1
412
+ assert len(stoks.shape) == 1, "batch processing is not supported"
413
+ if isinstance(stoks, np.ndarray): stoks = torch.tensor(stoks)
414
+ # remove padding
415
+ padding = torch.nonzero(stoks == self.vq_codes)
416
+ if padding.any(): stoks = stoks[:padding[0,0]]
417
+ stoks = F.pad(stoks, (0,self.stoks_len - stoks.shape[-1]), value=self.vq_codes if self.tunables.mask_embs else 0)
418
+ x = self.rq.layers[0]._codebook.embed[0,stoks.to(torch.long).view(-1)]
419
+ x = x.repeat_interleave(self.downsample, -2)
420
+ project_out = getattr(self.rq, 'project_out', None) or self.rq.layers[0].project_out
421
+ x = project_out(x).unsqueeze(0)
422
+ positions = torch.arange(0, x.shape[-2], dtype=torch.long, device=x.device)
423
+ x = x + self.positional_embedding(positions)
424
+ return self.ln_post(self.out_blocks(x))
425
+
426
+ def encode_audio(self, audio):
427
+ if isinstance(audio, str):
428
+ x, sr = torchaudio.load(audio)
429
+ x = torchaudio.transforms.Resample(sr, 16000)(x)[0]
430
+ audio = x.unsqueeze(0)
431
+ return self.encode_mel(whisper.log_mel_spectrogram(audio).to(self.device))
432
+
433
+ def encode_mel(self, mel):
434
+ assert len(mel.shape) == 3, "invalid mel spectrogram shape, expect (batch,chn,time)"
435
+ self.ensure_whisper(self.device)
436
+ n = mel.shape[-1]
437
+ if n > whisper.audio.N_FRAMES:
438
+ padding = 0
439
+ padded = mel[:,:,:whisper.audio.N_FRAMES]
440
+ else:
441
+ padding = -n % whisper.audio.N_FRAMES
442
+ padded = F.pad(mel, (0, padding), value=-1.5)
443
+ embs = self.whmodel[0].encoder(padded)#.to(self.whmodel[0].device))#[:,:n//2]
444
+ stoks = self.quantize(embs)
445
+ if self.tunables.mask_embs:
446
+ return stoks[:,:n//2//self.downsample]
447
+ else:
448
+ return stoks
449
+
450
+ def decode_text(self, stoks, decoding_options=None):
451
+ self.ensure_whisper(self.device)
452
+ if decoding_options is None: decoding_options = self.decoding_options
453
+ embs = self.dequantize(stoks).to(self.whmodel[0].device)
454
+ return self.whmodel[0].decode(embs, decoding_options)
455
+
456
+ # %% ../nbs/2B. Whisper quantization (semantic token) model.ipynb 33
457
+ def make_model(size:str, tunables:Tunables=Tunables(), dataset:torch.utils.data.Dataset=None):
458
+ if size == 'base.en-2d-4096c':
459
+ model = RQBottleneckTransformer(codebook_dim=32, vq_codes=4096, q_depth=1, n_head=8, depth=1,
460
+ downsample=2, threshold_ema_dead_code=0, use_cosine_sim=True,
461
+ whisper_model_name=size.split("-")[0], tunables=tunables)
462
+ return model
463
+ if size == 'base.en-2d-512c':
464
+ model = RQBottleneckTransformer(codebook_dim=32, vq_codes=512, q_depth=1, n_head=8, depth=1,
465
+ downsample=2, threshold_ema_dead_code=0, use_cosine_sim=True,
466
+ whisper_model_name=size.split("-")[0], tunables=tunables)
467
+ return model
468
+ if size == 'base.en-2d-512c-dim64':
469
+ model = RQBottleneckTransformer(codebook_dim=64, vq_codes=512, q_depth=1, n_head=8, depth=1,
470
+ downsample=2, threshold_ema_dead_code=0, use_cosine_sim=True,
471
+ whisper_model_name=size.split("-")[0], tunables=tunables)
472
+ return model
473
+ if size == 'base-2d-512c-dim64':
474
+ model = RQBottleneckTransformer(codebook_dim=64, vq_codes=512, q_depth=1, n_head=8, depth=1,
475
+ downsample=2, threshold_ema_dead_code=0, use_cosine_sim=True,
476
+ whisper_model_name=size.split("-")[0], tunables=tunables)
477
+ return model
478
+ if size == 'base-2d-1024c-dim64':
479
+ model = RQBottleneckTransformer(codebook_dim=64, vq_codes=1024, q_depth=1, n_head=8, depth=1,
480
+ downsample=2, threshold_ema_dead_code=0, use_cosine_sim=True,
481
+ whisper_model_name=size.split("-")[0], tunables=tunables)
482
+ return model
483
+ if size == 'medium-2d-512c-dim64':
484
+ model = RQBottleneckTransformer(codebook_dim=64, vq_codes=512, q_depth=1, n_head=16, depth=1,
485
+ downsample=2, threshold_ema_dead_code=0, use_cosine_sim=True,
486
+ whisper_model_name=size.split("-")[0], tunables=tunables)
487
+ return model
488
+ if size == 'medium-2d-1024c-dim64':
489
+ model = RQBottleneckTransformer(codebook_dim=64, vq_codes=1024, q_depth=1, n_head=16, depth=1,
490
+ downsample=2, threshold_ema_dead_code=0, use_cosine_sim=True,
491
+ whisper_model_name=size.split("-")[0], tunables=tunables)
492
+ return model
493
+ raise ArgumentError(f"invalid model size: {size}")
whisperspeech/wer_metrics.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/C. Word error rate metrics.ipynb.
2
+
3
+ # %% auto 0
4
+ __all__ = ['librispeech_data', 'DfBuilder', 'WERStats']
5
+
6
+ # %% ../nbs/C. Word error rate metrics.ipynb 2
7
+ import jiwer
8
+ from whisper_normalizer.english import EnglishTextNormalizer
9
+
10
+ import torchaudio
11
+ from pathlib import Path
12
+ import pandas as pd
13
+
14
+ # %% ../nbs/C. Word error rate metrics.ipynb 3
15
+ engnorm = EnglishTextNormalizer()
16
+ def whisper_normalize(x):
17
+ if type(x) == list:
18
+ return [engnorm(y) for y in x]
19
+ else:
20
+ return engnorm(x)
21
+
22
+ default_transform = jiwer.transforms.Compose([
23
+ jiwer.transforms.ToLowerCase(),
24
+ jiwer.transforms.ExpandCommonEnglishContractions(),
25
+ whisper_normalize,
26
+ jiwer.transforms.RemoveMultipleSpaces(),
27
+ jiwer.transforms.Strip(),
28
+ jiwer.transforms.RemovePunctuation(),
29
+ jiwer.transforms.ReduceToListOfListOfWords(),
30
+ ])
31
+
32
+ # %% ../nbs/C. Word error rate metrics.ipynb 5
33
+ def librispeech_data(datadir, sample_rate=16000):
34
+ for file in Path(datadir).rglob('*.txt'):
35
+ for line in file.read_text().split('\n'):
36
+ if not line: continue
37
+ idx, text = line.split(" ", 1)
38
+ x, sr = torchaudio.load((file.parent/idx).with_suffix('.flac'))
39
+ if sr != sample_rate:
40
+ x = torchaudio.transforms.Resample(sr, self.sample_rate)(x)
41
+ yield x, text
42
+
43
+ # %% ../nbs/C. Word error rate metrics.ipynb 6
44
+ class DfBuilder:
45
+ def __init__(self):
46
+ self.data = {}
47
+
48
+ def push(self, **kwargs):
49
+ for k,v in kwargs.items():
50
+ if k not in self.data:
51
+ self.data[k] = [v]
52
+ else:
53
+ self.data[k].append(v)
54
+
55
+ def df(self):
56
+ return pd.DataFrame(self.data)
57
+
58
+ # %% ../nbs/C. Word error rate metrics.ipynb 7
59
+ class WERStats(DfBuilder):
60
+ def __init__(self, transform=default_transform):
61
+ super().__init__()
62
+ self.reference_transform = transform
63
+ self.hypothesis_transform = transform
64
+
65
+ def push_sample(self, snd, gt_text, text, idx=None):
66
+ if snd is not None: self.push(secs = snd.shape[-1]/16000)
67
+ diff = jiwer.process_words(gt_text, text, reference_transform=self.reference_transform, hypothesis_transform=self.hypothesis_transform)
68
+ self.push(
69
+ idx = idx,
70
+ gt_text = gt_text,
71
+ text = text,
72
+ wer = diff.wer,
73
+ mer = diff.mer,
74
+ wil = diff.wil,
75
+ wip = diff.wip,
76
+ )
77
+ return diff
whisperspeech/wh_transcribe.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/2A. Whisper quantization dataset preparation.ipynb.
2
+
3
+ # %% auto 0
4
+ __all__ = []
5
+
6
+ # %% ../nbs/2A. Whisper quantization dataset preparation.ipynb 3
7
+ import os
8
+ import io
9
+ import time
10
+ import torch
11
+ import torchaudio
12
+
13
+ # %% ../nbs/2A. Whisper quantization dataset preparation.ipynb 4
14
+ from pathlib import Path
15
+ import json
16
+ from fastprogress import progress_bar, master_bar
17
+ import numpy as np
18
+ import random
19
+
20
+ import whisper
21
+
22
+ from torch import nn
23
+ import torch.nn.functional as F
24
+ from torch.utils.data.dataloader import DataLoader
25
+
26
+ from fastcore.script import *
27
+
28
+ from . import vad
29
+ import webdataset as wds
30
+
31
+ # %% ../nbs/2A. Whisper quantization dataset preparation.ipynb 9
32
+ # let's make it a bit more conservative
33
+ # with full 30 second chunks it sometimes misses a small part of the transcript
34
+ def random_cutter(dur):
35
+ if random.random() < 0.5:
36
+ return dur > 28 * (random.random()*0.95+0.05)
37
+ else:
38
+ return dur > 28
39
+
40
+ def chunk_merger(segments, should_cut=lambda x: x > 28):
41
+ if len(segments) == 0: return segments
42
+ curr_start = segments[0][0]
43
+ curr_end = 0
44
+ merged = []
45
+
46
+ for ts,te in segments:
47
+ if should_cut(te - curr_start) and curr_end - curr_start > 0:
48
+ merged.append((curr_start, curr_end))
49
+ curr_start = ts
50
+ curr_end = te
51
+ merged.append((curr_start, curr_end))
52
+ return merged
53
+
54
+ # %% ../nbs/2A. Whisper quantization dataset preparation.ipynb 18
55
+ def merge_in(*datasets):
56
+ """Merge multiple datasets into the current one returning samples with the union of keys.
57
+
58
+ It requires (and validates) all datasets to have the same ordering of keys so you have
59
+ to use it before any sample shuffling. Shard shuffling is ok.
60
+ """
61
+ def merge_loop(main_samples):
62
+ for samples in zip(*[main_samples]+[iter(x) for x in datasets]):
63
+ key = samples[0]['__key__']
64
+ news = {}
65
+ for s in samples:
66
+ assert s['__key__'] == key
67
+ news.update(s)
68
+ yield news
69
+ return merge_loop
70
+
71
+ # %% ../nbs/2A. Whisper quantization dataset preparation.ipynb 19
72
+ import copy
73
+
74
+ # %% ../nbs/2A. Whisper quantization dataset preparation.ipynb 20
75
+ # a workaround for https://github.com/webdataset/webdataset/issues/297
76
+ # should be possible to use ds.compose here
77
+ def wds_compose(ds, *args):
78
+ ds = copy.copy(ds)
79
+ ds.pipeline = copy.copy(ds.pipeline)
80
+ for f in args:
81
+ ds.append(f)
82
+ return ds
83
+
84
+ # %% ../nbs/2A. Whisper quantization dataset preparation.ipynb 24
85
+ def split_to_chunks(stream, pad_to_seconds=30, random_shift=False):
86
+ for s in stream:
87
+ audio, sr = s.get('flac', s.get('wav', (None, None)))
88
+ if audio is None:
89
+ print(f"warning: '{s['__key__']}' does not contain an audio file")
90
+ continue
91
+ imax = len(s['vad.npy']) - 1
92
+ for i,(ts,te) in enumerate(s['vad.npy']):
93
+ samples = audio[0,int(ts*sr):int(te*sr)]
94
+ if pad_to_seconds is not None:
95
+ padding = pad_to_seconds*sr-samples.shape[-1]
96
+ lpad = random.randint(0, padding) if random_shift else 0
97
+ samples = F.pad(samples, (lpad, padding-lpad))
98
+ yield {"__key__": s['__key__'] + f"_{i:03d}",
99
+ "__url__": s['__url__'],
100
+ "i": i, "imax": imax,
101
+ "tstart": ts, "tend": te, "total_seconds": audio.shape[-1]/sr,
102
+ "lpad": lpad, "rpad": padding-lpad,
103
+ "lpad_s": lpad/sr, "rpad_s": (padding-lpad)/sr,
104
+ "samples": samples, "sample_rate": sr}
105
+
106
+ # %% ../nbs/2A. Whisper quantization dataset preparation.ipynb 38
107
+ def flac_to_txt_name(input, model_size):
108
+ return input.rsplit("/", 1)[1].replace('flac', f'{model_size}-txt') + ".gz"
109
+
110
+ @call_parse
111
+ def process_shard(
112
+ input:str, # input shard URL/path
113
+ output:str=None, # output shard URL/path
114
+ bs:int=None, # batch size (16 uses around 11GB of VRAM)
115
+ n_samples:int=None, # limit the number of samples (useful for quick benchmarking)
116
+ whisper_model:str="base.en" # Whisper model size
117
+ ):
118
+ if output is None: output = flac_to_txt_name(input, whisper_model)
119
+ if bs is None: bs = 16
120
+ if n_samples is None: n_samples = 'noinfer'
121
+ else: n_samples = n_samples // bs
122
+
123
+ ds = wds_compose(vad.load_dataset(input),
124
+ merge_in(wds.WebDataset(vad.flac_to_vad_name(input)).decode()),
125
+ wds.map_dict(**{"vad.npy":chunk_merger}),
126
+ split_to_chunks,
127
+ wds.to_tuple('__key__', 'samples'),
128
+ wds.batched(bs),
129
+ )
130
+ dl = DataLoader(ds, num_workers=2, batch_size=None)
131
+
132
+ whmodel = whisper.load_model(whisper_model)
133
+ decoding_options = whisper.DecodingOptions(language='en')
134
+
135
+ tmp = output+".tmp"
136
+ with wds.TarWriter(tmp) as sink:
137
+ for keys, samples in progress_bar(dl, total=n_samples):
138
+ with torch.no_grad():
139
+ embs = whmodel.encoder(whisper.log_mel_spectrogram(samples).cuda())
140
+ decs = whmodel.decode(embs, decoding_options)
141
+ for key, dec in zip(keys, decs):
142
+ sink.write({
143
+ "__key__": key,
144
+ "txt": dec.text,
145
+ })
146
+ os.rename(tmp, output)