Spaces · Build error
ahmedghani committed
Commit 8235b4f
Parent(s): f8fec52
initial commit
Browse files
- CODE_OF_CONDUCT.md +78 -0
- CONTRIBUTING.md +25 -0
- LICENSE +437 -0
- README.md +123 -13
- app.py +100 -0
- packages.txt +2 -0
- requirements.txt +18 -0
- svoice/__init__.py +5 -0
- svoice/data/__init__.py +5 -0
- svoice/data/audio.py +89 -0
- svoice/data/data.py +207 -0
- svoice/data/preprocess.py +74 -0
- svoice/distrib.py +95 -0
- svoice/evaluate.py +212 -0
- svoice/evaluate_auto_select.py +184 -0
- svoice/executor.py +85 -0
- svoice/models/__init__.py +5 -0
- svoice/models/sisnr_loss.py +124 -0
- svoice/models/swave.py +294 -0
- svoice/separate.py +174 -0
- svoice/solver.py +227 -0
- svoice/utils.py +241 -0
CODE_OF_CONDUCT.md
ADDED
@@ -0,0 +1,78 @@
# Code of Conduct

## Our Pledge

In the interest of fostering an open and welcoming environment, we as
contributors and maintainers pledge to make participation in our project and
our community a harassment-free experience for everyone, regardless of age, body
size, disability, ethnicity, sex characteristics, gender identity and expression,
level of experience, education, socio-economic status, nationality, personal
appearance, race, religion, or sexual identity and orientation.

## Our Standards

Examples of behavior that contributes to creating a positive environment
include:

* Using welcoming and inclusive language
* Being respectful of differing viewpoints and experiences
* Gracefully accepting constructive criticism
* Focusing on what is best for the community
* Showing empathy towards other community members

Examples of unacceptable behavior by participants include:

* The use of sexualized language or imagery and unwelcome sexual attention or
  advances
* Trolling, insulting/derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or electronic
  address, without explicit permission
* Other conduct which could reasonably be considered inappropriate in a
  professional setting

## Our Responsibilities

Project maintainers are responsible for clarifying the standards of acceptable
behavior and are expected to take appropriate and fair corrective action in
response to any instances of unacceptable behavior.

Project maintainers have the right and responsibility to remove, edit, or
reject comments, commits, code, wiki edits, issues, and other contributions
that are not aligned to this Code of Conduct, or to ban temporarily or
permanently any contributor for other behaviors that they deem inappropriate,
threatening, offensive, or harmful.

## Scope

This Code of Conduct applies within all project spaces, and it also applies when
an individual is representing the project or its community in public spaces.
Examples of representing a project or community include using an official
project e-mail address, posting via an official social media account, or acting
as an appointed representative at an online or offline event. Representation of
a project may be further defined and clarified by project maintainers.

## Enforcement

Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported by contacting the project team at <opensource-conduct@fb.com>. All
complaints will be reviewed and investigated and will result in a response that
is deemed necessary and appropriate to the circumstances. The project team is
obligated to maintain confidentiality with regard to the reporter of an incident.
Further details of specific enforcement policies may be posted separately.

Project maintainers who do not follow or enforce the Code of Conduct in good
faith may face temporary or permanent repercussions as determined by other
members of the project's leadership.

## Attribution

This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html

[homepage]: https://www.contributor-covenant.org

For answers to common questions about this code of conduct, see
https://www.contributor-covenant.org/faq
CONTRIBUTING.md
ADDED
@@ -0,0 +1,25 @@
# Contributing to Denoiser

## Pull Requests

In order to accept your pull request, we need you to submit a CLA. You only need
to do this once to work on any of Facebook's open source projects.

Complete your CLA here: <https://code.facebook.com/cla>

Demucs is the implementation of a research paper.
Therefore, we do not plan on accepting many pull requests for new features.
We certainly welcome them for bug fixes.

## Issues

We use GitHub issues to track public bugs. Please ensure your description is
clear and has sufficient instructions to be able to reproduce the issue.
Please first check existing issues as well as the README for existing solutions.

## License
By contributing to this repository, you agree that your contributions will be licensed
under the LICENSE file in the root directory of this source tree.
LICENSE
ADDED
@@ -0,0 +1,437 @@
Attribution-NonCommercial-ShareAlike 4.0 International

=======================================================================

Creative Commons Corporation ("Creative Commons") is not a law firm and
does not provide legal services or legal advice. Distribution of
Creative Commons public licenses does not create a lawyer-client or
other relationship. Creative Commons makes its licenses and related
information available on an "as-is" basis. Creative Commons gives no
warranties regarding its licenses, any material licensed under their
terms and conditions, or any related information. Creative Commons
disclaims all liability for damages resulting from their use to the
fullest extent possible.

Using Creative Commons Public Licenses

Creative Commons public licenses provide a standard set of terms and
conditions that creators and other rights holders may use to share
original works of authorship and other material subject to copyright
and certain other rights specified in the public license below. The
following considerations are for informational purposes only, are not
exhaustive, and do not form part of our licenses.

     Considerations for licensors: Our public licenses are
     intended for use by those authorized to give the public
     permission to use material in ways otherwise restricted by
     copyright and certain other rights. Our licenses are
     irrevocable. Licensors should read and understand the terms
     and conditions of the license they choose before applying it.
     Licensors should also secure all rights necessary before
     applying our licenses so that the public can reuse the
     material as expected. Licensors should clearly mark any
     material not subject to the license. This includes other CC-
     licensed material, or material used under an exception or
     limitation to copyright. More considerations for licensors:
     wiki.creativecommons.org/Considerations_for_licensors

     Considerations for the public: By using one of our public
     licenses, a licensor grants the public permission to use the
     licensed material under specified terms and conditions. If
     the licensor's permission is not necessary for any reason--for
     example, because of any applicable exception or limitation to
     copyright--then that use is not regulated by the license. Our
     licenses grant only permissions under copyright and certain
     other rights that a licensor has authority to grant. Use of
     the licensed material may still be restricted for other
     reasons, including because others have copyright or other
     rights in the material. A licensor may make special requests,
     such as asking that all changes be marked or described.
     Although not required by our licenses, you are encouraged to
     respect those requests where reasonable. More considerations
     for the public:
     wiki.creativecommons.org/Considerations_for_licensees

=======================================================================

Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International
Public License

By exercising the Licensed Rights (defined below), You accept and agree
to be bound by the terms and conditions of this Creative Commons
Attribution-NonCommercial-ShareAlike 4.0 International Public License
("Public License"). To the extent this Public License may be
interpreted as a contract, You are granted the Licensed Rights in
consideration of Your acceptance of these terms and conditions, and the
Licensor grants You such rights in consideration of benefits the
Licensor receives from making the Licensed Material available under
these terms and conditions.


Section 1 -- Definitions.

  a. Adapted Material means material subject to Copyright and Similar
     Rights that is derived from or based upon the Licensed Material
     and in which the Licensed Material is translated, altered,
     arranged, transformed, or otherwise modified in a manner requiring
     permission under the Copyright and Similar Rights held by the
     Licensor. For purposes of this Public License, where the Licensed
     Material is a musical work, performance, or sound recording,
     Adapted Material is always produced where the Licensed Material is
     synched in timed relation with a moving image.

  b. Adapter's License means the license You apply to Your Copyright
     and Similar Rights in Your contributions to Adapted Material in
     accordance with the terms and conditions of this Public License.

  c. BY-NC-SA Compatible License means a license listed at
     creativecommons.org/compatiblelicenses, approved by Creative
     Commons as essentially the equivalent of this Public License.

  d. Copyright and Similar Rights means copyright and/or similar rights
     closely related to copyright including, without limitation,
     performance, broadcast, sound recording, and Sui Generis Database
     Rights, without regard to how the rights are labeled or
     categorized. For purposes of this Public License, the rights
     specified in Section 2(b)(1)-(2) are not Copyright and Similar
     Rights.

  e. Effective Technological Measures means those measures that, in the
     absence of proper authority, may not be circumvented under laws
     fulfilling obligations under Article 11 of the WIPO Copyright
     Treaty adopted on December 20, 1996, and/or similar international
     agreements.

  f. Exceptions and Limitations means fair use, fair dealing, and/or
     any other exception or limitation to Copyright and Similar Rights
     that applies to Your use of the Licensed Material.

  g. License Elements means the license attributes listed in the name
     of a Creative Commons Public License. The License Elements of this
     Public License are Attribution, NonCommercial, and ShareAlike.

  h. Licensed Material means the artistic or literary work, database,
     or other material to which the Licensor applied this Public
     License.

  i. Licensed Rights means the rights granted to You subject to the
     terms and conditions of this Public License, which are limited to
     all Copyright and Similar Rights that apply to Your use of the
     Licensed Material and that the Licensor has authority to license.

  j. Licensor means the individual(s) or entity(ies) granting rights
     under this Public License.

  k. NonCommercial means not primarily intended for or directed towards
     commercial advantage or monetary compensation. For purposes of
     this Public License, the exchange of the Licensed Material for
     other material subject to Copyright and Similar Rights by digital
     file-sharing or similar means is NonCommercial provided there is
     no payment of monetary compensation in connection with the
     exchange.

  l. Share means to provide material to the public by any means or
     process that requires permission under the Licensed Rights, such
     as reproduction, public display, public performance, distribution,
     dissemination, communication, or importation, and to make material
     available to the public including in ways that members of the
     public may access the material from a place and at a time
     individually chosen by them.

  m. Sui Generis Database Rights means rights other than copyright
     resulting from Directive 96/9/EC of the European Parliament and of
     the Council of 11 March 1996 on the legal protection of databases,
     as amended and/or succeeded, as well as other essentially
     equivalent rights anywhere in the world.

  n. You means the individual or entity exercising the Licensed Rights
     under this Public License. Your has a corresponding meaning.


Section 2 -- Scope.

  a. License grant.

       1. Subject to the terms and conditions of this Public License,
          the Licensor hereby grants You a worldwide, royalty-free,
          non-sublicensable, non-exclusive, irrevocable license to
          exercise the Licensed Rights in the Licensed Material to:

            a. reproduce and Share the Licensed Material, in whole or
               in part, for NonCommercial purposes only; and

            b. produce, reproduce, and Share Adapted Material for
               NonCommercial purposes only.

       2. Exceptions and Limitations. For the avoidance of doubt, where
          Exceptions and Limitations apply to Your use, this Public
          License does not apply, and You do not need to comply with
          its terms and conditions.

       3. Term. The term of this Public License is specified in Section
          6(a).

       4. Media and formats; technical modifications allowed. The
          Licensor authorizes You to exercise the Licensed Rights in
          all media and formats whether now known or hereafter created,
          and to make technical modifications necessary to do so. The
          Licensor waives and/or agrees not to assert any right or
          authority to forbid You from making technical modifications
          necessary to exercise the Licensed Rights, including
          technical modifications necessary to circumvent Effective
          Technological Measures. For purposes of this Public License,
          simply making modifications authorized by this Section 2(a)
          (4) never produces Adapted Material.

       5. Downstream recipients.

            a. Offer from the Licensor -- Licensed Material. Every
               recipient of the Licensed Material automatically
               receives an offer from the Licensor to exercise the
               Licensed Rights under the terms and conditions of this
               Public License.

            b. Additional offer from the Licensor -- Adapted Material.
               Every recipient of Adapted Material from You
               automatically receives an offer from the Licensor to
               exercise the Licensed Rights in the Adapted Material
               under the conditions of the Adapter's License You apply.

            c. No downstream restrictions. You may not offer or impose
               any additional or different terms or conditions on, or
               apply any Effective Technological Measures to, the
               Licensed Material if doing so restricts exercise of the
               Licensed Rights by any recipient of the Licensed
               Material.

       6. No endorsement. Nothing in this Public License constitutes or
          may be construed as permission to assert or imply that You
          are, or that Your use of the Licensed Material is, connected
          with, or sponsored, endorsed, or granted official status by,
          the Licensor or others designated to receive attribution as
          provided in Section 3(a)(1)(A)(i).

  b. Other rights.

       1. Moral rights, such as the right of integrity, are not
          licensed under this Public License, nor are publicity,
          privacy, and/or other similar personality rights; however, to
          the extent possible, the Licensor waives and/or agrees not to
          assert any such rights held by the Licensor to the limited
          extent necessary to allow You to exercise the Licensed
          Rights, but not otherwise.

       2. Patent and trademark rights are not licensed under this
          Public License.

       3. To the extent possible, the Licensor waives any right to
          collect royalties from You for the exercise of the Licensed
          Rights, whether directly or through a collecting society
          under any voluntary or waivable statutory or compulsory
          licensing scheme. In all other cases the Licensor expressly
          reserves any right to collect such royalties, including when
          the Licensed Material is used other than for NonCommercial
          purposes.


Section 3 -- License Conditions.

Your exercise of the Licensed Rights is expressly made subject to the
following conditions.

  a. Attribution.

       1. If You Share the Licensed Material (including in modified
          form), You must:

            a. retain the following if it is supplied by the Licensor
               with the Licensed Material:

                 i. identification of the creator(s) of the Licensed
                    Material and any others designated to receive
                    attribution, in any reasonable manner requested by
                    the Licensor (including by pseudonym if
                    designated);

                ii. a copyright notice;

               iii. a notice that refers to this Public License;

                iv. a notice that refers to the disclaimer of
                    warranties;

                 v. a URI or hyperlink to the Licensed Material to the
                    extent reasonably practicable;

            b. indicate if You modified the Licensed Material and
               retain an indication of any previous modifications; and

            c. indicate the Licensed Material is licensed under this
               Public License, and include the text of, or the URI or
               hyperlink to, this Public License.

       2. You may satisfy the conditions in Section 3(a)(1) in any
          reasonable manner based on the medium, means, and context in
          which You Share the Licensed Material. For example, it may be
          reasonable to satisfy the conditions by providing a URI or
          hyperlink to a resource that includes the required
          information.
       3. If requested by the Licensor, You must remove any of the
          information required by Section 3(a)(1)(A) to the extent
          reasonably practicable.

  b. ShareAlike.

     In addition to the conditions in Section 3(a), if You Share
     Adapted Material You produce, the following conditions also apply.

       1. The Adapter's License You apply must be a Creative Commons
          license with the same License Elements, this version or
          later, or a BY-NC-SA Compatible License.

       2. You must include the text of, or the URI or hyperlink to, the
          Adapter's License You apply. You may satisfy this condition
          in any reasonable manner based on the medium, means, and
          context in which You Share Adapted Material.

       3. You may not offer or impose any additional or different terms
          or conditions on, or apply any Effective Technological
          Measures to, Adapted Material that restrict exercise of the
          rights granted under the Adapter's License You apply.


Section 4 -- Sui Generis Database Rights.

Where the Licensed Rights include Sui Generis Database Rights that
apply to Your use of the Licensed Material:

  a. for the avoidance of doubt, Section 2(a)(1) grants You the right
     to extract, reuse, reproduce, and Share all or a substantial
     portion of the contents of the database for NonCommercial purposes
     only;

  b. if You include all or a substantial portion of the database
     contents in a database in which You have Sui Generis Database
     Rights, then the database in which You have Sui Generis Database
     Rights (but not its individual contents) is Adapted Material,
     including for purposes of Section 3(b); and

  c. You must comply with the conditions in Section 3(a) if You Share
     all or a substantial portion of the contents of the database.

For the avoidance of doubt, this Section 4 supplements and does not
replace Your obligations under this Public License where the Licensed
Rights include other Copyright and Similar Rights.


Section 5 -- Disclaimer of Warranties and Limitation of Liability.

  a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
     EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
     AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
     ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
     IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
     WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
     PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
     ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
     KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
     ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.

  b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
     TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
     NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
     INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
     COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
     USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
     ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
     DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
     IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.

  c. The disclaimer of warranties and limitation of liability provided
     above shall be interpreted in a manner that, to the extent
     possible, most closely approximates an absolute disclaimer and
     waiver of all liability.


Section 6 -- Term and Termination.

  a. This Public License applies for the term of the Copyright and
     Similar Rights licensed here. However, if You fail to comply with
     this Public License, then Your rights under this Public License
     terminate automatically.

  b. Where Your right to use the Licensed Material has terminated under
     Section 6(a), it reinstates:

       1. automatically as of the date the violation is cured, provided
          it is cured within 30 days of Your discovery of the
          violation; or

       2. upon express reinstatement by the Licensor.

     For the avoidance of doubt, this Section 6(b) does not affect any
     right the Licensor may have to seek remedies for Your violations
     of this Public License.

  c. For the avoidance of doubt, the Licensor may also offer the
     Licensed Material under separate terms or conditions or stop
     distributing the Licensed Material at any time; however, doing so
     will not terminate this Public License.

  d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
     License.


Section 7 -- Other Terms and Conditions.

  a. The Licensor shall not be bound by any additional or different
     terms or conditions communicated by You unless expressly agreed.

  b. Any arrangements, understandings, or agreements regarding the
     Licensed Material not stated herein are separate from and
     independent of the terms and conditions of this Public License.


Section 8 -- Interpretation.

  a. For the avoidance of doubt, this Public License does not, and
     shall not be interpreted to, reduce, limit, restrict, or impose
     conditions on any use of the Licensed Material that could lawfully
     be made without permission under this Public License.

  b. To the extent possible, if any provision of this Public License is
     deemed unenforceable, it shall be automatically reformed to the
     minimum extent necessary to make it enforceable. If the provision
     cannot be reformed, it shall be severed from this Public License
     without affecting the enforceability of the remaining terms and
     conditions.

  c. No term or condition of this Public License will be waived and no
     failure to comply consented to unless expressly agreed to by the
     Licensor.

  d. Nothing in this Public License constitutes or may be interpreted
     as a limitation upon, or waiver of, any privileges and immunities
     that apply to the Licensor or You, including from the legal
     processes of any jurisdiction or authority.

=======================================================================

Creative Commons is not a party to its public
licenses. Notwithstanding, Creative Commons may elect to apply one of
its public licenses to material it publishes and in those instances
will be considered the "Licensor." The text of the Creative Commons
public licenses is dedicated to the public domain under the CC0 Public
Domain Dedication. Except for the limited purpose of indicating that
material is shared under a Creative Commons public license or as
otherwise permitted by the Creative Commons policies published at
creativecommons.org/policies, Creative Commons does not authorize the
use of the trademark "Creative Commons" or any other trademark or logo
of Creative Commons without its prior written consent including,
without limitation, in connection with any unauthorized modifications
to any of its public licenses or any other arrangements,
understandings, or agreements concerning use of licensed material. For
the avoidance of doubt, this paragraph does not form part of the
public licenses.

Creative Commons may be contacted at creativecommons.org.
README.md
CHANGED
@@ -1,13 +1,123 @@
# Speaker Voice Separation using Neural Nets Gradio Demo

## Installation

```bash
git clone https://github.com/Muhammad-Ahmad-Ghani/svoice_demo.git
cd svoice_demo
conda create -n svoice python=3.7 -y
conda activate svoice
conda install pytorch==1.12.0 torchvision==0.13.0 torchaudio==0.12.0 cudatoolkit=11.3 -c pytorch -y
pip install -r requirements.txt
```

| Pretrained-Model | Dataset | Epochs | Train Loss | Valid Loss |
|:------------:|:------------:|:------------:|:------------:|:------------:|
| [checkpoint.th](https://drive.google.com/drive/folders/1WzhvH1oIB9LqoTyItA6jViTRai5aURzJ?usp=sharing) | Librimix-7 (16k-mix_clean) | 31 | 0.04 | 0.64 |

This is an intermediate checkpoint, provided for demo purposes only.

Create the directory ```outputs/exp_``` and save the checkpoint there:
```
svoice_demo
├── outputs
│   └── exp_
│       └── checkpoint.th
...
```

## Running the End-to-End Project
#### Terminal 1
```bash
conda activate svoice
python demo.py
```

## Training
Create the ```mix_clean``` dataset with a sample rate of ```16K``` using the [librimix](https://github.com/shakeddovrat/librimix) repo.

Dataset structure:
```
svoice_demo
├── Libri7Mix_Dataset
│   └── wav16k
│       └── min
│       │   └── dev
│       │   └── ...
│       │   └── test
│       │   └── ...
│       │   └── train-360
│       │   └── ...
...
```

#### Create ```metadata``` files
For the Librimix7 dataset:
```
bash create_metadata_librimix7.sh
```

For the Librimix10 dataset:
```
bash create_metadata_librimix10.sh
```

Change ```conf/config.yaml``` according to your setup. Set ```C: 10``` at line 66 for the number of speakers.

```
python train.py
```
This will automatically read all the configurations from the `conf/config.yaml` file.
To learn more about training, refer to the original [svoice](https://github.com/facebookresearch/svoice) repo.

#### Distributed Training

```
python train.py ddp=1
```

### Evaluating

```
python -m svoice.evaluate <path to the model> <path to folder containing mix.json and all target separated channels json files s<ID>.json>
```

### Citation

The svoice code is borrowed from the original [svoice](https://github.com/facebookresearch/svoice) repository. All rights to the code are reserved by [META Research](https://github.com/facebookresearch).

```
@inproceedings{nachmani2020voice,
  title={Voice Separation with an Unknown Number of Multiple Speakers},
  author={Nachmani, Eliya and Adi, Yossi and Wolf, Lior},
  booktitle={Proceedings of the 37th international conference on Machine learning},
  year={2020}
}
```
```
@misc{cosentino2020librimix,
  title={LibriMix: An Open-Source Dataset for Generalizable Speech Separation},
  author={Joris Cosentino and Manuel Pariente and Samuele Cornell and Antoine Deleforge and Emmanuel Vincent},
  year={2020},
  eprint={2005.11262},
  archivePrefix={arXiv},
  primaryClass={eess.AS}
}
```
## License
This repository is released under the CC-BY-NC-SA 4.0 license, as found in the [LICENSE](LICENSE) file.

The files `svoice/models/sisnr_loss.py` and `svoice/data/preprocess.py` were adapted from the [kaituoxu/Conv-TasNet][convtas] repository. It is an unofficial implementation of the [Conv-TasNet: Surpassing Ideal Time-Frequency Magnitude Masking for Speech Separation][convtas-paper] paper, released under the MIT License.
Additionally, several input manipulation functions were borrowed and modified from the [yluo42/TAC][tac] repository, released under the CC BY-NC-SA 3.0 License.

[icml]: https://arxiv.org/abs/2003.01531.pdf
[icassp]: https://arxiv.org/pdf/2011.02329.pdf
[web]: https://enk100.github.io/speaker_separation/
[pytorch]: https://pytorch.org/
[hydra]: https://github.com/facebookresearch/hydra
[hydra-web]: https://hydra.cc/
[convtas]: https://github.com/kaituoxu/Conv-TasNet
[convtas-paper]: https://arxiv.org/pdf/1809.07454.pdf
[tac]: https://github.com/yluo42/TAC
[nprirgen]: https://github.com/ty274/rir-generator
[rir]: https://asa.scitation.org/doi/10.1121/1.382599
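For quick checks outside the Gradio app, the same entry points that `app.py` (below) relies on can be driven from a short script. This is a minimal sketch, not part of the committed README: it assumes `load_model()` and `separate_demo(mix_dir=...)` are importable by name and behave as they are used in `app.py` (checkpoint read from `outputs/exp_/checkpoint.th`, per-speaker tracks written to `./separated`).

```python
# Sketch: run separation on a folder of mixtures without the Gradio UI.
# Assumes svoice.separate exposes load_model/separate_demo as used in app.py.
import os
from glob import glob

from svoice.separate import load_model, separate_demo

os.makedirs("input", exist_ok=True)   # place one or more mixture .wav files here
load_model()                          # loads the checkpoint from outputs/exp_/
separate_demo(mix_dir="./input")      # writes separated tracks to ./separated

for path in sorted(glob("separated/*.wav")):
    print("separated track:", path)
```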
app.py
ADDED
@@ -0,0 +1,100 @@
from svoice.separate import *
import scipy.io as sio
from scipy.io.wavfile import write
import gradio as gr
import os
from transformers import AutoProcessor, pipeline
from optimum.onnxruntime import ORTModelForSpeechSeq2Seq
from glob import glob
load_model()

BASE_PATH = os.path.dirname(os.path.abspath(__file__))
os.makedirs('input', exist_ok=True)
os.makedirs('separated', exist_ok=True)
os.makedirs('whisper_checkpoint', exist_ok=True)

print("Loading ASR model...")
processor = AutoProcessor.from_pretrained("openai/whisper-small")
if not os.path.exists("whisper_checkpoint"):
    model = ORTModelForSpeechSeq2Seq.from_pretrained("openai/whisper-small", from_transformers=True)
    speech_recognition_pipeline = pipeline(
        "automatic-speech-recognition",
        model=model,
        feature_extractor=processor.feature_extractor,
        tokenizer=processor.tokenizer,
    )
    model.save_pretrained("whisper_checkpoint")
else:
    model = ORTModelForSpeechSeq2Seq.from_pretrained("whisper_checkpoint", from_transformers=False)
    speech_recognition_pipeline = pipeline(
        "automatic-speech-recognition",
        model=model,
        feature_extractor=processor.feature_extractor,
        tokenizer=processor.tokenizer,
    )
print("Whisper ASR model loaded.")

def separator(audio, rec_audio):
    outputs = {}

    if audio:
        write('input/original.wav', audio[0], audio[1])
    elif rec_audio:
        write('input/original.wav', rec_audio[0], rec_audio[1])

    separate_demo(mix_dir="./input")
    separated_files = glob(os.path.join('separated', "*.wav"))
    separated_files = [f for f in separated_files if "original.wav" not in f]
    outputs['transcripts'] = []
    for file in sorted(separated_files):
        separated_audio = sio.wavfile.read(file)
        outputs['transcripts'].append(speech_recognition_pipeline(separated_audio[1])['text'])
    return sorted(separated_files) + outputs['transcripts']

def set_example_audio(example: list) -> dict:
    return gr.Audio.update(value=example[0])

demo = gr.Blocks()
with demo:
    gr.Markdown('''
    <center>
    <h1>Multiple Voice Separation with Transcription DEMO</h1>
    <div style="display:flex;align-items:center;justify-content:center;"><iframe src="https://streamable.com/e/0x8osl?autoplay=1&nocontrols=1" frameborder="0" allow="autoplay"></iframe></div>
    <p>
    This is a demo for the multiple voice separation algorithm. The algorithm is trained on the LibriMix7 dataset and can be used to separate multiple voices from a single audio file.
    </p>
    </center>
    ''')

    with gr.Row():
        input_audio = gr.Audio(label="Input audio", type="numpy")
        rec_audio = gr.Audio(label="Record Using Microphone", type="numpy", source="microphone")

    with gr.Row():
        output_audio1 = gr.Audio(label='Speaker 1', interactive=False)
        output_text1 = gr.Text(label='Speaker 1', interactive=False)
        output_audio2 = gr.Audio(label='Speaker 2', interactive=False)
        output_text2 = gr.Text(label='Speaker 2', interactive=False)

    with gr.Row():
        output_audio3 = gr.Audio(label='Speaker 3', interactive=False)
        output_text3 = gr.Text(label='Speaker 3', interactive=False)
        output_audio4 = gr.Audio(label='Speaker 4', interactive=False)
        output_text4 = gr.Text(label='Speaker 4', interactive=False)

    with gr.Row():
        output_audio5 = gr.Audio(label='Speaker 5', interactive=False)
        output_text5 = gr.Text(label='Speaker 5', interactive=False)
        output_audio6 = gr.Audio(label='Speaker 6', interactive=False)
        output_text6 = gr.Text(label='Speaker 6', interactive=False)

    with gr.Row():
        output_audio7 = gr.Audio(label='Speaker 7', interactive=False)
        output_text7 = gr.Text(label='Speaker 7', interactive=False)

    outputs_audio = [output_audio1, output_audio2, output_audio3, output_audio4, output_audio5, output_audio6, output_audio7]
    outputs_text = [output_text1, output_text2, output_text3, output_text4, output_text5, output_text6, output_text7]
    button = gr.Button("Separate")
    button.click(separator, inputs=[input_audio, rec_audio], outputs=outputs_audio + outputs_text)

demo.launch()
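The `separator` callback above can also be exercised without clicking through the UI, which is handy for debugging the separation/transcription chain. A hypothetical snippet, assuming it runs in a session where the definitions above are loaded (for example in a REPL before `demo.launch()` is called) and that a 16 kHz mixture exists on disk; the `(rate, data)` tuple mirrors what `gr.Audio(type="numpy")` passes to the callback:

```python
# Sketch: call the Gradio callback directly with a local file (illustrative only).
from scipy.io import wavfile

rate, data = wavfile.read("input/original.wav")   # any 16 kHz mixture wav
results = separator((rate, data), None)           # same signature as the button callback

# separator returns the separated file paths followed by one transcript per file
n = len(results) // 2
for path, text in zip(results[:n], results[n:]):
    print(path, "->", text)
```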
packages.txt
ADDED
@@ -0,0 +1,2 @@
ffmpeg
libsndfile1-dev
requirements.txt
ADDED
@@ -0,0 +1,18 @@
pesq==0.0.2
tqdm
hydra_core==1.0.3
hydra_colorlog==1.0.0
pystoi==0.3.3
librosa==0.7.1
numba==0.48
numpy
flask
flask-cors
uvicorn[standard]
asgiref
gradio
transformers==4.24.0
torch
torchvision
torchaudio
optimum[onnxruntime]==1.5.0
svoice/__init__.py
ADDED
@@ -0,0 +1,5 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
svoice/data/__init__.py
ADDED
@@ -0,0 +1,5 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
svoice/data/audio.py
ADDED
@@ -0,0 +1,89 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# Author: Alexandre Défossez @adefossez, 2020

import json
from pathlib import Path
import math
import os
import tqdm
import sys

import torchaudio
torchaudio.set_audio_backend("sox_io")
import soundfile as sf
import torch as th
from torch.nn import functional as F


# If used, this should be saved somewhere as it takes quite a bit
# of time to generate
def find_audio_files(path, exts=[".wav"], progress=True):
    audio_files = []
    for root, folders, files in os.walk(path, followlinks=True):
        for file in files:
            file = Path(root) / file
            if file.suffix.lower() in exts:
                audio_files.append(str(os.path.abspath(file)))
    meta = []
    if progress:
        audio_files = tqdm.tqdm(audio_files, ncols=80)
    for file in audio_files:
        # With the sox_io backend set above, torchaudio.info returns an
        # AudioMetaData object whose num_frames is already per channel
        # (it is not a (signal_info, encoding_info) pair as in the old API).
        siginfo = torchaudio.info(file)
        length = siginfo.num_frames
        meta.append((file, length))
    meta.sort()
    return meta


class Audioset:
    def __init__(self, files, length=None, stride=None, pad=True, augment=None):
        """
        files should be a list [(file, length)]
        """
        self.files = files
        self.num_examples = []
        self.length = length
        self.stride = stride or length
        self.augment = augment
        for file, file_length in self.files:
            if length is None:
                examples = 1
            elif file_length < length:
                examples = 1 if pad else 0
            elif pad:
                examples = int(
                    math.ceil((file_length - self.length) / self.stride) + 1)
            else:
                examples = (file_length - self.length) // self.stride + 1
            self.num_examples.append(examples)

    def __len__(self):
        return sum(self.num_examples)

    def __getitem__(self, index):
        for (file, _), examples in zip(self.files, self.num_examples):
            if index >= examples:
                index -= examples
                continue
            num_frames = 0
            offset = 0
            if self.length is not None:
                offset = self.stride * index
                num_frames = self.length
            # out = th.Tensor(sf.read(str(file), start=offset, frames=num_frames)[0]).unsqueeze(0)
            out = torchaudio.load(str(file), frame_offset=offset,
                                  num_frames=num_frames)[0]
            if self.augment:
                out = self.augment(out.squeeze(0).numpy()).unsqueeze(0)
            if num_frames:
                out = F.pad(out, (0, num_frames - out.shape[-1]))
            return out[0]


if __name__ == "__main__":
    json.dump(find_audio_files(sys.argv[1]), sys.stdout, indent=4)
    print()
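A short sketch of how `find_audio_files` and `Audioset` compose: the metadata list of `(path, num_frames)` pairs feeds the dataset, and `length`/`stride` control how each file is cut into fixed-size training segments. The folder path and the 4-second / 16 kHz values below are illustrative, not part of the commit.

```python
# Sketch: index 4-second segments (1-second hop) from a folder of 16 kHz wavs.
from svoice.data.audio import find_audio_files, Audioset

meta = find_audio_files("Libri7Mix_Dataset/wav16k/min/dev/mix_clean")  # [(path, n_frames), ...]
segments = Audioset(meta, length=4 * 16000, stride=1 * 16000, pad=True)

print(len(segments), "segments")
first = segments[0]      # 1-D tensor of 64000 samples, zero-padded if the file is short
print(first.shape)
```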
svoice/data/data.py
ADDED
@@ -0,0 +1,207 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# Authors: Yossi Adi (adiyoss) and Alexandre Défossez (adefossez)

import json
import logging
import math
from pathlib import Path
import os
import re

import librosa
import numpy as np
import torch
import torch.utils.data as data

from .preprocess import preprocess_one_dir
from .audio import Audioset

logger = logging.getLogger(__name__)


def sort(infos): return sorted(
    infos, key=lambda info: int(info[1]), reverse=True)


class Trainset:
    def __init__(self, json_dir, sample_rate=16000, segment=4.0, stride=1.0, pad=True):
        mix_json = os.path.join(json_dir, 'mix.json')
        s_jsons = list()
        s_infos = list()
        sets_re = re.compile(r's[0-9]+.json')
        print(os.listdir(json_dir))
        for s in os.listdir(json_dir):
            if sets_re.search(s):
                s_jsons.append(os.path.join(json_dir, s))

        with open(mix_json, 'r') as f:
            mix_infos = json.load(f)
        for s_json in s_jsons:
            with open(s_json, 'r') as f:
                s_infos.append(json.load(f))

        length = int(sample_rate * segment)
        stride = int(sample_rate * stride)

        kw = {'length': length, 'stride': stride, 'pad': pad}
        self.mix_set = Audioset(sort(mix_infos), **kw)

        self.sets = list()
        for s_info in s_infos:
            self.sets.append(Audioset(sort(s_info), **kw))

        # verify all sets have the same size
        for s in self.sets:
            assert len(s) == len(self.mix_set)

    def __getitem__(self, index):
        mix_sig = self.mix_set[index]
        tgt_sig = [self.sets[i][index] for i in range(len(self.sets))]
        return self.mix_set[index], torch.LongTensor([mix_sig.shape[0]]), torch.stack(tgt_sig)

    def __len__(self):
        return len(self.mix_set)


class Validset:
    """
    Load entire wavs.
    """

    def __init__(self, json_dir):
        mix_json = os.path.join(json_dir, 'mix.json')
        s_jsons = list()
        s_infos = list()
        sets_re = re.compile(r's[0-9]+.json')
        for s in os.listdir(json_dir):
            if sets_re.search(s):
                s_jsons.append(os.path.join(json_dir, s))
        with open(mix_json, 'r') as f:
            mix_infos = json.load(f)
        for s_json in s_jsons:
            with open(s_json, 'r') as f:
                s_infos.append(json.load(f))
        self.mix_set = Audioset(sort(mix_infos))
        self.sets = list()
        for s_info in s_infos:
            self.sets.append(Audioset(sort(s_info)))
        for s in self.sets:
            assert len(s) == len(self.mix_set)

    def __getitem__(self, index):
        mix_sig = self.mix_set[index]
        tgt_sig = [self.sets[i][index] for i in range(len(self.sets))]
        return self.mix_set[index], torch.LongTensor([mix_sig.shape[0]]), torch.stack(tgt_sig)

    def __len__(self):
        return len(self.mix_set)


# The following piece of code was adapted from https://github.com/kaituoxu/Conv-TasNet
# released under the MIT License.
# Author: Kaituo XU
# Created on 2018/12
class EvalDataset(data.Dataset):

    def __init__(self, mix_dir, mix_json, batch_size, sample_rate=8000):
        """
        Args:
            mix_dir: directory including mixture wav files
            mix_json: json file including mixture wav files
        """
        super(EvalDataset, self).__init__()
        assert mix_dir != None or mix_json != None
        if mix_dir is not None:
            # Generate mix.json given mix_dir
            preprocess_one_dir(mix_dir, mix_dir, 'mix',
                               sample_rate=sample_rate)
            mix_json = os.path.join(mix_dir, 'mix.json')
        with open(mix_json, 'r') as f:
            mix_infos = json.load(f)
        # sort it by #samples (impl bucket)
        def sort(infos): return sorted(
            infos, key=lambda info: int(info[1]), reverse=True)
        sorted_mix_infos = sort(mix_infos)
        # generate minibatch information
        minibatch = []
        start = 0
        while True:
            end = min(len(sorted_mix_infos), start + batch_size)
            minibatch.append([sorted_mix_infos[start:end],
                              sample_rate])
            if end == len(sorted_mix_infos):
                break
            start = end
        self.minibatch = minibatch

    def __getitem__(self, index):
        return self.minibatch[index]

    def __len__(self):
        return len(self.minibatch)


class EvalDataLoader(data.DataLoader):
    """
    NOTE: just use batchsize=1 here, so drop_last=True makes no sense here.
    """

    def __init__(self, *args, **kwargs):
        super(EvalDataLoader, self).__init__(*args, **kwargs)
        self.collate_fn = _collate_fn_eval


def _collate_fn_eval(batch):
    """
    Args:
        batch: list, len(batch) = 1. See AudioDataset.__getitem__()
    Returns:
        mixtures_pad: B x T, torch.Tensor
        ilens : B, torch.Tensor
        filenames: a list containing B strings
    """
    # batch should be located in list
    assert len(batch) == 1
    mixtures, filenames = load_mixtures(batch[0])

    # get batch of lengths of input sequences
    ilens = np.array([mix.shape[0] for mix in mixtures])

    # perform padding and convert to tensor
    pad_value = 0
    mixtures_pad = pad_list([torch.from_numpy(mix).float()
                             for mix in mixtures], pad_value)
    ilens = torch.from_numpy(ilens)
    return mixtures_pad, ilens, filenames


def load_mixtures(batch):
    """
    Returns:
        mixtures: a list containing B items, each item is T np.ndarray
        filenames: a list containing B strings
        T varies from item to item.
    """
    mixtures, filenames = [], []
    mix_infos, sample_rate = batch
    # for each utterance
    for mix_info in mix_infos:
        mix_path = mix_info[0]
        # read wav file
        mix, _ = librosa.load(mix_path, sr=sample_rate)
        mixtures.append(mix)
        filenames.append(mix_path)
    return mixtures, filenames


def pad_list(xs, pad_value):
    n_batch = len(xs)
    max_len = max(x.size(0) for x in xs)
    pad = xs[0].new(n_batch, max_len, *xs[0].size()[1:]).fill_(pad_value)
    for i in range(n_batch):
        pad[i, :xs[i].size(0)] = xs[i]
    return pad
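How the evaluation pieces above are meant to be wired together, as a sketch: `EvalDataset` buckets the mixture files into mini-batches of metadata, and `EvalDataLoader` (batch size 1 over those buckets) applies `_collate_fn_eval`, which loads the waveforms and zero-pads them with `pad_list`. The directory name below is illustrative.

```python
# Sketch: iterate padded mixture batches for evaluation.
from svoice.data.data import EvalDataset, EvalDataLoader

dataset = EvalDataset(mix_dir="./input", mix_json=None, batch_size=4, sample_rate=16000)
loader = EvalDataLoader(dataset, batch_size=1, num_workers=0)

for mixtures_pad, ilens, filenames in loader:
    # mixtures_pad: (B, T_max) float tensor, ilens: original lengths, filenames: wav paths
    print(mixtures_pad.shape, ilens.tolist(), filenames)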
svoice/data/preprocess.py
ADDED
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# The following piece of code was adapted from https://github.com/kaituoxu/Conv-TasNet
# released under the MIT License.
# Author: Kaituo XU
# Created on 2018/12

# Revised by: Eliya Nachmani (enk100), Yossi Adi (adiyoss), Lior Wolf

import argparse
import json
import os

import librosa
from tqdm import tqdm


def preprocess_one_dir(in_dir, out_dir, out_filename, sample_rate=8000):
    file_infos = []
    in_dir = os.path.abspath(in_dir)
    wav_list = os.listdir(in_dir)
    for wav_file in tqdm(wav_list):
        if not wav_file.endswith('.wav'):
            continue
        wav_path = os.path.join(in_dir, wav_file)
        samples, _ = librosa.load(wav_path, sr=sample_rate)
        file_infos.append((wav_path, len(samples)))
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
    with open(os.path.join(out_dir, out_filename + '.json'), 'w') as f:
        json.dump(file_infos, f, indent=4)


def preprocess(args):
    for data_type in ['tr', 'cv', 'tt']:
        for signal in ['noisy', 'clean']:
            preprocess_one_dir(os.path.join(args.in_dir, data_type, signal),
                               os.path.join(args.out_dir, data_type),
                               signal,
                               sample_rate=args.sample_rate)


def preprocess_alldirs(args):
    for d in os.listdir(args.in_dir):
        local_dir = os.path.join(args.in_dir, d)
        if os.path.isdir(local_dir):
            preprocess_one_dir(os.path.join(args.in_dir, local_dir),
                               os.path.join(args.out_dir),
                               d,
                               sample_rate=args.sample_rate)


if __name__ == "__main__":
    parser = argparse.ArgumentParser("WSJ0 data preprocessing")
    parser.add_argument('--in_dir', type=str, default=None,
                        help='Directory path of wsj0 including tr, cv and tt')
    parser.add_argument('--out_dir', type=str, default=None,
                        help='Directory path to put output files')
    parser.add_argument('--sample_rate', type=int, default=16000,
                        help='Sample rate of audio file')
    parser.add_argument("--one_dir", action="store_true",
                        help="Generate json files from specific directory")
    parser.add_argument("--all_dirs", action="store_true",
                        help="Generate json files from all dirs in specific directory")
    parser.add_argument('--json_name', type=str, default=None,
                        help='The name of the json to be generated. '
                        'To be used only with one-dir option.')
    args = parser.parse_args()
    print(args)
    if args.all_dirs:
        preprocess_alldirs(args)
    elif args.one_dir:
        preprocess_one_dir(args.in_dir, args.out_dir,
                           args.json_name, sample_rate=args.sample_rate)
    else:
        preprocess(args)

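A minimal sketch of how `preprocess_one_dir` could be driven from Python rather than the CLI; the directory paths here are placeholders, not part of the commit:

from svoice.data.preprocess import preprocess_one_dir

# writes egs/mydata/tr/mix.json listing (wav_path, num_samples) pairs
preprocess_one_dir('/path/to/wavs', 'egs/mydata/tr', 'mix', sample_rate=8000)
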
svoice/distrib.py
ADDED
@@ -0,0 +1,95 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# author: adefossez

import logging
import os

import torch
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import DataLoader, Subset
from torch.nn.parallel.distributed import DistributedDataParallel

logger = logging.getLogger(__name__)
rank = 0
world_size = 1


def init(args):
    """init.
    Initialize DDP using the given rendezvous file.
    """
    global rank, world_size
    if args.ddp:
        assert args.rank is not None and args.world_size is not None
        rank = args.rank
        world_size = args.world_size
    if world_size == 1:
        return
    torch.cuda.set_device(rank)
    torch.distributed.init_process_group(
        backend=args.ddp_backend,
        init_method='file://' + os.path.abspath(args.rendezvous_file),
        world_size=world_size,
        rank=rank)
    logger.debug("Distributed rendezvous went well, rank %d/%d", rank, world_size)


def average(metrics, count=1.):
    """average.
    Average all the relevant metrics across processes.
    `metrics` should be a 1D float32 vector. Returns the average of `metrics`
    over all hosts. You can use `count` to control the weight of each worker.
    """
    if world_size == 1:
        return metrics
    tensor = torch.tensor(list(metrics) + [1], device='cuda', dtype=torch.float32)
    tensor *= count
    torch.distributed.all_reduce(tensor, op=torch.distributed.ReduceOp.SUM)
    return (tensor[:-1] / tensor[-1]).cpu().numpy().tolist()


def wrap(model):
    """wrap.
    Wrap a model with DDP if distributed training is enabled.
    """
    if world_size == 1:
        return model
    else:
        return DistributedDataParallel(
            model,
            device_ids=[torch.cuda.current_device()],
            output_device=torch.cuda.current_device())


def barrier():
    if world_size > 1:
        torch.distributed.barrier()


def loader(dataset, *args, shuffle=False, klass=DataLoader, **kwargs):
    """loader.
    Create a dataloader properly in case of distributed training.
    If a gradient is going to be computed you must set `shuffle=True`.
    :param dataset: the dataset to be parallelized
    :param args: relevant args for the loader
    :param shuffle: shuffle examples
    :param klass: loader class
    :param kwargs: relevant args
    """

    if world_size == 1:
        return klass(dataset, *args, shuffle=shuffle, **kwargs)

    if shuffle:
        # train means we will compute backward, we use DistributedSampler
        sampler = DistributedSampler(dataset)
        # We ignore shuffle, DistributedSampler already shuffles
        return klass(dataset, *args, **kwargs, sampler=sampler)
    else:
        # We make a manual shard, as DistributedSampler otherwise replicates some examples
        dataset = Subset(dataset, list(range(rank, len(dataset), world_size)))
        return klass(dataset, *args, shuffle=shuffle)

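In the single-process case (world_size == 1) the helper above simply builds a plain DataLoader. A minimal sketch, assuming any map-style dataset; the random tensor below is only an illustration:

import torch
from svoice import distrib

dataset = torch.utils.data.TensorDataset(torch.randn(8, 16000))
# with world_size == 1 this is equivalent to DataLoader(dataset, batch_size=2, shuffle=True)
loader = distrib.loader(dataset, batch_size=2, shuffle=True)
for (batch,) in loader:
    print(batch.shape)  # torch.Size([2, 16000])
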
svoice/evaluate.py
ADDED
@@ -0,0 +1,212 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# Authors: Eliya Nachmani (enk100), Yossi Adi (adiyoss), Lior Wolf and Alexandre Defossez (adefossez)

import argparse
from concurrent.futures import ProcessPoolExecutor
import json
import logging
import sys

import numpy as np
from pesq import pesq
from pystoi import stoi
import torch

from .models.sisnr_loss import cal_loss
from .data.data import Validset
from . import distrib
from .utils import bold, deserialize_model, LogProgress


logger = logging.getLogger(__name__)

parser = argparse.ArgumentParser(
    'Evaluate separation performance using MulCat blocks')
parser.add_argument('model_path',
                    help='Path to model file created by training')
parser.add_argument('data_dir',
                    help='directory including mix.json, s1.json, s2.json, ... files')
parser.add_argument('--device', default="cuda")
parser.add_argument('--sdr', type=int, default=0)
parser.add_argument('--sample_rate', default=16000,
                    type=int, help='Sample rate')
parser.add_argument('--num_workers', type=int, default=5)
parser.add_argument('-v', '--verbose', action='store_const', const=logging.DEBUG,
                    default=logging.INFO, help="More logging")


def evaluate(args, model=None, data_loader=None, sr=None):
    total_sisnr = 0
    total_pesq = 0
    total_stoi = 0
    total_cnt = 0
    updates = 5

    # Load model
    if not model:
        pkg = torch.load(args.model_path, map_location=args.device)
        if 'model' in pkg:
            model = pkg['model']
        else:
            model = pkg
        model = deserialize_model(model)
        if 'best_state' in pkg:
            model.load_state_dict(pkg['best_state'])
    logger.debug(model)
    model.eval()
    model.to(args.device)
    # Load data
    if not data_loader:
        dataset = Validset(args.data_dir)
        data_loader = distrib.loader(
            dataset, batch_size=1, num_workers=args.num_workers)
        sr = args.sample_rate
    pendings = []
    with ProcessPoolExecutor(args.num_workers) as pool:
        with torch.no_grad():
            iterator = LogProgress(logger, data_loader, name="Eval estimates")
            for i, data in enumerate(iterator):
                # Get batch data
                mixture, lengths, sources = [x.to(args.device) for x in data]
                # Forward
                with torch.no_grad():
                    mixture /= mixture.max()
                    estimate = model(mixture)[-1]
                sisnr_loss, snr, estimate, reorder_estimate = cal_loss(
                    sources, estimate, lengths)
                reorder_estimate = reorder_estimate.cpu()
                sources = sources.cpu()
                mixture = mixture.cpu()

                pendings.append(
                    pool.submit(_run_metrics, sources, reorder_estimate, mixture, None,
                                sr=sr))
                total_cnt += sources.shape[0]

        for pending in LogProgress(logger, pendings, updates, name="Eval metrics"):
            sisnr_i, pesq_i, stoi_i = pending.result()
            total_sisnr += sisnr_i
            total_pesq += pesq_i
            total_stoi += stoi_i

    metrics = [total_sisnr, total_pesq, total_stoi]
    sisnr, pesq, stoi = distrib.average(
        [m/total_cnt for m in metrics], total_cnt)
    logger.info(
        bold(f'Test set performance: SISNRi={sisnr:.2f} PESQ={pesq}, STOI={stoi}.'))
    return sisnr, pesq, stoi


def _run_metrics(clean, estimate, mix, model, sr, pesq=False):
    if model is not None:
        torch.set_num_threads(1)
        # parallel evaluation here
        with torch.no_grad():
            estimate = model(estimate)[-1]
    estimate = estimate.numpy()
    clean = clean.numpy()
    mix = mix.numpy()
    sisnr = cal_SISNRi(clean, estimate, mix)
    if pesq:
        pesq_i = cal_PESQ(clean, estimate, sr=sr)
        stoi_i = cal_STOI(clean, estimate, sr=sr)
    else:
        pesq_i = 0
        stoi_i = 0
    return sisnr.mean(), pesq_i, stoi_i


def cal_SISNR(ref_sig, out_sig, eps=1e-8):
    """Calculate Scale-Invariant Source-to-Noise Ratio (SI-SNR)
    Args:
        ref_sig: numpy.ndarray, [B, T]
        out_sig: numpy.ndarray, [B, T]
    Returns:
        SISNR
    """
    assert len(ref_sig) == len(out_sig)
    B, T = ref_sig.shape
    ref_sig = ref_sig - np.mean(ref_sig, axis=1).reshape(B, 1)
    out_sig = out_sig - np.mean(out_sig, axis=1).reshape(B, 1)
    ref_energy = (np.sum(ref_sig ** 2, axis=1) + eps).reshape(B, 1)
    proj = (np.sum(ref_sig * out_sig, axis=1).reshape(B, 1)) * \
        ref_sig / ref_energy
    noise = out_sig - proj
    ratio = np.sum(proj ** 2, axis=1) / (np.sum(noise ** 2, axis=1) + eps)
    sisnr = 10 * np.log(ratio + eps) / np.log(10.0)
    return sisnr.mean()


def cal_PESQ(ref_sig, out_sig, sr):
    """Calculate PESQ.
    Args:
        ref_sig: numpy.ndarray, [B, C, T]
        out_sig: numpy.ndarray, [B, C, T]
    Returns:
        PESQ
    """
    B, C, T = ref_sig.shape
    ref_sig = ref_sig.reshape(B*C, T)
    out_sig = out_sig.reshape(B*C, T)
    pesq_val = 0
    for i in range(len(ref_sig)):
        pesq_val += pesq(sr, ref_sig[i], out_sig[i], 'nb')
    return pesq_val / (B*C)


def cal_STOI(ref_sig, out_sig, sr):
    """Calculate STOI.
    Args:
        ref_sig: numpy.ndarray, [B, C, T]
        out_sig: numpy.ndarray, [B, C, T]
    Returns:
        STOI
    """
    B, C, T = ref_sig.shape
    ref_sig = ref_sig.reshape(B*C, T)
    out_sig = out_sig.reshape(B*C, T)
    try:
        stoi_val = 0
        for i in range(len(ref_sig)):
            stoi_val += stoi(ref_sig[i], out_sig[i], sr, extended=False)
        return stoi_val / (B*C)
    except:
        return 0


def cal_SISNRi(src_ref, src_est, mix):
    """Calculate Scale-Invariant Source-to-Noise Ratio improvement (SI-SNRi)
    Args:
        src_ref: numpy.ndarray, [B, C, T]
        src_est: numpy.ndarray, [B, C, T], reordered by best PIT permutation
        mix: numpy.ndarray, [T]
    Returns:
        average_SISNRi
    """
    avg_SISNRi = 0.0
    B, C, T = src_ref.shape
    for c in range(C):
        sisnr = cal_SISNR(src_ref[:, c], src_est[:, c])
        sisnrb = cal_SISNR(src_ref[:, c], mix)
        avg_SISNRi += (sisnr - sisnrb)
    avg_SISNRi /= C
    return avg_SISNRi


def main():
    args = parser.parse_args()
    logging.basicConfig(stream=sys.stderr, level=args.verbose)
    logger.debug(args)
    sisnr, pesq, stoi = evaluate(args)
    json.dump({'sisnr': sisnr,
               'pesq': pesq, 'stoi': stoi}, sys.stdout)
    sys.stdout.write('\n')


if __name__ == '__main__':
    main()

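The SI-SNR metric above is scale-invariant: rescaling the estimate does not change the score, because both the projection and the residual noise scale together. A small self-contained check using `cal_SISNR` from this file (toy signals, invented for illustration):

import numpy as np

ref = np.random.randn(1, 8000)                       # [B, T] reference source
est = 0.5 * ref + 0.01 * np.random.randn(1, 8000)    # rescaled, slightly noisy estimate
print(cal_SISNR(ref, est))                           # high SI-SNR; identical for est and 2 * est
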
svoice/evaluate_auto_select.py
ADDED
@@ -0,0 +1,184 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# Authors: Yossi Adi (adiyoss)

import argparse
from concurrent.futures import ProcessPoolExecutor
import json
import logging
import sys

import numpy as np
from pesq import pesq
from pystoi import stoi
import torch

from .models.sisnr_loss import cal_loss
from .data.data import Validset
from . import distrib
from .utils import bold, deserialize_model, LogProgress
from .evaluate import _run_metrics


logger = logging.getLogger(__name__)

parser = argparse.ArgumentParser(
    'Evaluate model automatic selection performance')
parser.add_argument('model_path_2spk',
                    help='Path to 2spk model file created by training')
parser.add_argument('model_path_3spk',
                    help='Path to 3spk model file created by training')
parser.add_argument('model_path_4spk',
                    help='Path to 4spk model file created by training')
parser.add_argument('model_path_5spk',
                    help='Path to 5spk model file created by training')
parser.add_argument(
    'data_dir', help='directory including mix.json, s1.json and s2.json files')
parser.add_argument('--device', default="cuda")
parser.add_argument('--sample_rate', default=8000,
                    type=int, help='Sample rate')
parser.add_argument('--thresh', default=0.001,
                    type=float, help='Threshold for model auto selection')
parser.add_argument('--num_workers', type=int, default=5)
parser.add_argument('-v', '--verbose', action='store_const', const=logging.DEBUG,
                    default=logging.INFO, help="More logging")


# test pairwise matching
def pair_wise(padded_source, estimate_source):
    pair_wise = torch.sum(padded_source.unsqueeze(
        1)*estimate_source.unsqueeze(2), dim=3)
    if estimate_source.shape[1] != padded_source.shape[1]:
        idxs = pair_wise.argmax(dim=1)
        new_src = torch.FloatTensor(padded_source.shape)
        for b, idx in enumerate(idxs):
            new_src[b:, :, ] = estimate_source[b][idx]
        padded_source_pad = padded_source
        estimate_source_pad = new_src.cuda()
    else:
        padded_source_pad = padded_source
        estimate_source_pad = estimate_source
    return estimate_source_pad


def evaluate_auto_select(args):
    total_sisnr = 0
    total_pesq = 0
    total_stoi = 0
    total_cnt = 0
    updates = 5

    models = list()
    paths = [args.model_path_2spk, args.model_path_3spk,
             args.model_path_4spk, args.model_path_5spk]

    for path in paths:
        # Load model
        pkg = torch.load(path)
        if 'model' in pkg:
            model = pkg['model']
        else:
            model = pkg
        model = deserialize_model(model)
        if 'best_state' in pkg:
            model.load_state_dict(pkg['best_state'])
        logger.debug(model)

        model.eval()
        model.to(args.device)
        models.append(model)

    # Load data
    dataset = Validset(args.data_dir)
    data_loader = distrib.loader(
        dataset, batch_size=1, num_workers=args.num_workers)
    sr = args.sample_rate
    y_hat = torch.zeros((4))

    pendings = []
    with ProcessPoolExecutor(args.num_workers) as pool:
        with torch.no_grad():
            iterator = LogProgress(logger, data_loader, name="Eval estimates")
            for i, data in enumerate(iterator):
                # Get batch data
                mixture, lengths, sources = [x.to(args.device) for x in data]
                estimated_sources = list()
                reorder_estimated_sources = list()

                for model in models:
                    # Forward
                    with torch.no_grad():
                        raw_estimate = model(mixture)[-1]

                    estimate = pair_wise(sources, raw_estimate)
                    sisnr_loss, snr, estimate, reorder_estimate = cal_loss(
                        sources, estimate, lengths)
                    estimated_sources.insert(0, raw_estimate)
                    reorder_estimated_sources.insert(0, reorder_estimate)

                # =================== DETECT NUM. NON-ACTIVE CHANNELS ============== #
                selected_idx = 0
                thresh = args.thresh
                max_spk = 5
                mix_spk = 2
                ground = (max_spk - mix_spk)
                while (selected_idx <= ground):
                    no_sils = 0
                    vals = torch.mean(
                        (estimated_sources[selected_idx]/torch.abs(estimated_sources[selected_idx]).max())**2, axis=2)
                    new_selected_idx = max_spk - len(vals[vals > thresh])
                    if new_selected_idx == selected_idx:
                        break
                    else:
                        selected_idx = new_selected_idx
                if selected_idx < 0:
                    selected_idx = 0
                elif selected_idx > ground:
                    selected_idx = ground

                y_hat[ground - selected_idx] += 1
                reorder_estimate = reorder_estimated_sources[selected_idx].cpu(
                )
                sources = sources.cpu()
                mixture = mixture.cpu()

                pendings.append(
                    pool.submit(_run_metrics, sources, reorder_estimate, mixture, None,
                                sr=sr))
                total_cnt += sources.shape[0]

        for pending in LogProgress(logger, pendings, updates, name="Eval metrics"):
            sisnr_i, pesq_i, stoi_i = pending.result()
            total_sisnr += sisnr_i
            total_pesq += pesq_i
            total_stoi += stoi_i

    metrics = [total_sisnr, total_pesq, total_stoi]
    sisnr, pesq, stoi = distrib.average(
        [m/total_cnt for m in metrics], total_cnt)
    logger.info(bold(f'Test set performance: SISNRi={sisnr:.2f} '
                     f'PESQ={pesq}, STOI={stoi}.'))
    logger.info(f'Two spks prob: {y_hat[0]/(total_cnt)}')
    logger.info(f'Three spks prob: {y_hat[1]/(total_cnt)}')
    logger.info(f'Four spks prob: {y_hat[2]/(total_cnt)}')
    logger.info(f'Five spks prob: {y_hat[3]/(total_cnt)}')
    return sisnr, pesq, stoi


def main():
    args = parser.parse_args()
    logging.basicConfig(stream=sys.stderr, level=args.verbose)
    logger.debug(args)
    sisnr, pesq, stoi = evaluate_auto_select(args)
    json.dump({'sisnr': sisnr,
               'pesq': pesq, 'stoi': stoi}, sys.stdout)
    sys.stdout.write('\n')


if __name__ == '__main__':
    main()

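The selection loop above counts "active" output channels by their mean normalized energy against the threshold. A toy sketch of that single test on a fake 5-channel estimate (all values invented for illustration):

import torch

est = torch.zeros(1, 5, 8000)          # [B, C, T] fake output of the 5spk model
est[:, :2] = torch.randn(1, 2, 8000)   # only two channels carry signal
vals = torch.mean((est / est.abs().max()) ** 2, axis=2)
print((vals > 0.001).sum().item())     # ~2 active channels -> the 2spk model would be selected
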
svoice/executor.py
ADDED
@@ -0,0 +1,85 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# Author: Alexandre Defossez (adefossez)

"""
Start multiple processes locally for DDP.
"""

import logging
import subprocess as sp
import sys

from hydra import utils

logger = logging.getLogger(__name__)


class ChildrenManager:
    def __init__(self):
        self.children = []
        self.failed = False

    def add(self, child):
        child.rank = len(self.children)
        self.children.append(child)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        if exc_value is not None:
            logger.error(
                "An exception happened while starting workers %r", exc_value)
            self.failed = True
        try:
            while self.children and not self.failed:
                for child in list(self.children):
                    try:
                        exitcode = child.wait(0.1)
                    except sp.TimeoutExpired:
                        continue
                    else:
                        self.children.remove(child)
                        if exitcode:
                            logger.error(
                                f"Worker {child.rank} died, killing all workers")
                            self.failed = True
        except KeyboardInterrupt:
            logger.error(
                "Received keyboard interrupt, trying to kill all workers.")
            self.failed = True
        for child in self.children:
            child.terminate()
        if not self.failed:
            logger.info("All workers completed successfully")


def start_ddp_workers():
    import torch as th

    world_size = th.cuda.device_count()
    if not world_size:
        logger.error(
            "DDP is only available on GPU. Make sure GPUs are properly configured with cuda.")
        sys.exit(1)
    logger.info(f"Starting {world_size} worker processes for DDP.")
    with ChildrenManager() as manager:
        for rank in range(world_size):
            kwargs = {}
            argv = list(sys.argv)
            argv += [f"world_size={world_size}", f"rank={rank}"]
            if rank > 0:
                kwargs['stdin'] = sp.DEVNULL
                kwargs['stdout'] = sp.DEVNULL
                kwargs['stderr'] = sp.DEVNULL
                log = utils.HydraConfig().cfg.hydra.job_logging.handlers.file.filename
                log += f".{rank}"
                argv.append("hydra.job_logging.handlers.file.filename=" + log)
            manager.add(sp.Popen([sys.executable] + argv,
                                 cwd=utils.get_original_cwd(), **kwargs))
    sys.exit(int(manager.failed))

svoice/models/__init__.py
ADDED
@@ -0,0 +1,5 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

svoice/models/sisnr_loss.py
ADDED
@@ -0,0 +1,124 @@
# The following piece of code was adapted from https://github.com/kaituoxu/Conv-TasNet
# released under the MIT License.
# Author: Kaituo XU
# Created on 2018/12

from itertools import permutations

import torch
import torch.nn.functional as F

EPS = 1e-8


def cal_loss(source, estimate_source, source_lengths):
    """
    Args:
        source: [B, C, T], B is batch size
        estimate_source: [B, C, T]
        source_lengths: [B]
    """
    max_snr, perms, max_snr_idx, snr_set = cal_si_snr_with_pit(source,
                                                               estimate_source,
                                                               source_lengths)
    B, C, T = estimate_source.shape
    loss = 0 - torch.mean(max_snr)

    reorder_estimate_source = reorder_source(
        estimate_source, perms, max_snr_idx)
    return loss, max_snr, estimate_source, reorder_estimate_source


def cal_si_snr_with_pit(source, estimate_source, source_lengths):
    """Calculate SI-SNR with PIT training.
    Args:
        source: [B, C, T], B is batch size
        estimate_source: [B, C, T]
        source_lengths: [B], each item is between [0, T]
    """
    assert source.size() == estimate_source.size()
    B, C, T = source.size()
    # mask padding position along T
    mask = get_mask(source, source_lengths)
    estimate_source *= mask

    # Step 1. Zero-mean norm
    num_samples = source_lengths.view(-1, 1, 1).float()  # [B, 1, 1]
    mean_target = torch.sum(source, dim=2, keepdim=True) / num_samples
    mean_estimate = torch.sum(estimate_source, dim=2,
                              keepdim=True) / num_samples
    zero_mean_target = source - mean_target
    zero_mean_estimate = estimate_source - mean_estimate
    # mask padding position along T
    zero_mean_target *= mask
    zero_mean_estimate *= mask

    # Step 2. SI-SNR with PIT
    # reshape to use broadcast
    s_target = torch.unsqueeze(zero_mean_target, dim=1)  # [B, 1, C, T]
    s_estimate = torch.unsqueeze(zero_mean_estimate, dim=2)  # [B, C, 1, T]
    # s_target = <s', s>s / ||s||^2
    pair_wise_dot = torch.sum(s_estimate * s_target,
                              dim=3, keepdim=True)  # [B, C, C, 1]
    s_target_energy = torch.sum(
        s_target ** 2, dim=3, keepdim=True) + EPS  # [B, 1, C, 1]
    pair_wise_proj = pair_wise_dot * s_target / s_target_energy  # [B, C, C, T]
    # e_noise = s' - s_target
    e_noise = s_estimate - pair_wise_proj  # [B, C, C, T]
    # SI-SNR = 10 * log_10(||s_target||^2 / ||e_noise||^2)
    pair_wise_si_snr = torch.sum(
        pair_wise_proj ** 2, dim=3) / (torch.sum(e_noise ** 2, dim=3) + EPS)
    pair_wise_si_snr = 10 * torch.log10(pair_wise_si_snr + EPS)  # [B, C, C]
    pair_wise_si_snr = torch.transpose(pair_wise_si_snr, 1, 2)

    # Get max_snr of each utterance
    # permutations, [C!, C]
    perms = source.new_tensor(list(permutations(range(C))), dtype=torch.long)
    # one-hot, [C!, C, C]
    index = torch.unsqueeze(perms, 2)
    perms_one_hot = source.new_zeros((*perms.size(), C)).scatter_(2, index, 1)
    # [B, C!] <- [B, C, C] einsum [C!, C, C], SI-SNR sum of each permutation
    snr_set = torch.einsum('bij,pij->bp', [pair_wise_si_snr, perms_one_hot])
    max_snr_idx = torch.argmax(snr_set, dim=1)  # [B]
    # max_snr = torch.gather(snr_set, 1, max_snr_idx.view(-1, 1))  # [B, 1]
    max_snr, _ = torch.max(snr_set, dim=1, keepdim=True)
    max_snr /= C
    return max_snr, perms, max_snr_idx, snr_set / C


def reorder_source(source, perms, max_snr_idx):
    """
    Args:
        source: [B, C, T]
        perms: [C!, C], permutations
        max_snr_idx: [B], each item is between [0, C!)
    Returns:
        reorder_source: [B, C, T]
    """
    B, C, *_ = source.size()
    # [B, C], permutation whose SI-SNR is max of each utterance
    # for each utterance, reorder estimate source according this permutation
    max_snr_perm = torch.index_select(perms, dim=0, index=max_snr_idx)
    # print('max_snr_perm', max_snr_perm)
    # maybe use torch.gather()/index_select()/scatter() to impl this?
    reorder_source = torch.zeros_like(source)
    for b in range(B):
        for c in range(C):
            reorder_source[b, c] = source[b, max_snr_perm[b][c]]
    return reorder_source


def get_mask(source, source_lengths):
    """
    Args:
        source: [B, C, T]
        source_lengths: [B]
    Returns:
        mask: [B, 1, T]
    """
    B, _, T = source.size()
    mask = source.new_ones((B, 1, T))
    for i in range(B):
        mask[i, :, source_lengths[i]:] = 0
    return mask

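A minimal sanity check of the PIT loss above (random tensors only, invented for illustration): with C sources the loss searches all C! channel permutations and also returns the estimates reordered to best match the references, so an estimate that is just a channel-swapped copy of the sources is recovered exactly:

import torch

B, C, T = 2, 2, 8000
src = torch.randn(B, C, T)
est = src[:, [1, 0], :].clone()              # estimates handed over in swapped order
lengths = torch.full((B,), T, dtype=torch.int)
loss, max_snr, _, reordered = cal_loss(src, est, lengths)
print(loss.item())                           # large negative value: near-perfect SI-SNR
print(torch.allclose(reordered, src))        # True: PIT undid the channel swap
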
svoice/models/swave.py
ADDED
@@ -0,0 +1,294 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# Authors: Eliya Nachmani (enk100), Yossi Adi (adiyoss), Lior Wolf

import sys
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

from ..utils import overlap_and_add
from ..utils import capture_init


class MulCatBlock(nn.Module):

    def __init__(self, input_size, hidden_size, dropout=0, bidirectional=False):
        super(MulCatBlock, self).__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_direction = int(bidirectional) + 1

        self.rnn = nn.LSTM(input_size, hidden_size, 1, dropout=dropout,
                           batch_first=True, bidirectional=bidirectional)
        self.rnn_proj = nn.Linear(hidden_size * self.num_direction, input_size)

        self.gate_rnn = nn.LSTM(input_size, hidden_size, num_layers=1,
                                batch_first=True, dropout=dropout, bidirectional=bidirectional)
        self.gate_rnn_proj = nn.Linear(
            hidden_size * self.num_direction, input_size)

        self.block_projection = nn.Linear(input_size * 2, input_size)

    def forward(self, input):
        output = input
        # run rnn module
        rnn_output, _ = self.rnn(output)
        rnn_output = self.rnn_proj(rnn_output.contiguous(
        ).view(-1, rnn_output.shape[2])).view(output.shape).contiguous()
        # run gate rnn module
        gate_rnn_output, _ = self.gate_rnn(output)
        gate_rnn_output = self.gate_rnn_proj(gate_rnn_output.contiguous(
        ).view(-1, gate_rnn_output.shape[2])).view(output.shape).contiguous()
        # apply gated rnn
        gated_output = torch.mul(rnn_output, gate_rnn_output)
        gated_output = torch.cat([gated_output, output], 2)
        gated_output = self.block_projection(
            gated_output.contiguous().view(-1, gated_output.shape[2])).view(output.shape)
        return gated_output


class ByPass(nn.Module):
    def __init__(self):
        super(ByPass, self).__init__()

    def forward(self, input):
        return input


class DPMulCat(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, num_spk,
                 dropout=0, num_layers=1, bidirectional=True, input_normalize=False):
        super(DPMulCat, self).__init__()

        self.input_size = input_size
        self.output_size = output_size
        self.hidden_size = hidden_size
        self.in_norm = input_normalize
        self.num_layers = num_layers

        self.rows_grnn = nn.ModuleList([])
        self.cols_grnn = nn.ModuleList([])
        self.rows_normalization = nn.ModuleList([])
        self.cols_normalization = nn.ModuleList([])

        # create the dual path pipeline
        for i in range(num_layers):
            self.rows_grnn.append(MulCatBlock(
                input_size, hidden_size, dropout, bidirectional=bidirectional))
            self.cols_grnn.append(MulCatBlock(
                input_size, hidden_size, dropout, bidirectional=bidirectional))
            if self.in_norm:
                self.rows_normalization.append(
                    nn.GroupNorm(1, input_size, eps=1e-8))
                self.cols_normalization.append(
                    nn.GroupNorm(1, input_size, eps=1e-8))
            else:
                # used to disable normalization
                self.rows_normalization.append(ByPass())
                self.cols_normalization.append(ByPass())

        self.output = nn.Sequential(
            nn.PReLU(), nn.Conv2d(input_size, output_size * num_spk, 1))

    def forward(self, input):
        batch_size, _, d1, d2 = input.shape
        output = input
        output_all = []
        for i in range(self.num_layers):
            row_input = output.permute(0, 3, 2, 1).contiguous().view(
                batch_size * d2, d1, -1)
            row_output = self.rows_grnn[i](row_input)
            row_output = row_output.view(
                batch_size, d2, d1, -1).permute(0, 3, 2, 1).contiguous()
            row_output = self.rows_normalization[i](row_output)
            # apply a skip connection
            if self.training:
                output = output + row_output
            else:
                output += row_output

            col_input = output.permute(0, 2, 3, 1).contiguous().view(
                batch_size * d1, d2, -1)
            col_output = self.cols_grnn[i](col_input)
            col_output = col_output.view(
                batch_size, d1, d2, -1).permute(0, 3, 1, 2).contiguous()
            col_output = self.cols_normalization[i](col_output).contiguous()
            # apply a skip connection
            if self.training:
                output = output + col_output
            else:
                output += col_output

            output_i = self.output(output)
            if self.training or i == (self.num_layers - 1):
                output_all.append(output_i)
        return output_all


class Separator(nn.Module):
    def __init__(self, input_dim, feature_dim, hidden_dim, output_dim, num_spk=2,
                 layer=4, segment_size=100, input_normalize=False, bidirectional=True):
        super(Separator, self).__init__()

        self.input_dim = input_dim
        self.feature_dim = feature_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim

        self.layer = layer
        self.segment_size = segment_size
        self.num_spk = num_spk
        self.input_normalize = input_normalize

        self.rnn_model = DPMulCat(self.feature_dim, self.hidden_dim,
                                  self.feature_dim, self.num_spk, num_layers=layer, bidirectional=bidirectional, input_normalize=input_normalize)

    # ======================================= #
    # The following code block was borrowed and modified from https://github.com/yluo42/TAC
    # ================ BEGIN ================ #
    def pad_segment(self, input, segment_size):
        # input is the features: (B, N, T)
        batch_size, dim, seq_len = input.shape
        segment_stride = segment_size // 2
        rest = segment_size - (segment_stride + seq_len %
                               segment_size) % segment_size
        if rest > 0:
            pad = Variable(torch.zeros(batch_size, dim, rest)
                           ).type(input.type())
            input = torch.cat([input, pad], 2)

        pad_aux = Variable(torch.zeros(
            batch_size, dim, segment_stride)).type(input.type())
        input = torch.cat([pad_aux, input, pad_aux], 2)
        return input, rest

    def create_chuncks(self, input, segment_size):
        # split the feature into chunks of segment size
        # input is the features: (B, N, T)
        input, rest = self.pad_segment(input, segment_size)
        batch_size, dim, seq_len = input.shape
        segment_stride = segment_size // 2

        segments1 = input[:, :, :-segment_stride].contiguous().view(batch_size,
                                                                     dim, -1, segment_size)
        segments2 = input[:, :, segment_stride:].contiguous().view(
            batch_size, dim, -1, segment_size)
        segments = torch.cat([segments1, segments2], 3).view(
            batch_size, dim, -1, segment_size).transpose(2, 3)
        return segments.contiguous(), rest

    def merge_chuncks(self, input, rest):
        # merge the split features into a full utterance
        # input is the features: (B, N, L, K)
        batch_size, dim, segment_size, _ = input.shape
        segment_stride = segment_size // 2
        input = input.transpose(2, 3).contiguous().view(
            batch_size, dim, -1, segment_size*2)  # B, N, K, L

        input1 = input[:, :, :, :segment_size].contiguous().view(
            batch_size, dim, -1)[:, :, segment_stride:]
        input2 = input[:, :, :, segment_size:].contiguous().view(
            batch_size, dim, -1)[:, :, :-segment_stride]

        output = input1 + input2
        if rest > 0:
            output = output[:, :, :-rest]
        return output.contiguous()  # B, N, T
    # ================= END ================= #

    def forward(self, input):
        # create chunks
        enc_segments, enc_rest = self.create_chuncks(
            input, self.segment_size)
        # separate
        output_all = self.rnn_model(enc_segments)

        # merge back audio files
        output_all_wav = []
        for ii in range(len(output_all)):
            output_ii = self.merge_chuncks(
                output_all[ii], enc_rest)
            output_all_wav.append(output_ii)
        return output_all_wav


class SWave(nn.Module):
    @capture_init
    def __init__(self, N, L, H, R, C, sr, segment, input_normalize):
        super(SWave, self).__init__()
        # hyper-parameter
        self.N, self.L, self.H, self.R, self.C, self.sr, self.segment = N, L, H, R, C, sr, segment
        self.input_normalize = input_normalize
        self.context_len = 2 * self.sr / 1000
        self.context = int(self.sr * self.context_len / 1000)
        self.layer = self.R
        self.filter_dim = self.context * 2 + 1
        self.num_spk = self.C
        # similar to the dprnn paper, setting chunk size to sqrt(2*L)
        self.segment_size = int(
            np.sqrt(2 * self.sr * self.segment / (self.L/2)))

        # model sub-networks
        self.encoder = Encoder(L, N)
        self.decoder = Decoder(L)
        self.separator = Separator(self.filter_dim + self.N, self.N, self.H,
                                   self.filter_dim, self.num_spk, self.layer, self.segment_size, self.input_normalize)
        # init
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_normal_(p)

    def forward(self, mixture):
        mixture_w = self.encoder(mixture)
        output_all = self.separator(mixture_w)

        # fix time dimension, might change due to convolution operations
        T_mix = mixture.size(-1)
        # generate wav after each RNN block and optimize the loss
        outputs = []
        for ii in range(len(output_all)):
            output_ii = output_all[ii].view(
                mixture.shape[0], self.C, self.N, mixture_w.shape[2])
            output_ii = self.decoder(output_ii)

            T_est = output_ii.size(-1)
            output_ii = F.pad(output_ii, (0, T_mix - T_est))
            outputs.append(output_ii)
        return torch.stack(outputs)


class Encoder(nn.Module):
    def __init__(self, L, N):
        super(Encoder, self).__init__()
        self.L, self.N = L, N
        # setting 50% overlap
        self.conv = nn.Conv1d(
            1, N, kernel_size=L, stride=L // 2, bias=False)

    def forward(self, mixture):
        mixture = torch.unsqueeze(mixture, 1)
        mixture_w = F.relu(self.conv(mixture))
        return mixture_w


class Decoder(nn.Module):
    def __init__(self, L):
        super(Decoder, self).__init__()
        self.L = L

    def forward(self, est_source):
        est_source = torch.transpose(est_source, 2, 3)
        est_source = nn.AvgPool2d((1, self.L))(est_source)
        est_source = overlap_and_add(est_source, self.L//2)

        return est_source

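To make the tensor flow concrete, a small forward pass with made-up hyperparameters; these values are purely illustrative and much smaller than any trained configuration:

import torch

# N=encoder filters, L=kernel size, H=hidden size, R=MulCat blocks, C=speakers (toy values)
model = SWave(N=64, L=8, H=128, R=2, C=2, sr=8000, segment=2, input_normalize=False)
mix = torch.randn(1, 8000)   # one second of audio at 8 kHz
out = model(mix)
print(out.shape)             # torch.Size([2, 1, 2, 8000]): one [B, C, T] output per block in training mode
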
svoice/separate.py
ADDED
@@ -0,0 +1,174 @@
import argparse
import logging
import os
import sys

import librosa
import torch
import tqdm

from .data.data import EvalDataLoader, EvalDataset
from . import distrib
from .utils import remove_pad

from .utils import bold, deserialize_model, LogProgress
logger = logging.getLogger(__name__)


def load_model():
    global device
    global model
    global pkg
    print("Loading svoice model if available...")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    pkg = torch.load('checkpoint.th')
    if 'model' in pkg:
        model = pkg['model']
    else:
        model = pkg
    model = deserialize_model(model)
    logger.debug(model)
    model.eval()
    model.to(device)
    print("svoice model loaded.")
    print("Device: {}".format(device))


parser = argparse.ArgumentParser("Speech separation using MulCat blocks")
parser.add_argument("model_path", type=str, help="Model name")
parser.add_argument("out_dir", type=str, default="exp/result",
                    help="Directory putting enhanced wav files")
parser.add_argument("--mix_dir", type=str, default=None,
                    help="Directory including mix wav files")
parser.add_argument("--mix_json", type=str, default=None,
                    help="Json file including mix wav files")
parser.add_argument('--device', default="cuda")
parser.add_argument("--sample_rate", default=8000,
                    type=int, help="Sample rate")
parser.add_argument("--batch_size", default=1, type=int, help="Batch size")
parser.add_argument('-v', '--verbose', action='store_const', const=logging.DEBUG,
                    default=logging.INFO, help="More logging")


def save_wavs(estimate_source, mix_sig, lengths, filenames, out_dir, sr=16000):
    # Remove padding and flat
    flat_estimate = remove_pad(estimate_source, lengths)
    mix_sig = remove_pad(mix_sig, lengths)
    # Write result
    for i, filename in enumerate(filenames):
        filename = os.path.join(
            out_dir, os.path.basename(filename).strip(".wav"))
        write(mix_sig[i], filename + ".wav", sr=sr)
        C = flat_estimate[i].shape[0]
        # future support for wave playing
        for c in range(C):
            write(flat_estimate[i][c], filename + f"_s{c + 1}.wav", sr=sr)


def write(inputs, filename, sr=8000):
    librosa.output.write_wav(filename, inputs, sr, norm=True)


def separate_demo(mix_dir='mix/', batch_size=1, sample_rate=16000):
    mix_dir, mix_json = mix_dir, None
    out_dir = 'separated'
    # Load data
    eval_dataset = EvalDataset(
        mix_dir,
        mix_json,
        batch_size=batch_size,
        sample_rate=sample_rate,
    )
    eval_loader = distrib.loader(
        eval_dataset, batch_size=1, klass=EvalDataLoader)

    if distrib.rank == 0:
        os.makedirs(out_dir, exist_ok=True)
    distrib.barrier()

    with torch.no_grad():
        for i, data in enumerate(tqdm.tqdm(eval_loader, ncols=120)):
            # Get batch data
            mixture, lengths, filenames = data
            mixture = mixture.to(device)
            lengths = lengths.to(device)
            # Forward
            estimate_sources = model(mixture)[-1]
            # save wav files
            save_wavs(estimate_sources, mixture, lengths,
                      filenames, out_dir, sr=sample_rate)

    separated_files = [os.path.join(out_dir, f) for f in os.listdir(out_dir)]
    separated_files = [os.path.abspath(f) for f in separated_files]
    separated_files = [f for f in separated_files if not f.endswith('original.wav')]
    return separated_files


def get_mix_paths(args):
    mix_dir = None
    mix_json = None
    # fix mix dir
    try:
        if args.dset.mix_dir:
            mix_dir = args.dset.mix_dir
    except:
        mix_dir = args.mix_dir

    # fix mix json
    try:
        if args.dset.mix_json:
            mix_json = args.dset.mix_json
    except:
        mix_json = args.mix_json
    return mix_dir, mix_json


def separate(args, model=None, local_out_dir=None):
    mix_dir, mix_json = get_mix_paths(args)
    if not mix_json and not mix_dir:
        logger.error("Must provide mix_dir or mix_json! "
                     "When providing mix_dir, mix_json is ignored.")
    # Load model
    if not model:
        # model
        pkg = torch.load(args.model_path)
        if 'model' in pkg:
            model = pkg['model']
        else:
            model = pkg
        model = deserialize_model(model)
        logger.debug(model)
    model.eval()
    model.to(args.device)
    if local_out_dir:
        out_dir = local_out_dir
    else:
        out_dir = args.out_dir

    # Load data
    eval_dataset = EvalDataset(
        mix_dir,
        mix_json,
        batch_size=args.batch_size,
        sample_rate=args.sample_rate,
    )
    eval_loader = distrib.loader(
        eval_dataset, batch_size=1, klass=EvalDataLoader)

    if distrib.rank == 0:
        os.makedirs(out_dir, exist_ok=True)
    distrib.barrier()

    with torch.no_grad():
        for i, data in enumerate(tqdm.tqdm(eval_loader, ncols=120)):
            # Get batch data
            mixture, lengths, filenames = data
            mixture = mixture.to(args.device)
            lengths = lengths.to(args.device)
            # Forward
            estimate_sources = model(mixture)[-1]
            # save wav files
            save_wavs(estimate_sources, mixture, lengths,
                      filenames, out_dir, sr=args.sample_rate)


if __name__ == "__main__":
    args = parser.parse_args()
    logging.basicConfig(stream=sys.stderr, level=args.verbose)
    logger.debug(args)
    separate(args, local_out_dir=args.out_dir)

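A minimal sketch of how the two demo helpers above could be driven together; it assumes a checkpoint.th in the working directory (the path hard-coded in load_model) and a mix/ folder of wav files, both of which are placeholders here:

from svoice import separate

separate.load_model()                                  # reads ./checkpoint.th and sets the global model/device
files = separate.separate_demo(mix_dir='mix/', sample_rate=8000)
print(files)                                           # absolute paths of the separated source wavs
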
svoice/solver.py
ADDED
@@ -0,0 +1,227 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# Author: Eliya Nachmani (enk100), Yossi Adi (adiyoss), Lior Wolf

import json
import logging
from pathlib import Path
import os
import time

import numpy as np
import torch
import torch.nn.functional as F
from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR

from . import distrib
from .separate import separate
from .evaluate import evaluate
from .models.sisnr_loss import cal_loss
from .models.swave import SWave
from .utils import bold, copy_state, pull_metric, serialize_model, swap_state, LogProgress


logger = logging.getLogger(__name__)


class Solver(object):
    def __init__(self, data, model, optimizer, args):
        self.tr_loader = data['tr_loader']
        self.cv_loader = data['cv_loader']
        self.tt_loader = data['tt_loader']
        self.model = model
        self.dmodel = distrib.wrap(model)
        self.optimizer = optimizer
        if args.lr_sched == 'step':
            self.sched = StepLR(
                self.optimizer, step_size=args.step.step_size, gamma=args.step.gamma)
        elif args.lr_sched == 'plateau':
            self.sched = ReduceLROnPlateau(
                self.optimizer, factor=args.plateau.factor, patience=args.plateau.patience)
        else:
            self.sched = None

        # Training config
        self.device = args.device
        self.epochs = args.epochs
        self.max_norm = args.max_norm

        # Checkpoints
        self.continue_from = args.continue_from
        self.eval_every = args.eval_every
        self.checkpoint = Path(
            args.checkpoint_file) if args.checkpoint else None
        if self.checkpoint:
            logger.debug("Checkpoint will be saved to %s",
                         self.checkpoint.resolve())
        self.history_file = args.history_file

        self.best_state = None
        self.restart = args.restart
        # keep track of losses
        self.history = []

        # Where to save samples
        self.samples_dir = args.samples_dir

        # logging
        self.num_prints = args.num_prints

        # for separation tests
        self.args = args
        self._reset()

    def _serialize(self, path):
        package = {}
        package['model'] = serialize_model(self.model)
        package['optimizer'] = self.optimizer.state_dict()
        package['history'] = self.history
        package['best_state'] = self.best_state
        package['args'] = self.args
        torch.save(package, path)

    def _reset(self):
        load_from = None
        # Reset
        if self.checkpoint and self.checkpoint.exists() and not self.restart:
            load_from = self.checkpoint
        elif self.continue_from:
            load_from = self.continue_from

        if load_from:
            logger.info(f'Loading checkpoint model: {load_from}')
            package = torch.load(load_from, 'cpu')
            if load_from == self.continue_from and self.args.continue_best:
                self.model.load_state_dict(package['best_state'])
            else:
                self.model.load_state_dict(package['model']['state'])

            if 'optimizer' in package and not self.args.continue_best:
                self.optimizer.load_state_dict(package['optimizer'])
            self.history = package['history']
            self.best_state = package['best_state']

    def train(self):
        # Optimizing the model
        if self.history:
            logger.info("Replaying metrics from previous run")
            for epoch, metrics in enumerate(self.history):
                info = " ".join(f"{k}={v:.5f}" for k, v in metrics.items())
                logger.info(f"Epoch {epoch}: {info}")

        for epoch in range(len(self.history), self.epochs):
            # Train one epoch
            self.model.train()  # Turn on BatchNorm & Dropout
            start = time.time()
            logger.info('-' * 70)
            logger.info("Training...")
            train_loss = self._run_one_epoch(epoch)
            logger.info(bold(f'Train Summary | End of Epoch {epoch + 1} | '
                             f'Time {time.time() - start:.2f}s | Train Loss {train_loss:.5f}'))

            # Cross validation
            logger.info('-' * 70)
            logger.info('Cross validation...')
            self.model.eval()  # Turn off Batchnorm & Dropout
            with torch.no_grad():
                valid_loss = self._run_one_epoch(epoch, cross_valid=True)
            logger.info(bold(f'Valid Summary | End of Epoch {epoch + 1} | '
                             f'Time {time.time() - start:.2f}s | Valid Loss {valid_loss:.5f}'))

            # learning rate scheduling
            if self.sched:
                if self.args.lr_sched == 'plateau':
                    self.sched.step(valid_loss)
                else:
                    self.sched.step()
                logger.info(
                    f'Learning rate adjusted: {self.optimizer.state_dict()["param_groups"][0]["lr"]:.5f}')

            best_loss = min(pull_metric(self.history, 'valid') + [valid_loss])
            metrics = {'train': train_loss,
                       'valid': valid_loss, 'best': best_loss}
            # Save the best model
            if valid_loss == best_loss or self.args.keep_last:
                logger.info(bold('New best valid loss %.4f'), valid_loss)
                self.best_state = copy_state(self.model.state_dict())

            # evaluate and separate samples every 'eval_every' argument number of epochs
            # also evaluate on last epoch
            if (epoch + 1) % self.eval_every == 0 or epoch == self.epochs - 1:
                # Evaluate on the testset
                logger.info('-' * 70)
                logger.info('Evaluating on the test set...')
                # We switch to the best known model for testing
                with swap_state(self.model, self.best_state):
                    sisnr, pesq, stoi = evaluate(
                        self.args, self.model, self.tt_loader, self.args.sample_rate)
                metrics.update({'sisnr': sisnr, 'pesq': pesq, 'stoi': stoi})

                # separate some samples
                logger.info('Separate and save samples...')
                separate(self.args, self.model, self.samples_dir)

            self.history.append(metrics)
            info = " | ".join(
                f"{k.capitalize()} {v:.5f}" for k, v in metrics.items())
            logger.info('-' * 70)
            logger.info(bold(f"Overall Summary | Epoch {epoch + 1} | {info}"))

            if distrib.rank == 0:
                json.dump(self.history, open(self.history_file, "w"), indent=2)
                # Save model each epoch
                if self.checkpoint:
                    self._serialize(self.checkpoint)
                    logger.debug("Checkpoint saved to %s",
                                 self.checkpoint.resolve())

    def _run_one_epoch(self, epoch, cross_valid=False):
        total_loss = 0
        data_loader = self.tr_loader if not cross_valid else self.cv_loader

        # get a different order for distributed training, otherwise this will get ignored
        data_loader.epoch = epoch

        label = ["Train", "Valid"][cross_valid]
        name = label + f" | Epoch {epoch + 1}"
        logprog = LogProgress(logger, data_loader,
                              updates=self.num_prints, name=name)
        for i, data in enumerate(logprog):
            mixture, lengths, sources = [x.to(self.device) for x in data]
            estimate_source = self.dmodel(mixture)

            # only eval last layer
            if cross_valid:
                estimate_source = estimate_source[-1:]

            loss = 0
            cnt = len(estimate_source)
            # apply a loss function after each layer
            with torch.autograd.set_detect_anomaly(True):
                for c_idx, est_src in enumerate(estimate_source):
                    coeff = ((c_idx+1)*(1/cnt))
                    loss_i = 0
                    # SI-SNR loss
                    sisnr_loss, snr, est_src, reorder_est_src = cal_loss(
                        sources, estimate_source[c_idx], lengths)
                    loss += (coeff * sisnr_loss)
                loss /= len(estimate_source)

                if not cross_valid:
                    # optimize model in training mode
                    self.optimizer.zero_grad()
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                                   self.max_norm)
                    self.optimizer.step()

            total_loss += loss.item()
            logprog.update(loss=format(total_loss / (i + 1), ".5f"))

        # Just in case, clear some memory
        del loss, estimate_source
+
return distrib.average([total_loss / (i + 1)], i + 1)[0]
|
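Note on the loss above: each output of the multi-scale model contributes its own SI-SNR term, weighted linearly so that deeper outputs count more (`(c_idx + 1) / cnt`), and the weighted sum is then divided by the number of outputs. A minimal standalone sketch of just that weighting, using made-up per-layer loss values (the values and names below are illustrative only, not part of this repository):

# Sketch of the layer-weighted loss used in Solver._run_one_epoch.
# `layer_losses` is hypothetical example data; in the solver each entry comes from cal_loss().
layer_losses = [1.20, 0.95, 0.80, 0.70]    # one SI-SNR loss per model output, shallow -> deep

cnt = len(layer_losses)
loss = 0.0
for c_idx, layer_loss in enumerate(layer_losses):
    coeff = (c_idx + 1) * (1 / cnt)        # 0.25, 0.5, 0.75, 1.0 for four outputs
    loss += coeff * layer_loss
loss /= cnt                                # average over outputs -> single scalar to backprop

print(f"weighted multi-scale loss: {loss:.5f}")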
svoice/utils.py
ADDED
@@ -0,0 +1,241 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# Authors: Yossi Adi (adiyoss) and Alexandre Defossez (adefossez)

import functools
import logging
from contextlib import contextmanager
import inspect
import os
import time
import math
import torch

logger = logging.getLogger(__name__)


def capture_init(init):
    """
    Decorate `__init__` with this, and you can then
    recover the *args and **kwargs passed to it in `self._init_args_kwargs`
    """
    @functools.wraps(init)
    def __init__(self, *args, **kwargs):
        self._init_args_kwargs = (args, kwargs)
        init(self, *args, **kwargs)

    return __init__


def deserialize_model(package, strict=False):
    klass = package['class']
    if strict:
        model = klass(*package['args'], **package['kwargs'])
    else:
        sig = inspect.signature(klass)
        kw = package['kwargs']
        for key in list(kw):
            if key not in sig.parameters:
                logger.warning("Dropping nonexistent parameter %s", key)
                del kw[key]
        model = klass(*package['args'], **kw)
    model.load_state_dict(package['state'])
    return model


def copy_state(state):
    return {k: v.cpu().clone() for k, v in state.items()}


def serialize_model(model):
    args, kwargs = model._init_args_kwargs
    state = copy_state(model.state_dict())
    return {"class": model.__class__, "args": args, "kwargs": kwargs, "state": state}


@contextmanager
def swap_state(model, state):
    old_state = copy_state(model.state_dict())
    model.load_state_dict(state)
    try:
        yield
    finally:
        model.load_state_dict(old_state)


@contextmanager
def swap_cwd(cwd):
    old_cwd = os.getcwd()
    os.chdir(cwd)
    try:
        yield
    finally:
        os.chdir(old_cwd)


def pull_metric(history, name):
    out = []
    for metrics in history:
        if name in metrics:
            out.append(metrics[name])
    return out


class LogProgress:
    """
    Sort of like tqdm, but using log lines rather than a real-time progress bar.
    """

    def __init__(self, logger, iterable, updates=5, total=None,
                 name="LogProgress", level=logging.INFO):
        self.iterable = iterable
        self.total = total or len(iterable)
        self.updates = updates
        self.name = name
        self.logger = logger
        self.level = level

    def update(self, **infos):
        self._infos = infos

    def __iter__(self):
        self._iterator = iter(self.iterable)
        self._index = -1
        self._infos = {}
        self._begin = time.time()
        return self

    def __next__(self):
        self._index += 1
        try:
            value = next(self._iterator)
        except StopIteration:
            raise
        else:
            return value
        finally:
            log_every = max(1, self.total // self.updates)
            # logging is delayed by 1 iteration, in order to have the metrics from update
            if self._index >= 1 and self._index % log_every == 0:
                self._log()

    def _log(self):
        self._speed = (1 + self._index) / (time.time() - self._begin)
        infos = " | ".join(f"{k.capitalize()} {v}" for k,
                           v in self._infos.items())
        if self._speed < 1e-4:
            speed = "oo sec/it"
        elif self._speed < 0.1:
            speed = f"{1/self._speed:.1f} sec/it"
        else:
            speed = f"{self._speed:.1f} it/sec"
        out = f"{self.name} | {self._index}/{self.total} | {speed}"
        if infos:
            out += " | " + infos
        self.logger.log(self.level, out)


def colorize(text, color):
    code = f"\033[{color}m"
    restore = "\033[0m"
    return "".join([code, text, restore])


def bold(text):
    return colorize(text, "1")


def calculate_grad_norm(model):
    total_norm = 0.0
    is_first = True
    for p in model.parameters():
        param_norm = p.data.grad.flatten()
        if is_first:
            total_norm = param_norm
            is_first = False
        else:
            total_norm = torch.cat((total_norm.unsqueeze(
                1), p.data.grad.flatten().unsqueeze(1)), dim=0).squeeze(1)
    return total_norm.norm(2) ** (1. / 2)


def calculate_weight_norm(model):
    total_norm = 0.0
    is_first = True
    for p in model.parameters():
        param_norm = p.data.flatten()
        if is_first:
            total_norm = param_norm
            is_first = False
        else:
            total_norm = torch.cat((total_norm.unsqueeze(
                1), p.data.flatten().unsqueeze(1)), dim=0).squeeze(1)
    return total_norm.norm(2) ** (1. / 2)


def remove_pad(inputs, inputs_lengths):
    """
    Args:
        inputs: torch.Tensor, [B, C, T] or [B, T], B is batch size
        inputs_lengths: torch.Tensor, [B]
    Returns:
        results: a list containing B items, each item is [C, T], T varies
    """
    results = []
    dim = inputs.dim()
    if dim == 3:
        C = inputs.size(1)
    for input, length in zip(inputs, inputs_lengths):
        if dim == 3:  # [B, C, T]
            results.append(input[:, :length].view(C, -1).cpu().numpy())
        elif dim == 2:  # [B, T]
            results.append(input[:length].view(-1).cpu().numpy())
    return results


def overlap_and_add(signal, frame_step):
    """Reconstructs a signal from a framed representation.

    Adds potentially overlapping frames of a signal with shape
    `[..., frames, frame_length]`, offsetting subsequent frames by `frame_step`.
    The resulting tensor has shape `[..., output_size]` where

        output_size = (frames - 1) * frame_step + frame_length

    Args:
        signal: A [..., frames, frame_length] Tensor. All dimensions may be unknown, and rank must be at least 2.
        frame_step: An integer denoting overlap offsets. Must be less than or equal to frame_length.

    Returns:
        A Tensor with shape [..., output_size] containing the overlap-added frames of signal's inner-most two dimensions.
        output_size = (frames - 1) * frame_step + frame_length

    Based on https://github.com/tensorflow/tensorflow/blob/r1.12/tensorflow/contrib/signal/python/ops/reconstruction_ops.py
    """
    outer_dimensions = signal.size()[:-2]
    frames, frame_length = signal.size()[-2:]

    # gcd = Greatest Common Divisor
    subframe_length = math.gcd(frame_length, frame_step)
    subframe_step = frame_step // subframe_length
    subframes_per_frame = frame_length // subframe_length
    output_size = frame_step * (frames - 1) + frame_length
    output_subframes = output_size // subframe_length

    subframe_signal = signal.view(*outer_dimensions, -1, subframe_length)

    frame = torch.arange(0, output_subframes).unfold(
        0, subframes_per_frame, subframe_step)
    frame = frame.clone().detach().long().to(signal.device)
    # frame = signal.new_tensor(frame).clone().long()  # signal may be on GPU or CPU
    frame = frame.contiguous().view(-1)

    result = signal.new_zeros(
        *outer_dimensions, output_subframes, subframe_length)
    result.index_add_(-2, frame, subframe_signal)
    result = result.view(*outer_dimensions, -1)
    return result
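The `overlap_and_add` docstring gives the output length as output_size = (frames - 1) * frame_step + frame_length. A small sanity check of that relation (a toy example, not part of this commit; it assumes the script is run from the repository root so that `svoice.utils` is importable):

import torch
from svoice.utils import overlap_and_add

# Hypothetical toy input: a batch of 1 signal cut into 3 overlapping frames of
# length 4 with a hop (frame_step) of 2, all filled with ones.
frames = torch.ones(1, 3, 4)

out = overlap_and_add(frames, frame_step=2)
print(out.shape)  # torch.Size([1, 8]) -> (3 - 1) * 2 + 4 = 8
print(out)        # overlapping regions sum, so the middle samples are 2.0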