ahmedghani commited on
Commit
8235b4f
·
1 Parent(s): f8fec52

initial commit

Browse files
CODE_OF_CONDUCT.md ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # Code of Conduct
3
+
4
+ ## Our Pledge
5
+
6
+ In the interest of fostering an open and welcoming environment, we as
7
+ contributors and maintainers pledge to make participation in our project and
8
+ our community a harassment-free experience for everyone, regardless of age, body
9
+ size, disability, ethnicity, sex characteristics, gender identity and expression,
10
+ level of experience, education, socio-economic status, nationality, personal
11
+ appearance, race, religion, or sexual identity and orientation.
12
+
13
+ ## Our Standards
14
+
15
+ Examples of behavior that contributes to creating a positive environment
16
+ include:
17
+
18
+ * Using welcoming and inclusive language
19
+ * Being respectful of differing viewpoints and experiences
20
+ * Gracefully accepting constructive criticism
21
+ * Focusing on what is best for the community
22
+ * Showing empathy towards other community members
23
+
24
+ Examples of unacceptable behavior by participants include:
25
+
26
+ * The use of sexualized language or imagery and unwelcome sexual attention or
27
+ advances
28
+ * Trolling, insulting/derogatory comments, and personal or political attacks
29
+ * Public or private harassment
30
+ * Publishing others' private information, such as a physical or electronic
31
+ address, without explicit permission
32
+ * Other conduct which could reasonably be considered inappropriate in a
33
+ professional setting
34
+
35
+ ## Our Responsibilities
36
+
37
+ Project maintainers are responsible for clarifying the standards of acceptable
38
+ behavior and are expected to take appropriate and fair corrective action in
39
+ response to any instances of unacceptable behavior.
40
+
41
+ Project maintainers have the right and responsibility to remove, edit, or
42
+ reject comments, commits, code, wiki edits, issues, and other contributions
43
+ that are not aligned to this Code of Conduct, or to ban temporarily or
44
+ permanently any contributor for other behaviors that they deem inappropriate,
45
+ threatening, offensive, or harmful.
46
+
47
+ ## Scope
48
+
49
+ This Code of Conduct applies within all project spaces, and it also applies when
50
+ an individual is representing the project or its community in public spaces.
51
+ Examples of representing a project or community include using an official
52
+ project e-mail address, posting via an official social media account, or acting
53
+ as an appointed representative at an online or offline event. Representation of
54
+ a project may be further defined and clarified by project maintainers.
55
+
56
+ ## Enforcement
57
+
58
+ Instances of abusive, harassing, or otherwise unacceptable behavior may be
59
+ reported by contacting the project team at <opensource-conduct@fb.com>. All
60
+ complaints will be reviewed and investigated and will result in a response that
61
+ is deemed necessary and appropriate to the circumstances. The project team is
62
+ obligated to maintain confidentiality with regard to the reporter of an incident.
63
+ Further details of specific enforcement policies may be posted separately.
64
+
65
+ Project maintainers who do not follow or enforce the Code of Conduct in good
66
+ faith may face temporary or permanent repercussions as determined by other
67
+ members of the project's leadership.
68
+
69
+ ## Attribution
70
+
71
+ This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
72
+ available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
73
+
74
+ [homepage]: https://www.contributor-covenant.org
75
+
76
+ For answers to common questions about this code of conduct, see
77
+ https://www.contributor-covenant.org/faq
78
+
CONTRIBUTING.md ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Contributing to Denoiser
2
+
3
+ ## Pull Requests
4
+
5
+ In order to accept your pull request, we need you to submit a CLA. You only need
6
+ to do this once to work on any of Facebook's open source projects.
7
+
8
+ Complete your CLA here: <https://code.facebook.com/cla>
9
+
10
+ Demucs is the implementation of a research paper.
11
+ Therefore, we do not plan on accepting many pull requests for new features.
12
+ We certainly welcome them for bug fixes.
13
+
14
+
15
+ ## Issues
16
+
17
+ We use GitHub issues to track public bugs. Please ensure your description is
18
+ clear and has sufficient instructions to be able to reproduce the issue.
19
+ Please first check existing issues as well as the README for existing solutions.
20
+
21
+
22
+ ## License
23
+ By contributing to this repository, you agree that your contributions will be licensed
24
+ under the LICENSE file in the root directory of this source tree.
25
+
LICENSE ADDED
@@ -0,0 +1,437 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Attribution-NonCommercial-ShareAlike 4.0 International
2
+
3
+ =======================================================================
4
+
5
+ Creative Commons Corporation ("Creative Commons") is not a law firm and
6
+ does not provide legal services or legal advice. Distribution of
7
+ Creative Commons public licenses does not create a lawyer-client or
8
+ other relationship. Creative Commons makes its licenses and related
9
+ information available on an "as-is" basis. Creative Commons gives no
10
+ warranties regarding its licenses, any material licensed under their
11
+ terms and conditions, or any related information. Creative Commons
12
+ disclaims all liability for damages resulting from their use to the
13
+ fullest extent possible.
14
+
15
+ Using Creative Commons Public Licenses
16
+
17
+ Creative Commons public licenses provide a standard set of terms and
18
+ conditions that creators and other rights holders may use to share
19
+ original works of authorship and other material subject to copyright
20
+ and certain other rights specified in the public license below. The
21
+ following considerations are for informational purposes only, are not
22
+ exhaustive, and do not form part of our licenses.
23
+
24
+ Considerations for licensors: Our public licenses are
25
+ intended for use by those authorized to give the public
26
+ permission to use material in ways otherwise restricted by
27
+ copyright and certain other rights. Our licenses are
28
+ irrevocable. Licensors should read and understand the terms
29
+ and conditions of the license they choose before applying it.
30
+ Licensors should also secure all rights necessary before
31
+ applying our licenses so that the public can reuse the
32
+ material as expected. Licensors should clearly mark any
33
+ material not subject to the license. This includes other CC-
34
+ licensed material, or material used under an exception or
35
+ limitation to copyright. More considerations for licensors:
36
+ wiki.creativecommons.org/Considerations_for_licensors
37
+
38
+ Considerations for the public: By using one of our public
39
+ licenses, a licensor grants the public permission to use the
40
+ licensed material under specified terms and conditions. If
41
+ the licensor's permission is not necessary for any reason--for
42
+ example, because of any applicable exception or limitation to
43
+ copyright--then that use is not regulated by the license. Our
44
+ licenses grant only permissions under copyright and certain
45
+ other rights that a licensor has authority to grant. Use of
46
+ the licensed material may still be restricted for other
47
+ reasons, including because others have copyright or other
48
+ rights in the material. A licensor may make special requests,
49
+ such as asking that all changes be marked or described.
50
+ Although not required by our licenses, you are encouraged to
51
+ respect those requests where reasonable. More_considerations
52
+ for the public:
53
+ wiki.creativecommons.org/Considerations_for_licensees
54
+
55
+ =======================================================================
56
+
57
+ Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International
58
+ Public License
59
+
60
+ By exercising the Licensed Rights (defined below), You accept and agree
61
+ to be bound by the terms and conditions of this Creative Commons
62
+ Attribution-NonCommercial-ShareAlike 4.0 International Public License
63
+ ("Public License"). To the extent this Public License may be
64
+ interpreted as a contract, You are granted the Licensed Rights in
65
+ consideration of Your acceptance of these terms and conditions, and the
66
+ Licensor grants You such rights in consideration of benefits the
67
+ Licensor receives from making the Licensed Material available under
68
+ these terms and conditions.
69
+
70
+
71
+ Section 1 -- Definitions.
72
+
73
+ a. Adapted Material means material subject to Copyright and Similar
74
+ Rights that is derived from or based upon the Licensed Material
75
+ and in which the Licensed Material is translated, altered,
76
+ arranged, transformed, or otherwise modified in a manner requiring
77
+ permission under the Copyright and Similar Rights held by the
78
+ Licensor. For purposes of this Public License, where the Licensed
79
+ Material is a musical work, performance, or sound recording,
80
+ Adapted Material is always produced where the Licensed Material is
81
+ synched in timed relation with a moving image.
82
+
83
+ b. Adapter's License means the license You apply to Your Copyright
84
+ and Similar Rights in Your contributions to Adapted Material in
85
+ accordance with the terms and conditions of this Public License.
86
+
87
+ c. BY-NC-SA Compatible License means a license listed at
88
+ creativecommons.org/compatiblelicenses, approved by Creative
89
+ Commons as essentially the equivalent of this Public License.
90
+
91
+ d. Copyright and Similar Rights means copyright and/or similar rights
92
+ closely related to copyright including, without limitation,
93
+ performance, broadcast, sound recording, and Sui Generis Database
94
+ Rights, without regard to how the rights are labeled or
95
+ categorized. For purposes of this Public License, the rights
96
+ specified in Section 2(b)(1)-(2) are not Copyright and Similar
97
+ Rights.
98
+
99
+ e. Effective Technological Measures means those measures that, in the
100
+ absence of proper authority, may not be circumvented under laws
101
+ fulfilling obligations under Article 11 of the WIPO Copyright
102
+ Treaty adopted on December 20, 1996, and/or similar international
103
+ agreements.
104
+
105
+ f. Exceptions and Limitations means fair use, fair dealing, and/or
106
+ any other exception or limitation to Copyright and Similar Rights
107
+ that applies to Your use of the Licensed Material.
108
+
109
+ g. License Elements means the license attributes listed in the name
110
+ of a Creative Commons Public License. The License Elements of this
111
+ Public License are Attribution, NonCommercial, and ShareAlike.
112
+
113
+ h. Licensed Material means the artistic or literary work, database,
114
+ or other material to which the Licensor applied this Public
115
+ License.
116
+
117
+ i. Licensed Rights means the rights granted to You subject to the
118
+ terms and conditions of this Public License, which are limited to
119
+ all Copyright and Similar Rights that apply to Your use of the
120
+ Licensed Material and that the Licensor has authority to license.
121
+
122
+ j. Licensor means the individual(s) or entity(ies) granting rights
123
+ under this Public License.
124
+
125
+ k. NonCommercial means not primarily intended for or directed towards
126
+ commercial advantage or monetary compensation. For purposes of
127
+ this Public License, the exchange of the Licensed Material for
128
+ other material subject to Copyright and Similar Rights by digital
129
+ file-sharing or similar means is NonCommercial provided there is
130
+ no payment of monetary compensation in connection with the
131
+ exchange.
132
+
133
+ l. Share means to provide material to the public by any means or
134
+ process that requires permission under the Licensed Rights, such
135
+ as reproduction, public display, public performance, distribution,
136
+ dissemination, communication, or importation, and to make material
137
+ available to the public including in ways that members of the
138
+ public may access the material from a place and at a time
139
+ individually chosen by them.
140
+
141
+ m. Sui Generis Database Rights means rights other than copyright
142
+ resulting from Directive 96/9/EC of the European Parliament and of
143
+ the Council of 11 March 1996 on the legal protection of databases,
144
+ as amended and/or succeeded, as well as other essentially
145
+ equivalent rights anywhere in the world.
146
+
147
+ n. You means the individual or entity exercising the Licensed Rights
148
+ under this Public License. Your has a corresponding meaning.
149
+
150
+
151
+ Section 2 -- Scope.
152
+
153
+ a. License grant.
154
+
155
+ 1. Subject to the terms and conditions of this Public License,
156
+ the Licensor hereby grants You a worldwide, royalty-free,
157
+ non-sublicensable, non-exclusive, irrevocable license to
158
+ exercise the Licensed Rights in the Licensed Material to:
159
+
160
+ a. reproduce and Share the Licensed Material, in whole or
161
+ in part, for NonCommercial purposes only; and
162
+
163
+ b. produce, reproduce, and Share Adapted Material for
164
+ NonCommercial purposes only.
165
+
166
+ 2. Exceptions and Limitations. For the avoidance of doubt, where
167
+ Exceptions and Limitations apply to Your use, this Public
168
+ License does not apply, and You do not need to comply with
169
+ its terms and conditions.
170
+
171
+ 3. Term. The term of this Public License is specified in Section
172
+ 6(a).
173
+
174
+ 4. Media and formats; technical modifications allowed. The
175
+ Licensor authorizes You to exercise the Licensed Rights in
176
+ all media and formats whether now known or hereafter created,
177
+ and to make technical modifications necessary to do so. The
178
+ Licensor waives and/or agrees not to assert any right or
179
+ authority to forbid You from making technical modifications
180
+ necessary to exercise the Licensed Rights, including
181
+ technical modifications necessary to circumvent Effective
182
+ Technological Measures. For purposes of this Public License,
183
+ simply making modifications authorized by this Section 2(a)
184
+ (4) never produces Adapted Material.
185
+
186
+ 5. Downstream recipients.
187
+
188
+ a. Offer from the Licensor -- Licensed Material. Every
189
+ recipient of the Licensed Material automatically
190
+ receives an offer from the Licensor to exercise the
191
+ Licensed Rights under the terms and conditions of this
192
+ Public License.
193
+
194
+ b. Additional offer from the Licensor -- Adapted Material.
195
+ Every recipient of Adapted Material from You
196
+ automatically receives an offer from the Licensor to
197
+ exercise the Licensed Rights in the Adapted Material
198
+ under the conditions of the Adapter's License You apply.
199
+
200
+ c. No downstream restrictions. You may not offer or impose
201
+ any additional or different terms or conditions on, or
202
+ apply any Effective Technological Measures to, the
203
+ Licensed Material if doing so restricts exercise of the
204
+ Licensed Rights by any recipient of the Licensed
205
+ Material.
206
+
207
+ 6. No endorsement. Nothing in this Public License constitutes or
208
+ may be construed as permission to assert or imply that You
209
+ are, or that Your use of the Licensed Material is, connected
210
+ with, or sponsored, endorsed, or granted official status by,
211
+ the Licensor or others designated to receive attribution as
212
+ provided in Section 3(a)(1)(A)(i).
213
+
214
+ b. Other rights.
215
+
216
+ 1. Moral rights, such as the right of integrity, are not
217
+ licensed under this Public License, nor are publicity,
218
+ privacy, and/or other similar personality rights; however, to
219
+ the extent possible, the Licensor waives and/or agrees not to
220
+ assert any such rights held by the Licensor to the limited
221
+ extent necessary to allow You to exercise the Licensed
222
+ Rights, but not otherwise.
223
+
224
+ 2. Patent and trademark rights are not licensed under this
225
+ Public License.
226
+
227
+ 3. To the extent possible, the Licensor waives any right to
228
+ collect royalties from You for the exercise of the Licensed
229
+ Rights, whether directly or through a collecting society
230
+ under any voluntary or waivable statutory or compulsory
231
+ licensing scheme. In all other cases the Licensor expressly
232
+ reserves any right to collect such royalties, including when
233
+ the Licensed Material is used other than for NonCommercial
234
+ purposes.
235
+
236
+
237
+ Section 3 -- License Conditions.
238
+
239
+ Your exercise of the Licensed Rights is expressly made subject to the
240
+ following conditions.
241
+
242
+ a. Attribution.
243
+
244
+ 1. If You Share the Licensed Material (including in modified
245
+ form), You must:
246
+
247
+ a. retain the following if it is supplied by the Licensor
248
+ with the Licensed Material:
249
+
250
+ i. identification of the creator(s) of the Licensed
251
+ Material and any others designated to receive
252
+ attribution, in any reasonable manner requested by
253
+ the Licensor (including by pseudonym if
254
+ designated);
255
+
256
+ ii. a copyright notice;
257
+
258
+ iii. a notice that refers to this Public License;
259
+
260
+ iv. a notice that refers to the disclaimer of
261
+ warranties;
262
+
263
+ v. a URI or hyperlink to the Licensed Material to the
264
+ extent reasonably practicable;
265
+
266
+ b. indicate if You modified the Licensed Material and
267
+ retain an indication of any previous modifications; and
268
+
269
+ c. indicate the Licensed Material is licensed under this
270
+ Public License, and include the text of, or the URI or
271
+ hyperlink to, this Public License.
272
+
273
+ 2. You may satisfy the conditions in Section 3(a)(1) in any
274
+ reasonable manner based on the medium, means, and context in
275
+ which You Share the Licensed Material. For example, it may be
276
+ reasonable to satisfy the conditions by providing a URI or
277
+ hyperlink to a resource that includes the required
278
+ information.
279
+ 3. If requested by the Licensor, You must remove any of the
280
+ information required by Section 3(a)(1)(A) to the extent
281
+ reasonably practicable.
282
+
283
+ b. ShareAlike.
284
+
285
+ In addition to the conditions in Section 3(a), if You Share
286
+ Adapted Material You produce, the following conditions also apply.
287
+
288
+ 1. The Adapter's License You apply must be a Creative Commons
289
+ license with the same License Elements, this version or
290
+ later, or a BY-NC-SA Compatible License.
291
+
292
+ 2. You must include the text of, or the URI or hyperlink to, the
293
+ Adapter's License You apply. You may satisfy this condition
294
+ in any reasonable manner based on the medium, means, and
295
+ context in which You Share Adapted Material.
296
+
297
+ 3. You may not offer or impose any additional or different terms
298
+ or conditions on, or apply any Effective Technological
299
+ Measures to, Adapted Material that restrict exercise of the
300
+ rights granted under the Adapter's License You apply.
301
+
302
+
303
+ Section 4 -- Sui Generis Database Rights.
304
+
305
+ Where the Licensed Rights include Sui Generis Database Rights that
306
+ apply to Your use of the Licensed Material:
307
+
308
+ a. for the avoidance of doubt, Section 2(a)(1) grants You the right
309
+ to extract, reuse, reproduce, and Share all or a substantial
310
+ portion of the contents of the database for NonCommercial purposes
311
+ only;
312
+
313
+ b. if You include all or a substantial portion of the database
314
+ contents in a database in which You have Sui Generis Database
315
+ Rights, then the database in which You have Sui Generis Database
316
+ Rights (but not its individual contents) is Adapted Material,
317
+ including for purposes of Section 3(b); and
318
+
319
+ c. You must comply with the conditions in Section 3(a) if You Share
320
+ all or a substantial portion of the contents of the database.
321
+
322
+ For the avoidance of doubt, this Section 4 supplements and does not
323
+ replace Your obligations under this Public License where the Licensed
324
+ Rights include other Copyright and Similar Rights.
325
+
326
+
327
+ Section 5 -- Disclaimer of Warranties and Limitation of Liability.
328
+
329
+ a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
330
+ EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
331
+ AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
332
+ ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
333
+ IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
334
+ WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
335
+ PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
336
+ ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
337
+ KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
338
+ ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
339
+
340
+ b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
341
+ TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
342
+ NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
343
+ INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
344
+ COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
345
+ USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
346
+ ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
347
+ DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
348
+ IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
349
+
350
+ c. The disclaimer of warranties and limitation of liability provided
351
+ above shall be interpreted in a manner that, to the extent
352
+ possible, most closely approximates an absolute disclaimer and
353
+ waiver of all liability.
354
+
355
+
356
+ Section 6 -- Term and Termination.
357
+
358
+ a. This Public License applies for the term of the Copyright and
359
+ Similar Rights licensed here. However, if You fail to comply with
360
+ this Public License, then Your rights under this Public License
361
+ terminate automatically.
362
+
363
+ b. Where Your right to use the Licensed Material has terminated under
364
+ Section 6(a), it reinstates:
365
+
366
+ 1. automatically as of the date the violation is cured, provided
367
+ it is cured within 30 days of Your discovery of the
368
+ violation; or
369
+
370
+ 2. upon express reinstatement by the Licensor.
371
+
372
+ For the avoidance of doubt, this Section 6(b) does not affect any
373
+ right the Licensor may have to seek remedies for Your violations
374
+ of this Public License.
375
+
376
+ c. For the avoidance of doubt, the Licensor may also offer the
377
+ Licensed Material under separate terms or conditions or stop
378
+ distributing the Licensed Material at any time; however, doing so
379
+ will not terminate this Public License.
380
+
381
+ d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
382
+ License.
383
+
384
+
385
+ Section 7 -- Other Terms and Conditions.
386
+
387
+ a. The Licensor shall not be bound by any additional or different
388
+ terms or conditions communicated by You unless expressly agreed.
389
+
390
+ b. Any arrangements, understandings, or agreements regarding the
391
+ Licensed Material not stated herein are separate from and
392
+ independent of the terms and conditions of this Public License.
393
+
394
+
395
+ Section 8 -- Interpretation.
396
+
397
+ a. For the avoidance of doubt, this Public License does not, and
398
+ shall not be interpreted to, reduce, limit, restrict, or impose
399
+ conditions on any use of the Licensed Material that could lawfully
400
+ be made without permission under this Public License.
401
+
402
+ b. To the extent possible, if any provision of this Public License is
403
+ deemed unenforceable, it shall be automatically reformed to the
404
+ minimum extent necessary to make it enforceable. If the provision
405
+ cannot be reformed, it shall be severed from this Public License
406
+ without affecting the enforceability of the remaining terms and
407
+ conditions.
408
+
409
+ c. No term or condition of this Public License will be waived and no
410
+ failure to comply consented to unless expressly agreed to by the
411
+ Licensor.
412
+
413
+ d. Nothing in this Public License constitutes or may be interpreted
414
+ as a limitation upon, or waiver of, any privileges and immunities
415
+ that apply to the Licensor or You, including from the legal
416
+ processes of any jurisdiction or authority.
417
+
418
+ =======================================================================
419
+
420
+ Creative Commons is not a party to its public
421
+ licenses. Notwithstanding, Creative Commons may elect to apply one of
422
+ its public licenses to material it publishes and in those instances
423
+ will be considered the “Licensor.” The text of the Creative Commons
424
+ public licenses is dedicated to the public domain under the CC0 Public
425
+ Domain Dedication. Except for the limited purpose of indicating that
426
+ material is shared under a Creative Commons public license or as
427
+ otherwise permitted by the Creative Commons policies published at
428
+ creativecommons.org/policies, Creative Commons does not authorize the
429
+ use of the trademark "Creative Commons" or any other trademark or logo
430
+ of Creative Commons without its prior written consent including,
431
+ without limitation, in connection with any unauthorized modifications
432
+ to any of its public licenses or any other arrangements,
433
+ understandings, or agreements concerning use of licensed material. For
434
+ the avoidance of doubt, this paragraph does not form part of the
435
+ public licenses.
436
+
437
+ Creative Commons may be contacted at creativecommons.org.
README.md CHANGED
@@ -1,13 +1,123 @@
1
- ---
2
- title: Svoice Demo
3
- emoji: 🐠
4
- colorFrom: purple
5
- colorTo: pink
6
- sdk: gradio
7
- sdk_version: 3.11.0
8
- app_file: app.py
9
- pinned: false
10
- license: cc-by-nc-sa-4.0
11
- ---
12
-
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Speaker Voice Separation using Neural Nets Gradio Demo
2
+
3
+ ## Installation
4
+
5
+ ```bash
6
+ git clone https://github.com/Muhammad-Ahmad-Ghani/svoice_demo.git
7
+ cd svoice_demo
8
+ conda create -n svoice python=3.7 -y
9
+ conda activate svoice
10
+ conda install pytorch==1.12.0 torchvision==0.13.0 torchaudio==0.12.0 cudatoolkit=11.3 -c pytorch -y
11
+ pip install -r requirements.txt
12
+ ```
13
+
14
+ | Pretrained-Model | Dataset | Epochs | Train Loss | Valid Loss |
15
+ |:------------:|:------------:|:------------:|:------------:|:------------:
16
+ | [checkpoint.th](https://drive.google.com/drive/folders/1WzhvH1oIB9LqoTyItA6jViTRai5aURzJ?usp=sharing) | Librimix-7 (16k-mix_clean) | 31 | 0.04 | 0.64 |
17
+
18
+ This is an intermediate checkpoint just for demo purpose.
19
+
20
+ create directory ```outputs/exp_``` and save checkpoint there
21
+ ```
22
+ svoice_demo
23
+ ├── outputs
24
+ │ └── exp_
25
+ │ └── checkpoint.th
26
+ ...
27
+ ```
28
+
29
+ ## Running End To End project
30
+ #### Terminal 1
31
+ ```bash
32
+ conda activate svoice
33
+ python demo.py
34
+ ```
35
+
36
+ ## Training
37
+ Create dataset ```mix_clean``` with sample rate ```16K``` using [librimix](https://github.com/shakeddovrat/librimix) repo.
38
+
39
+ Dataset Structure
40
+ ```
41
+ svoice_demo
42
+ ├── Libri7Mix_Dataset
43
+ │ └── wav16k
44
+ │ └── min
45
+ │ │ └── dev
46
+ │ │ └── ...
47
+ │ │ └── test
48
+ │ │ └── ...
49
+ │ │ └── train-360
50
+ │ │ └── ...
51
+ ...
52
+ ```
53
+
54
+ #### Create ```metadata``` files
55
+ For Librimix7 dataset
56
+ ```
57
+ bash create_metadata_librimix7.sh
58
+ ```
59
+
60
+ For Librimix10 dataset
61
+ ```
62
+ bash create_metadata_librimix10.sh
63
+ ```
64
+
65
+ Change ```conf/config.yaml``` according to your settings. Set ```C: 10``` value at line 66 for number of speakers.
66
+
67
+ ```
68
+ python train.py
69
+ ```
70
+ This will automaticlly read all the configurations from the `conf/config.yaml` file.
71
+ To know more about the training you may refer to original [svoice](https://github.com/facebookresearch/svoice) repo.
72
+
73
+ #### Distributed Training
74
+
75
+ ```
76
+ python train.py ddp=1
77
+ ```
78
+
79
+ ### Evaluating
80
+
81
+ ```
82
+ python -m svoice.evaluate <path to the model> <path to folder containing mix.json and all target separated channels json files s<ID>.json>
83
+ ```
84
+
85
+ ### Citation
86
+
87
+ The svoice code is borrowed from original [svoice](https://github.com/facebookresearch/svoice) repository. All rights of code are reserved by [META Research](https://github.com/facebookresearch).
88
+
89
+ ```
90
+ @inproceedings{nachmani2020voice,
91
+ title={Voice Separation with an Unknown Number of Multiple Speakers},
92
+ author={Nachmani, Eliya and Adi, Yossi and Wolf, Lior},
93
+ booktitle={Proceedings of the 37th international conference on Machine learning},
94
+ year={2020}
95
+ }
96
+ ```
97
+ ```
98
+ @misc{cosentino2020librimix,
99
+ title={LibriMix: An Open-Source Dataset for Generalizable Speech Separation},
100
+ author={Joris Cosentino and Manuel Pariente and Samuele Cornell and Antoine Deleforge and Emmanuel Vincent},
101
+ year={2020},
102
+ eprint={2005.11262},
103
+ archivePrefix={arXiv},
104
+ primaryClass={eess.AS}
105
+ }
106
+ ```
107
+ ## License
108
+ This repository is released under the CC-BY-NC-SA 4.0. license as found in the [LICENSE](LICENSE) file.
109
+
110
+ The file: `svoice/models/sisnr_loss.py` and `svoice/data/preprocess.py` were adapted from the [kaituoxu/Conv-TasNet][convtas] repository. It is an unofficial implementation of the [Conv-TasNet: Surpassing Ideal Time-Frequency Magnitude Masking for Speech Separation][convtas-paper] paper, released under the MIT License.
111
+ Additionally, several input manipulation functions were borrowed and modified from the [yluo42/TAC][tac] repository, released under the CC BY-NC-SA 3.0 License.
112
+
113
+ [icml]: https://arxiv.org/abs/2003.01531.pdf
114
+ [icassp]: https://arxiv.org/pdf/2011.02329.pdf
115
+ [web]: https://enk100.github.io/speaker_separation/
116
+ [pytorch]: https://pytorch.org/
117
+ [hydra]: https://github.com/facebookresearch/hydra
118
+ [hydra-web]: https://hydra.cc/
119
+ [convtas]: https://github.com/kaituoxu/Conv-TasNet
120
+ [convtas-paper]: https://arxiv.org/pdf/1809.07454.pdf
121
+ [tac]: https://github.com/yluo42/TAC
122
+ [nprirgen]: https://github.com/ty274/rir-generator
123
+ [rir]:https://asa.scitation.org/doi/10.1121/1.382599
app.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from svoice.separate import *
2
+ import scipy.io as sio
3
+ from scipy.io.wavfile import write
4
+ import gradio as gr
5
+ import os
6
+ from transformers import AutoProcessor, pipeline
7
+ from optimum.onnxruntime import ORTModelForSpeechSeq2Seq
8
+ from glob import glob
9
+ load_model()
10
+
11
+ BASE_PATH = os.path.dirname(os.path.abspath(__file__))
12
+ os.makedirs('input', exist_ok=True)
13
+ os.makedirs('separated', exist_ok=True)
14
+ os.makedirs('whisper_checkpoint', exist_ok=True)
15
+
16
+ print("Loading ASR model...")
17
+ processor = AutoProcessor.from_pretrained("openai/whisper-small")
18
+ if not os.path.exists("whisper_checkpoint"):
19
+ model = ORTModelForSpeechSeq2Seq.from_pretrained("openai/whisper-small", from_transformers=True)
20
+ speech_recognition_pipeline = pipeline(
21
+ "automatic-speech-recognition",
22
+ model=model,
23
+ feature_extractor=processor.feature_extractor,
24
+ tokenizer=processor.tokenizer,
25
+ )
26
+ model.save_pretrained("whisper_checkpoint")
27
+ else:
28
+ model = ORTModelForSpeechSeq2Seq.from_pretrained("whisper_checkpoint", from_transformers=False)
29
+ speech_recognition_pipeline = pipeline(
30
+ "automatic-speech-recognition",
31
+ model=model,
32
+ feature_extractor=processor.feature_extractor,
33
+ tokenizer=processor.tokenizer,
34
+ )
35
+ print("Whisper ASR model loaded.")
36
+
37
+ def separator(audio, rec_audio):
38
+ outputs= {}
39
+
40
+ if audio:
41
+ write('input/original.wav', audio[0], audio[1])
42
+ elif rec_audio:
43
+ write('input/original.wav', rec_audio[0], rec_audio[1])
44
+
45
+ separate_demo(mix_dir="./input")
46
+ separated_files = glob(os.path.join('separated', "*.wav"))
47
+ separated_files = [f for f in separated_files if "original.wav" not in f]
48
+ outputs['transcripts'] = []
49
+ for file in sorted(separated_files):
50
+ separated_audio = sio.wavfile.read(file)
51
+ outputs['transcripts'].append(speech_recognition_pipeline(separated_audio[1])['text'])
52
+ return sorted(separated_files) + outputs['transcripts']
53
+
54
+ def set_example_audio(example: list) -> dict:
55
+ return gr.Audio.update(value=example[0])
56
+
57
+ demo = gr.Blocks()
58
+ with demo:
59
+ gr.Markdown('''
60
+ <center>
61
+ <h1>Multiple Voice Separation with Transcription DEMO</h1>
62
+ <div style="display:flex;align-items:center;justify-content:center;"><iframe src="https://streamable.com/e/0x8osl?autoplay=1&nocontrols=1" frameborder="0" allow="autoplay"></iframe></div>
63
+ <p>
64
+ This is a demo for the multiple voice separation algorithm. The algorithm is trained on the LibriMix7 dataset and can be used to separate multiple voices from a single audio file.
65
+ </p>
66
+ </center>
67
+ ''')
68
+
69
+ with gr.Row():
70
+ input_audio = gr.Audio(label="Input audio", type="numpy")
71
+ rec_audio = gr.Audio(label="Record Using Microphone", type="numpy", source="microphone")
72
+
73
+ with gr.Row():
74
+ output_audio1 = gr.Audio(label='Speaker 1', interactive=False)
75
+ output_text1 = gr.Text(label='Speaker 1', interactive=False)
76
+ output_audio2 = gr.Audio(label='Speaker 2', interactive=False)
77
+ output_text2 = gr.Text(label='Speaker 2', interactive=False)
78
+
79
+ with gr.Row():
80
+ output_audio3 = gr.Audio(label='Speaker 3', interactive=False)
81
+ output_text3 = gr.Text(label='Speaker 3', interactive=False)
82
+ output_audio4 = gr.Audio(label='Speaker 4', interactive=False)
83
+ output_text4 = gr.Text(label='Speaker 4', interactive=False)
84
+
85
+ with gr.Row():
86
+ output_audio5 = gr.Audio(label='Speaker 5', interactive=False)
87
+ output_text5 = gr.Text(label='Speaker 5', interactive=False)
88
+ output_audio6 = gr.Audio(label='Speaker 6', interactive=False)
89
+ output_text6 = gr.Text(label='Speaker 6', interactive=False)
90
+
91
+ with gr.Row():
92
+ output_audio7 = gr.Audio(label='Speaker 7', interactive=False)
93
+ output_text7 = gr.Text(label='Speaker 7', interactive=False)
94
+
95
+ outputs_audio = [output_audio1, output_audio2, output_audio3, output_audio4, output_audio5, output_audio6, output_audio7]
96
+ outputs_text = [output_text1, output_text2, output_text3, output_text4, output_text5, output_text6, output_text7]
97
+ button = gr.Button("Separate")
98
+ button.click(separator, inputs=[input_audio, rec_audio], outputs=outputs_audio + outputs_text)
99
+
100
+ demo.launch()
packages.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ ffmpeg
2
+ libsndfile1-dev
requirements.txt ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ pesq==0.0.2
2
+ tqdm
3
+ hydra_core==1.0.3
4
+ hydra_colorlog==1.0.0
5
+ pystoi==0.3.3
6
+ librosa==0.7.1
7
+ numba==0.48
8
+ numpy
9
+ flask
10
+ flask-cors
11
+ uvicorn[standard]
12
+ asgiref
13
+ gradio
14
+ transformers==4.24.0
15
+ torch
16
+ torchvision
17
+ torchaudio
18
+ optimum[onnxruntime]==1.5.0
svoice/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
svoice/data/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
svoice/data/audio.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # Author: Alexandre Défossez @adefossez, 2020
7
+
8
+ import json
9
+ from pathlib import Path
10
+ import math
11
+ import os
12
+ import tqdm
13
+ import sys
14
+
15
+ import torchaudio
16
+ torchaudio.set_audio_backend("sox_io")
17
+ import soundfile as sf
18
+ import torch as th
19
+ from torch.nn import functional as F
20
+
21
+
22
+ # If used, this should be saved somewhere as it takes quite a bit
23
+ # of time to generate
24
+ def find_audio_files(path, exts=[".wav"], progress=True):
25
+ audio_files = []
26
+ for root, folders, files in os.walk(path, followlinks=True):
27
+ for file in files:
28
+ file = Path(root) / file
29
+ if file.suffix.lower() in exts:
30
+ audio_files.append(str(os.path.abspath(file)))
31
+ meta = []
32
+ if progress:
33
+ audio_files = tqdm.tqdm(audio_files, ncols=80)
34
+ for file in audio_files:
35
+ siginfo, _ = torchaudio.info(file)
36
+ length = siginfo.length // siginfo.channels
37
+ meta.append((file, length))
38
+ meta.sort()
39
+ return meta
40
+
41
+
42
+ class Audioset:
43
+ def __init__(self, files, length=None, stride=None, pad=True, augment=None):
44
+ """
45
+ files should be a list [(file, length)]
46
+ """
47
+ self.files = files
48
+ self.num_examples = []
49
+ self.length = length
50
+ self.stride = stride or length
51
+ self.augment = augment
52
+ for file, file_length in self.files:
53
+ if length is None:
54
+ examples = 1
55
+ elif file_length < length:
56
+ examples = 1 if pad else 0
57
+ elif pad:
58
+ examples = int(
59
+ math.ceil((file_length - self.length) / self.stride) + 1)
60
+ else:
61
+ examples = (file_length - self.length) // self.stride + 1
62
+ self.num_examples.append(examples)
63
+
64
+ def __len__(self):
65
+ return sum(self.num_examples)
66
+
67
+ def __getitem__(self, index):
68
+ for (file, _), examples in zip(self.files, self.num_examples):
69
+ if index >= examples:
70
+ index -= examples
71
+ continue
72
+ num_frames = 0
73
+ offset = 0
74
+ if self.length is not None:
75
+ offset = self.stride * index
76
+ num_frames = self.length
77
+ # out = th.Tensor(sf.read(str(file), start=offset, frames=num_frames)[0]).unsqueeze(0)
78
+ out = torchaudio.load(str(file), frame_offset=offset,
79
+ num_frames=num_frames)[0]
80
+ if self.augment:
81
+ out = self.augment(out.squeeze(0).numpy()).unsqueeze(0)
82
+ if num_frames:
83
+ out = F.pad(out, (0, num_frames - out.shape[-1]))
84
+ return out[0]
85
+
86
+
87
+ if __name__ == "__main__":
88
+ json.dump(find_audio_files(sys.argv[1]), sys.stdout, indent=4)
89
+ print()
svoice/data/data.py ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # Authors: Yossi Adi (adiyoss) and Alexandre Défossez (adefossez)
7
+
8
+ import json
9
+ import logging
10
+ import math
11
+ from pathlib import Path
12
+ import os
13
+ import re
14
+
15
+ import librosa
16
+ import numpy as np
17
+ import torch
18
+ import torch.utils.data as data
19
+
20
+ from .preprocess import preprocess_one_dir
21
+ from .audio import Audioset
22
+
23
+ logger = logging.getLogger(__name__)
24
+
25
+
26
+ def sort(infos): return sorted(
27
+ infos, key=lambda info: int(info[1]), reverse=True)
28
+
29
+
30
+ class Trainset:
31
+ def __init__(self, json_dir, sample_rate=16000, segment=4.0, stride=1.0, pad=True):
32
+ mix_json = os.path.join(json_dir, 'mix.json')
33
+ s_jsons = list()
34
+ s_infos = list()
35
+ sets_re = re.compile(r's[0-9]+.json')
36
+ print(os.listdir(json_dir))
37
+ for s in os.listdir(json_dir):
38
+ if sets_re.search(s):
39
+ s_jsons.append(os.path.join(json_dir, s))
40
+
41
+ with open(mix_json, 'r') as f:
42
+ mix_infos = json.load(f)
43
+ for s_json in s_jsons:
44
+ with open(s_json, 'r') as f:
45
+ s_infos.append(json.load(f))
46
+
47
+ length = int(sample_rate * segment)
48
+ stride = int(sample_rate * stride)
49
+
50
+ kw = {'length': length, 'stride': stride, 'pad': pad}
51
+ self.mix_set = Audioset(sort(mix_infos), **kw)
52
+
53
+ self.sets = list()
54
+ for s_info in s_infos:
55
+ self.sets.append(Audioset(sort(s_info), **kw))
56
+
57
+ # verify all sets has the same size
58
+ for s in self.sets:
59
+ assert len(s) == len(self.mix_set)
60
+
61
+ def __getitem__(self, index):
62
+ mix_sig = self.mix_set[index]
63
+ tgt_sig = [self.sets[i][index] for i in range(len(self.sets))]
64
+ return self.mix_set[index], torch.LongTensor([mix_sig.shape[0]]), torch.stack(tgt_sig)
65
+
66
+ def __len__(self):
67
+ return len(self.mix_set)
68
+
69
+
70
+ class Validset:
71
+ """
72
+ load entire wav.
73
+ """
74
+
75
+ def __init__(self, json_dir):
76
+ mix_json = os.path.join(json_dir, 'mix.json')
77
+ s_jsons = list()
78
+ s_infos = list()
79
+ sets_re = re.compile(r's[0-9]+.json')
80
+ for s in os.listdir(json_dir):
81
+ if sets_re.search(s):
82
+ s_jsons.append(os.path.join(json_dir, s))
83
+ with open(mix_json, 'r') as f:
84
+ mix_infos = json.load(f)
85
+ for s_json in s_jsons:
86
+ with open(s_json, 'r') as f:
87
+ s_infos.append(json.load(f))
88
+ self.mix_set = Audioset(sort(mix_infos))
89
+ self.sets = list()
90
+ for s_info in s_infos:
91
+ self.sets.append(Audioset(sort(s_info)))
92
+ for s in self.sets:
93
+ assert len(s) == len(self.mix_set)
94
+
95
+ def __getitem__(self, index):
96
+ mix_sig = self.mix_set[index]
97
+ tgt_sig = [self.sets[i][index] for i in range(len(self.sets))]
98
+ return self.mix_set[index], torch.LongTensor([mix_sig.shape[0]]), torch.stack(tgt_sig)
99
+
100
+ def __len__(self):
101
+ return len(self.mix_set)
102
+
103
+
104
+ # The following piece of code was adapted from https://github.com/kaituoxu/Conv-TasNet
105
+ # released under the MIT License.
106
+ # Author: Kaituo XU
107
+ # Created on 2018/12
108
+ class EvalDataset(data.Dataset):
109
+
110
+ def __init__(self, mix_dir, mix_json, batch_size, sample_rate=8000):
111
+ """
112
+ Args:
113
+ mix_dir: directory including mixture wav files
114
+ mix_json: json file including mixture wav files
115
+ """
116
+ super(EvalDataset, self).__init__()
117
+ assert mix_dir != None or mix_json != None
118
+ if mix_dir is not None:
119
+ # Generate mix.json given mix_dir
120
+ preprocess_one_dir(mix_dir, mix_dir, 'mix',
121
+ sample_rate=sample_rate)
122
+ mix_json = os.path.join(mix_dir, 'mix.json')
123
+ with open(mix_json, 'r') as f:
124
+ mix_infos = json.load(f)
125
+ # sort it by #samples (impl bucket)
126
+ def sort(infos): return sorted(
127
+ infos, key=lambda info: int(info[1]), reverse=True)
128
+ sorted_mix_infos = sort(mix_infos)
129
+ # generate minibach infomations
130
+ minibatch = []
131
+ start = 0
132
+ while True:
133
+ end = min(len(sorted_mix_infos), start + batch_size)
134
+ minibatch.append([sorted_mix_infos[start:end],
135
+ sample_rate])
136
+ if end == len(sorted_mix_infos):
137
+ break
138
+ start = end
139
+ self.minibatch = minibatch
140
+
141
+ def __getitem__(self, index):
142
+ return self.minibatch[index]
143
+
144
+ def __len__(self):
145
+ return len(self.minibatch)
146
+
147
+
148
+ class EvalDataLoader(data.DataLoader):
149
+ """
150
+ NOTE: just use batchsize=1 here, so drop_last=True makes no sense here.
151
+ """
152
+
153
+ def __init__(self, *args, **kwargs):
154
+ super(EvalDataLoader, self).__init__(*args, **kwargs)
155
+ self.collate_fn = _collate_fn_eval
156
+
157
+
158
+ def _collate_fn_eval(batch):
159
+ """
160
+ Args:
161
+ batch: list, len(batch) = 1. See AudioDataset.__getitem__()
162
+ Returns:
163
+ mixtures_pad: B x T, torch.Tensor
164
+ ilens : B, torch.Tentor
165
+ filenames: a list contain B strings
166
+ """
167
+ # batch should be located in list
168
+ assert len(batch) == 1
169
+ mixtures, filenames = load_mixtures(batch[0])
170
+
171
+ # get batch of lengths of input sequences
172
+ ilens = np.array([mix.shape[0] for mix in mixtures])
173
+
174
+ # perform padding and convert to tensor
175
+ pad_value = 0
176
+ mixtures_pad = pad_list([torch.from_numpy(mix).float()
177
+ for mix in mixtures], pad_value)
178
+ ilens = torch.from_numpy(ilens)
179
+ return mixtures_pad, ilens, filenames
180
+
181
+
182
+ def load_mixtures(batch):
183
+ """
184
+ Returns:
185
+ mixtures: a list containing B items, each item is T np.ndarray
186
+ filenames: a list containing B strings
187
+ T varies from item to item.
188
+ """
189
+ mixtures, filenames = [], []
190
+ mix_infos, sample_rate = batch
191
+ # for each utterance
192
+ for mix_info in mix_infos:
193
+ mix_path = mix_info[0]
194
+ # read wav file
195
+ mix, _ = librosa.load(mix_path, sr=sample_rate)
196
+ mixtures.append(mix)
197
+ filenames.append(mix_path)
198
+ return mixtures, filenames
199
+
200
+
201
+ def pad_list(xs, pad_value):
202
+ n_batch = len(xs)
203
+ max_len = max(x.size(0) for x in xs)
204
+ pad = xs[0].new(n_batch, max_len, * xs[0].size()[1:]).fill_(pad_value)
205
+ for i in range(n_batch):
206
+ pad[i, :xs[i].size(0)] = xs[i]
207
+ return pad
svoice/data/preprocess.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # The following piece of code was adapted from https://github.com/kaituoxu/Conv-TasNet
2
+ # released under the MIT License.
3
+ # Author: Kaituo XU
4
+ # Created on 2018/12
5
+
6
+ # Revised by: Eliya Nachmani (enk100), Yossi Adi (adiyoss), Lior Wolf
7
+
8
+ import argparse
9
+ import json
10
+ import os
11
+
12
+ import librosa
13
+ from tqdm import tqdm
14
+
15
+
16
+ def preprocess_one_dir(in_dir, out_dir, out_filename, sample_rate=8000):
17
+ file_infos = []
18
+ in_dir = os.path.abspath(in_dir)
19
+ wav_list = os.listdir(in_dir)
20
+ for wav_file in tqdm(wav_list):
21
+ if not wav_file.endswith('.wav'):
22
+ continue
23
+ wav_path = os.path.join(in_dir, wav_file)
24
+ samples, _ = librosa.load(wav_path, sr=sample_rate)
25
+ file_infos.append((wav_path, len(samples)))
26
+ if not os.path.exists(out_dir):
27
+ os.makedirs(out_dir)
28
+ with open(os.path.join(out_dir, out_filename + '.json'), 'w') as f:
29
+ json.dump(file_infos, f, indent=4)
30
+
31
+
32
+ def preprocess(args):
33
+ for data_type in ['tr', 'cv', 'tt']:
34
+ for signal in ['noisy', 'clean']:
35
+ preprocess_one_dir(os.path.join(args.in_dir, data_type, signal),
36
+ os.path.join(args.out_dir, data_type),
37
+ signal,
38
+ sample_rate=args.sample_rate)
39
+
40
+
41
+ def preprocess_alldirs(args):
42
+ for d in os.listdir(args.in_dir):
43
+ local_dir = os.path.join(args.in_dir, d)
44
+ if os.path.isdir(local_dir):
45
+ preprocess_one_dir(os.path.join(args.in_dir, local_dir),
46
+ os.path.join(args.out_dir),
47
+ d,
48
+ sample_rate=args.sample_rate)
49
+
50
+
51
+ if __name__ == "__main__":
52
+ parser = argparse.ArgumentParser("WSJ0 data preprocessing")
53
+ parser.add_argument('--in_dir', type=str, default=None,
54
+ help='Directory path of wsj0 including tr, cv and tt')
55
+ parser.add_argument('--out_dir', type=str, default=None,
56
+ help='Directory path to put output files')
57
+ parser.add_argument('--sample_rate', type=int, default=16000,
58
+ help='Sample rate of audio file')
59
+ parser.add_argument("--one_dir", action="store_true",
60
+ help="Generate json files from specific directory")
61
+ parser.add_argument("--all_dirs", action="store_true",
62
+ help="Generate json files from all dirs in specific directory")
63
+ parser.add_argument('--json_name', type=str, default=None,
64
+ help='The name of the json to be generated. '
65
+ 'To be used only with one-dir option.')
66
+ args = parser.parse_args()
67
+ print(args)
68
+ if args.all_dirs:
69
+ preprocess_alldirs(args)
70
+ elif args.one_dir:
71
+ preprocess_one_dir(args.in_dir, args.out_dir,
72
+ args.json_name, sample_rate=args.sample_rate)
73
+ else:
74
+ preprocess(args)
svoice/distrib.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # author: adefossez
7
+
8
+ import logging
9
+ import os
10
+
11
+ import torch
12
+ from torch.utils.data.distributed import DistributedSampler
13
+ from torch.utils.data import DataLoader, Subset
14
+ from torch.nn.parallel.distributed import DistributedDataParallel
15
+
16
+ logger = logging.getLogger(__name__)
17
+ rank = 0
18
+ world_size = 1
19
+
20
+
21
+ def init(args):
22
+ """init.
23
+ Initialize DDP using the given rendezvous file.
24
+ """
25
+ global rank, world_size
26
+ if args.ddp:
27
+ assert args.rank is not None and args.world_size is not None
28
+ rank = args.rank
29
+ world_size = args.world_size
30
+ if world_size == 1:
31
+ return
32
+ torch.cuda.set_device(rank)
33
+ torch.distributed.init_process_group(
34
+ backend=args.ddp_backend,
35
+ init_method='file://' + os.path.abspath(args.rendezvous_file),
36
+ world_size=world_size,
37
+ rank=rank)
38
+ logger.debug("Distributed rendezvous went well, rank %d/%d", rank, world_size)
39
+
40
+
41
+ def average(metrics, count=1.):
42
+ """average.
43
+ Average all the relevant metrices across processes
44
+ `metrics`should be a 1D float32 fector. Returns the average of `metrics`
45
+ over all hosts. You can use `count` to control the weight of each worker.
46
+ """
47
+ if world_size == 1:
48
+ return metrics
49
+ tensor = torch.tensor(list(metrics) + [1], device='cuda', dtype=torch.float32)
50
+ tensor *= count
51
+ torch.distributed.all_reduce(tensor, op=torch.distributed.ReduceOp.SUM)
52
+ return (tensor[:-1] / tensor[-1]).cpu().numpy().tolist()
53
+
54
+
55
+ def wrap(model):
56
+ """wrap.
57
+ Wrap a model with DDP if distributed training is enabled.
58
+ """
59
+ if world_size == 1:
60
+ return model
61
+ else:
62
+ return DistributedDataParallel(
63
+ model,
64
+ device_ids=[torch.cuda.current_device()],
65
+ output_device=torch.cuda.current_device())
66
+
67
+
68
+ def barrier():
69
+ if world_size > 1:
70
+ torch.distributed.barrier()
71
+
72
+
73
+ def loader(dataset, *args, shuffle=False, klass=DataLoader, **kwargs):
74
+ """loader.
75
+ Create a dataloader properly in case of distributed training.
76
+ If a gradient is going to be computed you must set `shuffle=True`.
77
+ :param dataset: the dataset to be parallelized
78
+ :param args: relevant args for the loader
79
+ :param shuffle: shuffle examples
80
+ :param klass: loader class
81
+ :param kwargs: relevant args
82
+ """
83
+
84
+ if world_size == 1:
85
+ return klass(dataset, *args, shuffle=shuffle, **kwargs)
86
+
87
+ if shuffle:
88
+ # train means we will compute backward, we use DistributedSampler
89
+ sampler = DistributedSampler(dataset)
90
+ # We ignore shuffle, DistributedSampler already shuffles
91
+ return klass(dataset, *args, **kwargs, sampler=sampler)
92
+ else:
93
+ # We make a manual shard, as DistributedSampler otherwise replicate some examples
94
+ dataset = Subset(dataset, list(range(rank, len(dataset), world_size)))
95
+ return klass(dataset, *args, shuffle=shuffle)
svoice/evaluate.py ADDED
@@ -0,0 +1,212 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # Authors: Eliya Nachmani (enk100), Yossi Adi (adiyoss), Lior Wolf and Alexandre Defossez (adefossez)
8
+
9
+ import argparse
10
+ from concurrent.futures import ProcessPoolExecutor
11
+ import json
12
+ import logging
13
+ import sys
14
+
15
+ import numpy as np
16
+ from pesq import pesq
17
+ from pystoi import stoi
18
+ import torch
19
+
20
+ from .models.sisnr_loss import cal_loss
21
+ from .data.data import Validset
22
+ from . import distrib
23
+ from .utils import bold, deserialize_model, LogProgress
24
+
25
+
26
+ logger = logging.getLogger(__name__)
27
+
28
+ parser = argparse.ArgumentParser(
29
+ 'Evaluate separation performance using MulCat blocks')
30
+ parser.add_argument('model_path',
31
+ help='Path to model file created by training')
32
+ parser.add_argument('data_dir',
33
+ help='directory including mix.json, s1.json, s2.json, ... files')
34
+ parser.add_argument('--device', default="cuda")
35
+ parser.add_argument('--sdr', type=int, default=0)
36
+ parser.add_argument('--sample_rate', default=16000,
37
+ type=int, help='Sample rate')
38
+ parser.add_argument('--num_workers', type=int, default=5)
39
+ parser.add_argument('-v', '--verbose', action='store_const', const=logging.DEBUG,
40
+ default=logging.INFO, help="More loggging")
41
+
42
+
43
+ def evaluate(args, model=None, data_loader=None, sr=None):
44
+ total_sisnr = 0
45
+ total_pesq = 0
46
+ total_stoi = 0
47
+ total_cnt = 0
48
+ updates = 5
49
+
50
+ # Load model
51
+ if not model:
52
+ pkg = torch.load(args.model_path, map_location=args.device)
53
+ if 'model' in pkg:
54
+ model = pkg['model']
55
+ else:
56
+ model = pkg
57
+ model = deserialize_model(model)
58
+ if 'best_state' in pkg:
59
+ model.load_state_dict(pkg['best_state'])
60
+ logger.debug(model)
61
+ model.eval()
62
+ model.to(args.device)
63
+ # Load data
64
+ if not data_loader:
65
+ dataset = Validset(args.data_dir)
66
+ data_loader = distrib.loader(
67
+ dataset, batch_size=1, num_workers=args.num_workers)
68
+ sr = args.sample_rate
69
+ pendings = []
70
+ with ProcessPoolExecutor(args.num_workers) as pool:
71
+ with torch.no_grad():
72
+ iterator = LogProgress(logger, data_loader, name="Eval estimates")
73
+ for i, data in enumerate(iterator):
74
+ # Get batch data
75
+ mixture, lengths, sources = [x.to(args.device) for x in data]
76
+ # Forward
77
+ with torch.no_grad():
78
+ mixture /= mixture.max()
79
+ estimate = model(mixture)[-1]
80
+ sisnr_loss, snr, estimate, reorder_estimate = cal_loss(
81
+ sources, estimate, lengths)
82
+ reorder_estimate = reorder_estimate.cpu()
83
+ sources = sources.cpu()
84
+ mixture = mixture.cpu()
85
+
86
+ pendings.append(
87
+ pool.submit(_run_metrics, sources, reorder_estimate, mixture, None,
88
+ sr=sr))
89
+ total_cnt += sources.shape[0]
90
+
91
+ for pending in LogProgress(logger, pendings, updates, name="Eval metrics"):
92
+ sisnr_i, pesq_i, stoi_i = pending.result()
93
+ total_sisnr += sisnr_i
94
+ total_pesq += pesq_i
95
+ total_stoi += stoi_i
96
+
97
+ metrics = [total_sisnr, total_pesq, total_stoi]
98
+ sisnr, pesq, stoi = distrib.average(
99
+ [m/total_cnt for m in metrics], total_cnt)
100
+ logger.info(
101
+ bold(f'Test set performance: SISNRi={sisnr:.2f} PESQ={pesq}, STOI={stoi}.'))
102
+ return sisnr, pesq, stoi
103
+
104
+
105
+ def _run_metrics(clean, estimate, mix, model, sr, pesq=False):
106
+ if model is not None:
107
+ torch.set_num_threads(1)
108
+ # parallel evaluation here
109
+ with torch.no_grad():
110
+ estimate = model(estimate)[-1]
111
+ estimate = estimate.numpy()
112
+ clean = clean.numpy()
113
+ mix = mix.numpy()
114
+ sisnr = cal_SISNRi(clean, estimate, mix)
115
+ if pesq:
116
+ pesq_i = cal_PESQ(clean, estimate, sr=sr)
117
+ stoi_i = cal_STOI(clean, estimate, sr=sr)
118
+ else:
119
+ pesq_i = 0
120
+ stoi_i = 0
121
+ return sisnr.mean(), pesq_i, stoi_i
122
+
123
+
124
+ def cal_SISNR(ref_sig, out_sig, eps=1e-8):
125
+ """Calcuate Scale-Invariant Source-to-Noise Ratio (SI-SNR)
126
+ Args:
127
+ ref_sig: numpy.ndarray, [B, T]
128
+ out_sig: numpy.ndarray, [B, T]
129
+ Returns:
130
+ SISNR
131
+ """
132
+ assert len(ref_sig) == len(out_sig)
133
+ B, T = ref_sig.shape
134
+ ref_sig = ref_sig - np.mean(ref_sig, axis=1).reshape(B, 1)
135
+ out_sig = out_sig - np.mean(out_sig, axis=1).reshape(B, 1)
136
+ ref_energy = (np.sum(ref_sig ** 2, axis=1) + eps).reshape(B, 1)
137
+ proj = (np.sum(ref_sig * out_sig, axis=1).reshape(B, 1)) * \
138
+ ref_sig / ref_energy
139
+ noise = out_sig - proj
140
+ ratio = np.sum(proj ** 2, axis=1) / (np.sum(noise ** 2, axis=1) + eps)
141
+ sisnr = 10 * np.log(ratio + eps) / np.log(10.0)
142
+ return sisnr.mean()
143
+
144
+
145
+ def cal_PESQ(ref_sig, out_sig, sr):
146
+ """Calculate PESQ.
147
+ Args:
148
+ ref_sig: numpy.ndarray, [B, C, T]
149
+ out_sig: numpy.ndarray, [B, C, T]
150
+ Returns
151
+ PESQ
152
+ """
153
+ B, C, T = ref_sig.shape
154
+ ref_sig = ref_sig.reshape(B*C, T)
155
+ out_sig = out_sig.reshape(B*C, T)
156
+ pesq_val = 0
157
+ for i in range(len(ref_sig)):
158
+ pesq_val += pesq(sr, ref_sig[i], out_sig[i], 'nb')
159
+ return pesq_val / (B*C)
160
+
161
+
162
+ def cal_STOI(ref_sig, out_sig, sr):
163
+ """Calculate STOI.
164
+ Args:
165
+ ref_sig: numpy.ndarray, [B, C, T]
166
+ out_sig: numpy.ndarray, [B, C, T]
167
+ Returns:
168
+ STOI
169
+ """
170
+ B, C, T = ref_sig.shape
171
+ ref_sig = ref_sig.reshape(B*C, T)
172
+ out_sig = out_sig.reshape(B*C, T)
173
+ try:
174
+ stoi_val = 0
175
+ for i in range(len(ref_sig)):
176
+ stoi_val += stoi(ref_sig[i], out_sig[i], sr, extended=False)
177
+ return stoi_val / (B*C)
178
+ except:
179
+ return 0
180
+
181
+
182
+ def cal_SISNRi(src_ref, src_est, mix):
183
+ """Calculate Scale-Invariant Source-to-Noise Ratio improvement (SI-SNRi)
184
+ Args:
185
+ src_ref: numpy.ndarray, [B, C, T]
186
+ src_est: numpy.ndarray, [B, C, T], reordered by best PIT permutation
187
+ mix: numpy.ndarray, [T]
188
+ Returns:
189
+ average_SISNRi
190
+ """
191
+ avg_SISNRi = 0.0
192
+ B, C, T = src_ref.shape
193
+ for c in range(C):
194
+ sisnr = cal_SISNR(src_ref[:, c], src_est[:, c])
195
+ sisnrb = cal_SISNR(src_ref[:, c], mix)
196
+ avg_SISNRi += (sisnr - sisnrb)
197
+ avg_SISNRi /= C
198
+ return avg_SISNRi
199
+
200
+
201
+ def main():
202
+ args = parser.parse_args()
203
+ logging.basicConfig(stream=sys.stderr, level=args.verbose)
204
+ logger.debug(args)
205
+ sisnr, pesq, stoi = evaluate(args)
206
+ json.dump({'sisnr': sisnr,
207
+ 'pesq': pesq, 'stoi': stoi}, sys.stdout)
208
+ sys.stdout.write('\n')
209
+
210
+
211
+ if __name__ == '__main__':
212
+ main()
svoice/evaluate_auto_select.py ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # Authors: Yossi Adi (adiyoss)
8
+
9
+ import argparse
10
+ from concurrent.futures import ProcessPoolExecutor
11
+ import json
12
+ import logging
13
+ import sys
14
+
15
+ import numpy as np
16
+ from pesq import pesq
17
+ from pystoi import stoi
18
+ import torch
19
+
20
+ from .models.sisnr_loss import cal_loss
21
+ from .data.data import Validset
22
+ from . import distrib
23
+ from .utils import bold, deserialize_model, LogProgress
24
+ from .evaluate import _run_metrics
25
+
26
+
27
+ logger = logging.getLogger(__name__)
28
+
29
+ parser = argparse.ArgumentParser(
30
+ 'Evaluate model automatic selection performance')
31
+ parser.add_argument('model_path_2spk',
32
+ help='Path to 2spk model file created by training')
33
+ parser.add_argument('model_path_3spk',
34
+ help='Path to 3spk model file created by training')
35
+ parser.add_argument('model_path_4spk',
36
+ help='Path to 4spk model file created by training')
37
+ parser.add_argument('model_path_5spk',
38
+ help='Path to 5spk model file created by training')
39
+ parser.add_argument(
40
+ 'data_dir', help='directory including mix.json, s1.json and s2.json files')
41
+ parser.add_argument('--device', default="cuda")
42
+ parser.add_argument('--sample_rate', default=8000,
43
+ type=int, help='Sample rate')
44
+ parser.add_argument('--thresh', default=0.001,
45
+ type=float, help='Threshold for model auto selection')
46
+ parser.add_argument('--num_workers', type=int, default=5)
47
+ parser.add_argument('-v', '--verbose', action='store_const', const=logging.DEBUG,
48
+ default=logging.INFO, help="More loggging")
49
+
50
+
51
+
52
+ # test pariwise matching
53
+ def pair_wise(padded_source, estimate_source):
54
+ pair_wise = torch.sum(padded_source.unsqueeze(
55
+ 1)*estimate_source.unsqueeze(2), dim=3)
56
+ if estimate_source.shape[1] != padded_source.shape[1]:
57
+ idxs = pair_wise.argmax(dim=1)
58
+ new_src = torch.FloatTensor(padded_source.shape)
59
+ for b, idx in enumerate(idxs):
60
+ new_src[b:, :, ] = estimate_source[b][idx]
61
+ padded_source_pad = padded_source
62
+ estimate_source_pad = new_src.cuda()
63
+ else:
64
+ padded_source_pad = padded_source
65
+ estimate_source_pad = estimate_source
66
+ return estimate_source_pad
67
+
68
+
69
+ def evaluate_auto_select(args):
70
+ total_sisnr = 0
71
+ total_pesq = 0
72
+ total_stoi = 0
73
+ total_cnt = 0
74
+ updates = 5
75
+
76
+ models = list()
77
+ paths = [args.model_path_2spk, args.model_path_3spk,
78
+ args.model_path_4spk, args.model_path_5spk]
79
+
80
+ for path in paths:
81
+ # Load model
82
+ pkg = torch.load(path)
83
+ if 'model' in pkg:
84
+ model = pkg['model']
85
+ else:
86
+ model = pkg
87
+ model = deserialize_model(model)
88
+ if 'best_state' in pkg:
89
+ model.load_state_dict(pkg['best_state'])
90
+ logger.debug(model)
91
+
92
+ model.eval()
93
+ model.to(args.device)
94
+ models.append(model)
95
+
96
+ # Load data
97
+ dataset = Validset(args.data_dir)
98
+ data_loader = distrib.loader(
99
+ dataset, batch_size=1, num_workers=args.num_workers)
100
+ sr = args.sample_rate
101
+ y_hat = torch.zeros((4))
102
+
103
+ pendings = []
104
+ with ProcessPoolExecutor(args.num_workers) as pool:
105
+ with torch.no_grad():
106
+ iterator = LogProgress(logger, data_loader, name="Eval estimates")
107
+ for i, data in enumerate(iterator):
108
+ # Get batch data
109
+ mixture, lengths, sources = [x.to(args.device) for x in data]
110
+ estimated_sources = list()
111
+ reorder_estimated_sources = list()
112
+
113
+ for model in models:
114
+ # Forward
115
+ with torch.no_grad():
116
+ raw_estimate = model(mixture)[-1]
117
+
118
+ estimate = pair_wise(sources, raw_estimate)
119
+ sisnr_loss, snr, estimate, reorder_estimate = cal_loss(
120
+ sources, estimate, lengths)
121
+ estimated_sources.insert(0, raw_estimate)
122
+ reorder_estimated_sources.insert(0, reorder_estimate)
123
+
124
+ # =================== DETECT NUM. NON-ACTIVE CHANNELS ============== #
125
+ selected_idx = 0
126
+ thresh = args.thresh
127
+ max_spk = 5
128
+ mix_spk = 2
129
+ ground = (max_spk - mix_spk)
130
+ while (selected_idx <= ground):
131
+ no_sils = 0
132
+ vals = torch.mean(
133
+ (estimated_sources[selected_idx]/torch.abs(estimated_sources[selected_idx]).max())**2, axis=2)
134
+ new_selected_idx = max_spk - len(vals[vals > thresh])
135
+ if new_selected_idx == selected_idx:
136
+ break
137
+ else:
138
+ selected_idx = new_selected_idx
139
+ if selected_idx < 0:
140
+ selected_idx = 0
141
+ elif selected_idx > ground:
142
+ selected_idx = ground
143
+
144
+ y_hat[ground - selected_idx] += 1
145
+ reorder_estimate = reorder_estimated_sources[selected_idx].cpu(
146
+ )
147
+ sources = sources.cpu()
148
+ mixture = mixture.cpu()
149
+
150
+ pendings.append(
151
+ pool.submit(_run_metrics, sources, reorder_estimate, mixture, None,
152
+ sr=sr))
153
+ total_cnt += sources.shape[0]
154
+
155
+ for pending in LogProgress(logger, pendings, updates, name="Eval metrics"):
156
+ sisnr_i, pesq_i, stoi_i = pending.result()
157
+ total_sisnr += sisnr_i
158
+ total_pesq += pesq_i
159
+ total_stoi += stoi_i
160
+
161
+ metrics = [total_sisnr, total_pesq, total_stoi]
162
+ sisnr, pesq, stoi = distrib.average(
163
+ [m/total_cnt for m in metrics], total_cnt)
164
+ logger.info(bold(f'Test set performance: SISNRi={sisnr:.2f} '
165
+ f'PESQ={pesq}, STOI={stoi}.'))
166
+ logger.info(f'Two spks prob: {y_hat[0]/(total_cnt)}')
167
+ logger.info(f'Three spks prob: {y_hat[1]/(total_cnt)}')
168
+ logger.info(f'Four spks prob: {y_hat[2]/(total_cnt)}')
169
+ logger.info(f'Five spks prob: {y_hat[3]/(total_cnt)}')
170
+ return sisnr, pesq, stoi
171
+
172
+
173
+ def main():
174
+ args = parser.parse_args()
175
+ logging.basicConfig(stream=sys.stderr, level=args.verbose)
176
+ logger.debug(args)
177
+ sisnr, pesq, stoi = evaluate_auto_select(args)
178
+ json.dump({'sisnr': sisnr,
179
+ 'pesq': pesq, 'stoi': stoi}, sys.stdout)
180
+ sys.stdout.write('\n')
181
+
182
+
183
+ if __name__ == '__main__':
184
+ main()
svoice/executor.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # Author: Alexandre Defossez (adefossez)
8
+
9
+ """
10
+ Start multiple process locally for DDP.
11
+ """
12
+
13
+ import logging
14
+ import subprocess as sp
15
+ import sys
16
+
17
+ from hydra import utils
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+
22
+ class ChildrenManager:
23
+ def __init__(self):
24
+ self.children = []
25
+ self.failed = False
26
+
27
+ def add(self, child):
28
+ child.rank = len(self.children)
29
+ self.children.append(child)
30
+
31
+ def __enter__(self):
32
+ return self
33
+
34
+ def __exit__(self, exc_type, exc_value, traceback):
35
+ if exc_value is not None:
36
+ logger.error(
37
+ "An exception happened while starting workers %r", exc_value)
38
+ self.failed = True
39
+ try:
40
+ while self.children and not self.failed:
41
+ for child in list(self.children):
42
+ try:
43
+ exitcode = child.wait(0.1)
44
+ except sp.TimeoutExpired:
45
+ continue
46
+ else:
47
+ self.children.remove(child)
48
+ if exitcode:
49
+ logger.error(
50
+ f"Worker {child.rank} died, killing all workers")
51
+ self.failed = True
52
+ except KeyboardInterrupt:
53
+ logger.error(
54
+ "Received keyboard interrupt, trying to kill all workers.")
55
+ self.failed = True
56
+ for child in self.children:
57
+ child.terminate()
58
+ if not self.failed:
59
+ logger.info("All workers completed successfully")
60
+
61
+
62
+ def start_ddp_workers():
63
+ import torch as th
64
+
65
+ world_size = th.cuda.device_count()
66
+ if not world_size:
67
+ logger.error(
68
+ "DDP is only available on GPU. Make sure GPUs are properly configured with cuda.")
69
+ sys.exit(1)
70
+ logger.info(f"Starting {world_size} worker processes for DDP.")
71
+ with ChildrenManager() as manager:
72
+ for rank in range(world_size):
73
+ kwargs = {}
74
+ argv = list(sys.argv)
75
+ argv += [f"world_size={world_size}", f"rank={rank}"]
76
+ if rank > 0:
77
+ kwargs['stdin'] = sp.DEVNULL
78
+ kwargs['stdout'] = sp.DEVNULL
79
+ kwargs['stderr'] = sp.DEVNULL
80
+ log = utils.HydraConfig().cfg.hydra.job_logging.handlers.file.filename
81
+ log += f".{rank}"
82
+ argv.append("hydra.job_logging.handlers.file.filename=" + log)
83
+ manager.add(sp.Popen([sys.executable] + argv,
84
+ cwd=utils.get_original_cwd(), **kwargs))
85
+ sys.exit(int(manager.failed))
svoice/models/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
svoice/models/sisnr_loss.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # The following piece of code was adapted from https://github.com/kaituoxu/Conv-TasNet
2
+ # released under the MIT License.
3
+ # Author: Kaituo XU
4
+ # Created on 2018/12
5
+
6
+ from itertools import permutations
7
+
8
+ import torch
9
+ import torch.nn.functional as F
10
+
11
+ EPS = 1e-8
12
+
13
+
14
+ def cal_loss(source, estimate_source, source_lengths):
15
+ """
16
+ Args:
17
+ source: [B, C, T], B is batch size
18
+ estimate_source: [B, C, T]
19
+ source_lengths: [B]
20
+ """
21
+ max_snr, perms, max_snr_idx, snr_set = cal_si_snr_with_pit(source,
22
+ estimate_source,
23
+ source_lengths)
24
+ B, C, T = estimate_source.shape
25
+ loss = 0 - torch.mean(max_snr)
26
+
27
+ reorder_estimate_source = reorder_source(
28
+ estimate_source, perms, max_snr_idx)
29
+ return loss, max_snr, estimate_source, reorder_estimate_source
30
+
31
+
32
+ def cal_si_snr_with_pit(source, estimate_source, source_lengths):
33
+ """Calculate SI-SNR with PIT training.
34
+ Args:
35
+ source: [B, C, T], B is batch size
36
+ estimate_source: [B, C, T]
37
+ source_lengths: [B], each item is between [0, T]
38
+ """
39
+
40
+ assert source.size() == estimate_source.size()
41
+ B, C, T = source.size()
42
+ # mask padding position along T
43
+ mask = get_mask(source, source_lengths)
44
+ estimate_source *= mask
45
+
46
+ # Step 1. Zero-mean norm
47
+ num_samples = source_lengths.view(-1, 1, 1).float() # [B, 1, 1]
48
+ mean_target = torch.sum(source, dim=2, keepdim=True) / num_samples
49
+ mean_estimate = torch.sum(estimate_source, dim=2,
50
+ keepdim=True) / num_samples
51
+ zero_mean_target = source - mean_target
52
+ zero_mean_estimate = estimate_source - mean_estimate
53
+ # mask padding position along T
54
+ zero_mean_target *= mask
55
+ zero_mean_estimate *= mask
56
+
57
+ # Step 2. SI-SNR with PIT
58
+ # reshape to use broadcast
59
+ s_target = torch.unsqueeze(zero_mean_target, dim=1) # [B, 1, C, T]
60
+ s_estimate = torch.unsqueeze(zero_mean_estimate, dim=2) # [B, C, 1, T]
61
+ # s_target = <s', s>s / ||s||^2
62
+ pair_wise_dot = torch.sum(s_estimate * s_target,
63
+ dim=3, keepdim=True) # [B, C, C, 1]
64
+ s_target_energy = torch.sum(
65
+ s_target ** 2, dim=3, keepdim=True) + EPS # [B, 1, C, 1]
66
+ pair_wise_proj = pair_wise_dot * s_target / s_target_energy # [B, C, C, T]
67
+ # e_noise = s' - s_target
68
+ e_noise = s_estimate - pair_wise_proj # [B, C, C, T]
69
+ # SI-SNR = 10 * log_10(||s_target||^2 / ||e_noise||^2)
70
+ pair_wise_si_snr = torch.sum(
71
+ pair_wise_proj ** 2, dim=3) / (torch.sum(e_noise ** 2, dim=3) + EPS)
72
+ pair_wise_si_snr = 10 * torch.log10(pair_wise_si_snr + EPS) # [B, C, C]
73
+ pair_wise_si_snr = torch.transpose(pair_wise_si_snr, 1, 2)
74
+
75
+ # Get max_snr of each utterance
76
+ # permutations, [C!, C]
77
+ perms = source.new_tensor(list(permutations(range(C))), dtype=torch.long)
78
+ # one-hot, [C!, C, C]
79
+ index = torch.unsqueeze(perms, 2)
80
+ perms_one_hot = source.new_zeros((*perms.size(), C)).scatter_(2, index, 1)
81
+ # [B, C!] <- [B, C, C] einsum [C!, C, C], SI-SNR sum of each permutation
82
+ snr_set = torch.einsum('bij,pij->bp', [pair_wise_si_snr, perms_one_hot])
83
+ max_snr_idx = torch.argmax(snr_set, dim=1) # [B]
84
+ # max_snr = torch.gather(snr_set, 1, max_snr_idx.view(-1, 1)) # [B, 1]
85
+ max_snr, _ = torch.max(snr_set, dim=1, keepdim=True)
86
+ max_snr /= C
87
+ return max_snr, perms, max_snr_idx, snr_set / C
88
+
89
+
90
+ def reorder_source(source, perms, max_snr_idx):
91
+ """
92
+ Args:
93
+ source: [B, C, T]
94
+ perms: [C!, C], permutations
95
+ max_snr_idx: [B], each item is between [0, C!)
96
+ Returns:
97
+ reorder_source: [B, C, T]
98
+ """
99
+ B, C, *_ = source.size()
100
+ # [B, C], permutation whose SI-SNR is max of each utterance
101
+ # for each utterance, reorder estimate source according this permutation
102
+ max_snr_perm = torch.index_select(perms, dim=0, index=max_snr_idx)
103
+ # print('max_snr_perm', max_snr_perm)
104
+ # maybe use torch.gather()/index_select()/scatter() to impl this?
105
+ reorder_source = torch.zeros_like(source)
106
+ for b in range(B):
107
+ for c in range(C):
108
+ reorder_source[b, c] = source[b, max_snr_perm[b][c]]
109
+ return reorder_source
110
+
111
+
112
+ def get_mask(source, source_lengths):
113
+ """
114
+ Args:
115
+ source: [B, C, T]
116
+ source_lengths: [B]
117
+ Returns:
118
+ mask: [B, 1, T]
119
+ """
120
+ B, _, T = source.size()
121
+ mask = source.new_ones((B, 1, T))
122
+ for i in range(B):
123
+ mask[i, :, source_lengths[i]:] = 0
124
+ return mask
svoice/models/swave.py ADDED
@@ -0,0 +1,294 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # Authors: Eliya Nachmani (enk100), Yossi Adi (adiyoss), Lior Wolf
8
+
9
+ import sys
10
+ import numpy as np
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ from torch.autograd import Variable
15
+
16
+ from ..utils import overlap_and_add
17
+ from ..utils import capture_init
18
+
19
+
20
+ class MulCatBlock(nn.Module):
21
+
22
+ def __init__(self, input_size, hidden_size, dropout=0, bidirectional=False):
23
+ super(MulCatBlock, self).__init__()
24
+
25
+ self.input_size = input_size
26
+ self.hidden_size = hidden_size
27
+ self.num_direction = int(bidirectional) + 1
28
+
29
+ self.rnn = nn.LSTM(input_size, hidden_size, 1, dropout=dropout,
30
+ batch_first=True, bidirectional=bidirectional)
31
+ self.rnn_proj = nn.Linear(hidden_size * self.num_direction, input_size)
32
+
33
+ self.gate_rnn = nn.LSTM(input_size, hidden_size, num_layers=1,
34
+ batch_first=True, dropout=dropout, bidirectional=bidirectional)
35
+ self.gate_rnn_proj = nn.Linear(
36
+ hidden_size * self.num_direction, input_size)
37
+
38
+ self.block_projection = nn.Linear(input_size * 2, input_size)
39
+
40
+ def forward(self, input):
41
+ output = input
42
+ # run rnn module
43
+ rnn_output, _ = self.rnn(output)
44
+ rnn_output = self.rnn_proj(rnn_output.contiguous(
45
+ ).view(-1, rnn_output.shape[2])).view(output.shape).contiguous()
46
+ # run gate rnn module
47
+ gate_rnn_output, _ = self.gate_rnn(output)
48
+ gate_rnn_output = self.gate_rnn_proj(gate_rnn_output.contiguous(
49
+ ).view(-1, gate_rnn_output.shape[2])).view(output.shape).contiguous()
50
+ # apply gated rnn
51
+ gated_output = torch.mul(rnn_output, gate_rnn_output)
52
+ gated_output = torch.cat([gated_output, output], 2)
53
+ gated_output = self.block_projection(
54
+ gated_output.contiguous().view(-1, gated_output.shape[2])).view(output.shape)
55
+ return gated_output
56
+
57
+
58
+ class ByPass(nn.Module):
59
+ def __init__(self):
60
+ super(ByPass, self).__init__()
61
+
62
+ def forward(self, input):
63
+ return input
64
+
65
+
66
+ class DPMulCat(nn.Module):
67
+ def __init__(self, input_size, hidden_size, output_size, num_spk,
68
+ dropout=0, num_layers=1, bidirectional=True, input_normalize=False):
69
+ super(DPMulCat, self).__init__()
70
+
71
+ self.input_size = input_size
72
+ self.output_size = output_size
73
+ self.hidden_size = hidden_size
74
+ self.in_norm = input_normalize
75
+ self.num_layers = num_layers
76
+
77
+ self.rows_grnn = nn.ModuleList([])
78
+ self.cols_grnn = nn.ModuleList([])
79
+ self.rows_normalization = nn.ModuleList([])
80
+ self.cols_normalization = nn.ModuleList([])
81
+
82
+ # create the dual path pipeline
83
+ for i in range(num_layers):
84
+ self.rows_grnn.append(MulCatBlock(
85
+ input_size, hidden_size, dropout, bidirectional=bidirectional))
86
+ self.cols_grnn.append(MulCatBlock(
87
+ input_size, hidden_size, dropout, bidirectional=bidirectional))
88
+ if self.in_norm:
89
+ self.rows_normalization.append(
90
+ nn.GroupNorm(1, input_size, eps=1e-8))
91
+ self.cols_normalization.append(
92
+ nn.GroupNorm(1, input_size, eps=1e-8))
93
+ else:
94
+ # used to disable normalization
95
+ self.rows_normalization.append(ByPass())
96
+ self.cols_normalization.append(ByPass())
97
+
98
+ self.output = nn.Sequential(
99
+ nn.PReLU(), nn.Conv2d(input_size, output_size * num_spk, 1))
100
+
101
+ def forward(self, input):
102
+ batch_size, _, d1, d2 = input.shape
103
+ output = input
104
+ output_all = []
105
+ for i in range(self.num_layers):
106
+ row_input = output.permute(0, 3, 2, 1).contiguous().view(
107
+ batch_size * d2, d1, -1)
108
+ row_output = self.rows_grnn[i](row_input)
109
+ row_output = row_output.view(
110
+ batch_size, d2, d1, -1).permute(0, 3, 2, 1).contiguous()
111
+ row_output = self.rows_normalization[i](row_output)
112
+ # apply a skip connection
113
+ if self.training:
114
+ output = output + row_output
115
+ else:
116
+ output += row_output
117
+
118
+ col_input = output.permute(0, 2, 3, 1).contiguous().view(
119
+ batch_size * d1, d2, -1)
120
+ col_output = self.cols_grnn[i](col_input)
121
+ col_output = col_output.view(
122
+ batch_size, d1, d2, -1).permute(0, 3, 1, 2).contiguous()
123
+ col_output = self.cols_normalization[i](col_output).contiguous()
124
+ # apply a skip connection
125
+ if self.training:
126
+ output = output + col_output
127
+ else:
128
+ output += col_output
129
+
130
+ output_i = self.output(output)
131
+ if self.training or i == (self.num_layers - 1):
132
+ output_all.append(output_i)
133
+ return output_all
134
+
135
+
136
+ class Separator(nn.Module):
137
+ def __init__(self, input_dim, feature_dim, hidden_dim, output_dim, num_spk=2,
138
+ layer=4, segment_size=100, input_normalize=False, bidirectional=True):
139
+ super(Separator, self).__init__()
140
+
141
+ self.input_dim = input_dim
142
+ self.feature_dim = feature_dim
143
+ self.hidden_dim = hidden_dim
144
+ self.output_dim = output_dim
145
+
146
+ self.layer = layer
147
+ self.segment_size = segment_size
148
+ self.num_spk = num_spk
149
+ self.input_normalize = input_normalize
150
+
151
+ self.rnn_model = DPMulCat(self.feature_dim, self.hidden_dim,
152
+ self.feature_dim, self.num_spk, num_layers=layer, bidirectional=bidirectional, input_normalize=input_normalize)
153
+
154
+ # ======================================= #
155
+ # The following code block was borrowed and modified from https://github.com/yluo42/TAC
156
+ # ================ BEGIN ================ #
157
+ def pad_segment(self, input, segment_size):
158
+ # input is the features: (B, N, T)
159
+ batch_size, dim, seq_len = input.shape
160
+ segment_stride = segment_size // 2
161
+ rest = segment_size - (segment_stride + seq_len %
162
+ segment_size) % segment_size
163
+ if rest > 0:
164
+ pad = Variable(torch.zeros(batch_size, dim, rest)
165
+ ).type(input.type())
166
+ input = torch.cat([input, pad], 2)
167
+
168
+ pad_aux = Variable(torch.zeros(
169
+ batch_size, dim, segment_stride)).type(input.type())
170
+ input = torch.cat([pad_aux, input, pad_aux], 2)
171
+ return input, rest
172
+
173
+ def create_chuncks(self, input, segment_size):
174
+ # split the feature into chunks of segment size
175
+ # input is the features: (B, N, T)
176
+
177
+ input, rest = self.pad_segment(input, segment_size)
178
+ batch_size, dim, seq_len = input.shape
179
+ segment_stride = segment_size // 2
180
+
181
+ segments1 = input[:, :, :-segment_stride].contiguous().view(batch_size,
182
+ dim, -1, segment_size)
183
+ segments2 = input[:, :, segment_stride:].contiguous().view(
184
+ batch_size, dim, -1, segment_size)
185
+ segments = torch.cat([segments1, segments2], 3).view(
186
+ batch_size, dim, -1, segment_size).transpose(2, 3)
187
+ return segments.contiguous(), rest
188
+
189
+ def merge_chuncks(self, input, rest):
190
+ # merge the splitted features into full utterance
191
+ # input is the features: (B, N, L, K)
192
+
193
+ batch_size, dim, segment_size, _ = input.shape
194
+ segment_stride = segment_size // 2
195
+ input = input.transpose(2, 3).contiguous().view(
196
+ batch_size, dim, -1, segment_size*2) # B, N, K, L
197
+
198
+ input1 = input[:, :, :, :segment_size].contiguous().view(
199
+ batch_size, dim, -1)[:, :, segment_stride:]
200
+ input2 = input[:, :, :, segment_size:].contiguous().view(
201
+ batch_size, dim, -1)[:, :, :-segment_stride]
202
+
203
+ output = input1 + input2
204
+ if rest > 0:
205
+ output = output[:, :, :-rest]
206
+ return output.contiguous() # B, N, T
207
+ # ================= END ================= #
208
+
209
+ def forward(self, input):
210
+ # create chunks
211
+ enc_segments, enc_rest = self.create_chuncks(
212
+ input, self.segment_size)
213
+ # separate
214
+ output_all = self.rnn_model(enc_segments)
215
+
216
+ # merge back audio files
217
+ output_all_wav = []
218
+ for ii in range(len(output_all)):
219
+ output_ii = self.merge_chuncks(
220
+ output_all[ii], enc_rest)
221
+ output_all_wav.append(output_ii)
222
+ return output_all_wav
223
+
224
+
225
+ class SWave(nn.Module):
226
+ @capture_init
227
+ def __init__(self, N, L, H, R, C, sr, segment, input_normalize):
228
+ super(SWave, self).__init__()
229
+ # hyper-parameter
230
+ self.N, self.L, self.H, self.R, self.C, self.sr, self.segment = N, L, H, R, C, sr, segment
231
+ self.input_normalize = input_normalize
232
+ self.context_len = 2 * self.sr / 1000
233
+ self.context = int(self.sr * self.context_len / 1000)
234
+ self.layer = self.R
235
+ self.filter_dim = self.context * 2 + 1
236
+ self.num_spk = self.C
237
+ # similar to dprnn paper, setting chancksize to sqrt(2*L)
238
+ self.segment_size = int(
239
+ np.sqrt(2 * self.sr * self.segment / (self.L/2)))
240
+
241
+ # model sub-networks
242
+ self.encoder = Encoder(L, N)
243
+ self.decoder = Decoder(L)
244
+ self.separator = Separator(self.filter_dim + self.N, self.N, self.H,
245
+ self.filter_dim, self.num_spk, self.layer, self.segment_size, self.input_normalize)
246
+ # init
247
+ for p in self.parameters():
248
+ if p.dim() > 1:
249
+ nn.init.xavier_normal_(p)
250
+
251
+ def forward(self, mixture):
252
+ mixture_w = self.encoder(mixture)
253
+ output_all = self.separator(mixture_w)
254
+
255
+ # fix time dimension, might change due to convolution operations
256
+ T_mix = mixture.size(-1)
257
+ # generate wav after each RNN block and optimize the loss
258
+ outputs = []
259
+ for ii in range(len(output_all)):
260
+ output_ii = output_all[ii].view(
261
+ mixture.shape[0], self.C, self.N, mixture_w.shape[2])
262
+ output_ii = self.decoder(output_ii)
263
+
264
+ T_est = output_ii.size(-1)
265
+ output_ii = F.pad(output_ii, (0, T_mix - T_est))
266
+ outputs.append(output_ii)
267
+ return torch.stack(outputs)
268
+
269
+
270
+ class Encoder(nn.Module):
271
+ def __init__(self, L, N):
272
+ super(Encoder, self).__init__()
273
+ self.L, self.N = L, N
274
+ # setting 50% overlap
275
+ self.conv = nn.Conv1d(
276
+ 1, N, kernel_size=L, stride=L // 2, bias=False)
277
+
278
+ def forward(self, mixture):
279
+ mixture = torch.unsqueeze(mixture, 1)
280
+ mixture_w = F.relu(self.conv(mixture))
281
+ return mixture_w
282
+
283
+
284
+ class Decoder(nn.Module):
285
+ def __init__(self, L):
286
+ super(Decoder, self).__init__()
287
+ self.L = L
288
+
289
+ def forward(self, est_source):
290
+ est_source = torch.transpose(est_source, 2, 3)
291
+ est_source = nn.AvgPool2d((1, self.L))(est_source)
292
+ est_source = overlap_and_add(est_source, self.L//2)
293
+
294
+ return est_source
svoice/separate.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import logging
3
+ import os
4
+ import sys
5
+
6
+ import librosa
7
+ import torch
8
+ import tqdm
9
+
10
+ from .data.data import EvalDataLoader, EvalDataset
11
+ from . import distrib
12
+ from .utils import remove_pad
13
+
14
+ from .utils import bold, deserialize_model, LogProgress
15
+ logger = logging.getLogger(__name__)
16
+
17
+ def load_model():
18
+ global device
19
+ global model
20
+ global pkg
21
+ print("Loading svoice model if available...")
22
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
23
+ pkg = torch.load('checkpoint.th')
24
+ if 'model' in pkg:
25
+ model = pkg['model']
26
+ else:
27
+ model = pkg
28
+ model = deserialize_model(model)
29
+ logger.debug(model)
30
+ model.eval()
31
+ model.to(device)
32
+ print("svoice model loaded.")
33
+ print("Device: {}".format(device))
34
+
35
+ parser = argparse.ArgumentParser("Speech separation using MulCat blocks")
36
+ parser.add_argument("model_path", type=str, help="Model name")
37
+ parser.add_argument("out_dir", type=str, default="exp/result",
38
+ help="Directory putting enhanced wav files")
39
+ parser.add_argument("--mix_dir", type=str, default=None,
40
+ help="Directory including mix wav files")
41
+ parser.add_argument("--mix_json", type=str, default=None,
42
+ help="Json file including mix wav files")
43
+ parser.add_argument('--device', default="cuda")
44
+ parser.add_argument("--sample_rate", default=8000,
45
+ type=int, help="Sample rate")
46
+ parser.add_argument("--batch_size", default=1, type=int, help="Batch size")
47
+ parser.add_argument('-v', '--verbose', action='store_const', const=logging.DEBUG,
48
+ default=logging.INFO, help="More loggging")
49
+
50
+ def save_wavs(estimate_source, mix_sig, lengths, filenames, out_dir, sr=16000):
51
+ # Remove padding and flat
52
+ flat_estimate = remove_pad(estimate_source, lengths)
53
+ mix_sig = remove_pad(mix_sig, lengths)
54
+ # Write result
55
+ for i, filename in enumerate(filenames):
56
+ filename = os.path.join(
57
+ out_dir, os.path.basename(filename).strip(".wav"))
58
+ write(mix_sig[i], filename + ".wav", sr=sr)
59
+ C = flat_estimate[i].shape[0]
60
+ # future support for wave playing
61
+ for c in range(C):
62
+ write(flat_estimate[i][c], filename + f"_s{c + 1}.wav", sr=sr)
63
+
64
+
65
+ def write(inputs, filename, sr=8000):
66
+ librosa.output.write_wav(filename, inputs, sr, norm=True)
67
+
68
+ def separate_demo(mix_dir='mix/', batch_size=1, sample_rate=16000):
69
+ mix_dir, mix_json = mix_dir, None
70
+ out_dir = 'separated'
71
+ # Load data
72
+ eval_dataset = EvalDataset(
73
+ mix_dir,
74
+ mix_json,
75
+ batch_size=batch_size,
76
+ sample_rate=sample_rate,
77
+ )
78
+ eval_loader = distrib.loader(
79
+ eval_dataset, batch_size=1, klass=EvalDataLoader)
80
+
81
+ if distrib.rank == 0:
82
+ os.makedirs(out_dir, exist_ok=True)
83
+ distrib.barrier()
84
+
85
+ with torch.no_grad():
86
+ for i, data in enumerate(tqdm.tqdm(eval_loader, ncols=120)):
87
+ # Get batch data
88
+ mixture, lengths, filenames = data
89
+ mixture = mixture.to(device)
90
+ lengths = lengths.to(device)
91
+ # Forward
92
+ estimate_sources = model(mixture)[-1]
93
+ # save wav files
94
+ save_wavs(estimate_sources, mixture, lengths,
95
+ filenames, out_dir, sr=sample_rate)
96
+
97
+ separated_files = [os.path.join(out_dir, f) for f in os.listdir(out_dir)]
98
+ separated_files = [os.path.abspath(f) for f in separated_files]
99
+ separated_files = [f for f in separated_files if not f.endswith('original.wav')]
100
+ return separated_files
101
+
102
+ def get_mix_paths(args):
103
+ mix_dir = None
104
+ mix_json = None
105
+ # fix mix dir
106
+ try:
107
+ if args.dset.mix_dir:
108
+ mix_dir = args.dset.mix_dir
109
+ except:
110
+ mix_dir = args.mix_dir
111
+
112
+ # fix mix json
113
+ try:
114
+ if args.dset.mix_json:
115
+ mix_json = args.dset.mix_json
116
+ except:
117
+ mix_json = args.mix_json
118
+ return mix_dir, mix_json
119
+
120
+
121
+ def separate(args, model=None, local_out_dir=None):
122
+ mix_dir, mix_json = get_mix_paths(args)
123
+ if not mix_json and not mix_dir:
124
+ logger.error("Must provide mix_dir or mix_json! "
125
+ "When providing mix_dir, mix_json is ignored.")
126
+ # Load model
127
+ if not model:
128
+ # model
129
+ pkg = torch.load(args.model_path)
130
+ if 'model' in pkg:
131
+ model = pkg['model']
132
+ else:
133
+ model = pkg
134
+ model = deserialize_model(model)
135
+ logger.debug(model)
136
+ model.eval()
137
+ model.to(args.device)
138
+ if local_out_dir:
139
+ out_dir = local_out_dir
140
+ else:
141
+ out_dir = args.out_dir
142
+
143
+ # Load data
144
+ eval_dataset = EvalDataset(
145
+ mix_dir,
146
+ mix_json,
147
+ batch_size=args.batch_size,
148
+ sample_rate=args.sample_rate,
149
+ )
150
+ eval_loader = distrib.loader(
151
+ eval_dataset, batch_size=1, klass=EvalDataLoader)
152
+
153
+ if distrib.rank == 0:
154
+ os.makedirs(out_dir, exist_ok=True)
155
+ distrib.barrier()
156
+
157
+ with torch.no_grad():
158
+ for i, data in enumerate(tqdm.tqdm(eval_loader, ncols=120)):
159
+ # Get batch data
160
+ mixture, lengths, filenames = data
161
+ mixture = mixture.to(args.device)
162
+ lengths = lengths.to(args.device)
163
+ # Forward
164
+ estimate_sources = model(mixture)[-1]
165
+ # save wav files
166
+ save_wavs(estimate_sources, mixture, lengths,
167
+ filenames, out_dir, sr=args.sample_rate)
168
+
169
+
170
+ if __name__ == "__main__":
171
+ args = parser.parse_args()
172
+ logging.basicConfig(stream=sys.stderr, level=args.verbose)
173
+ logger.debug(args)
174
+ separate(args, local_out_dir=args.out_dir)
svoice/solver.py ADDED
@@ -0,0 +1,227 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # Author: Eliya Nachmani (enk100), Yossi Adi (adiyoss), Lior Wolf
8
+
9
+ import json
10
+ import logging
11
+ from pathlib import Path
12
+ import os
13
+ import time
14
+
15
+ import numpy as np
16
+ import torch
17
+ import torch.nn.functional as F
18
+ from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR
19
+
20
+ from . import distrib
21
+ from .separate import separate
22
+ from .evaluate import evaluate
23
+ from .models.sisnr_loss import cal_loss
24
+ from .models.swave import SWave
25
+ from .utils import bold, copy_state, pull_metric, serialize_model, swap_state, LogProgress
26
+
27
+
28
+ logger = logging.getLogger(__name__)
29
+
30
+
31
+ class Solver(object):
32
+ def __init__(self, data, model, optimizer, args):
33
+ self.tr_loader = data['tr_loader']
34
+ self.cv_loader = data['cv_loader']
35
+ self.tt_loader = data['tt_loader']
36
+ self.model = model
37
+ self.dmodel = distrib.wrap(model)
38
+ self.optimizer = optimizer
39
+ if args.lr_sched == 'step':
40
+ self.sched = StepLR(
41
+ self.optimizer, step_size=args.step.step_size, gamma=args.step.gamma)
42
+ elif args.lr_sched == 'plateau':
43
+ self.sched = ReduceLROnPlateau(
44
+ self.optimizer, factor=args.plateau.factor, patience=args.plateau.patience)
45
+ else:
46
+ self.sched = None
47
+
48
+ # Training config
49
+ self.device = args.device
50
+ self.epochs = args.epochs
51
+ self.max_norm = args.max_norm
52
+
53
+ # Checkpoints
54
+ self.continue_from = args.continue_from
55
+ self.eval_every = args.eval_every
56
+ self.checkpoint = Path(
57
+ args.checkpoint_file) if args.checkpoint else None
58
+ if self.checkpoint:
59
+ logger.debug("Checkpoint will be saved to %s",
60
+ self.checkpoint.resolve())
61
+ self.history_file = args.history_file
62
+
63
+ self.best_state = None
64
+ self.restart = args.restart
65
+ # keep track of losses
66
+ self.history = []
67
+
68
+ # Where to save samples
69
+ self.samples_dir = args.samples_dir
70
+
71
+ # logging
72
+ self.num_prints = args.num_prints
73
+
74
+ # for seperation tests
75
+ self.args = args
76
+ self._reset()
77
+
78
+ def _serialize(self, path):
79
+ package = {}
80
+ package['model'] = serialize_model(self.model)
81
+ package['optimizer'] = self.optimizer.state_dict()
82
+ package['history'] = self.history
83
+ package['best_state'] = self.best_state
84
+ package['args'] = self.args
85
+ torch.save(package, path)
86
+
87
+ def _reset(self):
88
+ load_from = None
89
+ # Reset
90
+ if self.checkpoint and self.checkpoint.exists() and not self.restart:
91
+ load_from = self.checkpoint
92
+ elif self.continue_from:
93
+ load_from = self.continue_from
94
+
95
+ if load_from:
96
+ logger.info(f'Loading checkpoint model: {load_from}')
97
+ package = torch.load(load_from, 'cpu')
98
+ if load_from == self.continue_from and self.args.continue_best:
99
+ self.model.load_state_dict(package['best_state'])
100
+ else:
101
+ self.model.load_state_dict(package['model']['state'])
102
+
103
+ if 'optimizer' in package and not self.args.continue_best:
104
+ self.optimizer.load_state_dict(package['optimizer'])
105
+ self.history = package['history']
106
+ self.best_state = package['best_state']
107
+
108
+ def train(self):
109
+ # Optimizing the model
110
+ if self.history:
111
+ logger.info("Replaying metrics from previous run")
112
+ for epoch, metrics in enumerate(self.history):
113
+ info = " ".join(f"{k}={v:.5f}" for k, v in metrics.items())
114
+ logger.info(f"Epoch {epoch}: {info}")
115
+
116
+ for epoch in range(len(self.history), self.epochs):
117
+ # Train one epoch
118
+ self.model.train() # Turn on BatchNorm & Dropout
119
+ start = time.time()
120
+ logger.info('-' * 70)
121
+ logger.info("Training...")
122
+ train_loss = self._run_one_epoch(epoch)
123
+ logger.info(bold(f'Train Summary | End of Epoch {epoch + 1} | '
124
+ f'Time {time.time() - start:.2f}s | Train Loss {train_loss:.5f}'))
125
+
126
+ # Cross validation
127
+ logger.info('-' * 70)
128
+ logger.info('Cross validation...')
129
+ self.model.eval() # Turn off Batchnorm & Dropout
130
+ with torch.no_grad():
131
+ valid_loss = self._run_one_epoch(epoch, cross_valid=True)
132
+ logger.info(bold(f'Valid Summary | End of Epoch {epoch + 1} | '
133
+ f'Time {time.time() - start:.2f}s | Valid Loss {valid_loss:.5f}'))
134
+
135
+ # learning rate scheduling
136
+ if self.sched:
137
+ if self.args.lr_sched == 'plateau':
138
+ self.sched.step(valid_loss)
139
+ else:
140
+ self.sched.step()
141
+ logger.info(
142
+ f'Learning rate adjusted: {self.optimizer.state_dict()["param_groups"][0]["lr"]:.5f}')
143
+
144
+ best_loss = min(pull_metric(self.history, 'valid') + [valid_loss])
145
+ metrics = {'train': train_loss,
146
+ 'valid': valid_loss, 'best': best_loss}
147
+ # Save the best model
148
+ if valid_loss == best_loss or self.args.keep_last:
149
+ logger.info(bold('New best valid loss %.4f'), valid_loss)
150
+ self.best_state = copy_state(self.model.state_dict())
151
+
152
+ # evaluate and separate samples every 'eval_every' argument number of epochs
153
+ # also evaluate on last epoch
154
+ if (epoch + 1) % self.eval_every == 0 or epoch == self.epochs - 1:
155
+ # Evaluate on the testset
156
+ logger.info('-' * 70)
157
+ logger.info('Evaluating on the test set...')
158
+ # We switch to the best known model for testing
159
+ with swap_state(self.model, self.best_state):
160
+ sisnr, pesq, stoi = evaluate(
161
+ self.args, self.model, self.tt_loader, self.args.sample_rate)
162
+ metrics.update({'sisnr': sisnr, 'pesq': pesq, 'stoi': stoi})
163
+
164
+ # separate some samples
165
+ logger.info('Separate and save samples...')
166
+ separate(self.args, self.model, self.samples_dir)
167
+
168
+ self.history.append(metrics)
169
+ info = " | ".join(
170
+ f"{k.capitalize()} {v:.5f}" for k, v in metrics.items())
171
+ logger.info('-' * 70)
172
+ logger.info(bold(f"Overall Summary | Epoch {epoch + 1} | {info}"))
173
+
174
+ if distrib.rank == 0:
175
+ json.dump(self.history, open(self.history_file, "w"), indent=2)
176
+ # Save model each epoch
177
+ if self.checkpoint:
178
+ self._serialize(self.checkpoint)
179
+ logger.debug("Checkpoint saved to %s",
180
+ self.checkpoint.resolve())
181
+
182
+ def _run_one_epoch(self, epoch, cross_valid=False):
183
+ total_loss = 0
184
+ data_loader = self.tr_loader if not cross_valid else self.cv_loader
185
+
186
+ # get a different order for distributed training, otherwise this will get ignored
187
+ data_loader.epoch = epoch
188
+
189
+ label = ["Train", "Valid"][cross_valid]
190
+ name = label + f" | Epoch {epoch + 1}"
191
+ logprog = LogProgress(logger, data_loader,
192
+ updates=self.num_prints, name=name)
193
+ for i, data in enumerate(logprog):
194
+ mixture, lengths, sources = [x.to(self.device) for x in data]
195
+ estimate_source = self.dmodel(mixture)
196
+
197
+ # only eval last layer
198
+ if cross_valid:
199
+ estimate_source = estimate_source[-1:]
200
+
201
+ loss = 0
202
+ cnt = len(estimate_source)
203
+ # apply a loss function after each layer
204
+ with torch.autograd.set_detect_anomaly(True):
205
+ for c_idx, est_src in enumerate(estimate_source):
206
+ coeff = ((c_idx+1)*(1/cnt))
207
+ loss_i = 0
208
+ # SI-SNR loss
209
+ sisnr_loss, snr, est_src, reorder_est_src = cal_loss(
210
+ sources, estimate_source[c_idx], lengths)
211
+ loss += (coeff * sisnr_loss)
212
+ loss /= len(estimate_source)
213
+
214
+ if not cross_valid:
215
+ # optimize model in training mode
216
+ self.optimizer.zero_grad()
217
+ loss.backward()
218
+ torch.nn.utils.clip_grad_norm_(self.model.parameters(),
219
+ self.max_norm)
220
+ self.optimizer.step()
221
+
222
+ total_loss += loss.item()
223
+ logprog.update(loss=format(total_loss / (i + 1), ".5f"))
224
+
225
+ # Just in case, clear some memory
226
+ del loss, estimate_source
227
+ return distrib.average([total_loss / (i + 1)], i + 1)[0]
svoice/utils.py ADDED
@@ -0,0 +1,241 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # Authors: Yossi Adi (adiyoss) and Alexandre Defossez (adefossez)
8
+
9
+ import functools
10
+ import logging
11
+ from contextlib import contextmanager
12
+ import inspect
13
+ import os
14
+ import time
15
+ import math
16
+ import torch
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+
21
+ def capture_init(init):
22
+ """
23
+ Decorate `__init__` with this, and you can then
24
+ recover the *args and **kwargs passed to it in `self._init_args_kwargs`
25
+ """
26
+ @functools.wraps(init)
27
+ def __init__(self, *args, **kwargs):
28
+ self._init_args_kwargs = (args, kwargs)
29
+ init(self, *args, **kwargs)
30
+
31
+ return __init__
32
+
33
+
34
+ def deserialize_model(package, strict=False):
35
+ klass = package['class']
36
+ if strict:
37
+ model = klass(*package['args'], **package['kwargs'])
38
+ else:
39
+ sig = inspect.signature(klass)
40
+ kw = package['kwargs']
41
+ for key in list(kw):
42
+ if key not in sig.parameters:
43
+ logger.warning("Dropping inexistant parameter %s", key)
44
+ del kw[key]
45
+ model = klass(*package['args'], **kw)
46
+ model.load_state_dict(package['state'])
47
+ return model
48
+
49
+
50
+ def copy_state(state):
51
+ return {k: v.cpu().clone() for k, v in state.items()}
52
+
53
+
54
+ def serialize_model(model):
55
+ args, kwargs = model._init_args_kwargs
56
+ state = copy_state(model.state_dict())
57
+ return {"class": model.__class__, "args": args, "kwargs": kwargs, "state": state}
58
+
59
+
60
+ @contextmanager
61
+ def swap_state(model, state):
62
+ old_state = copy_state(model.state_dict())
63
+ model.load_state_dict(state)
64
+ try:
65
+ yield
66
+ finally:
67
+ model.load_state_dict(old_state)
68
+
69
+
70
+ @contextmanager
71
+ def swap_cwd(cwd):
72
+ old_cwd = os.getcwd()
73
+ os.chdir(cwd)
74
+ try:
75
+ yield
76
+ finally:
77
+ os.chdir(old_cwd)
78
+
79
+
80
+ def pull_metric(history, name):
81
+ out = []
82
+ for metrics in history:
83
+ if name in metrics:
84
+ out.append(metrics[name])
85
+ return out
86
+
87
+
88
+ class LogProgress:
89
+ """
90
+ Sort of like tqdm but using log lines and not as real time.
91
+ """
92
+
93
+ def __init__(self, logger, iterable, updates=5, total=None,
94
+ name="LogProgress", level=logging.INFO):
95
+ self.iterable = iterable
96
+ self.total = total or len(iterable)
97
+ self.updates = updates
98
+ self.name = name
99
+ self.logger = logger
100
+ self.level = level
101
+
102
+ def update(self, **infos):
103
+ self._infos = infos
104
+
105
+ def __iter__(self):
106
+ self._iterator = iter(self.iterable)
107
+ self._index = -1
108
+ self._infos = {}
109
+ self._begin = time.time()
110
+ return self
111
+
112
+ def __next__(self):
113
+ self._index += 1
114
+ try:
115
+ value = next(self._iterator)
116
+ except StopIteration:
117
+ raise
118
+ else:
119
+ return value
120
+ finally:
121
+ log_every = max(1, self.total // self.updates)
122
+ # logging is delayed by 1 it, in order to have the metrics from update
123
+ if self._index >= 1 and self._index % log_every == 0:
124
+ self._log()
125
+
126
+ def _log(self):
127
+ self._speed = (1 + self._index) / (time.time() - self._begin)
128
+ infos = " | ".join(f"{k.capitalize()} {v}" for k,
129
+ v in self._infos.items())
130
+ if self._speed < 1e-4:
131
+ speed = "oo sec/it"
132
+ elif self._speed < 0.1:
133
+ speed = f"{1/self._speed:.1f} sec/it"
134
+ else:
135
+ speed = f"{self._speed:.1f} it/sec"
136
+ out = f"{self.name} | {self._index}/{self.total} | {speed}"
137
+ if infos:
138
+ out += " | " + infos
139
+ self.logger.log(self.level, out)
140
+
141
+
142
+ def colorize(text, color):
143
+ code = f"\033[{color}m"
144
+ restore = f"\033[0m"
145
+ return "".join([code, text, restore])
146
+
147
+
148
+ def bold(text):
149
+ return colorize(text, "1")
150
+
151
+
152
+ def calculate_grad_norm(model):
153
+ total_norm = 0.0
154
+ is_first = True
155
+ for p in model.parameters():
156
+ param_norm = p.data.grad.flatten()
157
+ if is_first:
158
+ total_norm = param_norm
159
+ is_first = False
160
+ else:
161
+ total_norm = torch.cat((total_norm.unsqueeze(
162
+ 1), p.data.grad.flatten().unsqueeze(1)), dim=0).squeeze(1)
163
+ return total_norm.norm(2) ** (1. / 2)
164
+
165
+
166
+ def calculate_weight_norm(model):
167
+ total_norm = 0.0
168
+ is_first = True
169
+ for p in model.parameters():
170
+ param_norm = p.data.flatten()
171
+ if is_first:
172
+ total_norm = param_norm
173
+ is_first = False
174
+ else:
175
+ total_norm = torch.cat((total_norm.unsqueeze(
176
+ 1), p.data.flatten().unsqueeze(1)), dim=0).squeeze(1)
177
+ return total_norm.norm(2) ** (1. / 2)
178
+
179
+
180
+ def remove_pad(inputs, inputs_lengths):
181
+ """
182
+ Args:
183
+ inputs: torch.Tensor, [B, C, T] or [B, T], B is batch size
184
+ inputs_lengths: torch.Tensor, [B]
185
+ Returns:
186
+ results: a list containing B items, each item is [C, T], T varies
187
+ """
188
+ results = []
189
+ dim = inputs.dim()
190
+ if dim == 3:
191
+ C = inputs.size(1)
192
+ for input, length in zip(inputs, inputs_lengths):
193
+ if dim == 3: # [B, C, T]
194
+ results.append(input[:, :length].view(C, -1).cpu().numpy())
195
+ elif dim == 2: # [B, T]
196
+ results.append(input[:length].view(-1).cpu().numpy())
197
+ return results
198
+
199
+
200
+ def overlap_and_add(signal, frame_step):
201
+ """Reconstructs a signal from a framed representation.
202
+
203
+ Adds potentially overlapping frames of a signal with shape
204
+ `[..., frames, frame_length]`, offsetting subsequent frames by `frame_step`.
205
+ The resulting tensor has shape `[..., output_size]` where
206
+
207
+ output_size = (frames - 1) * frame_step + frame_length
208
+
209
+ Args:
210
+ signal: A [..., frames, frame_length] Tensor. All dimensions may be unknown, and rank must be at least 2.
211
+ frame_step: An integer denoting overlap offsets. Must be less than or equal to frame_length.
212
+
213
+ Returns:
214
+ A Tensor with shape [..., output_size] containing the overlap-added frames of signal's inner-most two dimensions.
215
+ output_size = (frames - 1) * frame_step + frame_length
216
+
217
+ Based on https://github.com/tensorflow/tensorflow/blob/r1.12/tensorflow/contrib/signal/python/ops/reconstruction_ops.py
218
+ """
219
+ outer_dimensions = signal.size()[:-2]
220
+ frames, frame_length = signal.size()[-2:]
221
+
222
+ # gcd=Greatest Common Divisor
223
+ subframe_length = math.gcd(frame_length, frame_step)
224
+ subframe_step = frame_step // subframe_length
225
+ subframes_per_frame = frame_length // subframe_length
226
+ output_size = frame_step * (frames - 1) + frame_length
227
+ output_subframes = output_size // subframe_length
228
+
229
+ subframe_signal = signal.view(*outer_dimensions, -1, subframe_length)
230
+
231
+ frame = torch.arange(0, output_subframes).unfold(
232
+ 0, subframes_per_frame, subframe_step)
233
+ frame = frame.clone().detach().long().to(signal.device)
234
+ # frame = signal.new_tensor(frame).clone().long() # signal may in GPU or CPU
235
+ frame = frame.contiguous().view(-1)
236
+
237
+ result = signal.new_zeros(
238
+ *outer_dimensions, output_subframes, subframe_length)
239
+ result.index_add_(-2, frame, subframe_signal)
240
+ result = result.view(*outer_dimensions, -1)
241
+ return result