Spaces (status: Runtime error)

Commit: Add files with LFS

This view is limited to 50 files because it contains too many changes; see the raw diff for the complete change set.
- LICENSE +107 -0
- README.md +186 -13
- app.py +617 -0
- app_flux.py +305 -0
- app_p2p.py +567 -0
- densepose/__init__.py +22 -0
- densepose/config.py +277 -0
- densepose/converters/__init__.py +17 -0
- densepose/converters/base.py +95 -0
- densepose/converters/builtin.py +33 -0
- densepose/converters/chart_output_hflip.py +73 -0
- densepose/converters/chart_output_to_chart_result.py +190 -0
- densepose/converters/hflip.py +36 -0
- densepose/converters/segm_to_mask.py +152 -0
- densepose/converters/to_chart_result.py +72 -0
- densepose/converters/to_mask.py +51 -0
- densepose/data/__init__.py +27 -0
- densepose/data/build.py +738 -0
- densepose/data/combined_loader.py +46 -0
- densepose/data/dataset_mapper.py +170 -0
- densepose/data/datasets/__init__.py +7 -0
- densepose/data/datasets/builtin.py +18 -0
- densepose/data/datasets/chimpnsee.py +31 -0
- densepose/data/datasets/coco.py +434 -0
- densepose/data/datasets/dataset_type.py +13 -0
- densepose/data/datasets/lvis.py +259 -0
- densepose/data/image_list_dataset.py +74 -0
- densepose/data/inference_based_loader.py +174 -0
- densepose/data/meshes/__init__.py +7 -0
- densepose/data/meshes/builtin.py +103 -0
- densepose/data/meshes/catalog.py +73 -0
- densepose/data/samplers/__init__.py +10 -0
- densepose/data/samplers/densepose_base.py +205 -0
- densepose/data/samplers/densepose_confidence_based.py +110 -0
- densepose/data/samplers/densepose_cse_base.py +141 -0
- densepose/data/samplers/densepose_cse_confidence_based.py +121 -0
- densepose/data/samplers/densepose_cse_uniform.py +14 -0
- densepose/data/samplers/densepose_uniform.py +43 -0
- densepose/data/samplers/mask_from_densepose.py +30 -0
- densepose/data/samplers/prediction_to_gt.py +100 -0
- densepose/data/transform/__init__.py +5 -0
- densepose/data/transform/image.py +41 -0
- densepose/data/utils.py +40 -0
- densepose/data/video/__init__.py +19 -0
- densepose/data/video/frame_selector.py +89 -0
- densepose/data/video/video_keyframe_dataset.py +304 -0
- densepose/engine/__init__.py +5 -0
- densepose/engine/trainer.py +260 -0
- densepose/evaluation/__init__.py +5 -0
- densepose/evaluation/d2_evaluator_adapter.py +52 -0
LICENSE
ADDED
@@ -0,0 +1,107 @@
Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International

Creative Commons Corporation ("Creative Commons") is not a law firm and does not provide legal services or legal advice. Distribution of Creative Commons public licenses does not create a lawyer-client or other relationship. Creative Commons makes its licenses and related information available on an "as-is" basis. Creative Commons gives no warranties regarding its licenses, any material licensed under their terms and conditions, or any related information. Creative Commons disclaims all liability for damages resulting from their use to the fullest extent possible.

Using Creative Commons Public Licenses

Creative Commons public licenses provide a standard set of terms and conditions that creators and other rights holders may use to share original works of authorship and other material subject to copyright and certain other rights specified in the public license below. The following considerations are for informational purposes only, are not exhaustive, and do not form part of our licenses.

Considerations for licensors: Our public licenses are intended for use by those authorized to give the public permission to use material in ways otherwise restricted by copyright and certain other rights. Our licenses are irrevocable. Licensors should read and understand the terms and conditions of the license they choose before applying it. Licensors should also secure all rights necessary before applying our licenses so that the public can reuse the material as expected. Licensors should clearly mark any material not subject to the license. This includes other CC-licensed material, or material used under an exception or limitation to copyright. More considerations for licensors: wiki.creativecommons.org/Considerations_for_licensors

Considerations for the public: By using one of our public licenses, a licensor grants the public permission to use the licensed material under specified terms and conditions. If the licensor's permission is not necessary for any reason – for example, because of any applicable exception or limitation to copyright – then that use is not regulated by the license. Our licenses grant only permissions under copyright and certain other rights that a licensor has authority to grant. Use of the licensed material may still be restricted for other reasons, including because others have copyright or other rights in the material. A licensor may make special requests, such as asking that all changes be marked or described. Although not required by our licenses, you are encouraged to respect those requests where reasonable. More considerations for the public: wiki.creativecommons.org/Considerations_for_licensees

Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International Public License

By exercising the Licensed Rights (defined below), You accept and agree to be bound by the terms and conditions of this Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International Public License ("Public License"). To the extent this Public License may be interpreted as a contract, You are granted the Licensed Rights in consideration of Your acceptance of these terms and conditions, and the Licensor grants You such rights in consideration of benefits the Licensor receives from making the Licensed Material available under these terms and conditions.

Section 1 – Definitions.

a. Adapted Material means material subject to Copyright and Similar Rights that is derived from or based upon the Licensed Material and in which the Licensed Material is translated, altered, arranged, transformed, or otherwise modified in a manner requiring permission under the Copyright and Similar Rights held by the Licensor. For purposes of this Public License, where the Licensed Material is a musical work, performance, or sound recording, Adapted Material is always produced where the Licensed Material is synched in timed relation with a moving image.
b. Adapter's License means the license You apply to Your Copyright and Similar Rights in Your contributions to Adapted Material in accordance with the terms and conditions of this Public License.
c. BY-NC-SA Compatible License means a license listed at creativecommons.org/compatiblelicenses, approved by Creative Commons as essentially the equivalent of this Public License.
d. Copyright and Similar Rights means copyright and/or similar rights closely related to copyright including, without limitation, performance, broadcast, sound recording, and Sui Generis Database Rights, without regard to how the rights are labeled or categorized. For purposes of this Public License, the rights specified in Section 2(b)(1)-(2) are not Copyright and Similar Rights.
e. Effective Technological Measures means those measures that, in the absence of proper authority, may not be circumvented under laws fulfilling obligations under Article 11 of the WIPO Copyright Treaty adopted on December 20, 1996, and/or similar international agreements.
f. Exceptions and Limitations means fair use, fair dealing, and/or any other exception or limitation to Copyright and Similar Rights that applies to Your use of the Licensed Material.
g. License Elements means the license attributes listed in the name of a Creative Commons Public License. The License Elements of this Public License are Attribution, NonCommercial, and ShareAlike.
h. Licensed Material means the artistic or literary work, database, or other material to which the Licensor applied this Public License.
i. Licensed Rights means the rights granted to You subject to the terms and conditions of this Public License, which are limited to all Copyright and Similar Rights that apply to Your use of the Licensed Material and that the Licensor has authority to license.
j. Licensor means the individual(s) or entity(ies) granting rights under this Public License.
k. NonCommercial means not primarily intended for or directed towards commercial advantage or monetary compensation. For purposes of this Public License, the exchange of the Licensed Material for other material subject to Copyright and Similar Rights by digital file-sharing or similar means is NonCommercial provided there is no payment of monetary compensation in connection with the exchange.
l. Share means to provide material to the public by any means or process that requires permission under the Licensed Rights, such as reproduction, public display, public performance, distribution, dissemination, communication, or importation, and to make material available to the public including in ways that members of the public may access the material from a place and at a time individually chosen by them.
m. Sui Generis Database Rights means rights other than copyright resulting from Directive 96/9/EC of the European Parliament and of the Council of 11 March 1996 on the legal protection of databases, as amended and/or succeeded, as well as other essentially equivalent rights anywhere in the world.
n. You means the individual or entity exercising the Licensed Rights under this Public License. Your has a corresponding meaning.

Section 2 – Scope.

a. License grant.
1. Subject to the terms and conditions of this Public License, the Licensor hereby grants You a worldwide, royalty-free, non-sublicensable, non-exclusive, irrevocable license to exercise the Licensed Rights in the Licensed Material to:
A. reproduce and Share the Licensed Material, in whole or in part, for NonCommercial purposes only; and
B. produce, reproduce, and Share Adapted Material for NonCommercial purposes only.
2. Exceptions and Limitations. For the avoidance of doubt, where Exceptions and Limitations apply to Your use, this Public License does not apply, and You do not need to comply with its terms and conditions.
3. Term. The term of this Public License is specified in Section 6(a).
4. Media and formats; technical modifications allowed. The Licensor authorizes You to exercise the Licensed Rights in all media and formats whether now known or hereafter created, and to make technical modifications necessary to do so. The Licensor waives and/or agrees not to assert any right or authority to forbid You from making technical modifications necessary to exercise the Licensed Rights, including technical modifications necessary to circumvent Effective Technological Measures. For purposes of this Public License, simply making modifications authorized by this Section 2(a)(4) never produces Adapted Material.
5. Downstream recipients.
A. Offer from the Licensor – Licensed Material. Every recipient of the Licensed Material automatically receives an offer from the Licensor to exercise the Licensed Rights under the terms and conditions of this Public License.
B. Additional offer from the Licensor – Adapted Material. Every recipient of Adapted Material from You automatically receives an offer from the Licensor to exercise the Licensed Rights in the Adapted Material under the conditions of the Adapter's License You apply.
C. No downstream restrictions. You may not offer or impose any additional or different terms or conditions on, or apply any Effective Technological Measures to, the Licensed Material if doing so restricts exercise of the Licensed Rights by any recipient of the Licensed Material.
6. No endorsement. Nothing in this Public License constitutes or may be construed as permission to assert or imply that You are, or that Your use of the Licensed Material is, connected with, or sponsored, endorsed, or granted official status by, the Licensor or others designated to receive attribution as provided in Section 3(a)(1)(A)(i).
b. Other rights.
1. Moral rights, such as the right of integrity, are not licensed under this Public License, nor are publicity, privacy, and/or other similar personality rights; however, to the extent possible, the Licensor waives and/or agrees not to assert any such rights held by the Licensor to the limited extent necessary to allow You to exercise the Licensed Rights, but not otherwise.
2. Patent and trademark rights are not licensed under this Public License.
3. To the extent possible, the Licensor waives any right to collect royalties from You for the exercise of the Licensed Rights, whether directly or through a collecting society under any voluntary or waivable statutory or compulsory licensing scheme. In all other cases the Licensor expressly reserves any right to collect such royalties, including when the Licensed Material is used other than for NonCommercial purposes.

Section 3 – License Conditions.

Your exercise of the Licensed Rights is expressly made subject to the following conditions.

a. Attribution.
1. If You Share the Licensed Material (including in modified form), You must:
A. retain the following if it is supplied by the Licensor with the Licensed Material:
i. identification of the creator(s) of the Licensed Material and any others designated to receive attribution, in any reasonable manner requested by the Licensor (including by pseudonym if designated);
ii. a copyright notice;
iii. a notice that refers to this Public License;
iv. a notice that refers to the disclaimer of warranties;
v. a URI or hyperlink to the Licensed Material to the extent reasonably practicable;

B. indicate if You modified the Licensed Material and retain an indication of any previous modifications; and
C. indicate the Licensed Material is licensed under this Public License, and include the text of, or the URI or hyperlink to, this Public License.
2. You may satisfy the conditions in Section 3(a)(1) in any reasonable manner based on the medium, means, and context in which You Share the Licensed Material. For example, it may be reasonable to satisfy the conditions by providing a URI or hyperlink to a resource that includes the required information.
3. If requested by the Licensor, You must remove any of the information required by Section 3(a)(1)(A) to the extent reasonably practicable.
b. ShareAlike. In addition to the conditions in Section 3(a), if You Share Adapted Material You produce, the following conditions also apply.
1. The Adapter's License You apply must be a Creative Commons license with the same License Elements, this version or later, or a BY-NC-SA Compatible License.
2. You must include the text of, or the URI or hyperlink to, the Adapter's License You apply. You may satisfy this condition in any reasonable manner based on the medium, means, and context in which You Share Adapted Material.
3. You may not offer or impose any additional or different terms or conditions on, or apply any Effective Technological Measures to, Adapted Material that restrict exercise of the rights granted under the Adapter's License You apply.

Section 4 – Sui Generis Database Rights.

Where the Licensed Rights include Sui Generis Database Rights that apply to Your use of the Licensed Material:

a. for the avoidance of doubt, Section 2(a)(1) grants You the right to extract, reuse, reproduce, and Share all or a substantial portion of the contents of the database for NonCommercial purposes only;
b. if You include all or a substantial portion of the database contents in a database in which You have Sui Generis Database Rights, then the database in which You have Sui Generis Database Rights (but not its individual contents) is Adapted Material, including for purposes of Section 3(b); and
c. You must comply with the conditions in Section 3(a) if You Share all or a substantial portion of the contents of the database.
For the avoidance of doubt, this Section 4 supplements and does not replace Your obligations under this Public License where the Licensed Rights include other Copyright and Similar Rights.

Section 5 – Disclaimer of Warranties and Limitation of Liability.

a. Unless otherwise separately undertaken by the Licensor, to the extent possible, the Licensor offers the Licensed Material as-is and as-available, and makes no representations or warranties of any kind concerning the Licensed Material, whether express, implied, statutory, or other. This includes, without limitation, warranties of title, merchantability, fitness for a particular purpose, non-infringement, absence of latent or other defects, accuracy, or the presence or absence of errors, whether or not known or discoverable. Where disclaimers of warranties are not allowed in full or in part, this disclaimer may not apply to You.
b. To the extent possible, in no event will the Licensor be liable to You on any legal theory (including, without limitation, negligence) or otherwise for any direct, special, indirect, incidental, consequential, punitive, exemplary, or other losses, costs, expenses, or damages arising out of this Public License or use of the Licensed Material, even if the Licensor has been advised of the possibility of such losses, costs, expenses, or damages. Where a limitation of liability is not allowed in full or in part, this limitation may not apply to You.
c. The disclaimer of warranties and limitation of liability provided above shall be interpreted in a manner that, to the extent possible, most closely approximates an absolute disclaimer and waiver of all liability.

Section 6 – Term and Termination.

a. This Public License applies for the term of the Copyright and Similar Rights licensed here. However, if You fail to comply with this Public License, then Your rights under this Public License terminate automatically.
b. Where Your right to use the Licensed Material has terminated under Section 6(a), it reinstates:
1. automatically as of the date the violation is cured, provided it is cured within 30 days of Your discovery of the violation; or
2. upon express reinstatement by the Licensor.
For the avoidance of doubt, this Section 6(b) does not affect any right the Licensor may have to seek remedies for Your violations of this Public License.

c. For the avoidance of doubt, the Licensor may also offer the Licensed Material under separate terms or conditions or stop distributing the Licensed Material at any time; however, doing so will not terminate this Public License.
d. Sections 1, 5, 6, 7, and 8 survive termination of this Public License.

Section 7 – Other Terms and Conditions.

a. The Licensor shall not be bound by any additional or different terms or conditions communicated by You unless expressly agreed.
b. Any arrangements, understandings, or agreements regarding the Licensed Material not stated herein are separate from and independent of the terms and conditions of this Public License.

Section 8 – Interpretation.

a. For the avoidance of doubt, this Public License does not, and shall not be interpreted to, reduce, limit, restrict, or impose conditions on any use of the Licensed Material that could lawfully be made without permission under this Public License.
b. To the extent possible, if any provision of this Public License is deemed unenforceable, it shall be automatically reformed to the minimum extent necessary to make it enforceable. If the provision cannot be reformed, it shall be severed from this Public License without affecting the enforceability of the remaining terms and conditions.
c. No term or condition of this Public License will be waived and no failure to comply consented to unless expressly agreed to by the Licensor.
d. Nothing in this Public License constitutes or may be interpreted as a limitation upon, or waiver of, any privileges and immunities that apply to the Licensor or You, including from the legal processes of any jurisdiction or authority.

Creative Commons is not a party to its public licenses. Notwithstanding, Creative Commons may elect to apply one of its public licenses to material it publishes and in those instances will be considered the "Licensor." The text of the Creative Commons public licenses is dedicated to the public domain under the CC0 Public Domain Dedication. Except for the limited purpose of indicating that material is shared under a Creative Commons public license or as otherwise permitted by the Creative Commons policies published at creativecommons.org/policies, Creative Commons does not authorize the use of the trademark "Creative Commons" or any other trademark or logo of Creative Commons without its prior written consent including, without limitation, in connection with any unauthorized modifications to any of its public licenses or any other arrangements, understandings, or agreements concerning use of licensed material. For the avoidance of doubt, this paragraph does not form part of the public licenses.

Creative Commons may be contacted at creativecommons.org.
README.md
CHANGED
@@ -1,13 +1,186 @@
# 🐈 CatVTON: Concatenation Is All You Need for Virtual Try-On with Diffusion Models

<div style="display: flex; justify-content: center; align-items: center;">
  <a href="http://arxiv.org/abs/2407.15886" style="margin: 0 2px;">
    <img src='https://img.shields.io/badge/arXiv-2407.15886-red?style=flat&logo=arXiv&logoColor=red' alt='arxiv'>
  </a>
  <a href='https://huggingface.co/zhengchong/CatVTON' style="margin: 0 2px;">
    <img src='https://img.shields.io/badge/Hugging Face-ckpts-orange?style=flat&logo=HuggingFace&logoColor=orange' alt='huggingface'>
  </a>
  <a href="https://github.com/Zheng-Chong/CatVTON" style="margin: 0 2px;">
    <img src='https://img.shields.io/badge/GitHub-Repo-blue?style=flat&logo=GitHub' alt='GitHub'>
  </a>
  <a href="https://huggingface.co/spaces/zhengchong/CatVTON" style="margin: 0 2px;">
    <img src='https://img.shields.io/badge/Demo-Gradio-gold?style=flat&logo=Gradio&logoColor=red' alt='Demo'>
  </a>
  <a href="https://huggingface.co/spaces/zhengchong/CatVTON" style="margin: 0 2px;">
    <img src='https://img.shields.io/badge/Space-ZeroGPU-orange?style=flat&logo=Gradio&logoColor=red' alt='Demo'>
  </a>
  <a href='https://zheng-chong.github.io/CatVTON/' style="margin: 0 2px;">
    <img src='https://img.shields.io/badge/Webpage-Project-silver?style=flat&logo=&logoColor=orange' alt='webpage'>
  </a>
  <a href="https://github.com/Zheng-Chong/CatVTON/LICENCE" style="margin: 0 2px;">
    <img src='https://img.shields.io/badge/License-CC BY--NC--SA--4.0-lightgreen?style=flat&logo=Lisence' alt='License'>
  </a>
</div>

**CatVTON** is a simple and efficient virtual try-on diffusion model with ***1) a lightweight network (899.06M parameters in total)***, ***2) parameter-efficient training (49.57M trainable parameters)***, and ***3) simplified inference (< 8 GB VRAM at 1024×768 resolution)***.

<div align="center">
  <img src="resource/img/teaser.jpg" width="100%" height="100%"/>
</div>

## Updates
- **`2024/12/20`**: The code for the Gradio app of **CatVTON-FLUX** has been released! It is not a stable version yet, but it is a good start!
- **`2024/12/19`**: [**CatVTON-FLUX**](https://huggingface.co/spaces/zhengchong/CatVTON) has been released! It is an extremely lightweight LoRA (only a 37.4M checkpoint) for [FLUX.1-Fill-dev](https://huggingface.co/black-forest-labs/FLUX.1-Fill-dev); the LoRA weights are available in the **[HuggingFace repo](https://huggingface.co/zhengchong/CatVTON/tree/main/flux-lora)**, and the code will be released soon!
- **`2024/11/26`**: Our **unified vision-based model for image and video try-on** will be released soon, bringing a brand-new virtual try-on experience! While our demo page will be temporarily taken offline, [**the demo on HuggingFace Space**](https://huggingface.co/spaces/zhengchong/CatVTON) will remain available for use!
- **`2024/10/17`**: The [**mask-free version**](https://huggingface.co/zhengchong/CatVTON-MaskFree) 🤗 of CatVTON has been released!
- **`2024/10/13`**: We have built a repo, [**Awesome-Try-On-Models**](https://github.com/Zheng-Chong/Awesome-Try-On-Models), that focuses on image, video, and 3D-based try-on models published after 2023, aiming to provide insights into the latest technological trends. If you're interested, feel free to contribute or give it a star!
- **`2024/08/13`**: We localized DensePose & SCHP to avoid certain environment issues.
- **`2024/08/10`**: Our 🤗 [**HuggingFace Space**](https://huggingface.co/spaces/zhengchong/CatVTON) is available now! Thanks for the grant from [**ZeroGPU**](https://huggingface.co/zero-gpu-explorers)!
- **`2024/08/09`**: [**Evaluation code**](https://github.com/Zheng-Chong/CatVTON?tab=readme-ov-file#3-calculate-metrics) is provided to calculate metrics.
- **`2024/07/27`**: We provide code and a workflow for deploying CatVTON on [**ComfyUI**](https://github.com/Zheng-Chong/CatVTON?tab=readme-ov-file#comfyui-workflow).
- **`2024/07/24`**: Our [**paper on arXiv**](http://arxiv.org/abs/2407.15886) is available!
- **`2024/07/22`**: Our [**app code**](https://github.com/Zheng-Chong/CatVTON/blob/main/app.py) is released; deploy and enjoy CatVTON on your machine!
- **`2024/07/21`**: Our [**inference code**](https://github.com/Zheng-Chong/CatVTON/blob/main/inference.py) and [**weights** 🤗](https://huggingface.co/zhengchong/CatVTON) are released.
- **`2024/07/11`**: Our [**online demo**](https://huggingface.co/spaces/zhengchong/CatVTON) is released.

## Installation

Create a conda environment and install the requirements:
```shell
conda create -n catvton python==3.9.0
conda activate catvton
cd CatVTON-main  # or your path to the CatVTON project dir
pip install -r requirements.txt
```

## Deployment
### ComfyUI Workflow
We have modified the main code to enable easy deployment of CatVTON on [ComfyUI](https://github.com/comfyanonymous/ComfyUI). Because the code structure is incompatible with the main repo, this part is released separately in the [Releases](https://github.com/Zheng-Chong/CatVTON/releases/tag/ComfyUI), which include the code to be placed under `custom_nodes` of ComfyUI and our workflow JSON files.

To deploy CatVTON to your ComfyUI, follow these steps:
1. Install all the requirements for both CatVTON and ComfyUI; refer to the [Installation Guide for CatVTON](https://github.com/Zheng-Chong/CatVTON/blob/main/INSTALL.md) and the [Installation Guide for ComfyUI](https://github.com/comfyanonymous/ComfyUI?tab=readme-ov-file#installing).
2. Download [`ComfyUI-CatVTON.zip`](https://github.com/Zheng-Chong/CatVTON/releases/download/ComfyUI/ComfyUI-CatVTON.zip) and unzip it in the `custom_nodes` folder under your ComfyUI project (cloned from [ComfyUI](https://github.com/comfyanonymous/ComfyUI)); a minimal sketch of this step is shown at the end of this subsection.
3. Run ComfyUI.
4. Download [`catvton_workflow.json`](https://github.com/Zheng-Chong/CatVTON/releases/download/ComfyUI/catvton_workflow.json), drag it into your ComfyUI webpage, and enjoy!

> For problems under Windows OS, please refer to [issue#8](https://github.com/Zheng-Chong/CatVTON/issues/8).

When you run the CatVTON workflow for the first time, the weight files will be downloaded automatically, which usually takes dozens of minutes.

<div align="center">
  <img src="resource/img/comfyui-1.png" width="100%" height="100%"/>
</div>

<!-- <div align="center">
  <img src="resource/img/comfyui.png" width="100%" height="100%"/>
</div> -->

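Below is a minimal, hypothetical sketch of step 2, assuming `ComfyUI-CatVTON.zip` has been downloaded to the current directory and your ComfyUI clone is at `./ComfyUI`; adjust both paths to your setup.

```python
# Hypothetical helper for step 2: extract ComfyUI-CatVTON.zip into ComfyUI's custom_nodes.
from pathlib import Path
import zipfile

zip_path = Path("ComfyUI-CatVTON.zip")           # downloaded release asset (assumed location)
custom_nodes = Path("ComfyUI") / "custom_nodes"  # your ComfyUI clone (assumed location)

custom_nodes.mkdir(parents=True, exist_ok=True)
with zipfile.ZipFile(zip_path) as zf:
    zf.extractall(custom_nodes)                  # unpack the CatVTON custom node
print("Extracted to", custom_nodes.resolve())
```
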
### Gradio App

To deploy the Gradio App for CatVTON on your machine, run the following command; checkpoints will be automatically downloaded from HuggingFace.

```PowerShell
CUDA_VISIBLE_DEVICES=0 python app.py \
--output_dir="resource/demo/output" \
--mixed_precision="bf16" \
--allow_tf32
```
When using `bf16` precision, generating results at a resolution of `1024x768` only requires about `8G` of VRAM.

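If you prefer to fetch the checkpoints and build the pipeline in your own script rather than through `app.py`, the sketch below mirrors what `app.py` does at startup (`snapshot_download` plus `CatVTONPipeline` and `AutoMasker` from this repo). The `bf16`/TF32 settings shown are the app defaults and are assumptions you may change.

```python
# Sketch of app.py's startup; assumes this repo's `model` and `utils` packages are
# importable and a CUDA GPU is available.
import os
from huggingface_hub import snapshot_download
from model.pipeline import CatVTONPipeline
from model.cloth_masker import AutoMasker
from utils import init_weight_dtype

repo_path = snapshot_download(repo_id="zhengchong/CatVTON")  # attention weights, DensePose, SCHP

pipeline = CatVTONPipeline(
    base_ckpt="booksforcharlie/stable-diffusion-inpainting",  # default base model in app.py
    attn_ckpt=repo_path,
    attn_ckpt_version="mix",
    weight_dtype=init_weight_dtype("bf16"),
    use_tf32=True,
    device="cuda",
)
automasker = AutoMasker(
    densepose_ckpt=os.path.join(repo_path, "DensePose"),
    schp_ckpt=os.path.join(repo_path, "SCHP"),
    device="cuda",
)
```
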
## Inference
### 1. Data Preparation
Before inference, you need to download the [VITON-HD](https://github.com/shadow2496/VITON-HD) or [DressCode](https://github.com/aimagelab/dress-code) dataset.
Once the datasets are downloaded, the folder structures should look like these:
```
├── VITON-HD
|   ├── test_pairs_unpaired.txt
│   ├── test
|   |   ├── image
│   │   │   ├── [000006_00.jpg | 000008_00.jpg | ...]
│   │   ├── cloth
│   │   │   ├── [000006_00.jpg | 000008_00.jpg | ...]
│   │   ├── agnostic-mask
│   │   │   ├── [000006_00_mask.png | 000008_00.png | ...]
...
```

```
├── DressCode
|   ├── test_pairs_paired.txt
|   ├── test_pairs_unpaired.txt
│   ├── [dresses | lower_body | upper_body]
|   |   ├── test_pairs_paired.txt
|   |   ├── test_pairs_unpaired.txt
│   │   ├── images
│   │   │   ├── [013563_0.jpg | 013563_1.jpg | 013564_0.jpg | 013564_1.jpg | ...]
│   │   ├── agnostic_masks
│   │   │   ├── [013563_0.png| 013564_0.png | ...]
...
```
For the DressCode dataset, we provide a script to preprocess the agnostic masks; run the following command:
```PowerShell
CUDA_VISIBLE_DEVICES=0 python preprocess_agnostic_mask.py \
--data_root_path <your_path_to_DressCode>
```

### 2. Inference on VITON-HD/DressCode
To run inference on the DressCode or VITON-HD dataset, run the following command; checkpoints will be automatically downloaded from HuggingFace.

```PowerShell
CUDA_VISIBLE_DEVICES=0 python inference.py \
--dataset [dresscode | vitonhd] \
--data_root_path <path> \
--output_dir <path> \
--dataloader_num_workers 8 \
--batch_size 8 \
--seed 555 \
--mixed_precision [no | fp16 | bf16] \
--allow_tf32 \
--repaint \
--eval_pair
```
### 3. Calculate Metrics

After obtaining the inference results, calculate the metrics using the following command:

```PowerShell
CUDA_VISIBLE_DEVICES=0 python eval.py \
--gt_folder <your_path_to_gt_image_folder> \
--pred_folder <your_path_to_predicted_image_folder> \
--paired \
--batch_size=16 \
--num_workers=16
```

- `--gt_folder` and `--pred_folder` should be folders that contain **only images**.
- To evaluate the results in a paired setting, use `--paired`; for an unpaired setting, simply omit it.
- `--batch_size` and `--num_workers` should be adjusted based on your machine.

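As a quick sanity check before running `eval.py`, the hypothetical snippet below illustrates the folder convention the bullets above imply: ground-truth and predicted images are matched by file name. The PSNR computation is only an illustration; the actual metric set is implemented in `eval.py`.

```python
# Illustration of the gt/pred folder-pairing convention (not the metrics used by eval.py).
import os
import numpy as np
from PIL import Image

gt_folder, pred_folder = "path/to/gt", "path/to/pred"  # assumed paths
common = sorted(set(os.listdir(gt_folder)) & set(os.listdir(pred_folder)))

for name in common:
    gt = np.asarray(Image.open(os.path.join(gt_folder, name)).convert("RGB"), dtype=np.float32)
    pred_img = Image.open(os.path.join(pred_folder, name)).convert("RGB")
    pred_img = pred_img.resize(gt.shape[1::-1])  # match (W, H) of the ground truth
    pred = np.asarray(pred_img, dtype=np.float32)
    mse = np.mean((gt - pred) ** 2)
    psnr = 10 * np.log10(255.0 ** 2 / mse) if mse > 0 else float("inf")
    print(f"{name}: PSNR {psnr:.2f} dB")
```
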
## Acknowledgement
Our code is modified based on [Diffusers](https://github.com/huggingface/diffusers). We adopt [Stable Diffusion v1.5 inpainting](https://huggingface.co/runwayml/stable-diffusion-inpainting) as the base model. We use [SCHP](https://github.com/GoGoDuck912/Self-Correction-Human-Parsing/tree/master) and [DensePose](https://github.com/facebookresearch/DensePose) to automatically generate masks in our [Gradio](https://github.com/gradio-app/gradio) App and [ComfyUI](https://github.com/comfyanonymous/ComfyUI) workflow. Thanks to all the contributors!

## License
All the materials, including code, checkpoints, and demo, are made available under the [Creative Commons BY-NC-SA 4.0](https://creativecommons.org/licenses/by-nc-sa/4.0/) license. You are free to copy, redistribute, remix, transform, and build upon the project for non-commercial purposes, as long as you give appropriate credit and distribute your contributions under the same license.

## Citation

```bibtex
@misc{chong2024catvtonconcatenationneedvirtual,
      title={CatVTON: Concatenation Is All You Need for Virtual Try-On with Diffusion Models},
      author={Zheng Chong and Xiao Dong and Haoxiang Li and Shiyue Zhang and Wenqing Zhang and Xujie Zhang and Hanqing Zhao and Xiaodan Liang},
      year={2024},
      eprint={2407.15886},
      archivePrefix={arXiv},
      primaryClass={cs.CV},
      url={https://arxiv.org/abs/2407.15886},
}
```
app.py
ADDED
@@ -0,0 +1,617 @@
import argparse
import os
import base64
from datetime import datetime
from io import BytesIO

import gradio as gr
import numpy as np
import torch
from diffusers.image_processor import VaeImageProcessor
from dotenv import load_dotenv
from huggingface_hub import snapshot_download
from openai import AzureOpenAI
from PIL import Image

from model.cloth_masker import AutoMasker, vis_mask
from model.pipeline import CatVTONPipeline
from utils import init_weight_dtype, resize_and_crop, resize_and_padding

# Load environment variables (Azure OpenAI credentials) before anything reads them.
load_dotenv()

import gc
from transformers import T5EncoderModel
from diffusers import FluxPipeline, FluxTransformer2DModel

# Azure OpenAI client used for garment descriptions and campaign captions.
openai_client = AzureOpenAI(
    api_key=os.getenv("AZURE_OPENAI_API_KEY"),
    azure_endpoint=os.getenv("AZURE_ENDPOINT"),
    api_version="2024-02-15-preview",
    azure_deployment="gpt-4o-mvp-dev"
)

def parse_args():
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument(
        "--base_model_path",
        type=str,
        default="booksforcharlie/stable-diffusion-inpainting",  # changed to a mirror repo, as runwayml deleted the original repo
        help=(
            "The path to the base model to use for evaluation. This can be a local path or a model identifier from the Model Hub."
        ),
    )
    parser.add_argument(
        "--resume_path",
        type=str,
        default="zhengchong/CatVTON",
        help=(
            "The Path to the checkpoint of trained tryon model."
        ),
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="resource/demo/output",
        help="The output directory where the model predictions will be written.",
    )

    parser.add_argument(
        "--width",
        type=int,
        default=768,
        help=(
            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
            " resolution"
        ),
    )
    parser.add_argument(
        "--height",
        type=int,
        default=1024,
        help=(
            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
            " resolution"
        ),
    )
    parser.add_argument(
        "--repaint",
        action="store_true",
        help="Whether to repaint the result image with the original background."
    )
    parser.add_argument(
        "--allow_tf32",
        action="store_true",
        default=True,
        help=(
            "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
            " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
        ),
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default="bf16",
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
            " 1.10 and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
        ),
    )
    # Added so that the LOCAL_RANK check below has a defined attribute to update.
    parser.add_argument(
        "--local_rank",
        type=int,
        default=-1,
        help="For distributed training: local_rank (overridden by the LOCAL_RANK env var).",
    )

    args = parser.parse_args()
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    return args

def image_grid(imgs, rows, cols):
    assert len(imgs) == rows * cols

    w, h = imgs[0].size
    grid = Image.new("RGB", size=(cols * w, rows * h))

    for i, img in enumerate(imgs):
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid

args = parse_args()
repo_path = snapshot_download(repo_id=args.resume_path)

def flush():
    # Free CUDA memory between the text-encoding and denoising stages of FLUX.
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_max_memory_allocated()
    torch.cuda.reset_peak_memory_stats()

flush()

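# Note on the person-image generator below: to fit FLUX.1-dev into limited VRAM,
# this app uses the 4-bit (NF4) T5 text encoder and transformer from
# `sayakpaul/flux.1-dev-nf4-pkg`, encodes the prompt first, frees the text
# encoder with flush(), and only then loads the 4-bit transformer for denoising
# (see generate_person_image). This comment documents the existing flow; the
# exact VRAM savings depend on your GPU and are not guaranteed here.
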
ckpt_4bit_id = "sayakpaul/flux.1-dev-nf4-pkg"

# 4-bit T5 text encoder, shared by the FLUX person-image generation pipeline.
text_encoder_2_4bit = T5EncoderModel.from_pretrained(
    ckpt_4bit_id,
    subfolder="text_encoder_2",
)

# Try-on pipeline
pipeline = CatVTONPipeline(
    base_ckpt=args.base_model_path,
    attn_ckpt=repo_path,
    attn_ckpt_version="mix",
    weight_dtype=init_weight_dtype(args.mixed_precision),
    use_tf32=args.allow_tf32,
    device='cuda'
)
# AutoMasker
mask_processor = VaeImageProcessor(vae_scale_factor=8, do_normalize=False, do_binarize=True, do_convert_grayscale=True)
automasker = AutoMasker(
    densepose_ckpt=os.path.join(repo_path, "DensePose"),
    schp_ckpt=os.path.join(repo_path, "SCHP"),
    device='cuda',
)

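# submit_function runs one try-on request end to end:
#   1. take the person image and any mask drawn in the Gradio ImageEditor
#      (a drawn mask takes priority; otherwise AutoMasker generates one from
#      DensePose/SCHP for the selected cloth type),
#   2. run CatVTONPipeline on the person image, garment image, and mask,
#   3. save a [person | masked person | cloth | result] grid to --output_dir,
#   4. ask the Azure OpenAI deployment for a garment description and campaign
#      captions to show next to the result.
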
def submit_function(
    person_image,
    cloth_image,
    cloth_type,
    num_inference_steps,
    guidance_scale,
    seed,
    show_type,
    campaign_context,
):
    person_image, mask = person_image["background"], person_image["layers"][0]
    mask = Image.open(mask).convert("L")
    if len(np.unique(np.array(mask))) == 1:
        mask = None
    else:
        mask = np.array(mask)
        mask[mask > 0] = 255
        mask = Image.fromarray(mask)

    tmp_folder = args.output_dir
    date_str = datetime.now().strftime("%Y%m%d%H%M%S")
    result_save_path = os.path.join(tmp_folder, date_str[:8], date_str[8:] + ".png")
    if not os.path.exists(os.path.join(tmp_folder, date_str[:8])):
        os.makedirs(os.path.join(tmp_folder, date_str[:8]))

    generator = None
    if seed != -1:
        generator = torch.Generator(device='cuda').manual_seed(seed)

    person_image = Image.open(person_image).convert("RGB")
    cloth_image = Image.open(cloth_image).convert("RGB")
    person_image = resize_and_crop(person_image, (args.width, args.height))
    cloth_image = resize_and_padding(cloth_image, (args.width, args.height))

    # Process mask: use the hand-drawn mask if one was provided, otherwise let AutoMasker generate it.
    if mask is not None:
        mask = resize_and_crop(mask, (args.width, args.height))
    else:
        mask = automasker(
            person_image,
            cloth_type
        )['mask']
    mask = mask_processor.blur(mask, blur_factor=9)

    # Inference
    result_image = pipeline(
        image=person_image,
        condition_image=cloth_image,
        mask=mask,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        generator=generator
    )[0]

    # Post-process
    masked_person = vis_mask(person_image, mask)
    save_result_image = image_grid([person_image, masked_person, cloth_image, result_image], 1, 4)
    save_result_image.save(result_save_path)

    # Generate a product description and campaign captions for the result.
    product_description = generate_upper_cloth_description(cloth_image, cloth_type)
    captions = generate_captions(product_description, campaign_context)

    if show_type == "result only":
        # Always return both outputs so the (result_image, captions_textbox) wiring below works.
        return result_image, captions
    else:
        width, height = person_image.size
        if show_type == "input & result":
            condition_width = width // 2
            conditions = image_grid([person_image, cloth_image], 2, 1)
        else:
            condition_width = width // 3
            conditions = image_grid([person_image, masked_person, cloth_image], 3, 1)
        conditions = conditions.resize((condition_width, height), Image.NEAREST)
        new_result_image = Image.new("RGB", (width + condition_width + 5, height))
        new_result_image.paste(conditions, (0, 0))
        new_result_image.paste(result_image, (condition_width + 5, 0))
        return new_result_image, captions

def person_example_fn(image_path):
    return image_path

def generate_person_image(text, cloth_description):
    """
    Generates a person image with FLUX.1-dev from the model description and the
    garment description, and returns the path to the generated image.
    """
    prompt = generate_ai_model_prompt(text, cloth_description)
    ckpt_id = "black-forest-labs/FLUX.1-dev"

    print("generating image with prompt: ", prompt)
    # Stage 1: load only the text encoders (with the 4-bit T5) to encode the prompt.
    image_gen_pipeline = FluxPipeline.from_pretrained(
        ckpt_id,
        text_encoder_2=text_encoder_2_4bit,
        transformer=None,
        vae=None,
        torch_dtype=torch.float16,
    )
    image_gen_pipeline.enable_model_cpu_offload()

    with torch.no_grad():
        print("Encoding prompts.")
        prompt_embeds, pooled_prompt_embeds, text_ids = image_gen_pipeline.encode_prompt(
            prompt=prompt, prompt_2=None, max_sequence_length=256
        )

    # Free the text-encoder pipeline before loading the transformer.
    image_gen_pipeline = image_gen_pipeline.to("cpu")
    del image_gen_pipeline

    flush()

    print(f"prompt_embeds shape: {prompt_embeds.shape}")
    print(f"pooled_prompt_embeds shape: {pooled_prompt_embeds.shape}")

    # Stage 2: load the 4-bit transformer and denoise from the precomputed embeddings.
    transformer_4bit = FluxTransformer2DModel.from_pretrained(ckpt_4bit_id, subfolder="transformer")
    image_gen_pipeline = FluxPipeline.from_pretrained(
        ckpt_id,
        text_encoder=None,
        text_encoder_2=None,
        tokenizer=None,
        tokenizer_2=None,
        transformer=transformer_4bit,
        torch_dtype=torch.float16,
    )
    image_gen_pipeline.enable_model_cpu_offload()

    print("Running denoising.")
    height, width = 1024, 1024

    images = image_gen_pipeline(
        prompt_embeds=prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        num_inference_steps=50,
        guidance_scale=5.5,
        height=height,
        width=width,
        output_type="pil",
    ).images

    # Add the current time to make each output file name unique.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    # Create the output directory if it doesn't exist.
    os.makedirs('generated_images', exist_ok=True)

    # Save the image.
    output_path = f'generated_images/generated_{timestamp}.png'
    images[0].save(output_path)

    return output_path

def pil_image_to_base64(image, format: str = "PNG") -> str:
    """
    Converts an image to a Base64 encoded string.

    Args:
        image: Either a file path (str) or a PIL Image object
        format (str): The format to save the image as (default is PNG).

    Returns:
        str: A Base64 encoded string of the image.
    """
    try:
        # If image is a file path, open it
        if isinstance(image, str):
            image = Image.open(image)
        elif not isinstance(image, Image.Image):
            raise ValueError("Input must be either a file path or a PIL Image object")

        # Convert the image to Base64
        buffered = BytesIO()
        image.save(buffered, format=format)
        buffered.seek(0)  # Go to the start of the BytesIO stream
        image_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
        return image_base64
    except Exception as e:
        print(f"Error converting image to Base64: {e}")
        raise e

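# Example usage of pil_image_to_base64 (illustrative only; the file name is hypothetical):
#   encoded = pil_image_to_base64("resource/demo/example/cloth.jpg")
#   data_url = f"data:image/jpeg;base64,{encoded}"
# The data-URL form is what the Azure OpenAI vision calls below expect.
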
def generate_upper_cloth_description(product_image, cloth_type: str):
    # Despite its name, this helper describes the garment for any cloth_type
    # (upper, lower, or overall) by prompting the Azure OpenAI vision model.
    try:
        base_64_image = pil_image_to_base64(product_image)

        if cloth_type == "upper":
            system_prompt = """
            You are a world-class fashion designer.
            Your task is to write a detailed description of the upper body garment shown in the image, focusing on its fit, sleeve style, fabric type, neckline, and any notable design elements or features, in one or two lines for the given image.
            Don't start with "This image shows a pair of beige cargo ..." but instead start with "a pair of beige cargo ..."
            """
        elif cloth_type == "lower":
            system_prompt = """
            You are a world-class fashion designer.
            Your task is to write a detailed description of the lower body garment shown in the image, focusing on its fit, fabric type, waist style, and any notable design elements or features, in one or two lines for the given image.
            Don't start with "This image shows a pair of beige cargo ..." but instead start with "a pair of beige cargo ..."
            """
        elif cloth_type == "overall":
            system_prompt = """
            You are a world-class fashion designer.
            Your task is to write a detailed description of the overall garment shown in the image, focusing on its fit, fabric type, sleeve style, neckline, and any notable design elements or features, in one or two lines for the given image.
            Don't start with "This image shows a pair of beige cargo ..." but instead start with "a pair of beige cargo ..."
            """
        else:
            system_prompt = """
            You are a world-class fashion designer.
            Your task is to write a detailed description of the upper body garment shown in the image, focusing on its fit, sleeve style, fabric type, neckline, and any notable design elements or features, in one or two lines for the given image.
            Don't start with "This image shows a pair of beige cargo ..." but instead start with "a pair of beige cargo ..."
            """

        response = openai_client.chat.completions.create(
            model="gpt-4o",
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": [
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/jpeg;base64,{base_64_image}"
                        }
                    }
                ]},
            ],
        )

        return response.choices[0].message.content
    except Exception as e:
        print(f"Error in generate_upper_cloth_description: {e}")
        raise e

def generate_caption_for_image(image):
    """
    Generates a campaign caption for the given image using OpenAI's vision model.
    """
    if image is None:
        return "Please generate a try-on result first."

    # Convert the image to base64
    if isinstance(image, str):
        base64_image = pil_image_to_base64(image)
    else:
        # Convert numpy array to PIL Image
        if isinstance(image, np.ndarray):
            image = Image.fromarray((image * 255).astype(np.uint8))
        buffered = BytesIO()
        image.save(buffered, format="PNG")
        base64_image = base64.b64encode(buffered.getvalue()).decode("utf-8")

    system_prompt = """
    You are a world-class campaign generator for the clothing the model is wearing.
    Create a campaign caption for the image shown below.
    Create engaging campaign captions for products in the merchandise for Instagram stories that attract, convert, and retain customers.
    """

    try:
        response = openai_client.chat.completions.create(
            model="gpt-4o",
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": [
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/jpeg;base64,{base64_image}"
                        }
                    }
                ]},
            ],
        )
        return response.choices[0].message.content
    except Exception as e:
        return f"Error generating caption: {str(e)}"

def generate_ai_model_prompt(model_description, product_description):
    print("prompt for ai model generation", f" {model_description} wearing {product_description}.")
    return f" {model_description} wearing {product_description}, full image"

def generate_captions(product_description, campaign_context):

    # System prompt
    system_prompt = """
    You are a world-class marketing expert.
    Your task is to create engaging, professional, and contextually relevant campaign captions based on the details provided.
    Use creative language to highlight the product's key features and align with the campaign's goals.
    Ensure the captions are tailored to the specific advertising context provided.
    """

    # User prompt
    user_prompt = f"""
    Campaign Context: {campaign_context}
    Product Description: {product_description}
    Generate captivating captions for this campaign that align with the provided context.
    """

    # Call the OpenAI API
    response = openai_client.chat.completions.create(
        model="gpt-4o",
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
    )

    # Extract the generated captions
    captions = response.choices[0].message.content.strip()
    return captions


HEADER = """
"""

def app_gradio():
    with gr.Blocks(title="CatVTON") as demo:
        gr.Markdown(HEADER)
        with gr.Row():
            with gr.Column(scale=1, min_width=350):
                text_prompt = gr.Textbox(
                    label="Describe the person (e.g., 'a young woman in a neutral pose')",
                    lines=3
                )

                generate_button = gr.Button("Generate Person Image")

                # Hidden image path component
                image_path = gr.Image(
                    type="filepath",
                    interactive=True,
                    visible=False,
                )

                # Display the generated person image
                person_image = gr.ImageEditor(
                    interactive=True,
                    label="Generated Person Image",
                    type="filepath"
                )

                campaign_context = gr.Textbox(
                    label="Describe your campaign context (e.g., 'Summer sale campaign focusing on vibrant colors')",
                    lines=3,
                    placeholder="What message do you want to convey in this campaign?",
                )

                with gr.Row():
                    with gr.Column(scale=1, min_width=230):
                        cloth_image = gr.Image(
                            interactive=True, label="Condition Image", type="filepath"
                        )

                        cloth_description = gr.Textbox(
                            label="Cloth Description",
                            interactive=False,
                            lines=3
                        )

                    with gr.Column(scale=1, min_width=120):
                        gr.Markdown(
                            '<span style="color: #808080; font-size: small;">Two ways to provide the Mask:<br>1. Upload the person image and use the brush tool above to draw the Mask (higher priority)<br>2. Select the `Try-On Cloth Type` to generate it automatically</span>'
                        )
                        cloth_type = gr.Radio(
                            label="Try-On Cloth Type",
                            choices=["upper", "lower", "overall"],
                            value="upper",
                        )

                cloth_image.change(
                    generate_upper_cloth_description,
                    inputs=[cloth_image, cloth_type],
                    outputs=[cloth_description],
                )

                submit = gr.Button("Submit")
                gr.Markdown(
                    '<center><span style="color: #FF0000">!!! Click only once and wait for the delay !!!</span></center>'
                )

                gr.Markdown(
                    '<span style="color: #808080; font-size: small;">Advanced options can adjust details:<br>1. `Inference Step` may enhance details;<br>2. `CFG` is highly correlated with saturation;<br>3. `Random seed` may improve pseudo-shadow.</span>'
                )
                with gr.Accordion("Advanced Options", open=False):
                    num_inference_steps = gr.Slider(
                        label="Inference Step", minimum=10, maximum=100, step=5, value=50
                    )
                    # Guidance Scale
                    guidance_scale = gr.Slider(
                        label="CFG Strength", minimum=0.0, maximum=7.5, step=0.5, value=2.5
                    )
                    # Random Seed
                    seed = gr.Slider(
                        label="Seed", minimum=-1, maximum=10000, step=1, value=42
                    )
                    show_type = gr.Radio(
                        label="Show Type",
                        choices=["result only", "input & result", "input & mask & result"],
                        value="input & mask & result",
                    )

            with gr.Column(scale=2, min_width=500):
                # Single or multiple images
                result_image = gr.Image(interactive=False, label="Result")
                captions_textbox = gr.Textbox(
                    label="Generated Campaign Captions",
                    interactive=False,
                    lines=6
                )

                with gr.Row():
                    # Photo examples
                    root_path = "resource/demo/example"

                image_path.change(
                    person_example_fn, inputs=image_path, outputs=person_image
                )

                # Connect the generation button
                generate_button.click(
                    generate_person_image,
                    inputs=[text_prompt, cloth_description],
                    outputs=[person_image]
                )

                submit.click(
                    submit_function,
                    [
                        person_image,
                        cloth_image,
                        cloth_type,
                        num_inference_steps,
                        guidance_scale,
                        seed,
                        show_type,
                        campaign_context,
                    ],
                    [result_image, captions_textbox]
                )

                # generate_caption_btn.click(
                #     generate_caption_for_image,
+
# inputs=[result_image],
|
611 |
+
# outputs=[caption_text]
|
612 |
+
# )
|
613 |
+
demo.queue().launch(share=True, show_error=True)
|
614 |
+
|
615 |
+
|
616 |
+
if __name__ == "__main__":
|
617 |
+
app_gradio()
|
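A minimal standalone sketch of exercising the caption helper above outside Gradio. It assumes the `openai` package's `OpenAI()` client (app.py builds `openai_client` earlier, not shown here), an `OPENAI_API_KEY` in the environment, and a hypothetical local image path; the prompt wording is illustrative only.

```python
# Standalone sketch (not part of app.py): caption a local image with gpt-4o.
import base64
from io import BytesIO

from openai import OpenAI   # assumes the openai>=1.0 SDK is installed
from PIL import Image

openai_client = OpenAI()    # reads OPENAI_API_KEY from the environment

def caption_image(path: str) -> str:
    """Encode an image as a data URL and ask gpt-4o for one campaign caption."""
    buffered = BytesIO()
    Image.open(path).convert("RGB").save(buffered, format="PNG")
    b64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
    response = openai_client.chat.completions.create(
        model="gpt-4o",
        messages=[
            {"role": "system", "content": "Write one engaging Instagram-story caption for the outfit in this image."},
            {"role": "user", "content": [
                {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{b64}"}}
            ]},
        ],
    )
    return response.choices[0].message.content

if __name__ == "__main__":
    print(caption_image("person.png"))  # hypothetical image path
```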
app_flux.py
ADDED
@@ -0,0 +1,305 @@
1 |
+
import os
|
2 |
+
import argparse
|
3 |
+
import gradio as gr
|
4 |
+
from datetime import datetime
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
from diffusers.image_processor import VaeImageProcessor
|
9 |
+
from huggingface_hub import snapshot_download
|
10 |
+
from PIL import Image
|
11 |
+
|
12 |
+
from model.cloth_masker import AutoMasker, vis_mask
|
13 |
+
from model.flux.pipeline_flux_tryon import FluxTryOnPipeline
|
14 |
+
from utils import resize_and_crop, resize_and_padding
|
15 |
+
|
16 |
+
def parse_args():
|
17 |
+
parser = argparse.ArgumentParser(description="FLUX Try-On Demo")
|
18 |
+
parser.add_argument(
|
19 |
+
"--base_model_path",
|
20 |
+
type=str,
|
21 |
+
default="black-forest-labs/FLUX.1-Fill-dev",
|
22 |
+
# default="Models/FLUX.1-Fill-dev",
|
23 |
+
help="The path to the base model to use for evaluation."
|
24 |
+
)
|
25 |
+
parser.add_argument(
|
26 |
+
"--resume_path",
|
27 |
+
type=str,
|
28 |
+
default="zhengchong/CatVTON",
|
29 |
+
help="The Path to the checkpoint of trained tryon model."
|
30 |
+
)
|
31 |
+
parser.add_argument(
|
32 |
+
"--output_dir",
|
33 |
+
type=str,
|
34 |
+
default="resource/demo/output",
|
35 |
+
help="The output directory where the model predictions will be written."
|
36 |
+
)
|
37 |
+
parser.add_argument(
|
38 |
+
"--mixed_precision",
|
39 |
+
type=str,
|
40 |
+
default="bf16",
|
41 |
+
choices=["no", "fp16", "bf16"],
|
42 |
+
help="Whether to use mixed precision."
|
43 |
+
)
|
44 |
+
parser.add_argument(
|
45 |
+
"--allow_tf32",
|
46 |
+
action="store_true",
|
47 |
+
default=True,
|
48 |
+
help="Whether or not to allow TF32 on Ampere GPUs."
|
49 |
+
)
|
50 |
+
parser.add_argument(
|
51 |
+
"--width",
|
52 |
+
type=int,
|
53 |
+
default=768,
|
54 |
+
help="The width of the input image."
|
55 |
+
)
|
56 |
+
parser.add_argument(
|
57 |
+
"--height",
|
58 |
+
type=int,
|
59 |
+
default=1024,
|
60 |
+
help="The height of the input image."
|
61 |
+
)
|
62 |
+
return parser.parse_args()
|
63 |
+
|
64 |
+
def image_grid(imgs, rows, cols):
|
65 |
+
assert len(imgs) == rows * cols
|
66 |
+
w, h = imgs[0].size
|
67 |
+
grid = Image.new("RGB", size=(cols * w, rows * h))
|
68 |
+
for i, img in enumerate(imgs):
|
69 |
+
grid.paste(img, box=(i % cols * w, i // cols * h))
|
70 |
+
return grid
|
71 |
+
|
72 |
+
|
73 |
+
def submit_function_flux(
|
74 |
+
person_image,
|
75 |
+
cloth_image,
|
76 |
+
cloth_type,
|
77 |
+
num_inference_steps,
|
78 |
+
guidance_scale,
|
79 |
+
seed,
|
80 |
+
show_type
|
81 |
+
):
|
82 |
+
|
83 |
+
# Process image editor input
|
84 |
+
person_image, mask = person_image["background"], person_image["layers"][0]
|
85 |
+
mask = Image.open(mask).convert("L")
|
86 |
+
if len(np.unique(np.array(mask))) == 1:
|
87 |
+
mask = None
|
88 |
+
else:
|
89 |
+
mask = np.array(mask)
|
90 |
+
mask[mask > 0] = 255
|
91 |
+
mask = Image.fromarray(mask)
|
92 |
+
|
93 |
+
# Set random seed
|
94 |
+
generator = None
|
95 |
+
if seed != -1:
|
96 |
+
generator = torch.Generator(device='cuda').manual_seed(seed)
|
97 |
+
|
98 |
+
# Process input images
|
99 |
+
person_image = Image.open(person_image).convert("RGB")
|
100 |
+
cloth_image = Image.open(cloth_image).convert("RGB")
|
101 |
+
|
102 |
+
# Adjust image sizes
|
103 |
+
person_image = resize_and_crop(person_image, (args.width, args.height))
|
104 |
+
cloth_image = resize_and_padding(cloth_image, (args.width, args.height))
|
105 |
+
|
106 |
+
# Process mask
|
107 |
+
if mask is not None:
|
108 |
+
mask = resize_and_crop(mask, (args.width, args.height))
|
109 |
+
else:
|
110 |
+
mask = automasker(
|
111 |
+
person_image,
|
112 |
+
cloth_type
|
113 |
+
)['mask']
|
114 |
+
mask = mask_processor.blur(mask, blur_factor=9)
|
115 |
+
|
116 |
+
# Inference
|
117 |
+
result_image = pipeline_flux(
|
118 |
+
image=person_image,
|
119 |
+
condition_image=cloth_image,
|
120 |
+
mask_image=mask,
|
121 |
+
height=args.height,
|
122 |
+
width=args.width,
|
123 |
+
num_inference_steps=num_inference_steps,
|
124 |
+
guidance_scale=guidance_scale,
|
125 |
+
generator=generator
|
126 |
+
).images[0]
|
127 |
+
|
128 |
+
# Post-processing
|
129 |
+
masked_person = vis_mask(person_image, mask)
|
130 |
+
|
131 |
+
# Return result based on show type
|
132 |
+
if show_type == "result only":
|
133 |
+
return result_image
|
134 |
+
else:
|
135 |
+
width, height = person_image.size
|
136 |
+
if show_type == "input & result":
|
137 |
+
condition_width = width // 2
|
138 |
+
conditions = image_grid([person_image, cloth_image], 2, 1)
|
139 |
+
else:
|
140 |
+
condition_width = width // 3
|
141 |
+
conditions = image_grid([person_image, masked_person, cloth_image], 3, 1)
|
142 |
+
|
143 |
+
conditions = conditions.resize((condition_width, height), Image.NEAREST)
|
144 |
+
new_result_image = Image.new("RGB", (width + condition_width + 5, height))
|
145 |
+
new_result_image.paste(conditions, (0, 0))
|
146 |
+
new_result_image.paste(result_image, (condition_width + 5, 0))
|
147 |
+
return new_result_image
|
148 |
+
|
149 |
+
def person_example_fn(image_path):
|
150 |
+
return image_path
|
151 |
+
|
152 |
+
|
153 |
+
def app_gradio():
|
154 |
+
with gr.Blocks(title="CatVTON with FLUX.1-Fill-dev") as demo:
|
155 |
+
gr.Markdown("# CatVTON with FLUX.1-Fill-dev")
|
156 |
+
with gr.Row():
|
157 |
+
with gr.Column(scale=1, min_width=350):
|
158 |
+
with gr.Row():
|
159 |
+
image_path_flux = gr.Image(
|
160 |
+
type="filepath",
|
161 |
+
interactive=True,
|
162 |
+
visible=False,
|
163 |
+
)
|
164 |
+
person_image_flux = gr.ImageEditor(
|
165 |
+
interactive=True, label="Person Image", type="filepath"
|
166 |
+
)
|
167 |
+
|
168 |
+
with gr.Row():
|
169 |
+
with gr.Column(scale=1, min_width=230):
|
170 |
+
cloth_image_flux = gr.Image(
|
171 |
+
interactive=True, label="Condition Image", type="filepath"
|
172 |
+
)
|
173 |
+
with gr.Column(scale=1, min_width=120):
|
174 |
+
gr.Markdown(
|
175 |
+
'<span style="color: #808080; font-size: small;">Two ways to provide Mask:<br>1. Upload the person image and use the `ποΈ` above to draw the Mask (higher priority)<br>2. Select the `Try-On Cloth Type` to generate automatically </span>'
|
176 |
+
)
|
177 |
+
cloth_type = gr.Radio(
|
178 |
+
label="Try-On Cloth Type",
|
179 |
+
choices=["upper", "lower", "overall"],
|
180 |
+
value="upper",
|
181 |
+
)
|
182 |
+
|
183 |
+
submit_flux = gr.Button("Submit")
|
184 |
+
gr.Markdown(
|
185 |
+
'<center><span style="color: #FF0000">!!! Click only Once, Wait for Delay !!!</span></center>'
|
186 |
+
)
|
187 |
+
|
188 |
+
with gr.Accordion("Advanced Options", open=False):
|
189 |
+
num_inference_steps_flux = gr.Slider(
|
190 |
+
label="Inference Step", minimum=10, maximum=100, step=5, value=50
|
191 |
+
)
|
192 |
+
# Guidance Scale
|
193 |
+
guidance_scale_flux = gr.Slider(
|
194 |
+
label="CFG Strenth", minimum=0.0, maximum=50, step=0.5, value=30
|
195 |
+
)
|
196 |
+
# Random Seed
|
197 |
+
seed_flux = gr.Slider(
|
198 |
+
label="Seed", minimum=-1, maximum=10000, step=1, value=42
|
199 |
+
)
|
200 |
+
show_type = gr.Radio(
|
201 |
+
label="Show Type",
|
202 |
+
choices=["result only", "input & result", "input & mask & result"],
|
203 |
+
value="input & mask & result",
|
204 |
+
)
|
205 |
+
|
206 |
+
with gr.Column(scale=2, min_width=500):
|
207 |
+
result_image_flux = gr.Image(interactive=False, label="Result")
|
208 |
+
with gr.Row():
|
209 |
+
# Photo Examples
|
210 |
+
root_path = "resource/demo/example"
|
211 |
+
with gr.Column():
|
212 |
+
gr.Examples(
|
213 |
+
examples=[
|
214 |
+
os.path.join(root_path, "person", "men", _)
|
215 |
+
for _ in os.listdir(os.path.join(root_path, "person", "men"))
|
216 |
+
],
|
217 |
+
examples_per_page=4,
|
218 |
+
inputs=image_path_flux,
|
219 |
+
label="Person Examples β ",
|
220 |
+
)
|
221 |
+
gr.Examples(
|
222 |
+
examples=[
|
223 |
+
os.path.join(root_path, "person", "women", _)
|
224 |
+
for _ in os.listdir(os.path.join(root_path, "person", "women"))
|
225 |
+
],
|
226 |
+
examples_per_page=4,
|
227 |
+
inputs=image_path_flux,
|
228 |
+
label="Person Examples β‘",
|
229 |
+
)
|
230 |
+
gr.Markdown(
|
231 |
+
'<span style="color: #808080; font-size: small;">*Person examples come from the demos of <a href="https://huggingface.co/spaces/levihsu/OOTDiffusion">OOTDiffusion</a> and <a href="https://www.outfitanyone.org">OutfitAnyone</a>. </span>'
|
232 |
+
)
|
233 |
+
with gr.Column():
|
234 |
+
gr.Examples(
|
235 |
+
examples=[
|
236 |
+
os.path.join(root_path, "condition", "upper", _)
|
237 |
+
for _ in os.listdir(os.path.join(root_path, "condition", "upper"))
|
238 |
+
],
|
239 |
+
examples_per_page=4,
|
240 |
+
inputs=cloth_image_flux,
|
241 |
+
label="Condition Upper Examples",
|
242 |
+
)
|
243 |
+
gr.Examples(
|
244 |
+
examples=[
|
245 |
+
os.path.join(root_path, "condition", "overall", _)
|
246 |
+
for _ in os.listdir(os.path.join(root_path, "condition", "overall"))
|
247 |
+
],
|
248 |
+
examples_per_page=4,
|
249 |
+
inputs=cloth_image_flux,
|
250 |
+
label="Condition Overall Examples",
|
251 |
+
)
|
252 |
+
condition_person_exm = gr.Examples(
|
253 |
+
examples=[
|
254 |
+
os.path.join(root_path, "condition", "person", _)
|
255 |
+
for _ in os.listdir(os.path.join(root_path, "condition", "person"))
|
256 |
+
],
|
257 |
+
examples_per_page=4,
|
258 |
+
inputs=cloth_image_flux,
|
259 |
+
label="Condition Reference Person Examples",
|
260 |
+
)
|
261 |
+
gr.Markdown(
|
262 |
+
'<span style="color: #808080; font-size: small;">*Condition examples come from the Internet. </span>'
|
263 |
+
)
|
264 |
+
|
265 |
+
|
266 |
+
image_path_flux.change(
|
267 |
+
person_example_fn, inputs=image_path_flux, outputs=person_image_flux
|
268 |
+
)
|
269 |
+
|
270 |
+
submit_flux.click(
|
271 |
+
submit_function_flux,
|
272 |
+
[person_image_flux, cloth_image_flux, cloth_type, num_inference_steps_flux, guidance_scale_flux, seed_flux, show_type],
|
273 |
+
result_image_flux,
|
274 |
+
)
|
275 |
+
|
276 |
+
|
277 |
+
demo.queue().launch(share=True, show_error=True)
|
278 |
+
|
279 |
+
# Parse arguments
|
280 |
+
args = parse_args()
|
281 |
+
|
282 |
+
# ε 载樑ε
|
283 |
+
repo_path = snapshot_download(repo_id=args.resume_path)
|
284 |
+
pipeline_flux = FluxTryOnPipeline.from_pretrained(args.base_model_path)
|
285 |
+
pipeline_flux.load_lora_weights(
|
286 |
+
os.path.join(repo_path, "flux-lora"),
|
287 |
+
weight_name='pytorch_lora_weights.safetensors'
|
288 |
+
)
|
289 |
+
pipeline_flux.to("cuda", torch.bfloat16)
|
290 |
+
|
291 |
+
# Initialize AutoMasker
|
292 |
+
mask_processor = VaeImageProcessor(
|
293 |
+
vae_scale_factor=8,
|
294 |
+
do_normalize=False,
|
295 |
+
do_binarize=True,
|
296 |
+
do_convert_grayscale=True
|
297 |
+
)
|
298 |
+
automasker = AutoMasker(
|
299 |
+
densepose_ckpt=os.path.join(repo_path, "DensePose"),
|
300 |
+
schp_ckpt=os.path.join(repo_path, "SCHP"),
|
301 |
+
device='cuda'
|
302 |
+
)
|
303 |
+
|
304 |
+
if __name__ == "__main__":
|
305 |
+
app_gradio()
|
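A small self-contained sketch of the mask-binarization step used by `submit_function_flux` above: the hand-drawn ImageEditor layer becomes a binary mask, and a uniform (untouched) layer signals that the caller should fall back to the AutoMasker. File paths here are hypothetical.

```python
# Standalone sketch mirroring the layer-to-mask logic in submit_function_flux.
import numpy as np
from PIL import Image

def layer_to_mask(layer_path: str):
    """Return a binary PIL mask, or None if nothing was drawn on the layer."""
    mask = np.array(Image.open(layer_path).convert("L"))
    if len(np.unique(mask)) == 1:   # a uniform layer means the user drew nothing
        return None                 # caller falls back to AutoMasker
    mask[mask > 0] = 255            # any painted pixel becomes foreground
    return Image.fromarray(mask)

mask = layer_to_mask("drawn_layer.png")  # hypothetical path
print("manual mask provided" if mask is not None else "falling back to AutoMasker")
```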
app_p2p.py
ADDED
@@ -0,0 +1,567 @@
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
from datetime import datetime
|
4 |
+
|
5 |
+
import gradio as gr
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
from diffusers.image_processor import VaeImageProcessor
|
9 |
+
from huggingface_hub import snapshot_download
|
10 |
+
from PIL import Image
|
11 |
+
|
12 |
+
from model.cloth_masker import AutoMasker, vis_mask
|
13 |
+
from model.pipeline import CatVTONPipeline, CatVTONPix2PixPipeline
|
14 |
+
from utils import init_weight_dtype, resize_and_crop, resize_and_padding
|
15 |
+
|
16 |
+
|
17 |
+
def parse_args():
|
18 |
+
parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
19 |
+
parser.add_argument(
|
20 |
+
"--p2p_base_model_path",
|
21 |
+
type=str,
|
22 |
+
default="timbrooks/instruct-pix2pix",
|
23 |
+
help=(
|
24 |
+
"The path to the base model to use for evaluation. This can be a local path or a model identifier from the Model Hub."
|
25 |
+
),
|
26 |
+
)
|
27 |
+
parser.add_argument(
|
28 |
+
"--ip_base_model_path",
|
29 |
+
type=str,
|
30 |
+
default="booksforcharlie/stable-diffusion-inpainting",
|
31 |
+
help=(
|
32 |
+
"The path to the base model to use for evaluation. This can be a local path or a model identifier from the Model Hub."
|
33 |
+
),
|
34 |
+
)
|
35 |
+
parser.add_argument(
|
36 |
+
"--p2p_resume_path",
|
37 |
+
type=str,
|
38 |
+
default="zhengchong/CatVTON-MaskFree",
|
39 |
+
help=(
|
40 |
+
"The Path to the checkpoint of trained tryon model."
|
41 |
+
),
|
42 |
+
)
|
43 |
+
parser.add_argument(
|
44 |
+
"--ip_resume_path",
|
45 |
+
type=str,
|
46 |
+
default="zhengchong/CatVTON",
|
47 |
+
help=(
|
48 |
+
"The Path to the checkpoint of trained tryon model."
|
49 |
+
),
|
50 |
+
)
|
51 |
+
parser.add_argument(
|
52 |
+
"--output_dir",
|
53 |
+
type=str,
|
54 |
+
default="resource/demo/output",
|
55 |
+
help="The output directory where the model predictions will be written.",
|
56 |
+
)
|
57 |
+
|
58 |
+
parser.add_argument(
|
59 |
+
"--width",
|
60 |
+
type=int,
|
61 |
+
default=768,
|
62 |
+
help=(
|
63 |
+
"The resolution for input images, all the images in the train/validation dataset will be resized to this"
|
64 |
+
" resolution"
|
65 |
+
),
|
66 |
+
)
|
67 |
+
parser.add_argument(
|
68 |
+
"--height",
|
69 |
+
type=int,
|
70 |
+
default=1024,
|
71 |
+
help=(
|
72 |
+
"The resolution for input images, all the images in the train/validation dataset will be resized to this"
|
73 |
+
" resolution"
|
74 |
+
),
|
75 |
+
)
|
76 |
+
parser.add_argument(
|
77 |
+
"--repaint",
|
78 |
+
action="store_true",
|
79 |
+
help="Whether to repaint the result image with the original background."
|
80 |
+
)
|
81 |
+
parser.add_argument(
|
82 |
+
"--allow_tf32",
|
83 |
+
action="store_true",
|
84 |
+
default=True,
|
85 |
+
help=(
|
86 |
+
"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
|
87 |
+
" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
|
88 |
+
),
|
89 |
+
)
|
90 |
+
parser.add_argument(
|
91 |
+
"--mixed_precision",
|
92 |
+
type=str,
|
93 |
+
default="bf16",
|
94 |
+
choices=["no", "fp16", "bf16"],
|
95 |
+
help=(
|
96 |
+
"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
|
97 |
+
" 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
|
98 |
+
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
|
99 |
+
),
|
100 |
+
)
|
101 |
+
|
102 |
+
args = parser.parse_args()
|
103 |
+
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
|
104 |
+
if env_local_rank != -1 and env_local_rank != getattr(args, "local_rank", -1):
|
105 |
+
args.local_rank = env_local_rank
|
106 |
+
|
107 |
+
return args
|
108 |
+
|
109 |
+
def image_grid(imgs, rows, cols):
|
110 |
+
assert len(imgs) == rows * cols
|
111 |
+
|
112 |
+
w, h = imgs[0].size
|
113 |
+
grid = Image.new("RGB", size=(cols * w, rows * h))
|
114 |
+
|
115 |
+
for i, img in enumerate(imgs):
|
116 |
+
grid.paste(img, box=(i % cols * w, i // cols * h))
|
117 |
+
return grid
|
118 |
+
|
119 |
+
|
120 |
+
args = parse_args()
|
121 |
+
repo_path = snapshot_download(repo_id=args.p2p_resume_path)
|
122 |
+
# Pipeline
|
123 |
+
pipeline_p2p = CatVTONPix2PixPipeline(
|
124 |
+
base_ckpt=args.p2p_base_model_path,
|
125 |
+
attn_ckpt=repo_path,
|
126 |
+
attn_ckpt_version="mix-48k-1024",
|
127 |
+
weight_dtype=init_weight_dtype(args.mixed_precision),
|
128 |
+
use_tf32=args.allow_tf32,
|
129 |
+
device='cuda'
|
130 |
+
)
|
131 |
+
|
132 |
+
# Pipeline
|
133 |
+
repo_path = snapshot_download(repo_id=args.ip_resume_path)
|
134 |
+
pipeline = CatVTONPipeline(
|
135 |
+
base_ckpt=args.ip_base_model_path,
|
136 |
+
attn_ckpt=repo_path,
|
137 |
+
attn_ckpt_version="mix",
|
138 |
+
weight_dtype=init_weight_dtype(args.mixed_precision),
|
139 |
+
use_tf32=args.allow_tf32,
|
140 |
+
device='cuda'
|
141 |
+
)
|
142 |
+
|
143 |
+
# AutoMasker
|
144 |
+
mask_processor = VaeImageProcessor(vae_scale_factor=8, do_normalize=False, do_binarize=True, do_convert_grayscale=True)
|
145 |
+
automasker = AutoMasker(
|
146 |
+
densepose_ckpt=os.path.join(repo_path, "DensePose"),
|
147 |
+
schp_ckpt=os.path.join(repo_path, "SCHP"),
|
148 |
+
device='cuda',
|
149 |
+
)
|
150 |
+
|
151 |
+
|
152 |
+
def submit_function_p2p(
|
153 |
+
person_image,
|
154 |
+
cloth_image,
|
155 |
+
num_inference_steps,
|
156 |
+
guidance_scale,
|
157 |
+
seed):
|
158 |
+
person_image= person_image["background"]
|
159 |
+
|
160 |
+
tmp_folder = args.output_dir
|
161 |
+
date_str = datetime.now().strftime("%Y%m%d%H%M%S")
|
162 |
+
result_save_path = os.path.join(tmp_folder, date_str[:8], date_str[8:] + ".png")
|
163 |
+
if not os.path.exists(os.path.join(tmp_folder, date_str[:8])):
|
164 |
+
os.makedirs(os.path.join(tmp_folder, date_str[:8]))
|
165 |
+
|
166 |
+
generator = None
|
167 |
+
if seed != -1:
|
168 |
+
generator = torch.Generator(device='cuda').manual_seed(seed)
|
169 |
+
|
170 |
+
person_image = Image.open(person_image).convert("RGB")
|
171 |
+
cloth_image = Image.open(cloth_image).convert("RGB")
|
172 |
+
person_image = resize_and_crop(person_image, (args.width, args.height))
|
173 |
+
cloth_image = resize_and_padding(cloth_image, (args.width, args.height))
|
174 |
+
|
175 |
+
# Inference
|
176 |
+
try:
|
177 |
+
result_image = pipeline_p2p(
|
178 |
+
image=person_image,
|
179 |
+
condition_image=cloth_image,
|
180 |
+
num_inference_steps=num_inference_steps,
|
181 |
+
guidance_scale=guidance_scale,
|
182 |
+
generator=generator
|
183 |
+
)[0]
|
184 |
+
except Exception as e:
|
185 |
+
raise gr.Error(
|
186 |
+
"An error occurred. Please try again later: {}".format(e)
|
187 |
+
)
|
188 |
+
|
189 |
+
# Post-process
|
190 |
+
save_result_image = image_grid([person_image, cloth_image, result_image], 1, 3)
|
191 |
+
save_result_image.save(result_save_path)
|
192 |
+
return result_image
|
193 |
+
|
194 |
+
def submit_function(
|
195 |
+
person_image,
|
196 |
+
cloth_image,
|
197 |
+
cloth_type,
|
198 |
+
num_inference_steps,
|
199 |
+
guidance_scale,
|
200 |
+
seed,
|
201 |
+
show_type
|
202 |
+
):
|
203 |
+
person_image, mask = person_image["background"], person_image["layers"][0]
|
204 |
+
mask = Image.open(mask).convert("L")
|
205 |
+
if len(np.unique(np.array(mask))) == 1:
|
206 |
+
mask = None
|
207 |
+
else:
|
208 |
+
mask = np.array(mask)
|
209 |
+
mask[mask > 0] = 255
|
210 |
+
mask = Image.fromarray(mask)
|
211 |
+
|
212 |
+
tmp_folder = args.output_dir
|
213 |
+
date_str = datetime.now().strftime("%Y%m%d%H%M%S")
|
214 |
+
result_save_path = os.path.join(tmp_folder, date_str[:8], date_str[8:] + ".png")
|
215 |
+
if not os.path.exists(os.path.join(tmp_folder, date_str[:8])):
|
216 |
+
os.makedirs(os.path.join(tmp_folder, date_str[:8]))
|
217 |
+
|
218 |
+
generator = None
|
219 |
+
if seed != -1:
|
220 |
+
generator = torch.Generator(device='cuda').manual_seed(seed)
|
221 |
+
|
222 |
+
person_image = Image.open(person_image).convert("RGB")
|
223 |
+
cloth_image = Image.open(cloth_image).convert("RGB")
|
224 |
+
person_image = resize_and_crop(person_image, (args.width, args.height))
|
225 |
+
cloth_image = resize_and_padding(cloth_image, (args.width, args.height))
|
226 |
+
|
227 |
+
# Process mask
|
228 |
+
if mask is not None:
|
229 |
+
mask = resize_and_crop(mask, (args.width, args.height))
|
230 |
+
else:
|
231 |
+
mask = automasker(
|
232 |
+
person_image,
|
233 |
+
cloth_type
|
234 |
+
)['mask']
|
235 |
+
mask = mask_processor.blur(mask, blur_factor=9)
|
236 |
+
|
237 |
+
# Inference
|
238 |
+
# try:
|
239 |
+
result_image = pipeline(
|
240 |
+
image=person_image,
|
241 |
+
condition_image=cloth_image,
|
242 |
+
mask=mask,
|
243 |
+
num_inference_steps=num_inference_steps,
|
244 |
+
guidance_scale=guidance_scale,
|
245 |
+
generator=generator
|
246 |
+
)[0]
|
247 |
+
# except Exception as e:
|
248 |
+
# raise gr.Error(
|
249 |
+
# "An error occurred. Please try again later: {}".format(e)
|
250 |
+
# )
|
251 |
+
|
252 |
+
# Post-process
|
253 |
+
masked_person = vis_mask(person_image, mask)
|
254 |
+
save_result_image = image_grid([person_image, masked_person, cloth_image, result_image], 1, 4)
|
255 |
+
save_result_image.save(result_save_path)
|
256 |
+
if show_type == "result only":
|
257 |
+
return result_image
|
258 |
+
else:
|
259 |
+
width, height = person_image.size
|
260 |
+
if show_type == "input & result":
|
261 |
+
condition_width = width // 2
|
262 |
+
conditions = image_grid([person_image, cloth_image], 2, 1)
|
263 |
+
else:
|
264 |
+
condition_width = width // 3
|
265 |
+
conditions = image_grid([person_image, masked_person , cloth_image], 3, 1)
|
266 |
+
conditions = conditions.resize((condition_width, height), Image.NEAREST)
|
267 |
+
new_result_image = Image.new("RGB", (width + condition_width + 5, height))
|
268 |
+
new_result_image.paste(conditions, (0, 0))
|
269 |
+
new_result_image.paste(result_image, (condition_width + 5, 0))
|
270 |
+
return new_result_image
|
271 |
+
|
272 |
+
|
273 |
+
|
274 |
+
def person_example_fn(image_path):
|
275 |
+
return image_path
|
276 |
+
|
277 |
+
HEADER = """
|
278 |
+
<h1 style="text-align: center;"> 🐈 CatVTON: Concatenation Is All You Need for Virtual Try-On with Diffusion Models </h1>
|
279 |
+
<div style="display: flex; justify-content: center; align-items: center;">
|
280 |
+
<a href="http://arxiv.org/abs/2407.15886" style="margin: 0 2px;">
|
281 |
+
<img src='https://img.shields.io/badge/arXiv-2407.15886-red?style=flat&logo=arXiv&logoColor=red' alt='arxiv'>
|
282 |
+
</a>
|
283 |
+
<a href='https://huggingface.co/zhengchong/CatVTON' style="margin: 0 2px;">
|
284 |
+
<img src='https://img.shields.io/badge/Hugging Face-ckpts-orange?style=flat&logo=HuggingFace&logoColor=orange' alt='huggingface'>
|
285 |
+
</a>
|
286 |
+
<a href="https://github.com/Zheng-Chong/CatVTON" style="margin: 0 2px;">
|
287 |
+
<img src='https://img.shields.io/badge/GitHub-Repo-blue?style=flat&logo=GitHub' alt='GitHub'>
|
288 |
+
</a>
|
289 |
+
<a href="http://120.76.142.206:8888" style="margin: 0 2px;">
|
290 |
+
<img src='https://img.shields.io/badge/Demo-Gradio-gold?style=flat&logo=Gradio&logoColor=red' alt='Demo'>
|
291 |
+
</a>
|
292 |
+
<a href="https://huggingface.co/spaces/zhengchong/CatVTON" style="margin: 0 2px;">
|
293 |
+
<img src='https://img.shields.io/badge/Space-ZeroGPU-orange?style=flat&logo=Gradio&logoColor=red' alt='Demo'>
|
294 |
+
</a>
|
295 |
+
<a href='https://zheng-chong.github.io/CatVTON/' style="margin: 0 2px;">
|
296 |
+
<img src='https://img.shields.io/badge/Webpage-Project-silver?style=flat&logo=&logoColor=orange' alt='webpage'>
|
297 |
+
</a>
|
298 |
+
<a href="https://github.com/Zheng-Chong/CatVTON/LICENCE" style="margin: 0 2px;">
|
299 |
+
<img src='https://img.shields.io/badge/License-CC BY--NC--SA--4.0-lightgreen?style=flat&logo=Lisence' alt='License'>
|
300 |
+
</a>
|
301 |
+
</div>
|
302 |
+
<br>
|
303 |
+
· This demo and our weights are only for <span>Non-commercial Use</span>. <br>
· You can try CatVTON in our <a href="https://huggingface.co/spaces/zhengchong/CatVTON">HuggingFace Space</a> or our <a href="http://120.76.142.206:8888">online demo</a> (run on 3090). <br>
· Thanks to <a href="https://huggingface.co/zero-gpu-explorers">ZeroGPU</a> for providing A100 for our <a href="https://huggingface.co/spaces/zhengchong/CatVTON">HuggingFace Space</a>. <br>
· SafetyChecker is set to filter NSFW content, but it may block normal results too. Please adjust the <span>`seed`</span> for normal outcomes.<br>
|
307 |
+
"""
|
308 |
+
|
309 |
+
def app_gradio():
|
310 |
+
with gr.Blocks(title="CatVTON") as demo:
|
311 |
+
gr.Markdown(HEADER)
|
312 |
+
with gr.Tab("Mask-based Virtual Try-On"):
|
313 |
+
with gr.Row():
|
314 |
+
with gr.Column(scale=1, min_width=350):
|
315 |
+
with gr.Row():
|
316 |
+
image_path = gr.Image(
|
317 |
+
type="filepath",
|
318 |
+
interactive=True,
|
319 |
+
visible=False,
|
320 |
+
)
|
321 |
+
person_image = gr.ImageEditor(
|
322 |
+
interactive=True, label="Person Image", type="filepath"
|
323 |
+
)
|
324 |
+
|
325 |
+
with gr.Row():
|
326 |
+
with gr.Column(scale=1, min_width=230):
|
327 |
+
cloth_image = gr.Image(
|
328 |
+
interactive=True, label="Condition Image", type="filepath"
|
329 |
+
)
|
330 |
+
with gr.Column(scale=1, min_width=120):
|
331 |
+
gr.Markdown(
|
332 |
+
'<span style="color: #808080; font-size: small;">Two ways to provide Mask:<br>1. Upload the person image and use the `ποΈ` above to draw the Mask (higher priority)<br>2. Select the `Try-On Cloth Type` to generate automatically </span>'
|
333 |
+
)
|
334 |
+
cloth_type = gr.Radio(
|
335 |
+
label="Try-On Cloth Type",
|
336 |
+
choices=["upper", "lower", "overall"],
|
337 |
+
value="upper",
|
338 |
+
)
|
339 |
+
|
340 |
+
|
341 |
+
submit = gr.Button("Submit")
|
342 |
+
gr.Markdown(
|
343 |
+
'<center><span style="color: #FF0000">!!! Click only Once, Wait for Delay !!!</span></center>'
|
344 |
+
)
|
345 |
+
|
346 |
+
gr.Markdown(
|
347 |
+
'<span style="color: #808080; font-size: small;">Advanced options can adjust details:<br>1. `Inference Step` may enhance details;<br>2. `CFG` is highly correlated with saturation;<br>3. `Random seed` may improve pseudo-shadow.</span>'
|
348 |
+
)
|
349 |
+
with gr.Accordion("Advanced Options", open=False):
|
350 |
+
num_inference_steps = gr.Slider(
|
351 |
+
label="Inference Step", minimum=10, maximum=100, step=5, value=50
|
352 |
+
)
|
353 |
+
# Guidance Scale
|
354 |
+
guidance_scale = gr.Slider(
|
355 |
+
label="CFG Strenth", minimum=0.0, maximum=7.5, step=0.5, value=2.5
|
356 |
+
)
|
357 |
+
# Random Seed
|
358 |
+
seed = gr.Slider(
|
359 |
+
label="Seed", minimum=-1, maximum=10000, step=1, value=42
|
360 |
+
)
|
361 |
+
show_type = gr.Radio(
|
362 |
+
label="Show Type",
|
363 |
+
choices=["result only", "input & result", "input & mask & result"],
|
364 |
+
value="input & mask & result",
|
365 |
+
)
|
366 |
+
|
367 |
+
with gr.Column(scale=2, min_width=500):
|
368 |
+
result_image = gr.Image(interactive=False, label="Result")
|
369 |
+
with gr.Row():
|
370 |
+
# Photo Examples
|
371 |
+
root_path = "resource/demo/example"
|
372 |
+
with gr.Column():
|
373 |
+
men_exm = gr.Examples(
|
374 |
+
examples=[
|
375 |
+
os.path.join(root_path, "person", "men", _)
|
376 |
+
for _ in os.listdir(os.path.join(root_path, "person", "men"))
|
377 |
+
],
|
378 |
+
examples_per_page=4,
|
379 |
+
inputs=image_path,
|
380 |
+
label="Person Examples β ",
|
381 |
+
)
|
382 |
+
women_exm = gr.Examples(
|
383 |
+
examples=[
|
384 |
+
os.path.join(root_path, "person", "women", _)
|
385 |
+
for _ in os.listdir(os.path.join(root_path, "person", "women"))
|
386 |
+
],
|
387 |
+
examples_per_page=4,
|
388 |
+
inputs=image_path,
|
389 |
+
label="Person Examples β‘",
|
390 |
+
)
|
391 |
+
gr.Markdown(
|
392 |
+
'<span style="color: #808080; font-size: small;">*Person examples come from the demos of <a href="https://huggingface.co/spaces/levihsu/OOTDiffusion">OOTDiffusion</a> and <a href="https://www.outfitanyone.org">OutfitAnyone</a>. </span>'
|
393 |
+
)
|
394 |
+
with gr.Column():
|
395 |
+
condition_upper_exm = gr.Examples(
|
396 |
+
examples=[
|
397 |
+
os.path.join(root_path, "condition", "upper", _)
|
398 |
+
for _ in os.listdir(os.path.join(root_path, "condition", "upper"))
|
399 |
+
],
|
400 |
+
examples_per_page=4,
|
401 |
+
inputs=cloth_image,
|
402 |
+
label="Condition Upper Examples",
|
403 |
+
)
|
404 |
+
condition_overall_exm = gr.Examples(
|
405 |
+
examples=[
|
406 |
+
os.path.join(root_path, "condition", "overall", _)
|
407 |
+
for _ in os.listdir(os.path.join(root_path, "condition", "overall"))
|
408 |
+
],
|
409 |
+
examples_per_page=4,
|
410 |
+
inputs=cloth_image,
|
411 |
+
label="Condition Overall Examples",
|
412 |
+
)
|
413 |
+
condition_person_exm = gr.Examples(
|
414 |
+
examples=[
|
415 |
+
os.path.join(root_path, "condition", "person", _)
|
416 |
+
for _ in os.listdir(os.path.join(root_path, "condition", "person"))
|
417 |
+
],
|
418 |
+
examples_per_page=4,
|
419 |
+
inputs=cloth_image,
|
420 |
+
label="Condition Reference Person Examples",
|
421 |
+
)
|
422 |
+
gr.Markdown(
|
423 |
+
'<span style="color: #808080; font-size: small;">*Condition examples come from the Internet. </span>'
|
424 |
+
)
|
425 |
+
|
426 |
+
image_path.change(
|
427 |
+
person_example_fn, inputs=image_path, outputs=person_image
|
428 |
+
)
|
429 |
+
|
430 |
+
submit.click(
|
431 |
+
submit_function,
|
432 |
+
[
|
433 |
+
person_image,
|
434 |
+
cloth_image,
|
435 |
+
cloth_type,
|
436 |
+
num_inference_steps,
|
437 |
+
guidance_scale,
|
438 |
+
seed,
|
439 |
+
show_type,
|
440 |
+
],
|
441 |
+
result_image,
|
442 |
+
)
|
443 |
+
|
444 |
+
with gr.Tab("Mask-Free Virtual Try-On"):
|
445 |
+
with gr.Row():
|
446 |
+
with gr.Column(scale=1, min_width=350):
|
447 |
+
with gr.Row():
|
448 |
+
image_path_p2p = gr.Image(
|
449 |
+
type="filepath",
|
450 |
+
interactive=True,
|
451 |
+
visible=False,
|
452 |
+
)
|
453 |
+
person_image_p2p = gr.ImageEditor(
|
454 |
+
interactive=True, label="Person Image", type="filepath"
|
455 |
+
)
|
456 |
+
|
457 |
+
with gr.Row():
|
458 |
+
with gr.Column(scale=1, min_width=230):
|
459 |
+
cloth_image_p2p = gr.Image(
|
460 |
+
interactive=True, label="Condition Image", type="filepath"
|
461 |
+
)
|
462 |
+
|
463 |
+
submit_p2p = gr.Button("Submit")
|
464 |
+
gr.Markdown(
|
465 |
+
'<center><span style="color: #FF0000">!!! Click only Once, Wait for Delay !!!</span></center>'
|
466 |
+
)
|
467 |
+
|
468 |
+
gr.Markdown(
|
469 |
+
'<span style="color: #808080; font-size: small;">Advanced options can adjust details:<br>1. `Inference Step` may enhance details;<br>2. `CFG` is highly correlated with saturation;<br>3. `Random seed` may improve pseudo-shadow.</span>'
|
470 |
+
)
|
471 |
+
with gr.Accordion("Advanced Options", open=False):
|
472 |
+
num_inference_steps_p2p = gr.Slider(
|
473 |
+
label="Inference Step", minimum=10, maximum=100, step=5, value=50
|
474 |
+
)
|
475 |
+
# Guidance Scale
|
476 |
+
guidance_scale_p2p = gr.Slider(
|
477 |
+
label="CFG Strenth", minimum=0.0, maximum=7.5, step=0.5, value=2.5
|
478 |
+
)
|
479 |
+
# Random Seed
|
480 |
+
seed_p2p = gr.Slider(
|
481 |
+
label="Seed", minimum=-1, maximum=10000, step=1, value=42
|
482 |
+
)
|
483 |
+
# show_type = gr.Radio(
|
484 |
+
# label="Show Type",
|
485 |
+
# choices=["result only", "input & result", "input & mask & result"],
|
486 |
+
# value="input & mask & result",
|
487 |
+
# )
|
488 |
+
|
489 |
+
with gr.Column(scale=2, min_width=500):
|
490 |
+
result_image_p2p = gr.Image(interactive=False, label="Result")
|
491 |
+
with gr.Row():
|
492 |
+
# Photo Examples
|
493 |
+
root_path = "resource/demo/example"
|
494 |
+
with gr.Column():
|
495 |
+
gr.Examples(
|
496 |
+
examples=[
|
497 |
+
os.path.join(root_path, "person", "men", _)
|
498 |
+
for _ in os.listdir(os.path.join(root_path, "person", "men"))
|
499 |
+
],
|
500 |
+
examples_per_page=4,
|
501 |
+
inputs=image_path_p2p,
|
502 |
+
label="Person Examples β ",
|
503 |
+
)
|
504 |
+
gr.Examples(
|
505 |
+
examples=[
|
506 |
+
os.path.join(root_path, "person", "women", _)
|
507 |
+
for _ in os.listdir(os.path.join(root_path, "person", "women"))
|
508 |
+
],
|
509 |
+
examples_per_page=4,
|
510 |
+
inputs=image_path_p2p,
|
511 |
+
label="Person Examples β‘",
|
512 |
+
)
|
513 |
+
gr.Markdown(
|
514 |
+
'<span style="color: #808080; font-size: small;">*Person examples come from the demos of <a href="https://huggingface.co/spaces/levihsu/OOTDiffusion">OOTDiffusion</a> and <a href="https://www.outfitanyone.org">OutfitAnyone</a>. </span>'
|
515 |
+
)
|
516 |
+
with gr.Column():
|
517 |
+
gr.Examples(
|
518 |
+
examples=[
|
519 |
+
os.path.join(root_path, "condition", "upper", _)
|
520 |
+
for _ in os.listdir(os.path.join(root_path, "condition", "upper"))
|
521 |
+
],
|
522 |
+
examples_per_page=4,
|
523 |
+
inputs=cloth_image_p2p,
|
524 |
+
label="Condition Upper Examples",
|
525 |
+
)
|
526 |
+
gr.Examples(
|
527 |
+
examples=[
|
528 |
+
os.path.join(root_path, "condition", "overall", _)
|
529 |
+
for _ in os.listdir(os.path.join(root_path, "condition", "overall"))
|
530 |
+
],
|
531 |
+
examples_per_page=4,
|
532 |
+
inputs=cloth_image_p2p,
|
533 |
+
label="Condition Overall Examples",
|
534 |
+
)
|
535 |
+
condition_person_exm = gr.Examples(
|
536 |
+
examples=[
|
537 |
+
os.path.join(root_path, "condition", "person", _)
|
538 |
+
for _ in os.listdir(os.path.join(root_path, "condition", "person"))
|
539 |
+
],
|
540 |
+
examples_per_page=4,
|
541 |
+
inputs=cloth_image_p2p,
|
542 |
+
label="Condition Reference Person Examples",
|
543 |
+
)
|
544 |
+
gr.Markdown(
|
545 |
+
'<span style="color: #808080; font-size: small;">*Condition examples come from the Internet. </span>'
|
546 |
+
)
|
547 |
+
|
548 |
+
image_path_p2p.change(
|
549 |
+
person_example_fn, inputs=image_path_p2p, outputs=person_image_p2p
|
550 |
+
)
|
551 |
+
|
552 |
+
submit_p2p.click(
|
553 |
+
submit_function_p2p,
|
554 |
+
[
|
555 |
+
person_image_p2p,
|
556 |
+
cloth_image_p2p,
|
557 |
+
num_inference_steps_p2p,
|
558 |
+
guidance_scale_p2p,
|
559 |
+
seed_p2p],
|
560 |
+
result_image_p2p,
|
561 |
+
)
|
562 |
+
|
563 |
+
demo.queue().launch(share=True, show_error=True)
|
564 |
+
|
565 |
+
|
566 |
+
if __name__ == "__main__":
|
567 |
+
app_gradio()
|
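A brief sketch of the output-path convention used by `submit_function` and `submit_function_p2p` above: results are written to `<output_dir>/<YYYYMMDD>/<HHMMSS>.png`, with `resource/demo/output` as the parser's default. Using `exist_ok=True` here is a hedged simplification of the original check-then-create sequence.

```python
# Standalone sketch of the date-stamped result path used by the submit functions.
import os
from datetime import datetime

def make_result_path(output_dir: str = "resource/demo/output") -> str:
    stamp = datetime.now().strftime("%Y%m%d%H%M%S")
    day_dir = os.path.join(output_dir, stamp[:8])      # YYYYMMDD folder
    os.makedirs(day_dir, exist_ok=True)                # avoids the check-then-create race
    return os.path.join(day_dir, stamp[8:] + ".png")   # HHMMSS.png inside it

print(make_result_path())
```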
densepose/__init__.py
ADDED
@@ -0,0 +1,22 @@
# Copyright (c) Facebook, Inc. and its affiliates.

# pyre-unsafe
from .data.datasets import builtin  # just to register data
from .converters import builtin as builtin_converters  # register converters
from .config import (
    add_densepose_config,
    add_densepose_head_config,
    add_hrnet_config,
    add_dataset_category_config,
    add_bootstrap_config,
    load_bootstrap_config,
)
from .structures import DensePoseDataRelative, DensePoseList, DensePoseTransformData
from .evaluation import DensePoseCOCOEvaluator
from .modeling.roi_heads import DensePoseROIHeads
from .modeling.test_time_augmentation import (
    DensePoseGeneralizedRCNNWithTTA,
    DensePoseDatasetMapperTTA,
)
from .utils.transform import load_from_cfg
from .modeling.hrfpn import build_hrfpn_backbone
densepose/config.py
ADDED
@@ -0,0 +1,277 @@
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
3 |
+
# pyre-ignore-all-errors
|
4 |
+
|
5 |
+
from detectron2.config import CfgNode as CN
|
6 |
+
|
7 |
+
|
8 |
+
def add_dataset_category_config(cfg: CN) -> None:
|
9 |
+
"""
|
10 |
+
Add config for additional category-related dataset options
|
11 |
+
- category whitelisting
|
12 |
+
- category mapping
|
13 |
+
"""
|
14 |
+
_C = cfg
|
15 |
+
_C.DATASETS.CATEGORY_MAPS = CN(new_allowed=True)
|
16 |
+
_C.DATASETS.WHITELISTED_CATEGORIES = CN(new_allowed=True)
|
17 |
+
# class to mesh mapping
|
18 |
+
_C.DATASETS.CLASS_TO_MESH_NAME_MAPPING = CN(new_allowed=True)
|
19 |
+
|
20 |
+
|
21 |
+
def add_evaluation_config(cfg: CN) -> None:
|
22 |
+
_C = cfg
|
23 |
+
_C.DENSEPOSE_EVALUATION = CN()
|
24 |
+
# evaluator type, possible values:
|
25 |
+
# - "iou": evaluator for models that produce iou data
|
26 |
+
# - "cse": evaluator for models that produce cse data
|
27 |
+
_C.DENSEPOSE_EVALUATION.TYPE = "iou"
|
28 |
+
# storage for DensePose results, possible values:
|
29 |
+
# - "none": no explicit storage, all the results are stored in the
|
30 |
+
# dictionary with predictions, memory intensive;
|
31 |
+
# historically the default storage type
|
32 |
+
# - "ram": RAM storage, uses per-process RAM storage, which is
|
33 |
+
# reduced to a single process storage on later stages,
|
34 |
+
# less memory intensive
|
35 |
+
# - "file": file storage, uses per-process file-based storage,
|
36 |
+
# the least memory intensive, but may create bottlenecks
|
37 |
+
# on file system accesses
|
38 |
+
_C.DENSEPOSE_EVALUATION.STORAGE = "none"
|
39 |
+
# minimum threshold for IOU values: the lower its values is,
|
40 |
+
# the more matches are produced (and the higher the AP score)
|
41 |
+
_C.DENSEPOSE_EVALUATION.MIN_IOU_THRESHOLD = 0.5
|
42 |
+
# Non-distributed inference is slower (at inference time) but can avoid RAM OOM
|
43 |
+
_C.DENSEPOSE_EVALUATION.DISTRIBUTED_INFERENCE = True
|
44 |
+
# evaluate mesh alignment based on vertex embeddings, only makes sense in CSE context
|
45 |
+
_C.DENSEPOSE_EVALUATION.EVALUATE_MESH_ALIGNMENT = False
|
46 |
+
# meshes to compute mesh alignment for
|
47 |
+
_C.DENSEPOSE_EVALUATION.MESH_ALIGNMENT_MESH_NAMES = []
|
48 |
+
|
49 |
+
|
50 |
+
def add_bootstrap_config(cfg: CN) -> None:
|
51 |
+
""" """
|
52 |
+
_C = cfg
|
53 |
+
_C.BOOTSTRAP_DATASETS = []
|
54 |
+
_C.BOOTSTRAP_MODEL = CN()
|
55 |
+
_C.BOOTSTRAP_MODEL.WEIGHTS = ""
|
56 |
+
_C.BOOTSTRAP_MODEL.DEVICE = "cuda"
|
57 |
+
|
58 |
+
|
59 |
+
def get_bootstrap_dataset_config() -> CN:
|
60 |
+
_C = CN()
|
61 |
+
_C.DATASET = ""
|
62 |
+
# ratio used to mix data loaders
|
63 |
+
_C.RATIO = 0.1
|
64 |
+
# image loader
|
65 |
+
_C.IMAGE_LOADER = CN(new_allowed=True)
|
66 |
+
_C.IMAGE_LOADER.TYPE = ""
|
67 |
+
_C.IMAGE_LOADER.BATCH_SIZE = 4
|
68 |
+
_C.IMAGE_LOADER.NUM_WORKERS = 4
|
69 |
+
_C.IMAGE_LOADER.CATEGORIES = []
|
70 |
+
_C.IMAGE_LOADER.MAX_COUNT_PER_CATEGORY = 1_000_000
|
71 |
+
_C.IMAGE_LOADER.CATEGORY_TO_CLASS_MAPPING = CN(new_allowed=True)
|
72 |
+
# inference
|
73 |
+
_C.INFERENCE = CN()
|
74 |
+
# batch size for model inputs
|
75 |
+
_C.INFERENCE.INPUT_BATCH_SIZE = 4
|
76 |
+
# batch size to group model outputs
|
77 |
+
_C.INFERENCE.OUTPUT_BATCH_SIZE = 2
|
78 |
+
# sampled data
|
79 |
+
_C.DATA_SAMPLER = CN(new_allowed=True)
|
80 |
+
_C.DATA_SAMPLER.TYPE = ""
|
81 |
+
_C.DATA_SAMPLER.USE_GROUND_TRUTH_CATEGORIES = False
|
82 |
+
# filter
|
83 |
+
_C.FILTER = CN(new_allowed=True)
|
84 |
+
_C.FILTER.TYPE = ""
|
85 |
+
return _C
|
86 |
+
|
87 |
+
|
88 |
+
def load_bootstrap_config(cfg: CN) -> None:
|
89 |
+
"""
|
90 |
+
Bootstrap datasets are given as a list of `dict` that are not automatically
|
91 |
+
converted into CfgNode. This method processes all bootstrap dataset entries
|
92 |
+
and ensures that they are in CfgNode format and comply with the specification
|
93 |
+
"""
|
94 |
+
if not cfg.BOOTSTRAP_DATASETS:
|
95 |
+
return
|
96 |
+
|
97 |
+
bootstrap_datasets_cfgnodes = []
|
98 |
+
for dataset_cfg in cfg.BOOTSTRAP_DATASETS:
|
99 |
+
_C = get_bootstrap_dataset_config().clone()
|
100 |
+
_C.merge_from_other_cfg(CN(dataset_cfg))
|
101 |
+
bootstrap_datasets_cfgnodes.append(_C)
|
102 |
+
cfg.BOOTSTRAP_DATASETS = bootstrap_datasets_cfgnodes
|
103 |
+
|
104 |
+
|
105 |
+
def add_densepose_head_cse_config(cfg: CN) -> None:
|
106 |
+
"""
|
107 |
+
Add configuration options for Continuous Surface Embeddings (CSE)
|
108 |
+
"""
|
109 |
+
_C = cfg
|
110 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.CSE = CN()
|
111 |
+
# Dimensionality D of the embedding space
|
112 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBED_SIZE = 16
|
113 |
+
# Embedder specifications for various mesh IDs
|
114 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBEDDERS = CN(new_allowed=True)
|
115 |
+
# normalization coefficient for embedding distances
|
116 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBEDDING_DIST_GAUSS_SIGMA = 0.01
|
117 |
+
# normalization coefficient for geodesic distances
|
118 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.CSE.GEODESIC_DIST_GAUSS_SIGMA = 0.01
|
119 |
+
# embedding loss weight
|
120 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBED_LOSS_WEIGHT = 0.6
|
121 |
+
# embedding loss name, currently the following options are supported:
|
122 |
+
# - EmbeddingLoss: cross-entropy on vertex labels
|
123 |
+
# - SoftEmbeddingLoss: cross-entropy on vertex label combined with
|
124 |
+
# Gaussian penalty on distance between vertices
|
125 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBED_LOSS_NAME = "EmbeddingLoss"
|
126 |
+
# optimizer hyperparameters
|
127 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.CSE.FEATURES_LR_FACTOR = 1.0
|
128 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBEDDING_LR_FACTOR = 1.0
|
129 |
+
# Shape to shape cycle consistency loss parameters:
|
130 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.CSE.SHAPE_TO_SHAPE_CYCLE_LOSS = CN({"ENABLED": False})
|
131 |
+
# shape to shape cycle consistency loss weight
|
132 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.CSE.SHAPE_TO_SHAPE_CYCLE_LOSS.WEIGHT = 0.025
|
133 |
+
# norm type used for loss computation
|
134 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.CSE.SHAPE_TO_SHAPE_CYCLE_LOSS.NORM_P = 2
|
135 |
+
# normalization term for embedding similarity matrices
|
136 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.CSE.SHAPE_TO_SHAPE_CYCLE_LOSS.TEMPERATURE = 0.05
|
137 |
+
# maximum number of vertices to include into shape to shape cycle loss
|
138 |
+
# if negative or zero, all vertices are considered
|
139 |
+
# if positive, random subset of vertices of given size is considered
|
140 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.CSE.SHAPE_TO_SHAPE_CYCLE_LOSS.MAX_NUM_VERTICES = 4936
|
141 |
+
# Pixel to shape cycle consistency loss parameters:
|
142 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.CSE.PIX_TO_SHAPE_CYCLE_LOSS = CN({"ENABLED": False})
|
143 |
+
# pixel to shape cycle consistency loss weight
|
144 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.CSE.PIX_TO_SHAPE_CYCLE_LOSS.WEIGHT = 0.0001
|
145 |
+
# norm type used for loss computation
|
146 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.CSE.PIX_TO_SHAPE_CYCLE_LOSS.NORM_P = 2
|
147 |
+
# map images to all meshes and back (if false, use only gt meshes from the batch)
|
148 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.CSE.PIX_TO_SHAPE_CYCLE_LOSS.USE_ALL_MESHES_NOT_GT_ONLY = False
|
149 |
+
# Randomly select at most this number of pixels from every instance
|
150 |
+
# if negative or zero, all vertices are considered
|
151 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.CSE.PIX_TO_SHAPE_CYCLE_LOSS.NUM_PIXELS_TO_SAMPLE = 100
|
152 |
+
# normalization factor for pixel to pixel distances (higher value = smoother distribution)
|
153 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.CSE.PIX_TO_SHAPE_CYCLE_LOSS.PIXEL_SIGMA = 5.0
|
154 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.CSE.PIX_TO_SHAPE_CYCLE_LOSS.TEMPERATURE_PIXEL_TO_VERTEX = 0.05
|
155 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.CSE.PIX_TO_SHAPE_CYCLE_LOSS.TEMPERATURE_VERTEX_TO_PIXEL = 0.05
|
156 |
+
|
157 |
+
|
158 |
+
def add_densepose_head_config(cfg: CN) -> None:
|
159 |
+
"""
|
160 |
+
Add config for densepose head.
|
161 |
+
"""
|
162 |
+
_C = cfg
|
163 |
+
|
164 |
+
_C.MODEL.DENSEPOSE_ON = True
|
165 |
+
|
166 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD = CN()
|
167 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.NAME = ""
|
168 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.NUM_STACKED_CONVS = 8
|
169 |
+
# Number of parts used for point labels
|
170 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.NUM_PATCHES = 24
|
171 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.DECONV_KERNEL = 4
|
172 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.CONV_HEAD_DIM = 512
|
173 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.CONV_HEAD_KERNEL = 3
|
174 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.UP_SCALE = 2
|
175 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.HEATMAP_SIZE = 112
|
176 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.POOLER_TYPE = "ROIAlignV2"
|
177 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.POOLER_RESOLUTION = 28
|
178 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.POOLER_SAMPLING_RATIO = 2
|
179 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.NUM_COARSE_SEGM_CHANNELS = 2 # 15 or 2
|
180 |
+
# Overlap threshold for an RoI to be considered foreground (if >= FG_IOU_THRESHOLD)
|
181 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.FG_IOU_THRESHOLD = 0.7
|
182 |
+
# Loss weights for annotation masks.(14 Parts)
|
183 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.INDEX_WEIGHTS = 5.0
|
184 |
+
# Loss weights for surface parts. (24 Parts)
|
185 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.PART_WEIGHTS = 1.0
|
186 |
+
# Loss weights for UV regression.
|
187 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.POINT_REGRESSION_WEIGHTS = 0.01
|
188 |
+
# Coarse segmentation is trained using instance segmentation task data
|
189 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.COARSE_SEGM_TRAINED_BY_MASKS = False
|
190 |
+
# For Decoder
|
191 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.DECODER_ON = True
|
192 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.DECODER_NUM_CLASSES = 256
|
193 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.DECODER_CONV_DIMS = 256
|
194 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.DECODER_NORM = ""
|
195 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.DECODER_COMMON_STRIDE = 4
|
196 |
+
# For DeepLab head
|
197 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.DEEPLAB = CN()
|
198 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.DEEPLAB.NORM = "GN"
|
199 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.DEEPLAB.NONLOCAL_ON = 0
|
200 |
+
# Predictor class name, must be registered in DENSEPOSE_PREDICTOR_REGISTRY
|
201 |
+
# Some registered predictors:
|
202 |
+
# "DensePoseChartPredictor": predicts segmentation and UV coordinates for predefined charts
|
203 |
+
# "DensePoseChartWithConfidencePredictor": predicts segmentation, UV coordinates
|
204 |
+
# and associated confidences for predefined charts (default)
|
205 |
+
# "DensePoseEmbeddingWithConfidencePredictor": predicts segmentation, embeddings
|
206 |
+
# and associated confidences for CSE
|
207 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.PREDICTOR_NAME = "DensePoseChartWithConfidencePredictor"
|
208 |
+
# Loss class name, must be registered in DENSEPOSE_LOSS_REGISTRY
|
209 |
+
# Some registered losses:
|
210 |
+
# "DensePoseChartLoss": loss for chart-based models that estimate
|
211 |
+
# segmentation and UV coordinates
|
212 |
+
# "DensePoseChartWithConfidenceLoss": loss for chart-based models that estimate
|
213 |
+
# segmentation, UV coordinates and the corresponding confidences (default)
|
214 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.LOSS_NAME = "DensePoseChartWithConfidenceLoss"
|
215 |
+
# Confidences
|
216 |
+
# Enable learning UV confidences (variances) along with the actual values
|
217 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.UV_CONFIDENCE = CN({"ENABLED": False})
|
218 |
+
# UV confidence lower bound
|
219 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.UV_CONFIDENCE.EPSILON = 0.01
|
220 |
+
# Enable learning segmentation confidences (variances) along with the actual values
|
221 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.SEGM_CONFIDENCE = CN({"ENABLED": False})
|
222 |
+
# Segmentation confidence lower bound
|
223 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.SEGM_CONFIDENCE.EPSILON = 0.01
|
224 |
+
# Statistical model type for confidence learning, possible values:
|
225 |
+
# - "iid_iso": statistically independent identically distributed residuals
|
226 |
+
# with isotropic covariance
|
227 |
+
# - "indep_aniso": statistically independent residuals with anisotropic
|
228 |
+
# covariances
|
229 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.UV_CONFIDENCE.TYPE = "iid_iso"
|
230 |
+
# List of angles for rotation in data augmentation during training
|
231 |
+
_C.INPUT.ROTATION_ANGLES = [0]
|
232 |
+
_C.TEST.AUG.ROTATION_ANGLES = () # Rotation TTA
|
233 |
+
|
234 |
+
add_densepose_head_cse_config(cfg)
|
235 |
+
|
236 |
+
|
237 |
+
def add_hrnet_config(cfg: CN) -> None:
|
238 |
+
"""
|
239 |
+
Add config for HRNet backbone.
|
240 |
+
"""
|
241 |
+
_C = cfg
|
242 |
+
|
243 |
+
# For HigherHRNet w32
|
244 |
+
_C.MODEL.HRNET = CN()
|
245 |
+
_C.MODEL.HRNET.STEM_INPLANES = 64
|
246 |
+
_C.MODEL.HRNET.STAGE2 = CN()
|
247 |
+
_C.MODEL.HRNET.STAGE2.NUM_MODULES = 1
|
248 |
+
_C.MODEL.HRNET.STAGE2.NUM_BRANCHES = 2
|
249 |
+
_C.MODEL.HRNET.STAGE2.BLOCK = "BASIC"
|
250 |
+
_C.MODEL.HRNET.STAGE2.NUM_BLOCKS = [4, 4]
|
251 |
+
_C.MODEL.HRNET.STAGE2.NUM_CHANNELS = [32, 64]
|
252 |
+
_C.MODEL.HRNET.STAGE2.FUSE_METHOD = "SUM"
|
253 |
+
_C.MODEL.HRNET.STAGE3 = CN()
|
254 |
+
_C.MODEL.HRNET.STAGE3.NUM_MODULES = 4
|
255 |
+
_C.MODEL.HRNET.STAGE3.NUM_BRANCHES = 3
|
256 |
+
_C.MODEL.HRNET.STAGE3.BLOCK = "BASIC"
|
257 |
+
_C.MODEL.HRNET.STAGE3.NUM_BLOCKS = [4, 4, 4]
|
258 |
+
_C.MODEL.HRNET.STAGE3.NUM_CHANNELS = [32, 64, 128]
|
259 |
+
_C.MODEL.HRNET.STAGE3.FUSE_METHOD = "SUM"
|
260 |
+
_C.MODEL.HRNET.STAGE4 = CN()
|
261 |
+
_C.MODEL.HRNET.STAGE4.NUM_MODULES = 3
|
262 |
+
_C.MODEL.HRNET.STAGE4.NUM_BRANCHES = 4
|
263 |
+
_C.MODEL.HRNET.STAGE4.BLOCK = "BASIC"
|
264 |
+
_C.MODEL.HRNET.STAGE4.NUM_BLOCKS = [4, 4, 4, 4]
|
265 |
+
_C.MODEL.HRNET.STAGE4.NUM_CHANNELS = [32, 64, 128, 256]
|
266 |
+
_C.MODEL.HRNET.STAGE4.FUSE_METHOD = "SUM"
|
267 |
+
|
268 |
+
_C.MODEL.HRNET.HRFPN = CN()
|
269 |
+
_C.MODEL.HRNET.HRFPN.OUT_CHANNELS = 256
|
270 |
+
|
271 |
+
|
272 |
+
def add_densepose_config(cfg: CN) -> None:
|
273 |
+
add_densepose_head_config(cfg)
|
274 |
+
add_hrnet_config(cfg)
|
275 |
+
add_bootstrap_config(cfg)
|
276 |
+
add_dataset_category_config(cfg)
|
277 |
+
add_evaluation_config(cfg)
|
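A minimal sketch of how these config helpers are typically wired into a detectron2 config before building a DensePose model; the YAML file name and checkpoint path are hypothetical placeholders for whichever DensePose config is used.

```python
# Sketch: extend a detectron2 config with the DensePose options defined above.
from detectron2.config import get_cfg
from densepose import add_densepose_config   # exported by densepose/__init__.py

cfg = get_cfg()                               # base detectron2 config
add_densepose_config(cfg)                     # register the DensePose-specific keys
cfg.merge_from_file("configs/densepose_rcnn_R_50_FPN_s1x.yaml")  # hypothetical config path
cfg.MODEL.WEIGHTS = "model_final.pkl"                            # hypothetical checkpoint
cfg.freeze()
```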
densepose/converters/__init__.py
ADDED
@@ -0,0 +1,17 @@
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
|
3 |
+
# pyre-unsafe
|
4 |
+
|
5 |
+
from .hflip import HFlipConverter
|
6 |
+
from .to_mask import ToMaskConverter
|
7 |
+
from .to_chart_result import ToChartResultConverter, ToChartResultConverterWithConfidences
|
8 |
+
from .segm_to_mask import (
|
9 |
+
predictor_output_with_fine_and_coarse_segm_to_mask,
|
10 |
+
predictor_output_with_coarse_segm_to_mask,
|
11 |
+
resample_fine_and_coarse_segm_to_bbox,
|
12 |
+
)
|
13 |
+
from .chart_output_to_chart_result import (
|
14 |
+
densepose_chart_predictor_output_to_result,
|
15 |
+
densepose_chart_predictor_output_to_result_with_confidences,
|
16 |
+
)
|
17 |
+
from .chart_output_hflip import densepose_chart_predictor_output_hflip
|
densepose/converters/base.py
ADDED
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
|
3 |
+
# pyre-unsafe
|
4 |
+
|
5 |
+
from typing import Any, Tuple, Type
|
6 |
+
import torch
|
7 |
+
|
8 |
+
|
9 |
+
class BaseConverter:
|
10 |
+
"""
|
11 |
+
Converter base class to be reused by various converters.
|
12 |
+
Converter allows one to convert data from various source types to a particular
|
13 |
+
destination type. Each source type needs to register its converter. The
|
14 |
+
registration for each source type is valid for all descendants of that type.
|
15 |
+
"""
|
16 |
+
|
17 |
+
@classmethod
|
18 |
+
def register(cls, from_type: Type, converter: Any = None):
|
19 |
+
"""
|
20 |
+
Registers a converter for the specified type.
|
21 |
+
Can be used as a decorator (if converter is None), or called as a method.
|
22 |
+
|
23 |
+
Args:
|
24 |
+
from_type (type): type to register the converter for;
|
25 |
+
all instances of this type will use the same converter
|
26 |
+
converter (callable): converter to be registered for the given
|
27 |
+
type; if None, this method is assumed to be a decorator for the converter
|
28 |
+
"""
|
29 |
+
|
30 |
+
if converter is not None:
|
31 |
+
cls._do_register(from_type, converter)
|
32 |
+
|
33 |
+
def wrapper(converter: Any) -> Any:
|
34 |
+
cls._do_register(from_type, converter)
|
35 |
+
return converter
|
36 |
+
|
37 |
+
return wrapper
|
38 |
+
|
39 |
+
@classmethod
|
40 |
+
def _do_register(cls, from_type: Type, converter: Any):
|
41 |
+
cls.registry[from_type] = converter # pyre-ignore[16]
|
42 |
+
|
43 |
+
@classmethod
|
44 |
+
def _lookup_converter(cls, from_type: Type) -> Any:
|
45 |
+
"""
|
46 |
+
Perform recursive lookup for the given type
|
47 |
+
to find registered converter. If a converter was found for some base
|
48 |
+
class, it gets registered for this class to save on further lookups.
|
49 |
+
|
50 |
+
Args:
|
51 |
+
from_type: type for which to find a converter
|
52 |
+
Return:
|
53 |
+
callable or None - registered converter or None
|
54 |
+
if no suitable entry was found in the registry
|
55 |
+
"""
|
56 |
+
if from_type in cls.registry: # pyre-ignore[16]
|
57 |
+
return cls.registry[from_type]
|
58 |
+
for base in from_type.__bases__:
|
59 |
+
converter = cls._lookup_converter(base)
|
60 |
+
if converter is not None:
|
61 |
+
cls._do_register(from_type, converter)
|
62 |
+
return converter
|
63 |
+
return None
|
64 |
+
|
65 |
+
@classmethod
|
66 |
+
def convert(cls, instance: Any, *args, **kwargs):
|
67 |
+
"""
|
68 |
+
Convert an instance to the destination type using some registered
|
69 |
+
converter. Does recursive lookup for base classes, so there's no need
|
70 |
+
for explicit registration for derived classes.
|
71 |
+
|
72 |
+
Args:
|
73 |
+
instance: source instance to convert to the destination type
|
74 |
+
Return:
|
75 |
+
An instance of the destination type obtained from the source instance
|
76 |
+
Raises KeyError, if no suitable converter found
|
77 |
+
"""
|
78 |
+
instance_type = type(instance)
|
79 |
+
converter = cls._lookup_converter(instance_type)
|
80 |
+
if converter is None:
|
81 |
+
if cls.dst_type is None: # pyre-ignore[16]
|
82 |
+
output_type_str = "itself"
|
83 |
+
else:
|
84 |
+
output_type_str = cls.dst_type
|
85 |
+
raise KeyError(f"Could not find converter from {instance_type} to {output_type_str}")
|
86 |
+
return converter(instance, *args, **kwargs)
|
87 |
+
|
88 |
+
|
89 |
+
IntTupleBox = Tuple[int, int, int, int]
|
90 |
+
|
91 |
+
|
92 |
+
def make_int_box(box: torch.Tensor) -> IntTupleBox:
|
93 |
+
int_box = [0, 0, 0, 0]
|
94 |
+
int_box[0], int_box[1], int_box[2], int_box[3] = tuple(box.long().tolist())
|
95 |
+
return int_box[0], int_box[1], int_box[2], int_box[3]
|
densepose/converters/builtin.py
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
|
3 |
+
# pyre-unsafe
|
4 |
+
|
5 |
+
from ..structures import DensePoseChartPredictorOutput, DensePoseEmbeddingPredictorOutput
|
6 |
+
from . import (
|
7 |
+
HFlipConverter,
|
8 |
+
ToChartResultConverter,
|
9 |
+
ToChartResultConverterWithConfidences,
|
10 |
+
ToMaskConverter,
|
11 |
+
densepose_chart_predictor_output_hflip,
|
12 |
+
densepose_chart_predictor_output_to_result,
|
13 |
+
densepose_chart_predictor_output_to_result_with_confidences,
|
14 |
+
predictor_output_with_coarse_segm_to_mask,
|
15 |
+
predictor_output_with_fine_and_coarse_segm_to_mask,
|
16 |
+
)
|
17 |
+
|
18 |
+
ToMaskConverter.register(
|
19 |
+
DensePoseChartPredictorOutput, predictor_output_with_fine_and_coarse_segm_to_mask
|
20 |
+
)
|
21 |
+
ToMaskConverter.register(
|
22 |
+
DensePoseEmbeddingPredictorOutput, predictor_output_with_coarse_segm_to_mask
|
23 |
+
)
|
24 |
+
|
25 |
+
ToChartResultConverter.register(
|
26 |
+
DensePoseChartPredictorOutput, densepose_chart_predictor_output_to_result
|
27 |
+
)
|
28 |
+
|
29 |
+
ToChartResultConverterWithConfidences.register(
|
30 |
+
DensePoseChartPredictorOutput, densepose_chart_predictor_output_to_result_with_confidences
|
31 |
+
)
|
32 |
+
|
33 |
+
HFlipConverter.register(DensePoseChartPredictorOutput, densepose_chart_predictor_output_hflip)
|
densepose/converters/chart_output_hflip.py
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
|
3 |
+
# pyre-unsafe
|
4 |
+
from dataclasses import fields
|
5 |
+
import torch
|
6 |
+
|
7 |
+
from densepose.structures import DensePoseChartPredictorOutput, DensePoseTransformData
|
8 |
+
|
9 |
+
|
10 |
+
def densepose_chart_predictor_output_hflip(
|
11 |
+
densepose_predictor_output: DensePoseChartPredictorOutput,
|
12 |
+
transform_data: DensePoseTransformData,
|
13 |
+
) -> DensePoseChartPredictorOutput:
|
14 |
+
"""
|
15 |
+
Change to take into account a Horizontal flip.
|
16 |
+
"""
|
17 |
+
if len(densepose_predictor_output) > 0:
|
18 |
+
|
19 |
+
PredictorOutput = type(densepose_predictor_output)
|
20 |
+
output_dict = {}
|
21 |
+
|
22 |
+
for field in fields(densepose_predictor_output):
|
23 |
+
field_value = getattr(densepose_predictor_output, field.name)
|
24 |
+
# flip tensors
|
25 |
+
if isinstance(field_value, torch.Tensor):
|
26 |
+
setattr(densepose_predictor_output, field.name, torch.flip(field_value, [3]))
|
27 |
+
|
28 |
+
densepose_predictor_output = _flip_iuv_semantics_tensor(
|
29 |
+
densepose_predictor_output, transform_data
|
30 |
+
)
|
31 |
+
densepose_predictor_output = _flip_segm_semantics_tensor(
|
32 |
+
densepose_predictor_output, transform_data
|
33 |
+
)
|
34 |
+
|
35 |
+
for field in fields(densepose_predictor_output):
|
36 |
+
output_dict[field.name] = getattr(densepose_predictor_output, field.name)
|
37 |
+
|
38 |
+
return PredictorOutput(**output_dict)
|
39 |
+
else:
|
40 |
+
return densepose_predictor_output
|
41 |
+
|
42 |
+
|
43 |
+
def _flip_iuv_semantics_tensor(
|
44 |
+
densepose_predictor_output: DensePoseChartPredictorOutput,
|
45 |
+
dp_transform_data: DensePoseTransformData,
|
46 |
+
) -> DensePoseChartPredictorOutput:
|
47 |
+
point_label_symmetries = dp_transform_data.point_label_symmetries
|
48 |
+
uv_symmetries = dp_transform_data.uv_symmetries
|
49 |
+
|
50 |
+
N, C, H, W = densepose_predictor_output.u.shape
|
51 |
+
u_loc = (densepose_predictor_output.u[:, 1:, :, :].clamp(0, 1) * 255).long()
|
52 |
+
v_loc = (densepose_predictor_output.v[:, 1:, :, :].clamp(0, 1) * 255).long()
|
53 |
+
Iindex = torch.arange(C - 1, device=densepose_predictor_output.u.device)[
|
54 |
+
None, :, None, None
|
55 |
+
].expand(N, C - 1, H, W)
|
56 |
+
densepose_predictor_output.u[:, 1:, :, :] = uv_symmetries["U_transforms"][Iindex, v_loc, u_loc]
|
57 |
+
densepose_predictor_output.v[:, 1:, :, :] = uv_symmetries["V_transforms"][Iindex, v_loc, u_loc]
|
58 |
+
|
59 |
+
for el in ["fine_segm", "u", "v"]:
|
60 |
+
densepose_predictor_output.__dict__[el] = densepose_predictor_output.__dict__[el][
|
61 |
+
:, point_label_symmetries, :, :
|
62 |
+
]
|
63 |
+
return densepose_predictor_output
|
64 |
+
|
65 |
+
|
66 |
+
def _flip_segm_semantics_tensor(
|
67 |
+
densepose_predictor_output: DensePoseChartPredictorOutput, dp_transform_data
|
68 |
+
):
|
69 |
+
if densepose_predictor_output.coarse_segm.shape[1] > 2:
|
70 |
+
densepose_predictor_output.coarse_segm = densepose_predictor_output.coarse_segm[
|
71 |
+
:, dp_transform_data.mask_label_symmetries, :, :
|
72 |
+
]
|
73 |
+
return densepose_predictor_output
|
densepose/converters/chart_output_to_chart_result.py
ADDED
@@ -0,0 +1,190 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
|
3 |
+
# pyre-unsafe
|
4 |
+
|
5 |
+
from typing import Dict
|
6 |
+
import torch
|
7 |
+
from torch.nn import functional as F
|
8 |
+
|
9 |
+
from detectron2.structures.boxes import Boxes, BoxMode
|
10 |
+
|
11 |
+
from ..structures import (
|
12 |
+
DensePoseChartPredictorOutput,
|
13 |
+
DensePoseChartResult,
|
14 |
+
DensePoseChartResultWithConfidences,
|
15 |
+
)
|
16 |
+
from . import resample_fine_and_coarse_segm_to_bbox
|
17 |
+
from .base import IntTupleBox, make_int_box
|
18 |
+
|
19 |
+
|
20 |
+
def resample_uv_tensors_to_bbox(
|
21 |
+
u: torch.Tensor,
|
22 |
+
v: torch.Tensor,
|
23 |
+
labels: torch.Tensor,
|
24 |
+
box_xywh_abs: IntTupleBox,
|
25 |
+
) -> torch.Tensor:
|
26 |
+
"""
|
27 |
+
Resamples U and V coordinate estimates for the given bounding box
|
28 |
+
|
29 |
+
Args:
|
30 |
+
u (tensor [1, C, H, W] of float): U coordinates
|
31 |
+
v (tensor [1, C, H, W] of float): V coordinates
|
32 |
+
labels (tensor [H, W] of long): labels obtained by resampling segmentation
|
33 |
+
outputs for the given bounding box
|
34 |
+
box_xywh_abs (tuple of 4 int): bounding box that corresponds to predictor outputs
|
35 |
+
Return:
|
36 |
+
Resampled U and V coordinates - a tensor [2, H, W] of float
|
37 |
+
"""
|
38 |
+
x, y, w, h = box_xywh_abs
|
39 |
+
w = max(int(w), 1)
|
40 |
+
h = max(int(h), 1)
|
41 |
+
u_bbox = F.interpolate(u, (h, w), mode="bilinear", align_corners=False)
|
42 |
+
v_bbox = F.interpolate(v, (h, w), mode="bilinear", align_corners=False)
|
43 |
+
uv = torch.zeros([2, h, w], dtype=torch.float32, device=u.device)
|
44 |
+
for part_id in range(1, u_bbox.size(1)):
|
45 |
+
uv[0][labels == part_id] = u_bbox[0, part_id][labels == part_id]
|
46 |
+
uv[1][labels == part_id] = v_bbox[0, part_id][labels == part_id]
|
47 |
+
return uv
|
48 |
+
|
49 |
+
|
50 |
+
def resample_uv_to_bbox(
|
51 |
+
predictor_output: DensePoseChartPredictorOutput,
|
52 |
+
labels: torch.Tensor,
|
53 |
+
box_xywh_abs: IntTupleBox,
|
54 |
+
) -> torch.Tensor:
|
55 |
+
"""
|
56 |
+
Resamples U and V coordinate estimates for the given bounding box
|
57 |
+
|
58 |
+
Args:
|
59 |
+
predictor_output (DensePoseChartPredictorOutput): DensePose predictor
|
60 |
+
output to be resampled
|
61 |
+
labels (tensor [H, W] of long): labels obtained by resampling segmentation
|
62 |
+
outputs for the given bounding box
|
63 |
+
box_xywh_abs (tuple of 4 int): bounding box that corresponds to predictor outputs
|
64 |
+
Return:
|
65 |
+
Resampled U and V coordinates - a tensor [2, H, W] of float
|
66 |
+
"""
|
67 |
+
return resample_uv_tensors_to_bbox(
|
68 |
+
predictor_output.u,
|
69 |
+
predictor_output.v,
|
70 |
+
labels,
|
71 |
+
box_xywh_abs,
|
72 |
+
)
|
73 |
+
|
74 |
+
|
75 |
+
def densepose_chart_predictor_output_to_result(
|
76 |
+
predictor_output: DensePoseChartPredictorOutput, boxes: Boxes
|
77 |
+
) -> DensePoseChartResult:
|
78 |
+
"""
|
79 |
+
Convert densepose chart predictor outputs to results
|
80 |
+
|
81 |
+
Args:
|
82 |
+
predictor_output (DensePoseChartPredictorOutput): DensePose predictor
|
83 |
+
output to be converted to results, must contain only 1 output
|
84 |
+
boxes (Boxes): bounding box that corresponds to the predictor output,
|
85 |
+
must contain only 1 bounding box
|
86 |
+
Return:
|
87 |
+
DensePose chart-based result (DensePoseChartResult)
|
88 |
+
"""
|
89 |
+
assert len(predictor_output) == 1 and len(boxes) == 1, (
|
90 |
+
f"Predictor output to result conversion can operate only single outputs"
|
91 |
+
f", got {len(predictor_output)} predictor outputs and {len(boxes)} boxes"
|
92 |
+
)
|
93 |
+
|
94 |
+
boxes_xyxy_abs = boxes.tensor.clone()
|
95 |
+
boxes_xywh_abs = BoxMode.convert(boxes_xyxy_abs, BoxMode.XYXY_ABS, BoxMode.XYWH_ABS)
|
96 |
+
box_xywh = make_int_box(boxes_xywh_abs[0])
|
97 |
+
|
98 |
+
labels = resample_fine_and_coarse_segm_to_bbox(predictor_output, box_xywh).squeeze(0)
|
99 |
+
uv = resample_uv_to_bbox(predictor_output, labels, box_xywh)
|
100 |
+
return DensePoseChartResult(labels=labels, uv=uv)
|
101 |
+
|
102 |
+
|
103 |
+
def resample_confidences_to_bbox(
|
104 |
+
predictor_output: DensePoseChartPredictorOutput,
|
105 |
+
labels: torch.Tensor,
|
106 |
+
box_xywh_abs: IntTupleBox,
|
107 |
+
) -> Dict[str, torch.Tensor]:
|
108 |
+
"""
|
109 |
+
Resamples confidences for the given bounding box
|
110 |
+
|
111 |
+
Args:
|
112 |
+
predictor_output (DensePoseChartPredictorOutput): DensePose predictor
|
113 |
+
output to be resampled
|
114 |
+
labels (tensor [H, W] of long): labels obtained by resampling segmentation
|
115 |
+
outputs for the given bounding box
|
116 |
+
box_xywh_abs (tuple of 4 int): bounding box that corresponds to predictor outputs
|
117 |
+
Return:
|
118 |
+
Resampled confidences - a dict of [H, W] tensors of float
|
119 |
+
"""
|
120 |
+
|
121 |
+
x, y, w, h = box_xywh_abs
|
122 |
+
w = max(int(w), 1)
|
123 |
+
h = max(int(h), 1)
|
124 |
+
|
125 |
+
confidence_names = [
|
126 |
+
"sigma_1",
|
127 |
+
"sigma_2",
|
128 |
+
"kappa_u",
|
129 |
+
"kappa_v",
|
130 |
+
"fine_segm_confidence",
|
131 |
+
"coarse_segm_confidence",
|
132 |
+
]
|
133 |
+
confidence_results = {key: None for key in confidence_names}
|
134 |
+
confidence_names = [
|
135 |
+
key for key in confidence_names if getattr(predictor_output, key) is not None
|
136 |
+
]
|
137 |
+
confidence_base = torch.zeros([h, w], dtype=torch.float32, device=predictor_output.u.device)
|
138 |
+
|
139 |
+
# assign data from channels that correspond to the labels
|
140 |
+
for key in confidence_names:
|
141 |
+
resampled_confidence = F.interpolate(
|
142 |
+
getattr(predictor_output, key),
|
143 |
+
(h, w),
|
144 |
+
mode="bilinear",
|
145 |
+
align_corners=False,
|
146 |
+
)
|
147 |
+
result = confidence_base.clone()
|
148 |
+
for part_id in range(1, predictor_output.u.size(1)):
|
149 |
+
if resampled_confidence.size(1) != predictor_output.u.size(1):
|
150 |
+
# confidence is not part-based, don't try to fill it part by part
|
151 |
+
continue
|
152 |
+
result[labels == part_id] = resampled_confidence[0, part_id][labels == part_id]
|
153 |
+
|
154 |
+
if resampled_confidence.size(1) != predictor_output.u.size(1):
|
155 |
+
# confidence is not part-based, fill the data with the first channel
|
156 |
+
# (targeted for segmentation confidences that have only 1 channel)
|
157 |
+
result = resampled_confidence[0, 0]
|
158 |
+
|
159 |
+
confidence_results[key] = result
|
160 |
+
|
161 |
+
return confidence_results # pyre-ignore[7]
|
162 |
+
|
163 |
+
|
164 |
+
def densepose_chart_predictor_output_to_result_with_confidences(
|
165 |
+
predictor_output: DensePoseChartPredictorOutput, boxes: Boxes
|
166 |
+
) -> DensePoseChartResultWithConfidences:
|
167 |
+
"""
|
168 |
+
Convert densepose chart predictor outputs to results
|
169 |
+
|
170 |
+
Args:
|
171 |
+
predictor_output (DensePoseChartPredictorOutput): DensePose predictor
|
172 |
+
output with confidences to be converted to results, must contain only 1 output
|
173 |
+
boxes (Boxes): bounding box that corresponds to the predictor output,
|
174 |
+
must contain only 1 bounding box
|
175 |
+
Return:
|
176 |
+
DensePose chart-based result with confidences (DensePoseChartResultWithConfidences)
|
177 |
+
"""
|
178 |
+
assert len(predictor_output) == 1 and len(boxes) == 1, (
|
179 |
+
f"Predictor output to result conversion can operate only single outputs"
|
180 |
+
f", got {len(predictor_output)} predictor outputs and {len(boxes)} boxes"
|
181 |
+
)
|
182 |
+
|
183 |
+
boxes_xyxy_abs = boxes.tensor.clone()
|
184 |
+
boxes_xywh_abs = BoxMode.convert(boxes_xyxy_abs, BoxMode.XYXY_ABS, BoxMode.XYWH_ABS)
|
185 |
+
box_xywh = make_int_box(boxes_xywh_abs[0])
|
186 |
+
|
187 |
+
labels = resample_fine_and_coarse_segm_to_bbox(predictor_output, box_xywh).squeeze(0)
|
188 |
+
uv = resample_uv_to_bbox(predictor_output, labels, box_xywh)
|
189 |
+
confidences = resample_confidences_to_bbox(predictor_output, labels, box_xywh)
|
190 |
+
return DensePoseChartResultWithConfidences(labels=labels, uv=uv, **confidences)
|
densepose/converters/hflip.py
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
|
3 |
+
# pyre-unsafe
|
4 |
+
|
5 |
+
from typing import Any
|
6 |
+
|
7 |
+
from .base import BaseConverter
|
8 |
+
|
9 |
+
|
10 |
+
class HFlipConverter(BaseConverter):
|
11 |
+
"""
|
12 |
+
Converts various DensePose predictor outputs to DensePose results.
|
13 |
+
Each DensePose predictor output type has to register its convertion strategy.
|
14 |
+
"""
|
15 |
+
|
16 |
+
registry = {}
|
17 |
+
dst_type = None
|
18 |
+
|
19 |
+
@classmethod
|
20 |
+
# pyre-fixme[14]: `convert` overrides method defined in `BaseConverter`
|
21 |
+
# inconsistently.
|
22 |
+
def convert(cls, predictor_outputs: Any, transform_data: Any, *args, **kwargs):
|
23 |
+
"""
|
24 |
+
Performs an horizontal flip on DensePose predictor outputs.
|
25 |
+
Does recursive lookup for base classes, so there's no need
|
26 |
+
for explicit registration for derived classes.
|
27 |
+
|
28 |
+
Args:
|
29 |
+
predictor_outputs: DensePose predictor output to be converted to BitMasks
|
30 |
+
transform_data: Anything useful for the flip
|
31 |
+
Return:
|
32 |
+
An instance of the same type as predictor_outputs
|
33 |
+
"""
|
34 |
+
return super(HFlipConverter, cls).convert(
|
35 |
+
predictor_outputs, transform_data, *args, **kwargs
|
36 |
+
)
|
densepose/converters/segm_to_mask.py
ADDED
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
|
3 |
+
# pyre-unsafe
|
4 |
+
|
5 |
+
from typing import Any
|
6 |
+
import torch
|
7 |
+
from torch.nn import functional as F
|
8 |
+
|
9 |
+
from detectron2.structures import BitMasks, Boxes, BoxMode
|
10 |
+
|
11 |
+
from .base import IntTupleBox, make_int_box
|
12 |
+
from .to_mask import ImageSizeType
|
13 |
+
|
14 |
+
|
15 |
+
def resample_coarse_segm_tensor_to_bbox(coarse_segm: torch.Tensor, box_xywh_abs: IntTupleBox):
|
16 |
+
"""
|
17 |
+
Resample coarse segmentation tensor to the given
|
18 |
+
bounding box and derive labels for each pixel of the bounding box
|
19 |
+
|
20 |
+
Args:
|
21 |
+
coarse_segm: float tensor of shape [1, K, Hout, Wout]
|
22 |
+
box_xywh_abs (tuple of 4 int): bounding box given by its upper-left
|
23 |
+
corner coordinates, width (W) and height (H)
|
24 |
+
Return:
|
25 |
+
Labels for each pixel of the bounding box, a long tensor of size [1, H, W]
|
26 |
+
"""
|
27 |
+
x, y, w, h = box_xywh_abs
|
28 |
+
w = max(int(w), 1)
|
29 |
+
h = max(int(h), 1)
|
30 |
+
labels = F.interpolate(coarse_segm, (h, w), mode="bilinear", align_corners=False).argmax(dim=1)
|
31 |
+
return labels
|
32 |
+
|
33 |
+
|
34 |
+
def resample_fine_and_coarse_segm_tensors_to_bbox(
|
35 |
+
fine_segm: torch.Tensor, coarse_segm: torch.Tensor, box_xywh_abs: IntTupleBox
|
36 |
+
):
|
37 |
+
"""
|
38 |
+
Resample fine and coarse segmentation tensors to the given
|
39 |
+
bounding box and derive labels for each pixel of the bounding box
|
40 |
+
|
41 |
+
Args:
|
42 |
+
fine_segm: float tensor of shape [1, C, Hout, Wout]
|
43 |
+
coarse_segm: float tensor of shape [1, K, Hout, Wout]
|
44 |
+
box_xywh_abs (tuple of 4 int): bounding box given by its upper-left
|
45 |
+
corner coordinates, width (W) and height (H)
|
46 |
+
Return:
|
47 |
+
Labels for each pixel of the bounding box, a long tensor of size [1, H, W]
|
48 |
+
"""
|
49 |
+
x, y, w, h = box_xywh_abs
|
50 |
+
w = max(int(w), 1)
|
51 |
+
h = max(int(h), 1)
|
52 |
+
# coarse segmentation
|
53 |
+
coarse_segm_bbox = F.interpolate(
|
54 |
+
coarse_segm,
|
55 |
+
(h, w),
|
56 |
+
mode="bilinear",
|
57 |
+
align_corners=False,
|
58 |
+
).argmax(dim=1)
|
59 |
+
# combined coarse and fine segmentation
|
60 |
+
labels = (
|
61 |
+
F.interpolate(fine_segm, (h, w), mode="bilinear", align_corners=False).argmax(dim=1)
|
62 |
+
* (coarse_segm_bbox > 0).long()
|
63 |
+
)
|
64 |
+
return labels
|
65 |
+
|
66 |
+
|
67 |
+
def resample_fine_and_coarse_segm_to_bbox(predictor_output: Any, box_xywh_abs: IntTupleBox):
|
68 |
+
"""
|
69 |
+
Resample fine and coarse segmentation outputs from a predictor to the given
|
70 |
+
bounding box and derive labels for each pixel of the bounding box
|
71 |
+
|
72 |
+
Args:
|
73 |
+
predictor_output: DensePose predictor output that contains segmentation
|
74 |
+
results to be resampled
|
75 |
+
box_xywh_abs (tuple of 4 int): bounding box given by its upper-left
|
76 |
+
corner coordinates, width (W) and height (H)
|
77 |
+
Return:
|
78 |
+
Labels for each pixel of the bounding box, a long tensor of size [1, H, W]
|
79 |
+
"""
|
80 |
+
return resample_fine_and_coarse_segm_tensors_to_bbox(
|
81 |
+
predictor_output.fine_segm,
|
82 |
+
predictor_output.coarse_segm,
|
83 |
+
box_xywh_abs,
|
84 |
+
)
|
85 |
+
|
86 |
+
|
87 |
+
def predictor_output_with_coarse_segm_to_mask(
|
88 |
+
predictor_output: Any, boxes: Boxes, image_size_hw: ImageSizeType
|
89 |
+
) -> BitMasks:
|
90 |
+
"""
|
91 |
+
Convert predictor output with coarse and fine segmentation to a mask.
|
92 |
+
Assumes that predictor output has the following attributes:
|
93 |
+
- coarse_segm (tensor of size [N, D, H, W]): coarse segmentation
|
94 |
+
unnormalized scores for N instances; D is the number of coarse
|
95 |
+
segmentation labels, H and W is the resolution of the estimate
|
96 |
+
|
97 |
+
Args:
|
98 |
+
predictor_output: DensePose predictor output to be converted to mask
|
99 |
+
boxes (Boxes): bounding boxes that correspond to the DensePose
|
100 |
+
predictor outputs
|
101 |
+
image_size_hw (tuple [int, int]): image height Himg and width Wimg
|
102 |
+
Return:
|
103 |
+
BitMasks that contain a bool tensor of size [N, Himg, Wimg] with
|
104 |
+
a mask of the size of the image for each instance
|
105 |
+
"""
|
106 |
+
H, W = image_size_hw
|
107 |
+
boxes_xyxy_abs = boxes.tensor.clone()
|
108 |
+
boxes_xywh_abs = BoxMode.convert(boxes_xyxy_abs, BoxMode.XYXY_ABS, BoxMode.XYWH_ABS)
|
109 |
+
N = len(boxes_xywh_abs)
|
110 |
+
masks = torch.zeros((N, H, W), dtype=torch.bool, device=boxes.tensor.device)
|
111 |
+
for i in range(len(boxes_xywh_abs)):
|
112 |
+
box_xywh = make_int_box(boxes_xywh_abs[i])
|
113 |
+
box_mask = resample_coarse_segm_tensor_to_bbox(predictor_output[i].coarse_segm, box_xywh)
|
114 |
+
x, y, w, h = box_xywh
|
115 |
+
masks[i, y : y + h, x : x + w] = box_mask
|
116 |
+
|
117 |
+
return BitMasks(masks)
|
118 |
+
|
119 |
+
|
120 |
+
def predictor_output_with_fine_and_coarse_segm_to_mask(
|
121 |
+
predictor_output: Any, boxes: Boxes, image_size_hw: ImageSizeType
|
122 |
+
) -> BitMasks:
|
123 |
+
"""
|
124 |
+
Convert predictor output with coarse and fine segmentation to a mask.
|
125 |
+
Assumes that predictor output has the following attributes:
|
126 |
+
- coarse_segm (tensor of size [N, D, H, W]): coarse segmentation
|
127 |
+
unnormalized scores for N instances; D is the number of coarse
|
128 |
+
segmentation labels, H and W is the resolution of the estimate
|
129 |
+
- fine_segm (tensor of size [N, C, H, W]): fine segmentation
|
130 |
+
unnormalized scores for N instances; C is the number of fine
|
131 |
+
segmentation labels, H and W is the resolution of the estimate
|
132 |
+
|
133 |
+
Args:
|
134 |
+
predictor_output: DensePose predictor output to be converted to mask
|
135 |
+
boxes (Boxes): bounding boxes that correspond to the DensePose
|
136 |
+
predictor outputs
|
137 |
+
image_size_hw (tuple [int, int]): image height Himg and width Wimg
|
138 |
+
Return:
|
139 |
+
BitMasks that contain a bool tensor of size [N, Himg, Wimg] with
|
140 |
+
a mask of the size of the image for each instance
|
141 |
+
"""
|
142 |
+
H, W = image_size_hw
|
143 |
+
boxes_xyxy_abs = boxes.tensor.clone()
|
144 |
+
boxes_xywh_abs = BoxMode.convert(boxes_xyxy_abs, BoxMode.XYXY_ABS, BoxMode.XYWH_ABS)
|
145 |
+
N = len(boxes_xywh_abs)
|
146 |
+
masks = torch.zeros((N, H, W), dtype=torch.bool, device=boxes.tensor.device)
|
147 |
+
for i in range(len(boxes_xywh_abs)):
|
148 |
+
box_xywh = make_int_box(boxes_xywh_abs[i])
|
149 |
+
labels_i = resample_fine_and_coarse_segm_to_bbox(predictor_output[i], box_xywh)
|
150 |
+
x, y, w, h = box_xywh
|
151 |
+
masks[i, y : y + h, x : x + w] = labels_i > 0
|
152 |
+
return BitMasks(masks)
|
densepose/converters/to_chart_result.py
ADDED
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
|
3 |
+
# pyre-unsafe
|
4 |
+
|
5 |
+
from typing import Any
|
6 |
+
|
7 |
+
from detectron2.structures import Boxes
|
8 |
+
|
9 |
+
from ..structures import DensePoseChartResult, DensePoseChartResultWithConfidences
|
10 |
+
from .base import BaseConverter
|
11 |
+
|
12 |
+
|
13 |
+
class ToChartResultConverter(BaseConverter):
|
14 |
+
"""
|
15 |
+
Converts various DensePose predictor outputs to DensePose results.
|
16 |
+
Each DensePose predictor output type has to register its convertion strategy.
|
17 |
+
"""
|
18 |
+
|
19 |
+
registry = {}
|
20 |
+
dst_type = DensePoseChartResult
|
21 |
+
|
22 |
+
@classmethod
|
23 |
+
# pyre-fixme[14]: `convert` overrides method defined in `BaseConverter`
|
24 |
+
# inconsistently.
|
25 |
+
def convert(cls, predictor_outputs: Any, boxes: Boxes, *args, **kwargs) -> DensePoseChartResult:
|
26 |
+
"""
|
27 |
+
Convert DensePose predictor outputs to DensePoseResult using some registered
|
28 |
+
converter. Does recursive lookup for base classes, so there's no need
|
29 |
+
for explicit registration for derived classes.
|
30 |
+
|
31 |
+
Args:
|
32 |
+
densepose_predictor_outputs: DensePose predictor output to be
|
33 |
+
converted to BitMasks
|
34 |
+
boxes (Boxes): bounding boxes that correspond to the DensePose
|
35 |
+
predictor outputs
|
36 |
+
Return:
|
37 |
+
An instance of DensePoseResult. If no suitable converter was found, raises KeyError
|
38 |
+
"""
|
39 |
+
return super(ToChartResultConverter, cls).convert(predictor_outputs, boxes, *args, **kwargs)
|
40 |
+
|
41 |
+
|
42 |
+
class ToChartResultConverterWithConfidences(BaseConverter):
|
43 |
+
"""
|
44 |
+
Converts various DensePose predictor outputs to DensePose results.
|
45 |
+
Each DensePose predictor output type has to register its convertion strategy.
|
46 |
+
"""
|
47 |
+
|
48 |
+
registry = {}
|
49 |
+
dst_type = DensePoseChartResultWithConfidences
|
50 |
+
|
51 |
+
@classmethod
|
52 |
+
# pyre-fixme[14]: `convert` overrides method defined in `BaseConverter`
|
53 |
+
# inconsistently.
|
54 |
+
def convert(
|
55 |
+
cls, predictor_outputs: Any, boxes: Boxes, *args, **kwargs
|
56 |
+
) -> DensePoseChartResultWithConfidences:
|
57 |
+
"""
|
58 |
+
Convert DensePose predictor outputs to DensePoseResult with confidences
|
59 |
+
using some registered converter. Does recursive lookup for base classes,
|
60 |
+
so there's no need for explicit registration for derived classes.
|
61 |
+
|
62 |
+
Args:
|
63 |
+
densepose_predictor_outputs: DensePose predictor output with confidences
|
64 |
+
to be converted to BitMasks
|
65 |
+
boxes (Boxes): bounding boxes that correspond to the DensePose
|
66 |
+
predictor outputs
|
67 |
+
Return:
|
68 |
+
An instance of DensePoseResult. If no suitable converter was found, raises KeyError
|
69 |
+
"""
|
70 |
+
return super(ToChartResultConverterWithConfidences, cls).convert(
|
71 |
+
predictor_outputs, boxes, *args, **kwargs
|
72 |
+
)
|
densepose/converters/to_mask.py
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
|
3 |
+
# pyre-unsafe
|
4 |
+
|
5 |
+
from typing import Any, Tuple
|
6 |
+
|
7 |
+
from detectron2.structures import BitMasks, Boxes
|
8 |
+
|
9 |
+
from .base import BaseConverter
|
10 |
+
|
11 |
+
ImageSizeType = Tuple[int, int]
|
12 |
+
|
13 |
+
|
14 |
+
class ToMaskConverter(BaseConverter):
|
15 |
+
"""
|
16 |
+
Converts various DensePose predictor outputs to masks
|
17 |
+
in bit mask format (see `BitMasks`). Each DensePose predictor output type
|
18 |
+
has to register its convertion strategy.
|
19 |
+
"""
|
20 |
+
|
21 |
+
registry = {}
|
22 |
+
dst_type = BitMasks
|
23 |
+
|
24 |
+
@classmethod
|
25 |
+
# pyre-fixme[14]: `convert` overrides method defined in `BaseConverter`
|
26 |
+
# inconsistently.
|
27 |
+
def convert(
|
28 |
+
cls,
|
29 |
+
densepose_predictor_outputs: Any,
|
30 |
+
boxes: Boxes,
|
31 |
+
image_size_hw: ImageSizeType,
|
32 |
+
*args,
|
33 |
+
**kwargs
|
34 |
+
) -> BitMasks:
|
35 |
+
"""
|
36 |
+
Convert DensePose predictor outputs to BitMasks using some registered
|
37 |
+
converter. Does recursive lookup for base classes, so there's no need
|
38 |
+
for explicit registration for derived classes.
|
39 |
+
|
40 |
+
Args:
|
41 |
+
densepose_predictor_outputs: DensePose predictor output to be
|
42 |
+
converted to BitMasks
|
43 |
+
boxes (Boxes): bounding boxes that correspond to the DensePose
|
44 |
+
predictor outputs
|
45 |
+
image_size_hw (tuple [int, int]): image height and width
|
46 |
+
Return:
|
47 |
+
An instance of `BitMasks`. If no suitable converter was found, raises KeyError
|
48 |
+
"""
|
49 |
+
return super(ToMaskConverter, cls).convert(
|
50 |
+
densepose_predictor_outputs, boxes, image_size_hw, *args, **kwargs
|
51 |
+
)
|
densepose/data/__init__.py
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
|
3 |
+
# pyre-unsafe
|
4 |
+
|
5 |
+
from .meshes import builtin
|
6 |
+
from .build import (
|
7 |
+
build_detection_test_loader,
|
8 |
+
build_detection_train_loader,
|
9 |
+
build_combined_loader,
|
10 |
+
build_frame_selector,
|
11 |
+
build_inference_based_loaders,
|
12 |
+
has_inference_based_loaders,
|
13 |
+
BootstrapDatasetFactoryCatalog,
|
14 |
+
)
|
15 |
+
from .combined_loader import CombinedDataLoader
|
16 |
+
from .dataset_mapper import DatasetMapper
|
17 |
+
from .inference_based_loader import InferenceBasedLoader, ScoreBasedFilter
|
18 |
+
from .image_list_dataset import ImageListDataset
|
19 |
+
from .utils import is_relative_local_path, maybe_prepend_base_path
|
20 |
+
|
21 |
+
# ensure the builtin datasets are registered
|
22 |
+
from . import datasets
|
23 |
+
|
24 |
+
# ensure the bootstrap datasets builders are registered
|
25 |
+
from . import build
|
26 |
+
|
27 |
+
__all__ = [k for k in globals().keys() if not k.startswith("_")]
|
densepose/data/build.py
ADDED
@@ -0,0 +1,738 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
|
3 |
+
# pyre-unsafe
|
4 |
+
|
5 |
+
import itertools
|
6 |
+
import logging
|
7 |
+
import numpy as np
|
8 |
+
from collections import UserDict, defaultdict
|
9 |
+
from dataclasses import dataclass
|
10 |
+
from typing import Any, Callable, Collection, Dict, Iterable, List, Optional, Sequence, Tuple
|
11 |
+
import torch
|
12 |
+
from torch.utils.data.dataset import Dataset
|
13 |
+
|
14 |
+
from detectron2.config import CfgNode
|
15 |
+
from detectron2.data.build import build_detection_test_loader as d2_build_detection_test_loader
|
16 |
+
from detectron2.data.build import build_detection_train_loader as d2_build_detection_train_loader
|
17 |
+
from detectron2.data.build import (
|
18 |
+
load_proposals_into_dataset,
|
19 |
+
print_instances_class_histogram,
|
20 |
+
trivial_batch_collator,
|
21 |
+
worker_init_reset_seed,
|
22 |
+
)
|
23 |
+
from detectron2.data.catalog import DatasetCatalog, Metadata, MetadataCatalog
|
24 |
+
from detectron2.data.samplers import TrainingSampler
|
25 |
+
from detectron2.utils.comm import get_world_size
|
26 |
+
|
27 |
+
from densepose.config import get_bootstrap_dataset_config
|
28 |
+
from densepose.modeling import build_densepose_embedder
|
29 |
+
|
30 |
+
from .combined_loader import CombinedDataLoader, Loader
|
31 |
+
from .dataset_mapper import DatasetMapper
|
32 |
+
from .datasets.coco import DENSEPOSE_CSE_KEYS_WITHOUT_MASK, DENSEPOSE_IUV_KEYS_WITHOUT_MASK
|
33 |
+
from .datasets.dataset_type import DatasetType
|
34 |
+
from .inference_based_loader import InferenceBasedLoader, ScoreBasedFilter
|
35 |
+
from .samplers import (
|
36 |
+
DensePoseConfidenceBasedSampler,
|
37 |
+
DensePoseCSEConfidenceBasedSampler,
|
38 |
+
DensePoseCSEUniformSampler,
|
39 |
+
DensePoseUniformSampler,
|
40 |
+
MaskFromDensePoseSampler,
|
41 |
+
PredictionToGroundTruthSampler,
|
42 |
+
)
|
43 |
+
from .transform import ImageResizeTransform
|
44 |
+
from .utils import get_category_to_class_mapping, get_class_to_mesh_name_mapping
|
45 |
+
from .video import (
|
46 |
+
FirstKFramesSelector,
|
47 |
+
FrameSelectionStrategy,
|
48 |
+
LastKFramesSelector,
|
49 |
+
RandomKFramesSelector,
|
50 |
+
VideoKeyframeDataset,
|
51 |
+
video_list_from_file,
|
52 |
+
)
|
53 |
+
|
54 |
+
__all__ = ["build_detection_train_loader", "build_detection_test_loader"]
|
55 |
+
|
56 |
+
|
57 |
+
Instance = Dict[str, Any]
|
58 |
+
InstancePredicate = Callable[[Instance], bool]
|
59 |
+
|
60 |
+
|
61 |
+
def _compute_num_images_per_worker(cfg: CfgNode) -> int:
|
62 |
+
num_workers = get_world_size()
|
63 |
+
images_per_batch = cfg.SOLVER.IMS_PER_BATCH
|
64 |
+
assert (
|
65 |
+
images_per_batch % num_workers == 0
|
66 |
+
), "SOLVER.IMS_PER_BATCH ({}) must be divisible by the number of workers ({}).".format(
|
67 |
+
images_per_batch, num_workers
|
68 |
+
)
|
69 |
+
assert (
|
70 |
+
images_per_batch >= num_workers
|
71 |
+
), "SOLVER.IMS_PER_BATCH ({}) must be larger than the number of workers ({}).".format(
|
72 |
+
images_per_batch, num_workers
|
73 |
+
)
|
74 |
+
images_per_worker = images_per_batch // num_workers
|
75 |
+
return images_per_worker
|
76 |
+
|
77 |
+
|
78 |
+
def _map_category_id_to_contiguous_id(dataset_name: str, dataset_dicts: Iterable[Instance]) -> None:
|
79 |
+
meta = MetadataCatalog.get(dataset_name)
|
80 |
+
for dataset_dict in dataset_dicts:
|
81 |
+
for ann in dataset_dict["annotations"]:
|
82 |
+
ann["category_id"] = meta.thing_dataset_id_to_contiguous_id[ann["category_id"]]
|
83 |
+
|
84 |
+
|
85 |
+
@dataclass
|
86 |
+
class _DatasetCategory:
|
87 |
+
"""
|
88 |
+
Class representing category data in a dataset:
|
89 |
+
- id: category ID, as specified in the dataset annotations file
|
90 |
+
- name: category name, as specified in the dataset annotations file
|
91 |
+
- mapped_id: category ID after applying category maps (DATASETS.CATEGORY_MAPS config option)
|
92 |
+
- mapped_name: category name after applying category maps
|
93 |
+
- dataset_name: dataset in which the category is defined
|
94 |
+
|
95 |
+
For example, when training models in a class-agnostic manner, one could take LVIS 1.0
|
96 |
+
dataset and map the animal categories to the same category as human data from COCO:
|
97 |
+
id = 225
|
98 |
+
name = "cat"
|
99 |
+
mapped_id = 1
|
100 |
+
mapped_name = "person"
|
101 |
+
dataset_name = "lvis_v1_animals_dp_train"
|
102 |
+
"""
|
103 |
+
|
104 |
+
id: int
|
105 |
+
name: str
|
106 |
+
mapped_id: int
|
107 |
+
mapped_name: str
|
108 |
+
dataset_name: str
|
109 |
+
|
110 |
+
|
111 |
+
_MergedCategoriesT = Dict[int, List[_DatasetCategory]]
|
112 |
+
|
113 |
+
|
114 |
+
def _add_category_id_to_contiguous_id_maps_to_metadata(
|
115 |
+
merged_categories: _MergedCategoriesT,
|
116 |
+
) -> None:
|
117 |
+
merged_categories_per_dataset = {}
|
118 |
+
for contiguous_cat_id, cat_id in enumerate(sorted(merged_categories.keys())):
|
119 |
+
for cat in merged_categories[cat_id]:
|
120 |
+
if cat.dataset_name not in merged_categories_per_dataset:
|
121 |
+
merged_categories_per_dataset[cat.dataset_name] = defaultdict(list)
|
122 |
+
merged_categories_per_dataset[cat.dataset_name][cat_id].append(
|
123 |
+
(
|
124 |
+
contiguous_cat_id,
|
125 |
+
cat,
|
126 |
+
)
|
127 |
+
)
|
128 |
+
|
129 |
+
logger = logging.getLogger(__name__)
|
130 |
+
for dataset_name, merged_categories in merged_categories_per_dataset.items():
|
131 |
+
meta = MetadataCatalog.get(dataset_name)
|
132 |
+
if not hasattr(meta, "thing_classes"):
|
133 |
+
meta.thing_classes = []
|
134 |
+
meta.thing_dataset_id_to_contiguous_id = {}
|
135 |
+
meta.thing_dataset_id_to_merged_id = {}
|
136 |
+
else:
|
137 |
+
meta.thing_classes.clear()
|
138 |
+
meta.thing_dataset_id_to_contiguous_id.clear()
|
139 |
+
meta.thing_dataset_id_to_merged_id.clear()
|
140 |
+
logger.info(f"Dataset {dataset_name}: category ID to contiguous ID mapping:")
|
141 |
+
for _cat_id, categories in sorted(merged_categories.items()):
|
142 |
+
added_to_thing_classes = False
|
143 |
+
for contiguous_cat_id, cat in categories:
|
144 |
+
if not added_to_thing_classes:
|
145 |
+
meta.thing_classes.append(cat.mapped_name)
|
146 |
+
added_to_thing_classes = True
|
147 |
+
meta.thing_dataset_id_to_contiguous_id[cat.id] = contiguous_cat_id
|
148 |
+
meta.thing_dataset_id_to_merged_id[cat.id] = cat.mapped_id
|
149 |
+
logger.info(f"{cat.id} ({cat.name}) -> {contiguous_cat_id}")
|
150 |
+
|
151 |
+
|
152 |
+
def _maybe_create_general_keep_instance_predicate(cfg: CfgNode) -> Optional[InstancePredicate]:
|
153 |
+
def has_annotations(instance: Instance) -> bool:
|
154 |
+
return "annotations" in instance
|
155 |
+
|
156 |
+
def has_only_crowd_anotations(instance: Instance) -> bool:
|
157 |
+
for ann in instance["annotations"]:
|
158 |
+
if ann.get("is_crowd", 0) == 0:
|
159 |
+
return False
|
160 |
+
return True
|
161 |
+
|
162 |
+
def general_keep_instance_predicate(instance: Instance) -> bool:
|
163 |
+
return has_annotations(instance) and not has_only_crowd_anotations(instance)
|
164 |
+
|
165 |
+
if not cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS:
|
166 |
+
return None
|
167 |
+
return general_keep_instance_predicate
|
168 |
+
|
169 |
+
|
170 |
+
def _maybe_create_keypoints_keep_instance_predicate(cfg: CfgNode) -> Optional[InstancePredicate]:
|
171 |
+
|
172 |
+
min_num_keypoints = cfg.MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE
|
173 |
+
|
174 |
+
def has_sufficient_num_keypoints(instance: Instance) -> bool:
|
175 |
+
num_kpts = sum(
|
176 |
+
(np.array(ann["keypoints"][2::3]) > 0).sum()
|
177 |
+
for ann in instance["annotations"]
|
178 |
+
if "keypoints" in ann
|
179 |
+
)
|
180 |
+
return num_kpts >= min_num_keypoints
|
181 |
+
|
182 |
+
if cfg.MODEL.KEYPOINT_ON and (min_num_keypoints > 0):
|
183 |
+
return has_sufficient_num_keypoints
|
184 |
+
return None
|
185 |
+
|
186 |
+
|
187 |
+
def _maybe_create_mask_keep_instance_predicate(cfg: CfgNode) -> Optional[InstancePredicate]:
|
188 |
+
if not cfg.MODEL.MASK_ON:
|
189 |
+
return None
|
190 |
+
|
191 |
+
def has_mask_annotations(instance: Instance) -> bool:
|
192 |
+
return any("segmentation" in ann for ann in instance["annotations"])
|
193 |
+
|
194 |
+
return has_mask_annotations
|
195 |
+
|
196 |
+
|
197 |
+
def _maybe_create_densepose_keep_instance_predicate(cfg: CfgNode) -> Optional[InstancePredicate]:
|
198 |
+
if not cfg.MODEL.DENSEPOSE_ON:
|
199 |
+
return None
|
200 |
+
|
201 |
+
use_masks = cfg.MODEL.ROI_DENSEPOSE_HEAD.COARSE_SEGM_TRAINED_BY_MASKS
|
202 |
+
|
203 |
+
def has_densepose_annotations(instance: Instance) -> bool:
|
204 |
+
for ann in instance["annotations"]:
|
205 |
+
if all(key in ann for key in DENSEPOSE_IUV_KEYS_WITHOUT_MASK) or all(
|
206 |
+
key in ann for key in DENSEPOSE_CSE_KEYS_WITHOUT_MASK
|
207 |
+
):
|
208 |
+
return True
|
209 |
+
if use_masks and "segmentation" in ann:
|
210 |
+
return True
|
211 |
+
return False
|
212 |
+
|
213 |
+
return has_densepose_annotations
|
214 |
+
|
215 |
+
|
216 |
+
def _maybe_create_specific_keep_instance_predicate(cfg: CfgNode) -> Optional[InstancePredicate]:
|
217 |
+
specific_predicate_creators = [
|
218 |
+
_maybe_create_keypoints_keep_instance_predicate,
|
219 |
+
_maybe_create_mask_keep_instance_predicate,
|
220 |
+
_maybe_create_densepose_keep_instance_predicate,
|
221 |
+
]
|
222 |
+
predicates = [creator(cfg) for creator in specific_predicate_creators]
|
223 |
+
predicates = [p for p in predicates if p is not None]
|
224 |
+
if not predicates:
|
225 |
+
return None
|
226 |
+
|
227 |
+
def combined_predicate(instance: Instance) -> bool:
|
228 |
+
return any(p(instance) for p in predicates)
|
229 |
+
|
230 |
+
return combined_predicate
|
231 |
+
|
232 |
+
|
233 |
+
def _get_train_keep_instance_predicate(cfg: CfgNode):
|
234 |
+
general_keep_predicate = _maybe_create_general_keep_instance_predicate(cfg)
|
235 |
+
combined_specific_keep_predicate = _maybe_create_specific_keep_instance_predicate(cfg)
|
236 |
+
|
237 |
+
def combined_general_specific_keep_predicate(instance: Instance) -> bool:
|
238 |
+
return general_keep_predicate(instance) and combined_specific_keep_predicate(instance)
|
239 |
+
|
240 |
+
if (general_keep_predicate is None) and (combined_specific_keep_predicate is None):
|
241 |
+
return None
|
242 |
+
if general_keep_predicate is None:
|
243 |
+
return combined_specific_keep_predicate
|
244 |
+
if combined_specific_keep_predicate is None:
|
245 |
+
return general_keep_predicate
|
246 |
+
return combined_general_specific_keep_predicate
|
247 |
+
|
248 |
+
|
249 |
+
def _get_test_keep_instance_predicate(cfg: CfgNode):
|
250 |
+
general_keep_predicate = _maybe_create_general_keep_instance_predicate(cfg)
|
251 |
+
return general_keep_predicate
|
252 |
+
|
253 |
+
|
254 |
+
def _maybe_filter_and_map_categories(
|
255 |
+
dataset_name: str, dataset_dicts: List[Instance]
|
256 |
+
) -> List[Instance]:
|
257 |
+
meta = MetadataCatalog.get(dataset_name)
|
258 |
+
category_id_map = meta.thing_dataset_id_to_contiguous_id
|
259 |
+
filtered_dataset_dicts = []
|
260 |
+
for dataset_dict in dataset_dicts:
|
261 |
+
anns = []
|
262 |
+
for ann in dataset_dict["annotations"]:
|
263 |
+
cat_id = ann["category_id"]
|
264 |
+
if cat_id not in category_id_map:
|
265 |
+
continue
|
266 |
+
ann["category_id"] = category_id_map[cat_id]
|
267 |
+
anns.append(ann)
|
268 |
+
dataset_dict["annotations"] = anns
|
269 |
+
filtered_dataset_dicts.append(dataset_dict)
|
270 |
+
return filtered_dataset_dicts
|
271 |
+
|
272 |
+
|
273 |
+
def _add_category_whitelists_to_metadata(cfg: CfgNode) -> None:
|
274 |
+
for dataset_name, whitelisted_cat_ids in cfg.DATASETS.WHITELISTED_CATEGORIES.items():
|
275 |
+
meta = MetadataCatalog.get(dataset_name)
|
276 |
+
meta.whitelisted_categories = whitelisted_cat_ids
|
277 |
+
logger = logging.getLogger(__name__)
|
278 |
+
logger.info(
|
279 |
+
"Whitelisted categories for dataset {}: {}".format(
|
280 |
+
dataset_name, meta.whitelisted_categories
|
281 |
+
)
|
282 |
+
)
|
283 |
+
|
284 |
+
|
285 |
+
def _add_category_maps_to_metadata(cfg: CfgNode) -> None:
|
286 |
+
for dataset_name, category_map in cfg.DATASETS.CATEGORY_MAPS.items():
|
287 |
+
category_map = {
|
288 |
+
int(cat_id_src): int(cat_id_dst) for cat_id_src, cat_id_dst in category_map.items()
|
289 |
+
}
|
290 |
+
meta = MetadataCatalog.get(dataset_name)
|
291 |
+
meta.category_map = category_map
|
292 |
+
logger = logging.getLogger(__name__)
|
293 |
+
logger.info("Category maps for dataset {}: {}".format(dataset_name, meta.category_map))
|
294 |
+
|
295 |
+
|
296 |
+
def _add_category_info_to_bootstrapping_metadata(dataset_name: str, dataset_cfg: CfgNode) -> None:
|
297 |
+
meta = MetadataCatalog.get(dataset_name)
|
298 |
+
meta.category_to_class_mapping = get_category_to_class_mapping(dataset_cfg)
|
299 |
+
meta.categories = dataset_cfg.CATEGORIES
|
300 |
+
meta.max_count_per_category = dataset_cfg.MAX_COUNT_PER_CATEGORY
|
301 |
+
logger = logging.getLogger(__name__)
|
302 |
+
logger.info(
|
303 |
+
"Category to class mapping for dataset {}: {}".format(
|
304 |
+
dataset_name, meta.category_to_class_mapping
|
305 |
+
)
|
306 |
+
)
|
307 |
+
|
308 |
+
|
309 |
+
def _maybe_add_class_to_mesh_name_map_to_metadata(dataset_names: List[str], cfg: CfgNode) -> None:
|
310 |
+
for dataset_name in dataset_names:
|
311 |
+
meta = MetadataCatalog.get(dataset_name)
|
312 |
+
if not hasattr(meta, "class_to_mesh_name"):
|
313 |
+
meta.class_to_mesh_name = get_class_to_mesh_name_mapping(cfg)
|
314 |
+
|
315 |
+
|
316 |
+
def _merge_categories(dataset_names: Collection[str]) -> _MergedCategoriesT:
|
317 |
+
merged_categories = defaultdict(list)
|
318 |
+
category_names = {}
|
319 |
+
for dataset_name in dataset_names:
|
320 |
+
meta = MetadataCatalog.get(dataset_name)
|
321 |
+
whitelisted_categories = meta.get("whitelisted_categories")
|
322 |
+
category_map = meta.get("category_map", {})
|
323 |
+
cat_ids = (
|
324 |
+
whitelisted_categories if whitelisted_categories is not None else meta.categories.keys()
|
325 |
+
)
|
326 |
+
for cat_id in cat_ids:
|
327 |
+
cat_name = meta.categories[cat_id]
|
328 |
+
cat_id_mapped = category_map.get(cat_id, cat_id)
|
329 |
+
if cat_id_mapped == cat_id or cat_id_mapped in cat_ids:
|
330 |
+
category_names[cat_id] = cat_name
|
331 |
+
else:
|
332 |
+
category_names[cat_id] = str(cat_id_mapped)
|
333 |
+
# assign temporary mapped category name, this name can be changed
|
334 |
+
# during the second pass, since mapped ID can correspond to a category
|
335 |
+
# from a different dataset
|
336 |
+
cat_name_mapped = meta.categories[cat_id_mapped]
|
337 |
+
merged_categories[cat_id_mapped].append(
|
338 |
+
_DatasetCategory(
|
339 |
+
id=cat_id,
|
340 |
+
name=cat_name,
|
341 |
+
mapped_id=cat_id_mapped,
|
342 |
+
mapped_name=cat_name_mapped,
|
343 |
+
dataset_name=dataset_name,
|
344 |
+
)
|
345 |
+
)
|
346 |
+
# second pass to assign proper mapped category names
|
347 |
+
for cat_id, categories in merged_categories.items():
|
348 |
+
for cat in categories:
|
349 |
+
if cat_id in category_names and cat.mapped_name != category_names[cat_id]:
|
350 |
+
cat.mapped_name = category_names[cat_id]
|
351 |
+
|
352 |
+
return merged_categories
|
353 |
+
|
354 |
+
|
355 |
+
def _warn_if_merged_different_categories(merged_categories: _MergedCategoriesT) -> None:
|
356 |
+
logger = logging.getLogger(__name__)
|
357 |
+
for cat_id in merged_categories:
|
358 |
+
merged_categories_i = merged_categories[cat_id]
|
359 |
+
first_cat_name = merged_categories_i[0].name
|
360 |
+
if len(merged_categories_i) > 1 and not all(
|
361 |
+
cat.name == first_cat_name for cat in merged_categories_i[1:]
|
362 |
+
):
|
363 |
+
cat_summary_str = ", ".join(
|
364 |
+
[f"{cat.id} ({cat.name}) from {cat.dataset_name}" for cat in merged_categories_i]
|
365 |
+
)
|
366 |
+
logger.warning(
|
367 |
+
f"Merged category {cat_id} corresponds to the following categories: "
|
368 |
+
f"{cat_summary_str}"
|
369 |
+
)
|
370 |
+
|
371 |
+
|
def combine_detection_dataset_dicts(
    dataset_names: Collection[str],
    keep_instance_predicate: Optional[InstancePredicate] = None,
    proposal_files: Optional[Collection[str]] = None,
) -> List[Instance]:
    """
    Load and prepare dataset dicts for training / testing

    Args:
        dataset_names (Collection[str]): a list of dataset names
        keep_instance_predicate (Callable: Dict[str, Any] -> bool): predicate
            applied to instance dicts which defines whether to keep the instance
        proposal_files (Collection[str]): if given, a list of object proposal files
            that match each dataset in `dataset_names`.
    """
    assert len(dataset_names)
    if proposal_files is None:
        proposal_files = [None] * len(dataset_names)
    assert len(dataset_names) == len(proposal_files)
    # load datasets and metadata
    dataset_name_to_dicts = {}
    for dataset_name in dataset_names:
        dataset_name_to_dicts[dataset_name] = DatasetCatalog.get(dataset_name)
        assert len(dataset_name_to_dicts), f"Dataset '{dataset_name}' is empty!"
    # merge categories, requires category metadata to be loaded
    # cat_id -> [(orig_cat_id, cat_name, dataset_name)]
    merged_categories = _merge_categories(dataset_names)
    _warn_if_merged_different_categories(merged_categories)
    merged_category_names = [
        merged_categories[cat_id][0].mapped_name for cat_id in sorted(merged_categories)
    ]
    # map to contiguous category IDs
    _add_category_id_to_contiguous_id_maps_to_metadata(merged_categories)
    # load annotations and dataset metadata
    for dataset_name, proposal_file in zip(dataset_names, proposal_files):
        dataset_dicts = dataset_name_to_dicts[dataset_name]
        assert len(dataset_dicts), f"Dataset '{dataset_name}' is empty!"
        if proposal_file is not None:
            dataset_dicts = load_proposals_into_dataset(dataset_dicts, proposal_file)
        dataset_dicts = _maybe_filter_and_map_categories(dataset_name, dataset_dicts)
        print_instances_class_histogram(dataset_dicts, merged_category_names)
        dataset_name_to_dicts[dataset_name] = dataset_dicts

    if keep_instance_predicate is not None:
        all_datasets_dicts_plain = [
            d
            for d in itertools.chain.from_iterable(dataset_name_to_dicts.values())
            if keep_instance_predicate(d)
        ]
    else:
        all_datasets_dicts_plain = list(
            itertools.chain.from_iterable(dataset_name_to_dicts.values())
        )
    return all_datasets_dicts_plain


def build_detection_train_loader(cfg: CfgNode, mapper=None):
    """
    A data loader is created in a way similar to that of Detectron2.
    The main differences are:
     - it allows to combine datasets with different but compatible object category sets

    The data loader is created by the following steps:
    1. Use the dataset names in config to query :class:`DatasetCatalog`, and obtain a list of dicts.
    2. Start workers to work on the dicts. Each worker will:
        * Map each metadata dict into another format to be consumed by the model.
        * Batch them by simply putting dicts into a list.
    The batched ``list[mapped_dict]`` is what this dataloader will return.

    Args:
        cfg (CfgNode): the config
        mapper (callable): a callable which takes a sample (dict) from dataset and
            returns the format to be consumed by the model.
            By default it will be `DatasetMapper(cfg, True)`.

    Returns:
        an infinite iterator of training data
    """

    _add_category_whitelists_to_metadata(cfg)
    _add_category_maps_to_metadata(cfg)
    _maybe_add_class_to_mesh_name_map_to_metadata(cfg.DATASETS.TRAIN, cfg)
    dataset_dicts = combine_detection_dataset_dicts(
        cfg.DATASETS.TRAIN,
        keep_instance_predicate=_get_train_keep_instance_predicate(cfg),
        proposal_files=cfg.DATASETS.PROPOSAL_FILES_TRAIN if cfg.MODEL.LOAD_PROPOSALS else None,
    )
    if mapper is None:
        mapper = DatasetMapper(cfg, True)
    return d2_build_detection_train_loader(cfg, dataset=dataset_dicts, mapper=mapper)


def build_detection_test_loader(cfg, dataset_name, mapper=None):
    """
    Similar to `build_detection_train_loader`.
    But this function uses the given `dataset_name` argument (instead of the names in cfg),
    and uses batch size 1.

    Args:
        cfg: a detectron2 CfgNode
        dataset_name (str): a name of the dataset that's available in the DatasetCatalog
        mapper (callable): a callable which takes a sample (dict) from dataset
            and returns the format to be consumed by the model.
            By default it will be `DatasetMapper(cfg, False)`.

    Returns:
        DataLoader: a torch DataLoader, that loads the given detection
        dataset, with test-time transformation and batching.
    """
    _add_category_whitelists_to_metadata(cfg)
    _add_category_maps_to_metadata(cfg)
    _maybe_add_class_to_mesh_name_map_to_metadata([dataset_name], cfg)
    dataset_dicts = combine_detection_dataset_dicts(
        [dataset_name],
        keep_instance_predicate=_get_test_keep_instance_predicate(cfg),
        proposal_files=(
            [cfg.DATASETS.PROPOSAL_FILES_TEST[list(cfg.DATASETS.TEST).index(dataset_name)]]
            if cfg.MODEL.LOAD_PROPOSALS
            else None
        ),
    )
    sampler = None
    if not cfg.DENSEPOSE_EVALUATION.DISTRIBUTED_INFERENCE:
        sampler = torch.utils.data.SequentialSampler(dataset_dicts)
    if mapper is None:
        mapper = DatasetMapper(cfg, False)
    return d2_build_detection_test_loader(
        dataset_dicts, mapper=mapper, num_workers=cfg.DATALOADER.NUM_WORKERS, sampler=sampler
    )


def build_frame_selector(cfg: CfgNode):
    strategy = FrameSelectionStrategy(cfg.STRATEGY)
    if strategy == FrameSelectionStrategy.RANDOM_K:
        frame_selector = RandomKFramesSelector(cfg.NUM_IMAGES)
    elif strategy == FrameSelectionStrategy.FIRST_K:
        frame_selector = FirstKFramesSelector(cfg.NUM_IMAGES)
    elif strategy == FrameSelectionStrategy.LAST_K:
        frame_selector = LastKFramesSelector(cfg.NUM_IMAGES)
    elif strategy == FrameSelectionStrategy.ALL:
        frame_selector = None
    # pyre-fixme[61]: `frame_selector` may not be initialized here.
    return frame_selector


def build_transform(cfg: CfgNode, data_type: str):
    if cfg.TYPE == "resize":
        if data_type == "image":
            return ImageResizeTransform(cfg.MIN_SIZE, cfg.MAX_SIZE)
    raise ValueError(f"Unknown transform {cfg.TYPE} for data type {data_type}")


def build_combined_loader(cfg: CfgNode, loaders: Collection[Loader], ratios: Sequence[float]):
    images_per_worker = _compute_num_images_per_worker(cfg)
    return CombinedDataLoader(loaders, images_per_worker, ratios)


def build_bootstrap_dataset(dataset_name: str, cfg: CfgNode) -> Sequence[torch.Tensor]:
    """
    Build dataset that provides data to bootstrap on

    Args:
        dataset_name (str): Name of the dataset, needs to have associated metadata
            to load the data
        cfg (CfgNode): bootstrapping config
    Returns:
        Sequence[Tensor] - dataset that provides image batches, Tensors of size
        [N, C, H, W] of type float32
    """
    logger = logging.getLogger(__name__)
    _add_category_info_to_bootstrapping_metadata(dataset_name, cfg)
    meta = MetadataCatalog.get(dataset_name)
    factory = BootstrapDatasetFactoryCatalog.get(meta.dataset_type)
    dataset = None
    if factory is not None:
        dataset = factory(meta, cfg)
    if dataset is None:
        logger.warning(f"Failed to create dataset {dataset_name} of type {meta.dataset_type}")
    return dataset


def build_data_sampler(cfg: CfgNode, sampler_cfg: CfgNode, embedder: Optional[torch.nn.Module]):
    if sampler_cfg.TYPE == "densepose_uniform":
        data_sampler = PredictionToGroundTruthSampler()
        # transform densepose pred -> gt
        data_sampler.register_sampler(
            "pred_densepose",
            "gt_densepose",
            DensePoseUniformSampler(count_per_class=sampler_cfg.COUNT_PER_CLASS),
        )
        data_sampler.register_sampler("pred_densepose", "gt_masks", MaskFromDensePoseSampler())
        return data_sampler
    elif sampler_cfg.TYPE == "densepose_UV_confidence":
        data_sampler = PredictionToGroundTruthSampler()
        # transform densepose pred -> gt
        data_sampler.register_sampler(
            "pred_densepose",
            "gt_densepose",
            DensePoseConfidenceBasedSampler(
                confidence_channel="sigma_2",
                count_per_class=sampler_cfg.COUNT_PER_CLASS,
                search_proportion=0.5,
            ),
        )
        data_sampler.register_sampler("pred_densepose", "gt_masks", MaskFromDensePoseSampler())
        return data_sampler
    elif sampler_cfg.TYPE == "densepose_fine_segm_confidence":
        data_sampler = PredictionToGroundTruthSampler()
        # transform densepose pred -> gt
        data_sampler.register_sampler(
            "pred_densepose",
            "gt_densepose",
            DensePoseConfidenceBasedSampler(
                confidence_channel="fine_segm_confidence",
                count_per_class=sampler_cfg.COUNT_PER_CLASS,
                search_proportion=0.5,
            ),
        )
        data_sampler.register_sampler("pred_densepose", "gt_masks", MaskFromDensePoseSampler())
        return data_sampler
    elif sampler_cfg.TYPE == "densepose_coarse_segm_confidence":
        data_sampler = PredictionToGroundTruthSampler()
        # transform densepose pred -> gt
        data_sampler.register_sampler(
            "pred_densepose",
            "gt_densepose",
            DensePoseConfidenceBasedSampler(
                confidence_channel="coarse_segm_confidence",
                count_per_class=sampler_cfg.COUNT_PER_CLASS,
                search_proportion=0.5,
            ),
        )
        data_sampler.register_sampler("pred_densepose", "gt_masks", MaskFromDensePoseSampler())
        return data_sampler
    elif sampler_cfg.TYPE == "densepose_cse_uniform":
        assert embedder is not None
        data_sampler = PredictionToGroundTruthSampler()
        # transform densepose pred -> gt
        data_sampler.register_sampler(
            "pred_densepose",
            "gt_densepose",
            DensePoseCSEUniformSampler(
                cfg=cfg,
                use_gt_categories=sampler_cfg.USE_GROUND_TRUTH_CATEGORIES,
                embedder=embedder,
                count_per_class=sampler_cfg.COUNT_PER_CLASS,
            ),
        )
        data_sampler.register_sampler("pred_densepose", "gt_masks", MaskFromDensePoseSampler())
        return data_sampler
    elif sampler_cfg.TYPE == "densepose_cse_coarse_segm_confidence":
        assert embedder is not None
        data_sampler = PredictionToGroundTruthSampler()
        # transform densepose pred -> gt
        data_sampler.register_sampler(
            "pred_densepose",
            "gt_densepose",
            DensePoseCSEConfidenceBasedSampler(
                cfg=cfg,
                use_gt_categories=sampler_cfg.USE_GROUND_TRUTH_CATEGORIES,
                embedder=embedder,
                confidence_channel="coarse_segm_confidence",
                count_per_class=sampler_cfg.COUNT_PER_CLASS,
                search_proportion=0.5,
            ),
        )
        data_sampler.register_sampler("pred_densepose", "gt_masks", MaskFromDensePoseSampler())
        return data_sampler

    raise ValueError(f"Unknown data sampler type {sampler_cfg.TYPE}")


def build_data_filter(cfg: CfgNode):
    if cfg.TYPE == "detection_score":
        min_score = cfg.MIN_VALUE
        return ScoreBasedFilter(min_score=min_score)
    raise ValueError(f"Unknown data filter type {cfg.TYPE}")


def build_inference_based_loader(
    cfg: CfgNode,
    dataset_cfg: CfgNode,
    model: torch.nn.Module,
    embedder: Optional[torch.nn.Module] = None,
) -> InferenceBasedLoader:
    """
    Constructs data loader based on inference results of a model.
    """
    dataset = build_bootstrap_dataset(dataset_cfg.DATASET, dataset_cfg.IMAGE_LOADER)
    meta = MetadataCatalog.get(dataset_cfg.DATASET)
    training_sampler = TrainingSampler(len(dataset))
    data_loader = torch.utils.data.DataLoader(
        dataset,  # pyre-ignore[6]
        batch_size=dataset_cfg.IMAGE_LOADER.BATCH_SIZE,
        sampler=training_sampler,
        num_workers=dataset_cfg.IMAGE_LOADER.NUM_WORKERS,
        collate_fn=trivial_batch_collator,
        worker_init_fn=worker_init_reset_seed,
    )
    return InferenceBasedLoader(
        model,
        data_loader=data_loader,
        data_sampler=build_data_sampler(cfg, dataset_cfg.DATA_SAMPLER, embedder),
        data_filter=build_data_filter(dataset_cfg.FILTER),
        shuffle=True,
        batch_size=dataset_cfg.INFERENCE.OUTPUT_BATCH_SIZE,
        inference_batch_size=dataset_cfg.INFERENCE.INPUT_BATCH_SIZE,
        category_to_class_mapping=meta.category_to_class_mapping,
    )


def has_inference_based_loaders(cfg: CfgNode) -> bool:
    """
    Returns True, if at least one inference-based loader must
    be instantiated for training
    """
    return len(cfg.BOOTSTRAP_DATASETS) > 0


def build_inference_based_loaders(
    cfg: CfgNode, model: torch.nn.Module
) -> Tuple[List[InferenceBasedLoader], List[float]]:
    loaders = []
    ratios = []
    embedder = build_densepose_embedder(cfg).to(device=model.device)  # pyre-ignore[16]
    for dataset_spec in cfg.BOOTSTRAP_DATASETS:
        dataset_cfg = get_bootstrap_dataset_config().clone()
        dataset_cfg.merge_from_other_cfg(CfgNode(dataset_spec))
        loader = build_inference_based_loader(cfg, dataset_cfg, model, embedder)
        loaders.append(loader)
        ratios.append(dataset_cfg.RATIO)
    return loaders, ratios


def build_video_list_dataset(meta: Metadata, cfg: CfgNode):
    video_list_fpath = meta.video_list_fpath
    video_base_path = meta.video_base_path
    category = meta.category
    if cfg.TYPE == "video_keyframe":
        frame_selector = build_frame_selector(cfg.SELECT)
        transform = build_transform(cfg.TRANSFORM, data_type="image")
        video_list = video_list_from_file(video_list_fpath, video_base_path)
        keyframe_helper_fpath = getattr(cfg, "KEYFRAME_HELPER", None)
        return VideoKeyframeDataset(
            video_list, category, frame_selector, transform, keyframe_helper_fpath
        )


class _BootstrapDatasetFactoryCatalog(UserDict):
    """
    A global dictionary that stores information about bootstrapped datasets creation functions
    from metadata and config, for diverse DatasetType
    """

    def register(self, dataset_type: DatasetType, factory: Callable[[Metadata, CfgNode], Dataset]):
        """
        Args:
            dataset_type (DatasetType): a DatasetType e.g. DatasetType.VIDEO_LIST
            factory (Callable[Metadata, CfgNode]): a callable which takes Metadata and cfg
                arguments and returns a dataset object.
        """
        assert dataset_type not in self, "Dataset '{}' is already registered!".format(dataset_type)
        self[dataset_type] = factory


BootstrapDatasetFactoryCatalog = _BootstrapDatasetFactoryCatalog()
BootstrapDatasetFactoryCatalog.register(DatasetType.VIDEO_LIST, build_video_list_dataset)
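Taken together, these builders turn a config into concrete data loaders. The following is a minimal usage sketch, not part of the diff: the config path is a placeholder, and it assumes `add_densepose_config` and the loader builders are importable from the `densepose` package as registered by this Space.

# Sketch only: config path and dataset choice are placeholders.
from detectron2.config import get_cfg
from densepose import add_densepose_config
from densepose.data import build_detection_train_loader, build_detection_test_loader

cfg = get_cfg()
add_densepose_config(cfg)
cfg.merge_from_file("configs/densepose_rcnn_R_50_FPN_s1x.yaml")  # placeholder path

train_loader = build_detection_train_loader(cfg)  # infinite iterator of batched dicts
test_loader = build_detection_test_loader(cfg, cfg.DATASETS.TEST[0])  # batch size 1
for batch in train_loader:
    # each element is a mapped dict with "image", "instances", ...
    break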
densepose/data/combined_loader.py
ADDED
@@ -0,0 +1,46 @@
# Copyright (c) Facebook, Inc. and its affiliates.

# pyre-unsafe

import random
from collections import deque
from typing import Any, Collection, Deque, Iterable, Iterator, List, Sequence

Loader = Iterable[Any]


def _pooled_next(iterator: Iterator[Any], pool: Deque[Any]):
    if not pool:
        pool.extend(next(iterator))
    return pool.popleft()


class CombinedDataLoader:
    """
    Combines data loaders using the provided sampling ratios
    """

    BATCH_COUNT = 100

    def __init__(self, loaders: Collection[Loader], batch_size: int, ratios: Sequence[float]):
        self.loaders = loaders
        self.batch_size = batch_size
        self.ratios = ratios

    def __iter__(self) -> Iterator[List[Any]]:
        iters = [iter(loader) for loader in self.loaders]
        indices = []
        # one pool per loader (a single shared deque would mix elements across loaders)
        pool = [deque() for _ in iters]
        # infinite iterator, as in D2
        while True:
            if not indices:
                # just a buffer of indices, its size doesn't matter
                # as long as it's a multiple of batch_size
                k = self.batch_size * self.BATCH_COUNT
                indices = random.choices(range(len(self.loaders)), self.ratios, k=k)
            try:
                batch = [_pooled_next(iters[i], pool[i]) for i in indices[: self.batch_size]]
            except StopIteration:
                break
            indices = indices[self.batch_size :]
            yield batch
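CombinedDataLoader only requires iterables that yield batches, so its behaviour can be sketched with plain lists standing in for real loaders; the names below are illustrative only.

# Illustration with toy "loaders"; real usage passes detectron2 data loaders.
loader_a = [[("a", i)] for i in range(100)]  # each item is a batch (a list with one element)
loader_b = [[("b", i)] for i in range(100)]

combined = CombinedDataLoader([loader_a, loader_b], batch_size=4, ratios=[0.75, 0.25])
first_batch = next(iter(combined))
# first_batch has 4 elements, drawn from loader_a roughly three times as often as from loader_b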
densepose/data/dataset_mapper.py
ADDED
@@ -0,0 +1,170 @@
# -*- coding: utf-8 -*-
# Copyright (c) Facebook, Inc. and its affiliates.

# pyre-unsafe

import copy
import logging
from typing import Any, Dict, List, Tuple
import torch

from detectron2.data import MetadataCatalog
from detectron2.data import detection_utils as utils
from detectron2.data import transforms as T
from detectron2.layers import ROIAlign
from detectron2.structures import BoxMode
from detectron2.utils.file_io import PathManager

from densepose.structures import DensePoseDataRelative, DensePoseList, DensePoseTransformData


def build_augmentation(cfg, is_train):
    logger = logging.getLogger(__name__)
    result = utils.build_augmentation(cfg, is_train)
    if is_train:
        random_rotation = T.RandomRotation(
            cfg.INPUT.ROTATION_ANGLES, expand=False, sample_style="choice"
        )
        result.append(random_rotation)
        logger.info("DensePose-specific augmentation used in training: " + str(random_rotation))
    return result


class DatasetMapper:
    """
    A customized version of `detectron2.data.DatasetMapper`
    """

    def __init__(self, cfg, is_train=True):
        self.augmentation = build_augmentation(cfg, is_train)

        # fmt: off
        self.img_format = cfg.INPUT.FORMAT
        self.mask_on = (
            cfg.MODEL.MASK_ON or (
                cfg.MODEL.DENSEPOSE_ON
                and cfg.MODEL.ROI_DENSEPOSE_HEAD.COARSE_SEGM_TRAINED_BY_MASKS)
        )
        self.keypoint_on = cfg.MODEL.KEYPOINT_ON
        self.densepose_on = cfg.MODEL.DENSEPOSE_ON
        assert not cfg.MODEL.LOAD_PROPOSALS, "not supported yet"
        # fmt: on
        if self.keypoint_on and is_train:
            # Flip only makes sense in training
            self.keypoint_hflip_indices = utils.create_keypoint_hflip_indices(cfg.DATASETS.TRAIN)
        else:
            self.keypoint_hflip_indices = None

        if self.densepose_on:
            densepose_transform_srcs = [
                MetadataCatalog.get(ds).densepose_transform_src
                for ds in cfg.DATASETS.TRAIN + cfg.DATASETS.TEST
            ]
            assert len(densepose_transform_srcs) > 0
            # TODO: check that DensePose transformation data is the same for
            # all the datasets. Otherwise one would have to pass DB ID with
            # each entry to select proper transformation data. For now, since
            # all DensePose annotated data uses the same data semantics, we
            # omit this check.
            densepose_transform_data_fpath = PathManager.get_local_path(densepose_transform_srcs[0])
            self.densepose_transform_data = DensePoseTransformData.load(
                densepose_transform_data_fpath
            )

        self.is_train = is_train

    def __call__(self, dataset_dict):
        """
        Args:
            dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.

        Returns:
            dict: a format that builtin models in detectron2 accept
        """
        dataset_dict = copy.deepcopy(dataset_dict)  # it will be modified by code below
        image = utils.read_image(dataset_dict["file_name"], format=self.img_format)
        utils.check_image_size(dataset_dict, image)

        image, transforms = T.apply_transform_gens(self.augmentation, image)
        image_shape = image.shape[:2]  # h, w
        dataset_dict["image"] = torch.as_tensor(image.transpose(2, 0, 1).astype("float32"))

        if not self.is_train:
            dataset_dict.pop("annotations", None)
            return dataset_dict

        for anno in dataset_dict["annotations"]:
            if not self.mask_on:
                anno.pop("segmentation", None)
            if not self.keypoint_on:
                anno.pop("keypoints", None)

        # USER: Implement additional transformations if you have other types of data
        # USER: Don't call transpose_densepose if you don't need
        annos = [
            self._transform_densepose(
                utils.transform_instance_annotations(
                    obj, transforms, image_shape, keypoint_hflip_indices=self.keypoint_hflip_indices
                ),
                transforms,
            )
            for obj in dataset_dict.pop("annotations")
            if obj.get("iscrowd", 0) == 0
        ]

        if self.mask_on:
            self._add_densepose_masks_as_segmentation(annos, image_shape)

        instances = utils.annotations_to_instances(annos, image_shape, mask_format="bitmask")
        densepose_annotations = [obj.get("densepose") for obj in annos]
        if densepose_annotations and not all(v is None for v in densepose_annotations):
            instances.gt_densepose = DensePoseList(
                densepose_annotations, instances.gt_boxes, image_shape
            )

        dataset_dict["instances"] = instances[instances.gt_boxes.nonempty()]
        return dataset_dict

    def _transform_densepose(self, annotation, transforms):
        if not self.densepose_on:
            return annotation

        # Handle densepose annotations
        is_valid, reason_not_valid = DensePoseDataRelative.validate_annotation(annotation)
        if is_valid:
            densepose_data = DensePoseDataRelative(annotation, cleanup=True)
            densepose_data.apply_transform(transforms, self.densepose_transform_data)
            annotation["densepose"] = densepose_data
        else:
            # logger = logging.getLogger(__name__)
            # logger.debug("Could not load DensePose annotation: {}".format(reason_not_valid))
            DensePoseDataRelative.cleanup_annotation(annotation)
            # NOTE: annotations for certain instances may be unavailable.
            # 'None' is accepted by the DensePoseList data structure.
            annotation["densepose"] = None
        return annotation

    def _add_densepose_masks_as_segmentation(
        self, annotations: List[Dict[str, Any]], image_shape_hw: Tuple[int, int]
    ):
        for obj in annotations:
            if ("densepose" not in obj) or ("segmentation" in obj):
                continue
            # DP segmentation: torch.Tensor [S, S] of float32, S=256
            segm_dp = torch.zeros_like(obj["densepose"].segm)
            segm_dp[obj["densepose"].segm > 0] = 1
            segm_h, segm_w = segm_dp.shape
            bbox_segm_dp = torch.tensor((0, 0, segm_h - 1, segm_w - 1), dtype=torch.float32)
            # image bbox
            x0, y0, x1, y1 = (
                v.item() for v in BoxMode.convert(obj["bbox"], obj["bbox_mode"], BoxMode.XYXY_ABS)
            )
            segm_aligned = (
                ROIAlign((y1 - y0, x1 - x0), 1.0, 0, aligned=True)
                .forward(segm_dp.view(1, 1, *segm_dp.shape), bbox_segm_dp)
                .squeeze()
            )
            image_mask = torch.zeros(*image_shape_hw, dtype=torch.float32)
            image_mask[y0:y1, x0:x1] = segm_aligned
            # segmentation for BitMask: np.array [H, W] of bool
            obj["segmentation"] = image_mask >= 0.5
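The mapper contract is simple: it takes one Detectron2 dataset dict and returns a model-ready dict. A hedged sketch of how any such mapper plugs into a plain torch DataLoader follows; the record contents and the identity mapper are placeholders, while `DatasetMapper(cfg, True)` would sit in the same position once a config and real dataset dicts exist.

# Sketch of the mapper contract; names below are illustrative, not part of the diff.
import torch.utils.data as td

dataset_dicts = [{"file_name": "img0.jpg", "height": 4, "width": 4, "annotations": []}]  # placeholder

class MapDataset(td.Dataset):
    def __init__(self, dicts, mapper):
        self.dicts, self.mapper = dicts, mapper
    def __len__(self):
        return len(self.dicts)
    def __getitem__(self, i):
        return self.mapper(self.dicts[i])  # DatasetMapper would read the image and build Instances here

identity_mapper = lambda d: dict(d)
loader = td.DataLoader(MapDataset(dataset_dicts, identity_mapper), batch_size=1, collate_fn=lambda x: x)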
densepose/data/datasets/__init__.py
ADDED
@@ -0,0 +1,7 @@
# Copyright (c) Facebook, Inc. and its affiliates.

# pyre-unsafe

from . import builtin  # ensure the builtin datasets are registered

__all__ = [k for k in globals().keys() if "builtin" not in k and not k.startswith("_")]
densepose/data/datasets/builtin.py
ADDED
@@ -0,0 +1,18 @@
# Copyright (c) Facebook, Inc. and its affiliates.

# pyre-unsafe
from .chimpnsee import register_dataset as register_chimpnsee_dataset
from .coco import BASE_DATASETS as BASE_COCO_DATASETS
from .coco import DATASETS as COCO_DATASETS
from .coco import register_datasets as register_coco_datasets
from .lvis import DATASETS as LVIS_DATASETS
from .lvis import register_datasets as register_lvis_datasets

DEFAULT_DATASETS_ROOT = "datasets"


register_coco_datasets(COCO_DATASETS, DEFAULT_DATASETS_ROOT)
register_coco_datasets(BASE_COCO_DATASETS, DEFAULT_DATASETS_ROOT)
register_lvis_datasets(LVIS_DATASETS, DEFAULT_DATASETS_ROOT)

register_chimpnsee_dataset(DEFAULT_DATASETS_ROOT)  # pyre-ignore[19]
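The registrations above run at import time and only record loader callables; the data itself is read lazily. A small sketch of how the registered names are consumed later, assuming the corresponding annotation files are present under the datasets root:

# Sketch: query a name registered by the calls above (data must exist on disk for .get to succeed).
from detectron2.data import DatasetCatalog, MetadataCatalog
import densepose.data.datasets.builtin  # noqa: F401  (triggers the register_* calls)

meta = MetadataCatalog.get("densepose_coco_2014_minival")   # e.g. meta.json_file, meta.image_root
dicts = DatasetCatalog.get("densepose_coco_2014_minival")   # invokes load_coco_json lazily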
densepose/data/datasets/chimpnsee.py
ADDED
@@ -0,0 +1,31 @@
# Copyright (c) Facebook, Inc. and its affiliates.

# pyre-unsafe

from typing import Optional

from detectron2.data import DatasetCatalog, MetadataCatalog

from ..utils import maybe_prepend_base_path
from .dataset_type import DatasetType

CHIMPNSEE_DATASET_NAME = "chimpnsee"


def register_dataset(datasets_root: Optional[str] = None) -> None:
    def empty_load_callback():
        pass

    video_list_fpath = maybe_prepend_base_path(
        datasets_root,
        "chimpnsee/cdna.eva.mpg.de/video_list.txt",
    )
    video_base_path = maybe_prepend_base_path(datasets_root, "chimpnsee/cdna.eva.mpg.de")

    DatasetCatalog.register(CHIMPNSEE_DATASET_NAME, empty_load_callback)
    MetadataCatalog.get(CHIMPNSEE_DATASET_NAME).set(
        dataset_type=DatasetType.VIDEO_LIST,
        video_list_fpath=video_list_fpath,
        video_base_path=video_base_path,
        category="chimpanzee",
    )
densepose/data/datasets/coco.py
ADDED
@@ -0,0 +1,434 @@
# Copyright (c) Facebook, Inc. and its affiliates.

# pyre-unsafe
import contextlib
import io
import logging
import os
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Dict, Iterable, List, Optional
from fvcore.common.timer import Timer

from detectron2.data import DatasetCatalog, MetadataCatalog
from detectron2.structures import BoxMode
from detectron2.utils.file_io import PathManager

from ..utils import maybe_prepend_base_path

DENSEPOSE_MASK_KEY = "dp_masks"
DENSEPOSE_IUV_KEYS_WITHOUT_MASK = ["dp_x", "dp_y", "dp_I", "dp_U", "dp_V"]
DENSEPOSE_CSE_KEYS_WITHOUT_MASK = ["dp_x", "dp_y", "dp_vertex", "ref_model"]
DENSEPOSE_ALL_POSSIBLE_KEYS = set(
    DENSEPOSE_IUV_KEYS_WITHOUT_MASK + DENSEPOSE_CSE_KEYS_WITHOUT_MASK + [DENSEPOSE_MASK_KEY]
)
DENSEPOSE_METADATA_URL_PREFIX = "https://dl.fbaipublicfiles.com/densepose/data/"


@dataclass
class CocoDatasetInfo:
    name: str
    images_root: str
    annotations_fpath: str


DATASETS = [
    CocoDatasetInfo(
        name="densepose_coco_2014_train",
        images_root="coco/train2014",
        annotations_fpath="coco/annotations/densepose_train2014.json",
    ),
    CocoDatasetInfo(
        name="densepose_coco_2014_minival",
        images_root="coco/val2014",
        annotations_fpath="coco/annotations/densepose_minival2014.json",
    ),
    CocoDatasetInfo(
        name="densepose_coco_2014_minival_100",
        images_root="coco/val2014",
        annotations_fpath="coco/annotations/densepose_minival2014_100.json",
    ),
    CocoDatasetInfo(
        name="densepose_coco_2014_valminusminival",
        images_root="coco/val2014",
        annotations_fpath="coco/annotations/densepose_valminusminival2014.json",
    ),
    CocoDatasetInfo(
        name="densepose_coco_2014_train_cse",
        images_root="coco/train2014",
        annotations_fpath="coco_cse/densepose_train2014_cse.json",
    ),
    CocoDatasetInfo(
        name="densepose_coco_2014_minival_cse",
        images_root="coco/val2014",
        annotations_fpath="coco_cse/densepose_minival2014_cse.json",
    ),
    CocoDatasetInfo(
        name="densepose_coco_2014_minival_100_cse",
        images_root="coco/val2014",
        annotations_fpath="coco_cse/densepose_minival2014_100_cse.json",
    ),
    CocoDatasetInfo(
        name="densepose_coco_2014_valminusminival_cse",
        images_root="coco/val2014",
        annotations_fpath="coco_cse/densepose_valminusminival2014_cse.json",
    ),
    CocoDatasetInfo(
        name="densepose_chimps",
        images_root="densepose_chimps/images",
        annotations_fpath="densepose_chimps/densepose_chimps_densepose.json",
    ),
    CocoDatasetInfo(
        name="densepose_chimps_cse_train",
        images_root="densepose_chimps/images",
        annotations_fpath="densepose_chimps/densepose_chimps_cse_train.json",
    ),
    CocoDatasetInfo(
        name="densepose_chimps_cse_val",
        images_root="densepose_chimps/images",
        annotations_fpath="densepose_chimps/densepose_chimps_cse_val.json",
    ),
    CocoDatasetInfo(
        name="posetrack2017_train",
        images_root="posetrack2017/posetrack_data_2017",
        annotations_fpath="posetrack2017/densepose_posetrack_train2017.json",
    ),
    CocoDatasetInfo(
        name="posetrack2017_val",
        images_root="posetrack2017/posetrack_data_2017",
        annotations_fpath="posetrack2017/densepose_posetrack_val2017.json",
    ),
    CocoDatasetInfo(
        name="lvis_v05_train",
        images_root="coco/train2017",
        annotations_fpath="lvis/lvis_v0.5_plus_dp_train.json",
    ),
    CocoDatasetInfo(
        name="lvis_v05_val",
        images_root="coco/val2017",
        annotations_fpath="lvis/lvis_v0.5_plus_dp_val.json",
    ),
]


BASE_DATASETS = [
    CocoDatasetInfo(
        name="base_coco_2017_train",
        images_root="coco/train2017",
        annotations_fpath="coco/annotations/instances_train2017.json",
    ),
    CocoDatasetInfo(
        name="base_coco_2017_val",
        images_root="coco/val2017",
        annotations_fpath="coco/annotations/instances_val2017.json",
    ),
    CocoDatasetInfo(
        name="base_coco_2017_val_100",
        images_root="coco/val2017",
        annotations_fpath="coco/annotations/instances_val2017_100.json",
    ),
]


def get_metadata(base_path: Optional[str]) -> Dict[str, Any]:
    """
    Returns metadata associated with COCO DensePose datasets

    Args:
        base_path: Optional[str]
            Base path used to load metadata from

    Returns:
        Dict[str, Any]
            Metadata in the form of a dictionary
    """
    meta = {
        "densepose_transform_src": maybe_prepend_base_path(base_path, "UV_symmetry_transforms.mat"),
        "densepose_smpl_subdiv": maybe_prepend_base_path(base_path, "SMPL_subdiv.mat"),
        "densepose_smpl_subdiv_transform": maybe_prepend_base_path(
            base_path,
            "SMPL_SUBDIV_TRANSFORM.mat",
        ),
    }
    return meta


def _load_coco_annotations(json_file: str):
    """
    Load COCO annotations from a JSON file

    Args:
        json_file: str
            Path to the file to load annotations from
    Returns:
        Instance of `pycocotools.coco.COCO` that provides access to annotations
        data
    """
    from pycocotools.coco import COCO

    logger = logging.getLogger(__name__)
    timer = Timer()
    with contextlib.redirect_stdout(io.StringIO()):
        coco_api = COCO(json_file)
    if timer.seconds() > 1:
        logger.info("Loading {} takes {:.2f} seconds.".format(json_file, timer.seconds()))
    return coco_api


def _add_categories_metadata(dataset_name: str, categories: List[Dict[str, Any]]):
    meta = MetadataCatalog.get(dataset_name)
    meta.categories = {c["id"]: c["name"] for c in categories}
    logger = logging.getLogger(__name__)
    logger.info("Dataset {} categories: {}".format(dataset_name, meta.categories))


def _verify_annotations_have_unique_ids(json_file: str, anns: List[List[Dict[str, Any]]]):
    if "minival" in json_file:
        # Skip validation on COCO2014 valminusminival and minival annotations
        # The ratio of buggy annotations there is tiny and does not affect accuracy
        # Therefore we explicitly white-list them
        return
    ann_ids = [ann["id"] for anns_per_image in anns for ann in anns_per_image]
    assert len(set(ann_ids)) == len(ann_ids), "Annotation ids in '{}' are not unique!".format(
        json_file
    )


def _maybe_add_bbox(obj: Dict[str, Any], ann_dict: Dict[str, Any]):
    if "bbox" not in ann_dict:
        return
    obj["bbox"] = ann_dict["bbox"]
    obj["bbox_mode"] = BoxMode.XYWH_ABS


def _maybe_add_segm(obj: Dict[str, Any], ann_dict: Dict[str, Any]):
    if "segmentation" not in ann_dict:
        return
    segm = ann_dict["segmentation"]
    if not isinstance(segm, dict):
        # filter out invalid polygons (< 3 points)
        segm = [poly for poly in segm if len(poly) % 2 == 0 and len(poly) >= 6]
        if len(segm) == 0:
            return
    obj["segmentation"] = segm


def _maybe_add_keypoints(obj: Dict[str, Any], ann_dict: Dict[str, Any]):
    if "keypoints" not in ann_dict:
        return
    keypts = ann_dict["keypoints"]  # list[int]
    for idx, v in enumerate(keypts):
        if idx % 3 != 2:
            # COCO's segmentation coordinates are floating points in [0, H or W],
            # but keypoint coordinates are integers in [0, H-1 or W-1]
            # Therefore we assume the coordinates are "pixel indices" and
            # add 0.5 to convert to floating point coordinates.
            keypts[idx] = v + 0.5
    obj["keypoints"] = keypts


def _maybe_add_densepose(obj: Dict[str, Any], ann_dict: Dict[str, Any]):
    for key in DENSEPOSE_ALL_POSSIBLE_KEYS:
        if key in ann_dict:
            obj[key] = ann_dict[key]


def _combine_images_with_annotations(
    dataset_name: str,
    image_root: str,
    img_datas: Iterable[Dict[str, Any]],
    ann_datas: Iterable[Iterable[Dict[str, Any]]],
):

    ann_keys = ["iscrowd", "category_id"]
    dataset_dicts = []
    contains_video_frame_info = False

    for img_dict, ann_dicts in zip(img_datas, ann_datas):
        record = {}
        record["file_name"] = os.path.join(image_root, img_dict["file_name"])
        record["height"] = img_dict["height"]
        record["width"] = img_dict["width"]
        record["image_id"] = img_dict["id"]
        record["dataset"] = dataset_name
        if "frame_id" in img_dict:
            record["frame_id"] = img_dict["frame_id"]
            record["video_id"] = img_dict.get("vid_id", None)
            contains_video_frame_info = True
        objs = []
        for ann_dict in ann_dicts:
            assert ann_dict["image_id"] == record["image_id"]
            assert ann_dict.get("ignore", 0) == 0
            obj = {key: ann_dict[key] for key in ann_keys if key in ann_dict}
            _maybe_add_bbox(obj, ann_dict)
            _maybe_add_segm(obj, ann_dict)
            _maybe_add_keypoints(obj, ann_dict)
            _maybe_add_densepose(obj, ann_dict)
            objs.append(obj)
        record["annotations"] = objs
        dataset_dicts.append(record)
    if contains_video_frame_info:
        create_video_frame_mapping(dataset_name, dataset_dicts)
    return dataset_dicts


def get_contiguous_id_to_category_id_map(metadata):
    cat_id_2_cont_id = metadata.thing_dataset_id_to_contiguous_id
    cont_id_2_cat_id = {}
    for cat_id, cont_id in cat_id_2_cont_id.items():
        if cont_id in cont_id_2_cat_id:
            continue
        cont_id_2_cat_id[cont_id] = cat_id
    return cont_id_2_cat_id


def maybe_filter_categories_cocoapi(dataset_name, coco_api):
    meta = MetadataCatalog.get(dataset_name)
    cont_id_2_cat_id = get_contiguous_id_to_category_id_map(meta)
    cat_id_2_cont_id = meta.thing_dataset_id_to_contiguous_id
    # filter categories
    cats = []
    for cat in coco_api.dataset["categories"]:
        cat_id = cat["id"]
        if cat_id not in cat_id_2_cont_id:
            continue
        cont_id = cat_id_2_cont_id[cat_id]
        if (cont_id in cont_id_2_cat_id) and (cont_id_2_cat_id[cont_id] == cat_id):
            cats.append(cat)
    coco_api.dataset["categories"] = cats
    # filter annotations, if multiple categories are mapped to a single
    # contiguous ID, use only one category ID and map all annotations to that category ID
    anns = []
    for ann in coco_api.dataset["annotations"]:
        cat_id = ann["category_id"]
        if cat_id not in cat_id_2_cont_id:
            continue
        cont_id = cat_id_2_cont_id[cat_id]
        ann["category_id"] = cont_id_2_cat_id[cont_id]
        anns.append(ann)
    coco_api.dataset["annotations"] = anns
    # recreate index
    coco_api.createIndex()


def maybe_filter_and_map_categories_cocoapi(dataset_name, coco_api):
    meta = MetadataCatalog.get(dataset_name)
    category_id_map = meta.thing_dataset_id_to_contiguous_id
    # map categories
    cats = []
    for cat in coco_api.dataset["categories"]:
        cat_id = cat["id"]
        if cat_id not in category_id_map:
            continue
        cat["id"] = category_id_map[cat_id]
        cats.append(cat)
    coco_api.dataset["categories"] = cats
    # map annotation categories
    anns = []
    for ann in coco_api.dataset["annotations"]:
        cat_id = ann["category_id"]
        if cat_id not in category_id_map:
            continue
        ann["category_id"] = category_id_map[cat_id]
        anns.append(ann)
    coco_api.dataset["annotations"] = anns
    # recreate index
    coco_api.createIndex()


def create_video_frame_mapping(dataset_name, dataset_dicts):
    mapping = defaultdict(dict)
    for d in dataset_dicts:
        video_id = d.get("video_id")
        if video_id is None:
            continue
        mapping[video_id].update({d["frame_id"]: d["file_name"]})
    MetadataCatalog.get(dataset_name).set(video_frame_mapping=mapping)


def load_coco_json(annotations_json_file: str, image_root: str, dataset_name: str):
    """
    Loads a JSON file with annotations in COCO instances format.
    Replaces `detectron2.data.datasets.coco.load_coco_json` to handle metadata
    in a more flexible way. Postpones category mapping to a later stage to be
    able to combine several datasets with different (but coherent) sets of
    categories.

    Args:

    annotations_json_file: str
        Path to the JSON file with annotations in COCO instances format.
    image_root: str
        directory that contains all the images
    dataset_name: str
        the name that identifies a dataset, e.g. "densepose_coco_2014_train"
    extra_annotation_keys: Optional[List[str]]
        If provided, these keys are used to extract additional data from
        the annotations.
    """
    coco_api = _load_coco_annotations(PathManager.get_local_path(annotations_json_file))
    _add_categories_metadata(dataset_name, coco_api.loadCats(coco_api.getCatIds()))
    # sort indices for reproducible results
    img_ids = sorted(coco_api.imgs.keys())
    # imgs is a list of dicts, each looks something like:
    # {'license': 4,
    #  'url': 'http://farm6.staticflickr.com/5454/9413846304_881d5e5c3b_z.jpg',
    #  'file_name': 'COCO_val2014_000000001268.jpg',
    #  'height': 427,
    #  'width': 640,
    #  'date_captured': '2013-11-17 05:57:24',
    #  'id': 1268}
    imgs = coco_api.loadImgs(img_ids)
    logger = logging.getLogger(__name__)
    logger.info("Loaded {} images in COCO format from {}".format(len(imgs), annotations_json_file))
    # anns is a list[list[dict]], where each dict is an annotation
    # record for an object. The inner list enumerates the objects in an image
    # and the outer list enumerates over images.
    anns = [coco_api.imgToAnns[img_id] for img_id in img_ids]
    _verify_annotations_have_unique_ids(annotations_json_file, anns)
    dataset_records = _combine_images_with_annotations(dataset_name, image_root, imgs, anns)
    return dataset_records


def register_dataset(dataset_data: CocoDatasetInfo, datasets_root: Optional[str] = None):
    """
    Registers provided COCO DensePose dataset

    Args:
        dataset_data: CocoDatasetInfo
            Dataset data
        datasets_root: Optional[str]
            Datasets root folder (default: None)
    """
    annotations_fpath = maybe_prepend_base_path(datasets_root, dataset_data.annotations_fpath)
    images_root = maybe_prepend_base_path(datasets_root, dataset_data.images_root)

    def load_annotations():
        return load_coco_json(
            annotations_json_file=annotations_fpath,
            image_root=images_root,
            dataset_name=dataset_data.name,
        )

    DatasetCatalog.register(dataset_data.name, load_annotations)
    MetadataCatalog.get(dataset_data.name).set(
        json_file=annotations_fpath,
        image_root=images_root,
        **get_metadata(DENSEPOSE_METADATA_URL_PREFIX)
    )


def register_datasets(
    datasets_data: Iterable[CocoDatasetInfo], datasets_root: Optional[str] = None
):
    """
    Registers provided COCO DensePose datasets

    Args:
        datasets_data: Iterable[CocoDatasetInfo]
            An iterable of dataset datas
        datasets_root: Optional[str]
            Datasets root folder (default: None)
    """
    for dataset_data in datasets_data:
        register_dataset(dataset_data, datasets_root)
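Registering an additional COCO-format DensePose dataset follows the same CocoDatasetInfo pattern as the built-in entries. A hypothetical example, with name and paths that are placeholders rather than files shipped with the repo:

# Hypothetical registration; the name and paths are placeholders.
from densepose.data.datasets.coco import CocoDatasetInfo, register_dataset

my_dataset = CocoDatasetInfo(
    name="densepose_my_custom_val",
    images_root="my_dataset/images",
    annotations_fpath="my_dataset/annotations/densepose_val.json",
)
register_dataset(my_dataset, datasets_root="datasets")
# DatasetCatalog.get("densepose_my_custom_val") will then call load_coco_json on that file.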
densepose/data/datasets/dataset_type.py
ADDED
@@ -0,0 +1,13 @@
# Copyright (c) Facebook, Inc. and its affiliates.

# pyre-unsafe

from enum import Enum


class DatasetType(Enum):
    """
    Dataset type, mostly used for datasets that contain data to bootstrap models on
    """

    VIDEO_LIST = "video_list"
densepose/data/datasets/lvis.py
ADDED
@@ -0,0 +1,259 @@
# Copyright (c) Facebook, Inc. and its affiliates.

# pyre-unsafe
import logging
import os
from typing import Any, Dict, Iterable, List, Optional
from fvcore.common.timer import Timer

from detectron2.data import DatasetCatalog, MetadataCatalog
from detectron2.data.datasets.lvis import get_lvis_instances_meta
from detectron2.structures import BoxMode
from detectron2.utils.file_io import PathManager

from ..utils import maybe_prepend_base_path
from .coco import (
    DENSEPOSE_ALL_POSSIBLE_KEYS,
    DENSEPOSE_METADATA_URL_PREFIX,
    CocoDatasetInfo,
    get_metadata,
)

DATASETS = [
    CocoDatasetInfo(
        name="densepose_lvis_v1_ds1_train_v1",
        images_root="coco_",
        annotations_fpath="lvis/densepose_lvis_v1_ds1_train_v1.json",
    ),
    CocoDatasetInfo(
        name="densepose_lvis_v1_ds1_val_v1",
        images_root="coco_",
        annotations_fpath="lvis/densepose_lvis_v1_ds1_val_v1.json",
    ),
    CocoDatasetInfo(
        name="densepose_lvis_v1_ds2_train_v1",
        images_root="coco_",
        annotations_fpath="lvis/densepose_lvis_v1_ds2_train_v1.json",
    ),
    CocoDatasetInfo(
        name="densepose_lvis_v1_ds2_val_v1",
        images_root="coco_",
        annotations_fpath="lvis/densepose_lvis_v1_ds2_val_v1.json",
    ),
    CocoDatasetInfo(
        name="densepose_lvis_v1_ds1_val_animals_100",
        images_root="coco_",
        annotations_fpath="lvis/densepose_lvis_v1_val_animals_100_v2.json",
    ),
]


def _load_lvis_annotations(json_file: str):
    """
    Load COCO annotations from a JSON file

    Args:
        json_file: str
            Path to the file to load annotations from
    Returns:
        Instance of `pycocotools.coco.COCO` that provides access to annotations
        data
    """
    from lvis import LVIS

    json_file = PathManager.get_local_path(json_file)
    logger = logging.getLogger(__name__)
    timer = Timer()
    lvis_api = LVIS(json_file)
    if timer.seconds() > 1:
        logger.info("Loading {} takes {:.2f} seconds.".format(json_file, timer.seconds()))
    return lvis_api


def _add_categories_metadata(dataset_name: str) -> None:
    metadict = get_lvis_instances_meta(dataset_name)
    categories = metadict["thing_classes"]
    metadata = MetadataCatalog.get(dataset_name)
    metadata.categories = {i + 1: categories[i] for i in range(len(categories))}
    logger = logging.getLogger(__name__)
    logger.info(f"Dataset {dataset_name} has {len(categories)} categories")


def _verify_annotations_have_unique_ids(json_file: str, anns: List[List[Dict[str, Any]]]) -> None:
    ann_ids = [ann["id"] for anns_per_image in anns for ann in anns_per_image]
    assert len(set(ann_ids)) == len(ann_ids), "Annotation ids in '{}' are not unique!".format(
        json_file
    )


def _maybe_add_bbox(obj: Dict[str, Any], ann_dict: Dict[str, Any]) -> None:
    if "bbox" not in ann_dict:
        return
    obj["bbox"] = ann_dict["bbox"]
    obj["bbox_mode"] = BoxMode.XYWH_ABS


def _maybe_add_segm(obj: Dict[str, Any], ann_dict: Dict[str, Any]) -> None:
    if "segmentation" not in ann_dict:
        return
    segm = ann_dict["segmentation"]
    if not isinstance(segm, dict):
        # filter out invalid polygons (< 3 points)
        segm = [poly for poly in segm if len(poly) % 2 == 0 and len(poly) >= 6]
        if len(segm) == 0:
            return
    obj["segmentation"] = segm


def _maybe_add_keypoints(obj: Dict[str, Any], ann_dict: Dict[str, Any]) -> None:
    if "keypoints" not in ann_dict:
        return
    keypts = ann_dict["keypoints"]  # list[int]
    for idx, v in enumerate(keypts):
        if idx % 3 != 2:
            # COCO's segmentation coordinates are floating points in [0, H or W],
            # but keypoint coordinates are integers in [0, H-1 or W-1]
            # Therefore we assume the coordinates are "pixel indices" and
            # add 0.5 to convert to floating point coordinates.
            keypts[idx] = v + 0.5
    obj["keypoints"] = keypts


def _maybe_add_densepose(obj: Dict[str, Any], ann_dict: Dict[str, Any]) -> None:
    for key in DENSEPOSE_ALL_POSSIBLE_KEYS:
        if key in ann_dict:
            obj[key] = ann_dict[key]


def _combine_images_with_annotations(
    dataset_name: str,
    image_root: str,
    img_datas: Iterable[Dict[str, Any]],
    ann_datas: Iterable[Iterable[Dict[str, Any]]],
):

    dataset_dicts = []

    def get_file_name(img_root, img_dict):
        # Determine the path including the split folder ("train2017", "val2017", "test2017") from
        # the coco_url field. Example:
        #   'coco_url': 'http://images.cocodataset.org/train2017/000000155379.jpg'
        split_folder, file_name = img_dict["coco_url"].split("/")[-2:]
        return os.path.join(img_root + split_folder, file_name)

    for img_dict, ann_dicts in zip(img_datas, ann_datas):
        record = {}
        record["file_name"] = get_file_name(image_root, img_dict)
        record["height"] = img_dict["height"]
        record["width"] = img_dict["width"]
        record["not_exhaustive_category_ids"] = img_dict.get("not_exhaustive_category_ids", [])
        record["neg_category_ids"] = img_dict.get("neg_category_ids", [])
        record["image_id"] = img_dict["id"]
        record["dataset"] = dataset_name

        objs = []
        for ann_dict in ann_dicts:
            assert ann_dict["image_id"] == record["image_id"]
            obj = {}
            _maybe_add_bbox(obj, ann_dict)
            obj["iscrowd"] = ann_dict.get("iscrowd", 0)
            obj["category_id"] = ann_dict["category_id"]
            _maybe_add_segm(obj, ann_dict)
            _maybe_add_keypoints(obj, ann_dict)
            _maybe_add_densepose(obj, ann_dict)
            objs.append(obj)
        record["annotations"] = objs
        dataset_dicts.append(record)
    return dataset_dicts


def load_lvis_json(annotations_json_file: str, image_root: str, dataset_name: str):
    """
    Loads a JSON file with annotations in LVIS instances format.
    Replaces `detectron2.data.datasets.coco.load_lvis_json` to handle metadata
    in a more flexible way. Postpones category mapping to a later stage to be
    able to combine several datasets with different (but coherent) sets of
    categories.

    Args:

    annotations_json_file: str
        Path to the JSON file with annotations in COCO instances format.
    image_root: str
        directory that contains all the images
    dataset_name: str
        the name that identifies a dataset, e.g. "densepose_coco_2014_train"
    extra_annotation_keys: Optional[List[str]]
        If provided, these keys are used to extract additional data from
        the annotations.
    """
    lvis_api = _load_lvis_annotations(PathManager.get_local_path(annotations_json_file))

    _add_categories_metadata(dataset_name)

    # sort indices for reproducible results
    img_ids = sorted(lvis_api.imgs.keys())
    # imgs is a list of dicts, each looks something like:
    # {'license': 4,
    #  'url': 'http://farm6.staticflickr.com/5454/9413846304_881d5e5c3b_z.jpg',
    #  'file_name': 'COCO_val2014_000000001268.jpg',
    #  'height': 427,
    #  'width': 640,
    #  'date_captured': '2013-11-17 05:57:24',
    #  'id': 1268}
    imgs = lvis_api.load_imgs(img_ids)
    logger = logging.getLogger(__name__)
    logger.info("Loaded {} images in LVIS format from {}".format(len(imgs), annotations_json_file))
    # anns is a list[list[dict]], where each dict is an annotation
    # record for an object. The inner list enumerates the objects in an image
    # and the outer list enumerates over images.
    anns = [lvis_api.img_ann_map[img_id] for img_id in img_ids]

    _verify_annotations_have_unique_ids(annotations_json_file, anns)
    dataset_records = _combine_images_with_annotations(dataset_name, image_root, imgs, anns)
    return dataset_records


def register_dataset(dataset_data: CocoDatasetInfo, datasets_root: Optional[str] = None) -> None:
    """
    Registers provided LVIS DensePose dataset

    Args:
        dataset_data: CocoDatasetInfo
            Dataset data
        datasets_root: Optional[str]
            Datasets root folder (default: None)
    """
    annotations_fpath = maybe_prepend_base_path(datasets_root, dataset_data.annotations_fpath)
    images_root = maybe_prepend_base_path(datasets_root, dataset_data.images_root)

    def load_annotations():
        return load_lvis_json(
            annotations_json_file=annotations_fpath,
            image_root=images_root,
            dataset_name=dataset_data.name,
        )

    DatasetCatalog.register(dataset_data.name, load_annotations)
    MetadataCatalog.get(dataset_data.name).set(
        json_file=annotations_fpath,
        image_root=images_root,
        evaluator_type="lvis",
        **get_metadata(DENSEPOSE_METADATA_URL_PREFIX),
    )


def register_datasets(
    datasets_data: Iterable[CocoDatasetInfo], datasets_root: Optional[str] = None
) -> None:
    """
    Registers provided LVIS DensePose datasets

    Args:
        datasets_data: Iterable[CocoDatasetInfo]
            An iterable of dataset datas
        datasets_root: Optional[str]
            Datasets root folder (default: None)
    """
    for dataset_data in datasets_data:
        register_dataset(dataset_data, datasets_root)
densepose/data/image_list_dataset.py
ADDED
@@ -0,0 +1,74 @@
# -*- coding: utf-8 -*-
# Copyright (c) Facebook, Inc. and its affiliates.

# pyre-unsafe

import logging
import numpy as np
from typing import Any, Callable, Dict, List, Optional, Union
import torch
from torch.utils.data.dataset import Dataset

from detectron2.data.detection_utils import read_image

ImageTransform = Callable[[torch.Tensor], torch.Tensor]


class ImageListDataset(Dataset):
    """
    Dataset that provides images from a list.
    """

    _EMPTY_IMAGE = torch.empty((0, 3, 1, 1))

    def __init__(
        self,
        image_list: List[str],
        category_list: Union[str, List[str], None] = None,
        transform: Optional[ImageTransform] = None,
    ):
        """
        Args:
            image_list (List[str]): list of paths to image files
            category_list (Union[str, List[str], None]): list of animal categories for
                each image. If it is a string, or None, this applies to all images
        """
        if type(category_list) is list:
            self.category_list = category_list
        else:
            self.category_list = [category_list] * len(image_list)
        assert len(image_list) == len(
            self.category_list
        ), "length of image and category lists must be equal"
        self.image_list = image_list
        self.transform = transform

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """
        Gets selected images from the list

        Args:
            idx (int): video index in the video list file
        Returns:
            A dictionary containing two keys:
                images (torch.Tensor): tensor of size [N, 3, H, W] (N = 1, or 0 for _EMPTY_IMAGE)
                categories (List[str]): categories of the frames
        """
        categories = [self.category_list[idx]]
        fpath = self.image_list[idx]
        transform = self.transform

        try:
            image = torch.from_numpy(np.ascontiguousarray(read_image(fpath, format="BGR")))
            image = image.permute(2, 0, 1).unsqueeze(0).float()  # HWC -> NCHW
            if transform is not None:
                image = transform(image)
            return {"images": image, "categories": categories}
        except (OSError, RuntimeError) as e:
            logger = logging.getLogger(__name__)
            logger.warning(f"Error opening image file container {fpath}: {e}")

        return {"images": self._EMPTY_IMAGE, "categories": []}

    def __len__(self):
        return len(self.image_list)
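A minimal usage sketch (not part of the added files) for the dataset above. The image paths are hypothetical; a failed read simply yields the empty [0, 3, 1, 1] tensor defined by `_EMPTY_IMAGE`.

# Sketch: iterate over a small list of images with ImageListDataset.
from torch.utils.data import DataLoader
from densepose.data.image_list_dataset import ImageListDataset
from densepose.data.transform import ImageResizeTransform

dataset = ImageListDataset(
    image_list=["images/cat_001.jpg", "images/cat_002.jpg"],  # hypothetical files
    category_list="cat",                                      # one category for all images
    transform=ImageResizeTransform(min_size=800, max_size=1333),
)
# The identity collate_fn keeps each item's [N, 3, H, W] tensor unstacked,
# since image sizes may differ after resizing.
loader = DataLoader(dataset, batch_size=1, collate_fn=lambda batch: batch)
for batch in loader:
    for item in batch:
        print(item["images"].shape, item["categories"])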
densepose/data/inference_based_loader.py
ADDED
@@ -0,0 +1,174 @@
# Copyright (c) Facebook, Inc. and its affiliates.

# pyre-unsafe

import random
from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple
import torch
from torch import nn

SampledData = Any
ModelOutput = Any


def _grouper(iterable: Iterable[Any], n: int, fillvalue=None) -> Iterator[Tuple[Any]]:
    """
    Group elements of an iterable by chunks of size `n`, e.g.
    grouper(range(9), 4) ->
        (0, 1, 2, 3), (4, 5, 6, 7), (8, None, None, None)
    """
    it = iter(iterable)
    while True:
        values = []
        for _ in range(n):
            try:
                value = next(it)
            except StopIteration:
                if values:
                    values.extend([fillvalue] * (n - len(values)))
                    yield tuple(values)
                return
            values.append(value)
        yield tuple(values)


class ScoreBasedFilter:
    """
    Filters entries in model output based on their scores
    Discards all entries with score less than the specified minimum
    """

    def __init__(self, min_score: float = 0.8):
        self.min_score = min_score

    def __call__(self, model_output: ModelOutput) -> ModelOutput:
        for model_output_i in model_output:
            instances = model_output_i["instances"]
            if not instances.has("scores"):
                continue
            instances_filtered = instances[instances.scores >= self.min_score]
            model_output_i["instances"] = instances_filtered
        return model_output


class InferenceBasedLoader:
    """
    Data loader based on results inferred by a model. Consists of:
     - a data loader that provides batches of images
     - a model that is used to infer the results
     - a data sampler that converts inferred results to annotations
    """

    def __init__(
        self,
        model: nn.Module,
        data_loader: Iterable[List[Dict[str, Any]]],
        data_sampler: Optional[Callable[[ModelOutput], List[SampledData]]] = None,
        data_filter: Optional[Callable[[ModelOutput], ModelOutput]] = None,
        shuffle: bool = True,
        batch_size: int = 4,
        inference_batch_size: int = 4,
        drop_last: bool = False,
        category_to_class_mapping: Optional[dict] = None,
    ):
        """
        Constructor

        Args:
            model (torch.nn.Module): model used to produce data
            data_loader (Iterable[List[Dict[str, Any]]]): iterable that provides
                dictionaries with "images" and "categories" fields to perform inference on
            data_sampler (Callable: ModelOutput -> SampledData): functor
                that produces annotation data from inference results;
                (optional, default: None)
            data_filter (Callable: ModelOutput -> ModelOutput): filter
                that selects model outputs for further processing
                (optional, default: None)
            shuffle (bool): if True, the input images get shuffled
            batch_size (int): batch size for the produced annotation data
            inference_batch_size (int): batch size for input images
            drop_last (bool): if True, drop the last batch if it is undersized
            category_to_class_mapping (dict): category to class mapping
        """
        self.model = model
        self.model.eval()
        self.data_loader = data_loader
        self.data_sampler = data_sampler
        self.data_filter = data_filter
        self.shuffle = shuffle
        self.batch_size = batch_size
        self.inference_batch_size = inference_batch_size
        self.drop_last = drop_last
        if category_to_class_mapping is not None:
            self.category_to_class_mapping = category_to_class_mapping
        else:
            self.category_to_class_mapping = {}

    def __iter__(self) -> Iterator[List[SampledData]]:
        for batch in self.data_loader:
            # batch : List[Dict[str: Tensor[N, C, H, W], str: Optional[str]]]
            # images_batch : Tensor[N, C, H, W]
            # image : Tensor[C, H, W]
            images_and_categories = [
                {"image": image, "category": category}
                for element in batch
                for image, category in zip(element["images"], element["categories"])
            ]
            if not images_and_categories:
                continue
            if self.shuffle:
                random.shuffle(images_and_categories)
            yield from self._produce_data(images_and_categories)  # pyre-ignore[6]

    def _produce_data(
        self, images_and_categories: List[Tuple[torch.Tensor, Optional[str]]]
    ) -> Iterator[List[SampledData]]:
        """
        Produce batches of data from images

        Args:
            images_and_categories (List[Tuple[torch.Tensor, Optional[str]]]):
                list of images and corresponding categories to process

        Returns:
            Iterator over batches of data sampled from model outputs
        """
        data_batches: List[SampledData] = []
        category_to_class_mapping = self.category_to_class_mapping
        batched_images_and_categories = _grouper(images_and_categories, self.inference_batch_size)
        for batch in batched_images_and_categories:
            batch = [
                {
                    "image": image_and_category["image"].to(self.model.device),
                    "category": image_and_category["category"],
                }
                for image_and_category in batch
                if image_and_category is not None
            ]
            if not batch:
                continue
            with torch.no_grad():
                model_output = self.model(batch)
            for model_output_i, batch_i in zip(model_output, batch):
                assert len(batch_i["image"].shape) == 3
                model_output_i["image"] = batch_i["image"]
                instance_class = category_to_class_mapping.get(batch_i["category"], 0)
                model_output_i["instances"].dataset_classes = torch.tensor(
                    [instance_class] * len(model_output_i["instances"])
                )
            model_output_filtered = (
                model_output if self.data_filter is None else self.data_filter(model_output)
            )
            data = (
                model_output_filtered
                if self.data_sampler is None
                else self.data_sampler(model_output_filtered)
            )
            for data_i in data:
                if len(data_i["instances"]):
                    data_batches.append(data_i)
            if len(data_batches) >= self.batch_size:
                yield data_batches[: self.batch_size]
                data_batches = data_batches[self.batch_size :]
        if not self.drop_last and data_batches:
            yield data_batches
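The sketch below (not part of the added files) shows one way the pieces above could be wired together: an image-list loader feeding a detectron2-style model, a score filter, and a prediction-to-GT sampler. It assumes a GeneralizedRCNN-like module exposing a `.device` attribute and returning per-image dicts with an "instances" field; the config, weights, and image path are hypothetical, and a real setup would merge a DensePose config before building the model.

# Sketch: assemble an inference-based loader over unlabeled images.
from torch.utils.data import DataLoader
from detectron2.config import get_cfg
from detectron2.modeling import build_model
from densepose.data.image_list_dataset import ImageListDataset
from densepose.data.inference_based_loader import InferenceBasedLoader, ScoreBasedFilter
from densepose.data.samplers import MaskFromDensePoseSampler, PredictionToGroundTruthSampler

cfg = get_cfg()          # a DensePose config would normally be merged in here
model = build_model(cfg)  # weights would normally be loaded from a checkpoint

sampler = PredictionToGroundTruthSampler()
# convert predicted DensePose outputs into a GT-style mask field (hypothetical field name)
sampler.register_sampler("pred_densepose", "gt_masks", MaskFromDensePoseSampler())

image_loader = DataLoader(
    ImageListDataset(["images/demo.jpg"], category_list=None),  # hypothetical image
    batch_size=1,
    collate_fn=lambda batch: batch,
)
loader = InferenceBasedLoader(
    model,
    data_loader=image_loader,
    data_sampler=sampler,
    data_filter=ScoreBasedFilter(min_score=0.8),
    batch_size=2,
    inference_batch_size=2,
)
for sampled_batch in loader:
    # each element is a dict with "image", "instances" and "dataset" entries
    pass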
densepose/data/meshes/__init__.py
ADDED
@@ -0,0 +1,7 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

# pyre-unsafe

from . import builtin

__all__ = [k for k in globals().keys() if "builtin" not in k and not k.startswith("_")]
densepose/data/meshes/builtin.py
ADDED
@@ -0,0 +1,103 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

# pyre-unsafe

from .catalog import MeshInfo, register_meshes

DENSEPOSE_MESHES_DIR = "https://dl.fbaipublicfiles.com/densepose/meshes/"

MESHES = [
    MeshInfo(name="smpl_27554", data="smpl_27554.pkl", geodists="geodists/geodists_smpl_27554.pkl",
             symmetry="symmetry/symmetry_smpl_27554.pkl", texcoords="texcoords/texcoords_smpl_27554.pkl"),
    MeshInfo(name="chimp_5029", data="chimp_5029.pkl", geodists="geodists/geodists_chimp_5029.pkl",
             symmetry="symmetry/symmetry_chimp_5029.pkl", texcoords="texcoords/texcoords_chimp_5029.pkl"),
    MeshInfo(name="cat_5001", data="cat_5001.pkl", geodists="geodists/geodists_cat_5001.pkl",
             symmetry="symmetry/symmetry_cat_5001.pkl", texcoords="texcoords/texcoords_cat_5001.pkl"),
    MeshInfo(name="cat_7466", data="cat_7466.pkl", geodists="geodists/geodists_cat_7466.pkl",
             symmetry="symmetry/symmetry_cat_7466.pkl", texcoords="texcoords/texcoords_cat_7466.pkl"),
    MeshInfo(name="sheep_5004", data="sheep_5004.pkl", geodists="geodists/geodists_sheep_5004.pkl",
             symmetry="symmetry/symmetry_sheep_5004.pkl", texcoords="texcoords/texcoords_sheep_5004.pkl"),
    MeshInfo(name="zebra_5002", data="zebra_5002.pkl", geodists="geodists/geodists_zebra_5002.pkl",
             symmetry="symmetry/symmetry_zebra_5002.pkl", texcoords="texcoords/texcoords_zebra_5002.pkl"),
    MeshInfo(name="horse_5004", data="horse_5004.pkl", geodists="geodists/geodists_horse_5004.pkl",
             symmetry="symmetry/symmetry_horse_5004.pkl", texcoords="texcoords/texcoords_zebra_5002.pkl"),
    MeshInfo(name="giraffe_5002", data="giraffe_5002.pkl", geodists="geodists/geodists_giraffe_5002.pkl",
             symmetry="symmetry/symmetry_giraffe_5002.pkl", texcoords="texcoords/texcoords_giraffe_5002.pkl"),
    MeshInfo(name="elephant_5002", data="elephant_5002.pkl", geodists="geodists/geodists_elephant_5002.pkl",
             symmetry="symmetry/symmetry_elephant_5002.pkl", texcoords="texcoords/texcoords_elephant_5002.pkl"),
    MeshInfo(name="dog_5002", data="dog_5002.pkl", geodists="geodists/geodists_dog_5002.pkl",
             symmetry="symmetry/symmetry_dog_5002.pkl", texcoords="texcoords/texcoords_dog_5002.pkl"),
    MeshInfo(name="dog_7466", data="dog_7466.pkl", geodists="geodists/geodists_dog_7466.pkl",
             symmetry="symmetry/symmetry_dog_7466.pkl", texcoords="texcoords/texcoords_dog_7466.pkl"),
    MeshInfo(name="cow_5002", data="cow_5002.pkl", geodists="geodists/geodists_cow_5002.pkl",
             symmetry="symmetry/symmetry_cow_5002.pkl", texcoords="texcoords/texcoords_cow_5002.pkl"),
    MeshInfo(name="bear_4936", data="bear_4936.pkl", geodists="geodists/geodists_bear_4936.pkl",
             symmetry="symmetry/symmetry_bear_4936.pkl", texcoords="texcoords/texcoords_bear_4936.pkl"),
]

register_meshes(MESHES, DENSEPOSE_MESHES_DIR)
densepose/data/meshes/catalog.py
ADDED
@@ -0,0 +1,73 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

# pyre-unsafe

import logging
from collections import UserDict
from dataclasses import dataclass
from typing import Iterable, Optional

from ..utils import maybe_prepend_base_path


@dataclass
class MeshInfo:
    name: str
    data: str
    geodists: Optional[str] = None
    symmetry: Optional[str] = None
    texcoords: Optional[str] = None


class _MeshCatalog(UserDict):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.mesh_ids = {}
        self.mesh_names = {}
        self.max_mesh_id = -1

    def __setitem__(self, key, value):
        if key in self:
            logger = logging.getLogger(__name__)
            logger.warning(
                f"Overwriting mesh catalog entry '{key}': old value {self[key]}"
                f", new value {value}"
            )
            mesh_id = self.mesh_ids[key]
        else:
            self.max_mesh_id += 1
            mesh_id = self.max_mesh_id
        super().__setitem__(key, value)
        self.mesh_ids[key] = mesh_id
        self.mesh_names[mesh_id] = key

    def get_mesh_id(self, shape_name: str) -> int:
        return self.mesh_ids[shape_name]

    def get_mesh_name(self, mesh_id: int) -> str:
        return self.mesh_names[mesh_id]


MeshCatalog = _MeshCatalog()


def register_mesh(mesh_info: MeshInfo, base_path: Optional[str]) -> None:
    geodists, symmetry, texcoords = mesh_info.geodists, mesh_info.symmetry, mesh_info.texcoords
    if geodists:
        geodists = maybe_prepend_base_path(base_path, geodists)
    if symmetry:
        symmetry = maybe_prepend_base_path(base_path, symmetry)
    if texcoords:
        texcoords = maybe_prepend_base_path(base_path, texcoords)
    MeshCatalog[mesh_info.name] = MeshInfo(
        name=mesh_info.name,
        data=maybe_prepend_base_path(base_path, mesh_info.data),
        geodists=geodists,
        symmetry=symmetry,
        texcoords=texcoords,
    )


def register_meshes(mesh_infos: Iterable[MeshInfo], base_path: Optional[str]) -> None:
    for mesh_info in mesh_infos:
        register_mesh(mesh_info, base_path)
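A short sketch (not part of the added files): the catalog behaves like a dict keyed by mesh name, with stable integer ids assigned in registration order. The mesh name, file, and base path below are hypothetical.

# Sketch: register a mesh and look it up by name and by id.
from densepose.data.meshes.catalog import MeshCatalog, MeshInfo, register_mesh

register_mesh(MeshInfo(name="toy_mesh", data="toy_mesh.pkl"), base_path="/tmp/meshes")  # hypothetical
mesh_id = MeshCatalog.get_mesh_id("toy_mesh")       # integer id assigned at registration
print(mesh_id, MeshCatalog.get_mesh_name(mesh_id))  # e.g. 13 "toy_mesh"
print(MeshCatalog["toy_mesh"].data)                 # "/tmp/meshes/toy_mesh.pkl"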
densepose/data/samplers/__init__.py
ADDED
@@ -0,0 +1,10 @@
# Copyright (c) Facebook, Inc. and its affiliates.

# pyre-unsafe

from .densepose_uniform import DensePoseUniformSampler
from .densepose_confidence_based import DensePoseConfidenceBasedSampler
from .densepose_cse_uniform import DensePoseCSEUniformSampler
from .densepose_cse_confidence_based import DensePoseCSEConfidenceBasedSampler
from .mask_from_densepose import MaskFromDensePoseSampler
from .prediction_to_gt import PredictionToGroundTruthSampler
densepose/data/samplers/densepose_base.py
ADDED
@@ -0,0 +1,205 @@
# Copyright (c) Facebook, Inc. and its affiliates.

# pyre-unsafe

from typing import Any, Dict, List, Tuple
import torch
from torch.nn import functional as F

from detectron2.structures import BoxMode, Instances

from densepose.converters import ToChartResultConverter
from densepose.converters.base import IntTupleBox, make_int_box
from densepose.structures import DensePoseDataRelative, DensePoseList


class DensePoseBaseSampler:
    """
    Base DensePose sampler to produce DensePose data from DensePose predictions.
    Samples for each class are drawn according to some distribution over all pixels estimated
    to belong to that class.
    """

    def __init__(self, count_per_class: int = 8):
        """
        Constructor

        Args:
            count_per_class (int): the sampler produces at most `count_per_class`
                samples for each category
        """
        self.count_per_class = count_per_class

    def __call__(self, instances: Instances) -> DensePoseList:
        """
        Convert DensePose predictions (an instance of `DensePoseChartPredictorOutput`)
        into DensePose annotations data (an instance of `DensePoseList`)
        """
        boxes_xyxy_abs = instances.pred_boxes.tensor.clone().cpu()
        boxes_xywh_abs = BoxMode.convert(boxes_xyxy_abs, BoxMode.XYXY_ABS, BoxMode.XYWH_ABS)
        dp_datas = []
        for i in range(len(boxes_xywh_abs)):
            annotation_i = self._sample(instances[i], make_int_box(boxes_xywh_abs[i]))
            annotation_i[DensePoseDataRelative.S_KEY] = self._resample_mask(  # pyre-ignore[6]
                instances[i].pred_densepose
            )
            dp_datas.append(DensePoseDataRelative(annotation_i))
        # create densepose annotations on CPU
        dp_list = DensePoseList(dp_datas, boxes_xyxy_abs, instances.image_size)
        return dp_list

    def _sample(self, instance: Instances, bbox_xywh: IntTupleBox) -> Dict[str, List[Any]]:
        """
        Sample DensePoseDataRelative from estimation results
        """
        labels, dp_result = self._produce_labels_and_results(instance)
        annotation = {
            DensePoseDataRelative.X_KEY: [],
            DensePoseDataRelative.Y_KEY: [],
            DensePoseDataRelative.U_KEY: [],
            DensePoseDataRelative.V_KEY: [],
            DensePoseDataRelative.I_KEY: [],
        }
        n, h, w = dp_result.shape
        for part_id in range(1, DensePoseDataRelative.N_PART_LABELS + 1):
            # indices - tuple of 3 1D tensors of size k
            # 0: index along the first dimension N
            # 1: index along H dimension
            # 2: index along W dimension
            indices = torch.nonzero(labels.expand(n, h, w) == part_id, as_tuple=True)
            # values - an array of size [n, k]
            # n: number of channels (U, V, confidences)
            # k: number of points labeled with part_id
            values = dp_result[indices].view(n, -1)
            k = values.shape[1]
            count = min(self.count_per_class, k)
            if count <= 0:
                continue
            index_sample = self._produce_index_sample(values, count)
            sampled_values = values[:, index_sample]
            sampled_y = indices[1][index_sample] + 0.5
            sampled_x = indices[2][index_sample] + 0.5
            # prepare / normalize data
            x = (sampled_x / w * 256.0).cpu().tolist()
            y = (sampled_y / h * 256.0).cpu().tolist()
            u = sampled_values[0].clamp(0, 1).cpu().tolist()
            v = sampled_values[1].clamp(0, 1).cpu().tolist()
            fine_segm_labels = [part_id] * count
            # extend annotations
            annotation[DensePoseDataRelative.X_KEY].extend(x)
            annotation[DensePoseDataRelative.Y_KEY].extend(y)
            annotation[DensePoseDataRelative.U_KEY].extend(u)
            annotation[DensePoseDataRelative.V_KEY].extend(v)
            annotation[DensePoseDataRelative.I_KEY].extend(fine_segm_labels)
        return annotation

    def _produce_index_sample(self, values: torch.Tensor, count: int):
        """
        Abstract method to produce a sample of indices to select data
        To be implemented in descendants

        Args:
            values (torch.Tensor): an array of size [n, k] that contains
                estimated values (U, V, confidences);
                n: number of channels (U, V, confidences)
                k: number of points labeled with part_id
            count (int): number of samples to produce, should be positive and <= k

        Return:
            list(int): indices of values (along axis 1) selected as a sample
        """
        raise NotImplementedError

    def _produce_labels_and_results(self, instance: Instances) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Method to get labels and DensePose results from an instance

        Args:
            instance (Instances): an instance of `DensePoseChartPredictorOutput`

        Return:
            labels (torch.Tensor): shape [H, W], DensePose segmentation labels
            dp_result (torch.Tensor): shape [2, H, W], stacked DensePose results u and v
        """
        converter = ToChartResultConverter
        chart_result = converter.convert(instance.pred_densepose, instance.pred_boxes)
        labels, dp_result = chart_result.labels.cpu(), chart_result.uv.cpu()
        return labels, dp_result

    def _resample_mask(self, output: Any) -> torch.Tensor:
        """
        Convert DensePose predictor output to segmentation annotation - tensors of size
        (256, 256) and type `int64`.

        Args:
            output: DensePose predictor output with the following attributes:
             - coarse_segm: tensor of size [N, D, H, W] with unnormalized coarse
               segmentation scores
             - fine_segm: tensor of size [N, C, H, W] with unnormalized fine
               segmentation scores
        Return:
            Tensor of size (S, S) and type `int64` with coarse segmentation annotations,
            where S = DensePoseDataRelative.MASK_SIZE
        """
        sz = DensePoseDataRelative.MASK_SIZE
        S = (
            F.interpolate(output.coarse_segm, (sz, sz), mode="bilinear", align_corners=False)
            .argmax(dim=1)
            .long()
        )
        I = (
            (
                F.interpolate(
                    output.fine_segm,
                    (sz, sz),
                    mode="bilinear",
                    align_corners=False,
                ).argmax(dim=1)
                * (S > 0).long()
            )
            .squeeze()
            .cpu()
        )
        # Map fine segmentation results to coarse segmentation ground truth
        # TODO: extract this into separate classes
        # coarse segmentation: 1 = Torso, 2 = Right Hand, 3 = Left Hand,
        # 4 = Left Foot, 5 = Right Foot, 6 = Upper Leg Right, 7 = Upper Leg Left,
        # 8 = Lower Leg Right, 9 = Lower Leg Left, 10 = Upper Arm Left,
        # 11 = Upper Arm Right, 12 = Lower Arm Left, 13 = Lower Arm Right, 14 = Head
        # fine segmentation: 1, 2 = Torso, 3 = Right Hand, 4 = Left Hand,
        # 5 = Left Foot, 6 = Right Foot, 7, 9 = Upper Leg Right, 8, 10 = Upper Leg Left,
        # 11, 13 = Lower Leg Right, 12, 14 = Lower Leg Left, 15, 17 = Upper Arm Left,
        # 16, 18 = Upper Arm Right, 19, 21 = Lower Arm Left, 20, 22 = Lower Arm Right,
        # 23, 24 = Head
        FINE_TO_COARSE_SEGMENTATION = {
            1: 1, 2: 1, 3: 2, 4: 3, 5: 4, 6: 5, 7: 6, 8: 7, 9: 6, 10: 7, 11: 8, 12: 9,
            13: 8, 14: 9, 15: 10, 16: 11, 17: 10, 18: 11, 19: 12, 20: 13, 21: 12, 22: 13,
            23: 14, 24: 14,
        }
        mask = torch.zeros((sz, sz), dtype=torch.int64, device=torch.device("cpu"))
        for i in range(DensePoseDataRelative.N_PART_LABELS):
            mask[I == i + 1] = FINE_TO_COARSE_SEGMENTATION[i + 1]
        return mask
densepose/data/samplers/densepose_confidence_based.py
ADDED
@@ -0,0 +1,110 @@
# Copyright (c) Facebook, Inc. and its affiliates.

# pyre-unsafe

import random
from typing import Optional, Tuple
import torch

from densepose.converters import ToChartResultConverterWithConfidences

from .densepose_base import DensePoseBaseSampler


class DensePoseConfidenceBasedSampler(DensePoseBaseSampler):
    """
    Samples DensePose data from DensePose predictions.
    Samples for each class are drawn using confidence value estimates.
    """

    def __init__(
        self,
        confidence_channel: str,
        count_per_class: int = 8,
        search_count_multiplier: Optional[float] = None,
        search_proportion: Optional[float] = None,
    ):
        """
        Constructor

        Args:
            confidence_channel (str): confidence channel to use for sampling;
                possible values:
                  "sigma_2": confidences for UV values
                  "fine_segm_confidence": confidences for fine segmentation
                  "coarse_segm_confidence": confidences for coarse segmentation
                (default: "sigma_2")
            count_per_class (int): the sampler produces at most `count_per_class`
                samples for each category (default: 8)
            search_count_multiplier (float or None): if not None, the total number
                of the most confident estimates of a given class to consider is
                defined as `min(search_count_multiplier * count_per_class, N)`,
                where `N` is the total number of estimates of the class; cannot be
                specified together with `search_proportion` (default: None)
            search_proportion (float or None): if not None, the total number of
                the most confident estimates of a given class to consider is
                defined as `min(max(search_proportion * N, count_per_class), N)`,
                where `N` is the total number of estimates of the class; cannot be
                specified together with `search_count_multiplier` (default: None)
        """
        super().__init__(count_per_class)
        self.confidence_channel = confidence_channel
        self.search_count_multiplier = search_count_multiplier
        self.search_proportion = search_proportion
        assert (search_count_multiplier is None) or (search_proportion is None), (
            f"Cannot specify both search_count_multiplier (={search_count_multiplier})"
            f"and search_proportion (={search_proportion})"
        )

    def _produce_index_sample(self, values: torch.Tensor, count: int):
        """
        Produce a sample of indices to select data based on confidences

        Args:
            values (torch.Tensor): an array of size [n, k] that contains
                estimated values (U, V, confidences);
                n: number of channels (U, V, confidences)
                k: number of points labeled with part_id
            count (int): number of samples to produce, should be positive and <= k

        Return:
            list(int): indices of values (along axis 1) selected as a sample
        """
        k = values.shape[1]
        if k == count:
            index_sample = list(range(k))
        else:
            # take the best count * search_count_multiplier pixels,
            # sample from them uniformly
            # (here best = smallest variance)
            _, sorted_confidence_indices = torch.sort(values[2])
            if self.search_count_multiplier is not None:
                search_count = min(int(count * self.search_count_multiplier), k)
            elif self.search_proportion is not None:
                search_count = min(max(int(k * self.search_proportion), count), k)
            else:
                search_count = min(count, k)
            sample_from_top = random.sample(range(search_count), count)
            index_sample = sorted_confidence_indices[:search_count][sample_from_top]
        return index_sample

    def _produce_labels_and_results(self, instance) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Method to get labels and DensePose results from an instance, with confidences

        Args:
            instance (Instances): an instance of `DensePoseChartPredictorOutputWithConfidences`

        Return:
            labels (torch.Tensor): shape [H, W], DensePose segmentation labels
            dp_result (torch.Tensor): shape [3, H, W], DensePose results u and v
                stacked with the confidence channel
        """
        converter = ToChartResultConverterWithConfidences
        chart_result = converter.convert(instance.pred_densepose, instance.pred_boxes)
        labels, dp_result = chart_result.labels.cpu(), chart_result.uv.cpu()
        dp_result = torch.cat(
            (dp_result, getattr(chart_result, self.confidence_channel)[None].cpu())
        )

        return labels, dp_result
densepose/data/samplers/densepose_cse_base.py
ADDED
@@ -0,0 +1,141 @@
# Copyright (c) Facebook, Inc. and its affiliates.

# pyre-unsafe

from typing import Any, Dict, List, Tuple
import torch
from torch.nn import functional as F

from detectron2.config import CfgNode
from detectron2.structures import Instances

from densepose.converters.base import IntTupleBox
from densepose.data.utils import get_class_to_mesh_name_mapping
from densepose.modeling.cse.utils import squared_euclidean_distance_matrix
from densepose.structures import DensePoseDataRelative

from .densepose_base import DensePoseBaseSampler


class DensePoseCSEBaseSampler(DensePoseBaseSampler):
    """
    Base DensePose sampler to produce DensePose data from DensePose predictions.
    Samples for each class are drawn according to some distribution over all pixels estimated
    to belong to that class.
    """

    def __init__(
        self,
        cfg: CfgNode,
        use_gt_categories: bool,
        embedder: torch.nn.Module,
        count_per_class: int = 8,
    ):
        """
        Constructor

        Args:
            cfg (CfgNode): the config of the model
            embedder (torch.nn.Module): necessary to compute mesh vertex embeddings
            count_per_class (int): the sampler produces at most `count_per_class`
                samples for each category
        """
        super().__init__(count_per_class)
        self.embedder = embedder
        self.class_to_mesh_name = get_class_to_mesh_name_mapping(cfg)
        self.use_gt_categories = use_gt_categories

    def _sample(self, instance: Instances, bbox_xywh: IntTupleBox) -> Dict[str, List[Any]]:
        """
        Sample DensePoseDataRelative from estimation results
        """
        if self.use_gt_categories:
            instance_class = instance.dataset_classes.tolist()[0]
        else:
            instance_class = instance.pred_classes.tolist()[0]
        mesh_name = self.class_to_mesh_name[instance_class]

        annotation = {
            DensePoseDataRelative.X_KEY: [],
            DensePoseDataRelative.Y_KEY: [],
            DensePoseDataRelative.VERTEX_IDS_KEY: [],
            DensePoseDataRelative.MESH_NAME_KEY: mesh_name,
        }

        mask, embeddings, other_values = self._produce_mask_and_results(instance, bbox_xywh)
        indices = torch.nonzero(mask, as_tuple=True)
        selected_embeddings = embeddings.permute(1, 2, 0)[indices].cpu()
        values = other_values[:, indices[0], indices[1]]
        k = values.shape[1]

        count = min(self.count_per_class, k)
        if count <= 0:
            return annotation

        index_sample = self._produce_index_sample(values, count)
        closest_vertices = squared_euclidean_distance_matrix(
            selected_embeddings[index_sample], self.embedder(mesh_name)
        )
        closest_vertices = torch.argmin(closest_vertices, dim=1)

        sampled_y = indices[0][index_sample] + 0.5
        sampled_x = indices[1][index_sample] + 0.5
        # prepare / normalize data
        _, _, w, h = bbox_xywh
        x = (sampled_x / w * 256.0).cpu().tolist()
        y = (sampled_y / h * 256.0).cpu().tolist()
        # extend annotations
        annotation[DensePoseDataRelative.X_KEY].extend(x)
        annotation[DensePoseDataRelative.Y_KEY].extend(y)
        annotation[DensePoseDataRelative.VERTEX_IDS_KEY].extend(closest_vertices.cpu().tolist())
        return annotation

    def _produce_mask_and_results(
        self, instance: Instances, bbox_xywh: IntTupleBox
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Method to get labels and DensePose results from an instance

        Args:
            instance (Instances): an instance of `DensePoseEmbeddingPredictorOutput`
            bbox_xywh (IntTupleBox): the corresponding bounding box

        Return:
            mask (torch.Tensor): shape [H, W], DensePose segmentation mask
            embeddings (Tuple[torch.Tensor]): a tensor of shape [D, H, W],
                DensePose CSE Embeddings
            other_values (Tuple[torch.Tensor]): a tensor of shape [0, H, W],
                for potential other values
        """
        densepose_output = instance.pred_densepose
        S = densepose_output.coarse_segm
        E = densepose_output.embedding
        _, _, w, h = bbox_xywh
        embeddings = F.interpolate(E, size=(h, w), mode="bilinear")[0]
        coarse_segm_resized = F.interpolate(S, size=(h, w), mode="bilinear")[0]
        mask = coarse_segm_resized.argmax(0) > 0
        other_values = torch.empty((0, h, w), device=E.device)
        return mask, embeddings, other_values

    def _resample_mask(self, output: Any) -> torch.Tensor:
        """
        Convert DensePose predictor output to segmentation annotation - tensors of size
        (256, 256) and type `int64`.

        Args:
            output: DensePose predictor output with the following attributes:
             - coarse_segm: tensor of size [N, D, H, W] with unnormalized coarse
               segmentation scores
        Return:
            Tensor of size (S, S) and type `int64` with coarse segmentation annotations,
            where S = DensePoseDataRelative.MASK_SIZE
        """
        sz = DensePoseDataRelative.MASK_SIZE
        mask = (
            F.interpolate(output.coarse_segm, (sz, sz), mode="bilinear", align_corners=False)
            .argmax(dim=1)
            .long()
            .squeeze()
            .cpu()
        )
        return mask
densepose/data/samplers/densepose_cse_confidence_based.py
ADDED
@@ -0,0 +1,121 @@
# Copyright (c) Facebook, Inc. and its affiliates.

# pyre-unsafe

import random
from typing import Optional, Tuple
import torch
from torch.nn import functional as F

from detectron2.config import CfgNode
from detectron2.structures import Instances

from densepose.converters.base import IntTupleBox

from .densepose_cse_base import DensePoseCSEBaseSampler


class DensePoseCSEConfidenceBasedSampler(DensePoseCSEBaseSampler):
    """
    Samples DensePose data from DensePose predictions.
    Samples for each class are drawn using confidence value estimates.
    """

    def __init__(
        self,
        cfg: CfgNode,
        use_gt_categories: bool,
        embedder: torch.nn.Module,
        confidence_channel: str,
        count_per_class: int = 8,
        search_count_multiplier: Optional[float] = None,
        search_proportion: Optional[float] = None,
    ):
        """
        Constructor

        Args:
            cfg (CfgNode): the config of the model
            embedder (torch.nn.Module): necessary to compute mesh vertex embeddings
            confidence_channel (str): confidence channel to use for sampling;
                possible values:
                  "coarse_segm_confidence": confidences for coarse segmentation
                (default: "coarse_segm_confidence")
            count_per_class (int): the sampler produces at most `count_per_class`
                samples for each category (default: 8)
            search_count_multiplier (float or None): if not None, the total number
                of the most confident estimates of a given class to consider is
                defined as `min(search_count_multiplier * count_per_class, N)`,
                where `N` is the total number of estimates of the class; cannot be
                specified together with `search_proportion` (default: None)
            search_proportion (float or None): if not None, the total number of
                the most confident estimates of a given class to consider is
                defined as `min(max(search_proportion * N, count_per_class), N)`,
                where `N` is the total number of estimates of the class; cannot be
                specified together with `search_count_multiplier` (default: None)
        """
        super().__init__(cfg, use_gt_categories, embedder, count_per_class)
        self.confidence_channel = confidence_channel
        self.search_count_multiplier = search_count_multiplier
        self.search_proportion = search_proportion
        assert (search_count_multiplier is None) or (search_proportion is None), (
            f"Cannot specify both search_count_multiplier (={search_count_multiplier})"
            f"and search_proportion (={search_proportion})"
        )

    def _produce_index_sample(self, values: torch.Tensor, count: int):
        """
        Produce a sample of indices to select data based on confidences

        Args:
            values (torch.Tensor): a tensor of length k that contains confidences
                k: number of points labeled with part_id
            count (int): number of samples to produce, should be positive and <= k

        Return:
            list(int): indices of values (along axis 1) selected as a sample
        """
        k = values.shape[1]
        if k == count:
            index_sample = list(range(k))
        else:
            # take the best count * search_count_multiplier pixels,
            # sample from them uniformly
            # (here best = smallest variance)
            _, sorted_confidence_indices = torch.sort(values[0])
            if self.search_count_multiplier is not None:
                search_count = min(int(count * self.search_count_multiplier), k)
            elif self.search_proportion is not None:
                search_count = min(max(int(k * self.search_proportion), count), k)
            else:
                search_count = min(count, k)
            sample_from_top = random.sample(range(search_count), count)
            index_sample = sorted_confidence_indices[-search_count:][sample_from_top]
        return index_sample

    def _produce_mask_and_results(
        self, instance: Instances, bbox_xywh: IntTupleBox
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Method to get labels and DensePose results from an instance

        Args:
            instance (Instances): an instance of
                `DensePoseEmbeddingPredictorOutputWithConfidences`
            bbox_xywh (IntTupleBox): the corresponding bounding box

        Return:
            mask (torch.Tensor): shape [H, W], DensePose segmentation mask
            embeddings (Tuple[torch.Tensor]): a tensor of shape [D, H, W]
                DensePose CSE Embeddings
            other_values: a tensor of shape [1, H, W], DensePose CSE confidence
        """
        _, _, w, h = bbox_xywh
        densepose_output = instance.pred_densepose
        mask, embeddings, _ = super()._produce_mask_and_results(instance, bbox_xywh)
        other_values = F.interpolate(
            getattr(densepose_output, self.confidence_channel),
            size=(h, w),
            mode="bilinear",
        )[0].cpu()
        return mask, embeddings, other_values
densepose/data/samplers/densepose_cse_uniform.py
ADDED
@@ -0,0 +1,14 @@
# Copyright (c) Facebook, Inc. and its affiliates.

# pyre-unsafe

from .densepose_cse_base import DensePoseCSEBaseSampler
from .densepose_uniform import DensePoseUniformSampler


class DensePoseCSEUniformSampler(DensePoseCSEBaseSampler, DensePoseUniformSampler):
    """
    Uniform Sampler for CSE
    """

    pass
densepose/data/samplers/densepose_uniform.py
ADDED
@@ -0,0 +1,43 @@
# Copyright (c) Facebook, Inc. and its affiliates.

# pyre-unsafe

import random
import torch

from .densepose_base import DensePoseBaseSampler


class DensePoseUniformSampler(DensePoseBaseSampler):
    """
    Samples DensePose data from DensePose predictions.
    Samples for each class are drawn uniformly over all pixels estimated
    to belong to that class.
    """

    def __init__(self, count_per_class: int = 8):
        """
        Constructor

        Args:
            count_per_class (int): the sampler produces at most `count_per_class`
                samples for each category
        """
        super().__init__(count_per_class)

    def _produce_index_sample(self, values: torch.Tensor, count: int):
        """
        Produce a uniform sample of indices to select data

        Args:
            values (torch.Tensor): an array of size [n, k] that contains
                estimated values (U, V, confidences);
                n: number of channels (U, V, confidences)
                k: number of points labeled with part_id
            count (int): number of samples to produce, should be positive and <= k

        Return:
            list(int): indices of values (along axis 1) selected as a sample
        """
        k = values.shape[1]
        return random.sample(range(k), count)
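As the samplers above illustrate, the only piece a concrete chart-based sampler has to supply is `_produce_index_sample`. The sketch below (not part of the added files) shows a hypothetical deterministic variant built on the same base class.

# Sketch: a custom sampler that keeps the first `count` labeled points of each part.
import torch
from densepose.data.samplers.densepose_base import DensePoseBaseSampler


class FirstKPointsSampler(DensePoseBaseSampler):
    """Hypothetical sampler: deterministic 'take the first count points' strategy."""

    def _produce_index_sample(self, values: torch.Tensor, count: int):
        # `values` has shape [n, k]; indices are taken along axis 1,
        # and the base class guarantees 0 < count <= k.
        return list(range(count))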
densepose/data/samplers/mask_from_densepose.py
ADDED
@@ -0,0 +1,30 @@
# Copyright (c) Facebook, Inc. and its affiliates.

# pyre-unsafe

from detectron2.structures import BitMasks, Instances

from densepose.converters import ToMaskConverter


class MaskFromDensePoseSampler:
    """
    Produce mask GT from DensePose predictions
    This sampler simply converts DensePose predictions to BitMasks
    that contain a bool tensor of the size of the input image
    """

    def __call__(self, instances: Instances) -> BitMasks:
        """
        Converts predicted data from `instances` into the GT mask data

        Args:
            instances (Instances): predicted results, expected to have `pred_densepose` field

        Returns:
            Boolean Tensor of the size of the input image that has non-zero
            values at pixels that are estimated to belong to the detected object
        """
        return ToMaskConverter.convert(
            instances.pred_densepose, instances.pred_boxes, instances.image_size
        )
densepose/data/samplers/prediction_to_gt.py
ADDED
@@ -0,0 +1,100 @@
# Copyright (c) Facebook, Inc. and its affiliates.

# pyre-unsafe

from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional

from detectron2.structures import Instances

ModelOutput = Dict[str, Any]
SampledData = Dict[str, Any]


@dataclass
class _Sampler:
    """
    Sampler registry entry that contains:
     - src (str): source field to sample from (deleted after sampling)
     - dst (Optional[str]): destination field to sample to, if not None
     - func (Optional[Callable: Any -> Any]): function that performs sampling,
       if None, reference copy is performed
    """

    src: str
    dst: Optional[str]
    func: Optional[Callable[[Any], Any]]


class PredictionToGroundTruthSampler:
    """
    Sampler implementation that converts predictions to GT using registered
    samplers for different fields of `Instances`.
    """

    def __init__(self, dataset_name: str = ""):
        self.dataset_name = dataset_name
        self._samplers = {}
        self.register_sampler("pred_boxes", "gt_boxes", None)
        self.register_sampler("pred_classes", "gt_classes", None)
        # delete scores
        self.register_sampler("scores")

    def __call__(self, model_output: List[ModelOutput]) -> List[SampledData]:
        """
        Transform model output into ground truth data through sampling

        Args:
            model_output (Dict[str, Any]): model output
        Returns:
            Dict[str, Any]: sampled data
        """
        for model_output_i in model_output:
            instances: Instances = model_output_i["instances"]
            # transform data in each field
            for _, sampler in self._samplers.items():
                if not instances.has(sampler.src) or sampler.dst is None:
                    continue
                if sampler.func is None:
                    instances.set(sampler.dst, instances.get(sampler.src))
                else:
                    instances.set(sampler.dst, sampler.func(instances))
            # delete model output data that was transformed
            for _, sampler in self._samplers.items():
                if sampler.src != sampler.dst and instances.has(sampler.src):
                    instances.remove(sampler.src)
            model_output_i["dataset"] = self.dataset_name
        return model_output

    def register_sampler(
        self,
        prediction_attr: str,
        gt_attr: Optional[str] = None,
        func: Optional[Callable[[Any], Any]] = None,
    ):
        """
        Register sampler for a field

        Args:
            prediction_attr (str): field to replace with a sampled value
            gt_attr (Optional[str]): field to store the sampled value to, if not None
            func (Optional[Callable: Any -> Any]): sampler function
        """
        self._samplers[(prediction_attr, gt_attr)] = _Sampler(
            src=prediction_attr, dst=gt_attr, func=func
        )

    def remove_sampler(
        self,
        prediction_attr: str,
        gt_attr: Optional[str] = None,
    ):
        """
        Remove sampler for a field

        Args:
            prediction_attr (str): field to replace with a sampled value
            gt_attr (Optional[str]): field to store the sampled value to, if not None
        """
        assert (prediction_attr, gt_attr) in self._samplers
        del self._samplers[(prediction_attr, gt_attr)]
densepose/data/transform/__init__.py
ADDED
@@ -0,0 +1,5 @@
# Copyright (c) Facebook, Inc. and its affiliates.

# pyre-unsafe

from .image import ImageResizeTransform
densepose/data/transform/image.py
ADDED
@@ -0,0 +1,41 @@
# Copyright (c) Facebook, Inc. and its affiliates.

# pyre-unsafe

import torch


class ImageResizeTransform:
    """
    Transform that resizes images loaded from a dataset
    (BGR data in NCHW channel order, typically uint8) to a format ready to be
    consumed by DensePose training (BGR float32 data in NCHW channel order)
    """

    def __init__(self, min_size: int = 800, max_size: int = 1333):
        self.min_size = min_size
        self.max_size = max_size

    def __call__(self, images: torch.Tensor) -> torch.Tensor:
        """
        Args:
            images (torch.Tensor): tensor of size [N, 3, H, W] that contains
                BGR data (typically in uint8)
        Returns:
            images (torch.Tensor): tensor of size [N, 3, H1, W1] where
                H1 and W1 are chosen to respect the specified min and max sizes
                and preserve the original aspect ratio, the data channels
                follow BGR order and the data type is `torch.float32`
        """
        # resize with min size
        images = images.float()
        min_size = min(images.shape[-2:])
        max_size = max(images.shape[-2:])
        scale = min(self.min_size / min_size, self.max_size / max_size)
        images = torch.nn.functional.interpolate(
            images,
            scale_factor=scale,
            mode="bilinear",
            align_corners=False,
        )
        return images
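A quick sketch (not part of the added files) of the scaling rule above: the short side is pushed toward `min_size` without letting the long side exceed `max_size`, preserving the aspect ratio. The input tensor here is synthetic.

# Sketch: resize a fake 480x640 BGR batch with ImageResizeTransform.
import torch
from densepose.data.transform import ImageResizeTransform

transform = ImageResizeTransform(min_size=800, max_size=1333)
images = torch.randint(0, 256, (1, 3, 480, 640), dtype=torch.uint8)  # fake uint8 batch
resized = transform(images)
# scale = min(800 / 480, 1333 / 640) = min(1.667, 2.083) = 1.667 -> roughly 800 x 1066
print(resized.shape, resized.dtype)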
densepose/data/utils.py
ADDED
@@ -0,0 +1,40 @@
# Copyright (c) Facebook, Inc. and its affiliates.

# pyre-unsafe

import os
from typing import Dict, Optional

from detectron2.config import CfgNode


def is_relative_local_path(path: str) -> bool:
    path_str = os.fsdecode(path)
    return ("://" not in path_str) and not os.path.isabs(path)


def maybe_prepend_base_path(base_path: Optional[str], path: str):
    """
    Prepends the provided path with a base path prefix if:
    1) base path is not None;
    2) path is a local path
    """
    if base_path is None:
        return path
    if is_relative_local_path(path):
        return os.path.join(base_path, path)
    return path


def get_class_to_mesh_name_mapping(cfg: CfgNode) -> Dict[int, str]:
    return {
        int(class_id): mesh_name
        for class_id, mesh_name in cfg.DATASETS.CLASS_TO_MESH_NAME_MAPPING.items()
    }


def get_category_to_class_mapping(dataset_cfg: CfgNode) -> Dict[str, int]:
    return {
        category: int(class_id)
        for category, class_id in dataset_cfg.CATEGORY_TO_CLASS_MAPPING.items()
    }
densepose/data/video/__init__.py
ADDED
@@ -0,0 +1,19 @@
# Copyright (c) Facebook, Inc. and its affiliates.

# pyre-unsafe

from .frame_selector import (
    FrameSelectionStrategy,
    RandomKFramesSelector,
    FirstKFramesSelector,
    LastKFramesSelector,
    FrameTsList,
    FrameSelector,
)

from .video_keyframe_dataset import (
    VideoKeyframeDataset,
    video_list_from_file,
    list_keyframes,
    read_keyframes,
)
densepose/data/video/frame_selector.py
ADDED
@@ -0,0 +1,89 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+
+# pyre-unsafe
+
+import random
+from collections.abc import Callable
+from enum import Enum
+from typing import Callable as TCallable
+from typing import List
+
+FrameTsList = List[int]
+FrameSelector = TCallable[[FrameTsList], FrameTsList]
+
+
+class FrameSelectionStrategy(Enum):
+    """
+    Frame selection strategy used with videos:
+    - "random_k": select k random frames
+    - "first_k": select k first frames
+    - "last_k": select k last frames
+    - "all": select all frames
+    """
+
+    # fmt: off
+    RANDOM_K = "random_k"
+    FIRST_K  = "first_k"
+    LAST_K   = "last_k"
+    ALL      = "all"
+    # fmt: on
+
+
+class RandomKFramesSelector(Callable):  # pyre-ignore[39]
+    """
+    Selector that retains at most `k` random frames
+    """
+
+    def __init__(self, k: int):
+        self.k = k
+
+    def __call__(self, frame_tss: FrameTsList) -> FrameTsList:
+        """
+        Select `k` random frames
+
+        Args:
+          frame_tss (List[int]): timestamps of input frames
+        Returns:
+          List[int]: timestamps of selected frames
+        """
+        return random.sample(frame_tss, min(self.k, len(frame_tss)))
+
+
+class FirstKFramesSelector(Callable):  # pyre-ignore[39]
+    """
+    Selector that retains at most `k` first frames
+    """
+
+    def __init__(self, k: int):
+        self.k = k
+
+    def __call__(self, frame_tss: FrameTsList) -> FrameTsList:
+        """
+        Select `k` first frames
+
+        Args:
+          frame_tss (List[int]): timestamps of input frames
+        Returns:
+          List[int]: timestamps of selected frames
+        """
+        return frame_tss[: self.k]
+
+
+class LastKFramesSelector(Callable):  # pyre-ignore[39]
+    """
+    Selector that retains at most `k` last frames from video data
+    """
+
+    def __init__(self, k: int):
+        self.k = k
+
+    def __call__(self, frame_tss: FrameTsList) -> FrameTsList:
+        """
+        Select `k` last frames
+
+        Args:
+          frame_tss (List[int]): timestamps of input frames
+        Returns:
+          List[int]: timestamps of selected frames
+        """
+        return frame_tss[-self.k :]
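A minimal sketch of the selectors above applied to a hypothetical keyframe timestamp list:

from densepose.data.video import (
    FirstKFramesSelector,
    LastKFramesSelector,
    RandomKFramesSelector,
)

keyframe_ts = [0, 250, 500, 750, 1000]  # timestamps in timebase units (made up)

print(FirstKFramesSelector(k=2)(keyframe_ts))        # [0, 250]
print(LastKFramesSelector(k=2)(keyframe_ts))         # [750, 1000]
print(len(RandomKFramesSelector(k=3)(keyframe_ts)))  # 3 (at most k frames, chosen at random)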
densepose/data/video/video_keyframe_dataset.py
ADDED
@@ -0,0 +1,304 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) Facebook, Inc. and its affiliates.
+
+# pyre-unsafe
+
+import csv
+import logging
+import numpy as np
+from typing import Any, Callable, Dict, List, Optional, Union
+import av
+import torch
+from torch.utils.data.dataset import Dataset
+
+from detectron2.utils.file_io import PathManager
+
+from ..utils import maybe_prepend_base_path
+from .frame_selector import FrameSelector, FrameTsList
+
+FrameList = List[av.frame.Frame]  # pyre-ignore[16]
+FrameTransform = Callable[[torch.Tensor], torch.Tensor]
+
+
+def list_keyframes(video_fpath: str, video_stream_idx: int = 0) -> FrameTsList:
+    """
+    Traverses all keyframes of a video file. Returns a list of keyframe
+    timestamps. Timestamps are counts in timebase units.
+
+    Args:
+       video_fpath (str): Video file path
+       video_stream_idx (int): Video stream index (default: 0)
+    Returns:
+       List[int]: list of keyframe timestamps (timestamp is a count in timebase
+           units)
+    """
+    try:
+        with PathManager.open(video_fpath, "rb") as io:
+            # pyre-fixme[16]: Module `av` has no attribute `open`.
+            container = av.open(io, mode="r")
+            stream = container.streams.video[video_stream_idx]
+            keyframes = []
+            pts = -1
+            # Note: even though we request forward seeks for keyframes, sometimes
+            # a keyframe in backwards direction is returned. We introduce tolerance
+            # as a max count of ignored backward seeks
+            tolerance_backward_seeks = 2
+            while True:
+                try:
+                    container.seek(pts + 1, backward=False, any_frame=False, stream=stream)
+                except av.AVError as e:
+                    # the exception occurs when the video length is exceeded,
+                    # we then return whatever data we've already collected
+                    logger = logging.getLogger(__name__)
+                    logger.debug(
+                        f"List keyframes: Error seeking video file {video_fpath}, "
+                        f"video stream {video_stream_idx}, pts {pts + 1}, AV error: {e}"
+                    )
+                    return keyframes
+                except OSError as e:
+                    logger = logging.getLogger(__name__)
+                    logger.warning(
+                        f"List keyframes: Error seeking video file {video_fpath}, "
+                        f"video stream {video_stream_idx}, pts {pts + 1}, OS error: {e}"
+                    )
+                    return []
+                packet = next(container.demux(video=video_stream_idx))
+                if packet.pts is not None and packet.pts <= pts:
+                    logger = logging.getLogger(__name__)
+                    logger.warning(
+                        f"Video file {video_fpath}, stream {video_stream_idx}: "
+                        f"bad seek for packet {pts + 1} (got packet {packet.pts}), "
+                        f"tolerance {tolerance_backward_seeks}."
+                    )
+                    tolerance_backward_seeks -= 1
+                    if tolerance_backward_seeks == 0:
+                        return []
+                    pts += 1
+                    continue
+                tolerance_backward_seeks = 2
+                pts = packet.pts
+                if pts is None:
+                    return keyframes
+                if packet.is_keyframe:
+                    keyframes.append(pts)
+            return keyframes
+    except OSError as e:
+        logger = logging.getLogger(__name__)
+        logger.warning(
+            f"List keyframes: Error opening video file container {video_fpath}, " f"OS error: {e}"
+        )
+    except RuntimeError as e:
+        logger = logging.getLogger(__name__)
+        logger.warning(
+            f"List keyframes: Error opening video file container {video_fpath}, "
+            f"Runtime error: {e}"
+        )
+    return []
+
+
+def read_keyframes(
+    video_fpath: str, keyframes: FrameTsList, video_stream_idx: int = 0
+) -> FrameList:  # pyre-ignore[11]
+    """
+    Reads keyframe data from a video file.
+
+    Args:
+        video_fpath (str): Video file path
+        keyframes (List[int]): List of keyframe timestamps (as counts in
+            timebase units to be used in container seek operations)
+        video_stream_idx (int): Video stream index (default: 0)
+    Returns:
+        List[Frame]: list of frames that correspond to the specified timestamps
+    """
+    try:
+        with PathManager.open(video_fpath, "rb") as io:
+            # pyre-fixme[16]: Module `av` has no attribute `open`.
+            container = av.open(io)
+            stream = container.streams.video[video_stream_idx]
+            frames = []
+            for pts in keyframes:
+                try:
+                    container.seek(pts, any_frame=False, stream=stream)
+                    frame = next(container.decode(video=0))
+                    frames.append(frame)
+                except av.AVError as e:
+                    logger = logging.getLogger(__name__)
+                    logger.warning(
+                        f"Read keyframes: Error seeking video file {video_fpath}, "
+                        f"video stream {video_stream_idx}, pts {pts}, AV error: {e}"
+                    )
+                    container.close()
+                    return frames
+                except OSError as e:
+                    logger = logging.getLogger(__name__)
+                    logger.warning(
+                        f"Read keyframes: Error seeking video file {video_fpath}, "
+                        f"video stream {video_stream_idx}, pts {pts}, OS error: {e}"
+                    )
+                    container.close()
+                    return frames
+                except StopIteration:
+                    logger = logging.getLogger(__name__)
+                    logger.warning(
+                        f"Read keyframes: Error decoding frame from {video_fpath}, "
+                        f"video stream {video_stream_idx}, pts {pts}"
+                    )
+                    container.close()
+                    return frames
+
+            container.close()
+            return frames
+    except OSError as e:
+        logger = logging.getLogger(__name__)
+        logger.warning(
+            f"Read keyframes: Error opening video file container {video_fpath}, OS error: {e}"
+        )
+    except RuntimeError as e:
+        logger = logging.getLogger(__name__)
+        logger.warning(
+            f"Read keyframes: Error opening video file container {video_fpath}, Runtime error: {e}"
+        )
+    return []
+
+
+def video_list_from_file(video_list_fpath: str, base_path: Optional[str] = None):
+    """
+    Create a list of paths to video files from a text file.
+
+    Args:
+        video_list_fpath (str): path to a plain text file with the list of videos
+        base_path (str): base path for entries from the video list (default: None)
+    """
+    video_list = []
+    with PathManager.open(video_list_fpath, "r") as io:
+        for line in io:
+            video_list.append(maybe_prepend_base_path(base_path, str(line.strip())))
+    return video_list
+
+
+def read_keyframe_helper_data(fpath: str):
+    """
+    Read keyframe data from a file in CSV format: the header should contain
+    "video_id" and "keyframes" fields. Value specifications are:
+        video_id: int
+        keyframes: list(int)
+    Example of contents:
+        video_id,keyframes
+        2,"[1,11,21,31,41,51,61,71,81]"
+
+    Args:
+        fpath (str): File containing keyframe data
+
+    Return:
+        video_id_to_keyframes (dict: int -> list(int)): for a given video ID it
+            contains a list of keyframes for that video
+    """
+    video_id_to_keyframes = {}
+    try:
+        with PathManager.open(fpath, "r") as io:
+            csv_reader = csv.reader(io)
+            header = next(csv_reader)
+            video_id_idx = header.index("video_id")
+            keyframes_idx = header.index("keyframes")
+            for row in csv_reader:
+                video_id = int(row[video_id_idx])
+                assert (
+                    video_id not in video_id_to_keyframes
+                ), f"Duplicate keyframes entry for video {fpath}"
+                video_id_to_keyframes[video_id] = (
+                    [int(v) for v in row[keyframes_idx][1:-1].split(",")]
+                    if len(row[keyframes_idx]) > 2
+                    else []
+                )
+    except Exception as e:
+        logger = logging.getLogger(__name__)
+        logger.warning(f"Error reading keyframe helper data from {fpath}: {e}")
+    return video_id_to_keyframes
+
+
+class VideoKeyframeDataset(Dataset):
+    """
+    Dataset that provides keyframes for a set of videos.
+    """
+
+    _EMPTY_FRAMES = torch.empty((0, 3, 1, 1))
+
+    def __init__(
+        self,
+        video_list: List[str],
+        category_list: Union[str, List[str], None] = None,
+        frame_selector: Optional[FrameSelector] = None,
+        transform: Optional[FrameTransform] = None,
+        keyframe_helper_fpath: Optional[str] = None,
+    ):
+        """
+        Dataset constructor
+
+        Args:
+            video_list (List[str]): list of paths to video files
+            category_list (Union[str, List[str], None]): list of animal categories for each
+                video file. If it is a string, or None, this applies to all videos
+            frame_selector (Callable: KeyFrameList -> KeyFrameList):
+                selects keyframes to process, keyframes are given by
+                packet timestamps in timebase counts. If None, all keyframes
+                are selected (default: None)
+            transform (Callable: torch.Tensor -> torch.Tensor):
+                transforms a batch of RGB images (tensors of size [B, 3, H, W]),
+                returns a tensor of the same size. If None, no transform is
+                applied (default: None)
+
+        """
+        if type(category_list) is list:
+            self.category_list = category_list
+        else:
+            self.category_list = [category_list] * len(video_list)
+        assert len(video_list) == len(
+            self.category_list
+        ), "length of video and category lists must be equal"
+        self.video_list = video_list
+        self.frame_selector = frame_selector
+        self.transform = transform
+        self.keyframe_helper_data = (
+            read_keyframe_helper_data(keyframe_helper_fpath)
+            if keyframe_helper_fpath is not None
+            else None
+        )
+
+    def __getitem__(self, idx: int) -> Dict[str, Any]:
+        """
+        Gets selected keyframes from a given video
+
+        Args:
+            idx (int): video index in the video list file
+        Returns:
+            A dictionary containing two keys:
+                images (torch.Tensor): tensor of size [N, H, W, 3] or of size
+                    defined by the transform that contains keyframes data
+                categories (List[str]): categories of the frames
+        """
+        categories = [self.category_list[idx]]
+        fpath = self.video_list[idx]
+        keyframes = (
+            list_keyframes(fpath)
+            if self.keyframe_helper_data is None or idx not in self.keyframe_helper_data
+            else self.keyframe_helper_data[idx]
+        )
+        transform = self.transform
+        frame_selector = self.frame_selector
+        if not keyframes:
+            return {"images": self._EMPTY_FRAMES, "categories": []}
+        if frame_selector is not None:
+            keyframes = frame_selector(keyframes)
+        frames = read_keyframes(fpath, keyframes)
+        if not frames:
+            return {"images": self._EMPTY_FRAMES, "categories": []}
+        frames = np.stack([frame.to_rgb().to_ndarray() for frame in frames])
+        frames = torch.as_tensor(frames, device=torch.device("cpu"))
+        frames = frames[..., [2, 1, 0]]  # RGB -> BGR
+        frames = frames.permute(0, 3, 1, 2).float()  # NHWC -> NCHW
+        if transform is not None:
+            frames = transform(frames)
+        return {"images": frames, "categories": categories}
+
+    def __len__(self):
+        return len(self.video_list)
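Putting the video pieces together, a hedged end-to-end sketch (the video path and category are hypothetical; reading frames requires PyAV and an actual video file):

from densepose.data.transform import ImageResizeTransform
from densepose.data.video import RandomKFramesSelector, VideoKeyframeDataset

dataset = VideoKeyframeDataset(
    video_list=["/datasets/chimps/videos/chimp_01.mp4"],  # made-up path
    category_list="chimpanzee",                           # one category applied to all videos
    frame_selector=RandomKFramesSelector(k=4),            # keep at most 4 keyframes per video
    transform=ImageResizeTransform(),                     # BGR uint8 NCHW -> BGR float32 NCHW
)

sample = dataset[0]
print(sample["images"].shape)  # e.g. [4, 3, H1, W1] after resizing
print(sample["categories"])    # ["chimpanzee"]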
densepose/engine/__init__.py
ADDED
@@ -0,0 +1,5 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+
+# pyre-unsafe
+
+from .trainer import Trainer
densepose/engine/trainer.py
ADDED
@@ -0,0 +1,260 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+
+# pyre-unsafe
+
+import logging
+import os
+from collections import OrderedDict
+from typing import List, Optional, Union
+import torch
+from torch import nn
+
+from detectron2.checkpoint import DetectionCheckpointer
+from detectron2.config import CfgNode
+from detectron2.engine import DefaultTrainer
+from detectron2.evaluation import (
+    DatasetEvaluator,
+    DatasetEvaluators,
+    inference_on_dataset,
+    print_csv_format,
+)
+from detectron2.solver.build import get_default_optimizer_params, maybe_add_gradient_clipping
+from detectron2.utils import comm
+from detectron2.utils.events import EventWriter, get_event_storage
+
+from densepose import DensePoseDatasetMapperTTA, DensePoseGeneralizedRCNNWithTTA, load_from_cfg
+from densepose.data import (
+    DatasetMapper,
+    build_combined_loader,
+    build_detection_test_loader,
+    build_detection_train_loader,
+    build_inference_based_loaders,
+    has_inference_based_loaders,
+)
+from densepose.evaluation.d2_evaluator_adapter import Detectron2COCOEvaluatorAdapter
+from densepose.evaluation.evaluator import DensePoseCOCOEvaluator, build_densepose_evaluator_storage
+from densepose.modeling.cse import Embedder
+
+
+class SampleCountingLoader:
+    def __init__(self, loader):
+        self.loader = loader
+
+    def __iter__(self):
+        it = iter(self.loader)
+        storage = get_event_storage()
+        while True:
+            try:
+                batch = next(it)
+                num_inst_per_dataset = {}
+                for data in batch:
+                    dataset_name = data["dataset"]
+                    if dataset_name not in num_inst_per_dataset:
+                        num_inst_per_dataset[dataset_name] = 0
+                    num_inst = len(data["instances"])
+                    num_inst_per_dataset[dataset_name] += num_inst
+                for dataset_name in num_inst_per_dataset:
+                    storage.put_scalar(f"batch/{dataset_name}", num_inst_per_dataset[dataset_name])
+                yield batch
+            except StopIteration:
+                break
+
+
+class SampleCountMetricPrinter(EventWriter):
+    def __init__(self):
+        self.logger = logging.getLogger(__name__)
+
+    def write(self):
+        storage = get_event_storage()
+        batch_stats_strs = []
+        for key, buf in storage.histories().items():
+            if key.startswith("batch/"):
+                batch_stats_strs.append(f"{key} {buf.avg(20)}")
+        self.logger.info(", ".join(batch_stats_strs))
+
+
+class Trainer(DefaultTrainer):
+    @classmethod
+    def extract_embedder_from_model(cls, model: nn.Module) -> Optional[Embedder]:
+        if isinstance(model, nn.parallel.DistributedDataParallel):
+            model = model.module
+        if hasattr(model, "roi_heads") and hasattr(model.roi_heads, "embedder"):
+            return model.roi_heads.embedder
+        return None
+
+    # TODO: the only reason to copy the base class code here is to pass the embedder from
+    # the model to the evaluator; that should be refactored to avoid unnecessary copy-pasting
+    @classmethod
+    def test(
+        cls,
+        cfg: CfgNode,
+        model: nn.Module,
+        evaluators: Optional[Union[DatasetEvaluator, List[DatasetEvaluator]]] = None,
+    ):
+        """
+        Args:
+            cfg (CfgNode):
+            model (nn.Module):
+            evaluators (DatasetEvaluator, list[DatasetEvaluator] or None): if None, will call
+                :meth:`build_evaluator`. Otherwise, must have the same length as
+                ``cfg.DATASETS.TEST``.
+
+        Returns:
+            dict: a dict of result metrics
+        """
+        logger = logging.getLogger(__name__)
+        if isinstance(evaluators, DatasetEvaluator):
+            evaluators = [evaluators]
+        if evaluators is not None:
+            assert len(cfg.DATASETS.TEST) == len(evaluators), "{} != {}".format(
+                len(cfg.DATASETS.TEST), len(evaluators)
+            )
+
+        results = OrderedDict()
+        for idx, dataset_name in enumerate(cfg.DATASETS.TEST):
+            data_loader = cls.build_test_loader(cfg, dataset_name)
+            # When evaluators are passed in as arguments,
+            # implicitly assume that evaluators can be created before data_loader.
+            if evaluators is not None:
+                evaluator = evaluators[idx]
+            else:
+                try:
+                    embedder = cls.extract_embedder_from_model(model)
+                    evaluator = cls.build_evaluator(cfg, dataset_name, embedder=embedder)
+                except NotImplementedError:
+                    logger.warn(
+                        "No evaluator found. Use `DefaultTrainer.test(evaluators=)`, "
+                        "or implement its `build_evaluator` method."
+                    )
+                    results[dataset_name] = {}
+                    continue
+            if cfg.DENSEPOSE_EVALUATION.DISTRIBUTED_INFERENCE or comm.is_main_process():
+                results_i = inference_on_dataset(model, data_loader, evaluator)
+            else:
+                results_i = {}
+            results[dataset_name] = results_i
+            if comm.is_main_process():
+                assert isinstance(
+                    results_i, dict
+                ), "Evaluator must return a dict on the main process. Got {} instead.".format(
+                    results_i
+                )
+                logger.info("Evaluation results for {} in csv format:".format(dataset_name))
+                print_csv_format(results_i)
+
+        if len(results) == 1:
+            results = list(results.values())[0]
+        return results
+
+    @classmethod
+    def build_evaluator(
+        cls,
+        cfg: CfgNode,
+        dataset_name: str,
+        output_folder: Optional[str] = None,
+        embedder: Optional[Embedder] = None,
+    ) -> DatasetEvaluators:
+        if output_folder is None:
+            output_folder = os.path.join(cfg.OUTPUT_DIR, "inference")
+        evaluators = []
+        distributed = cfg.DENSEPOSE_EVALUATION.DISTRIBUTED_INFERENCE
+        # Note: we currently use COCO evaluator for both COCO and LVIS datasets
+        # to have compatible metrics. LVIS bbox evaluator could also be used
+        # with an adapter to properly handle filtered / mapped categories
+        # evaluator_type = MetadataCatalog.get(dataset_name).evaluator_type
+        # if evaluator_type == "coco":
+        #     evaluators.append(COCOEvaluator(dataset_name, output_dir=output_folder))
+        # elif evaluator_type == "lvis":
+        #     evaluators.append(LVISEvaluator(dataset_name, output_dir=output_folder))
+        evaluators.append(
+            Detectron2COCOEvaluatorAdapter(
+                dataset_name, output_dir=output_folder, distributed=distributed
+            )
+        )
+        if cfg.MODEL.DENSEPOSE_ON:
+            storage = build_densepose_evaluator_storage(cfg, output_folder)
+            evaluators.append(
+                DensePoseCOCOEvaluator(
+                    dataset_name,
+                    distributed,
+                    output_folder,
+                    evaluator_type=cfg.DENSEPOSE_EVALUATION.TYPE,
+                    min_iou_threshold=cfg.DENSEPOSE_EVALUATION.MIN_IOU_THRESHOLD,
+                    storage=storage,
+                    embedder=embedder,
+                    should_evaluate_mesh_alignment=cfg.DENSEPOSE_EVALUATION.EVALUATE_MESH_ALIGNMENT,
+                    mesh_alignment_mesh_names=cfg.DENSEPOSE_EVALUATION.MESH_ALIGNMENT_MESH_NAMES,
+                )
+            )
+        return DatasetEvaluators(evaluators)
+
+    @classmethod
+    def build_optimizer(cls, cfg: CfgNode, model: nn.Module):
+        params = get_default_optimizer_params(
+            model,
+            base_lr=cfg.SOLVER.BASE_LR,
+            weight_decay_norm=cfg.SOLVER.WEIGHT_DECAY_NORM,
+            bias_lr_factor=cfg.SOLVER.BIAS_LR_FACTOR,
+            weight_decay_bias=cfg.SOLVER.WEIGHT_DECAY_BIAS,
+            overrides={
+                "features": {
+                    "lr": cfg.SOLVER.BASE_LR * cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.FEATURES_LR_FACTOR,
+                },
+                "embeddings": {
+                    "lr": cfg.SOLVER.BASE_LR * cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBEDDING_LR_FACTOR,
+                },
+            },
+        )
+        optimizer = torch.optim.SGD(
+            params,
+            cfg.SOLVER.BASE_LR,
+            momentum=cfg.SOLVER.MOMENTUM,
+            nesterov=cfg.SOLVER.NESTEROV,
+            weight_decay=cfg.SOLVER.WEIGHT_DECAY,
+        )
+        # pyre-fixme[6]: For 2nd param expected `Type[Optimizer]` but got `SGD`.
+        return maybe_add_gradient_clipping(cfg, optimizer)
+
+    @classmethod
+    def build_test_loader(cls, cfg: CfgNode, dataset_name):
+        return build_detection_test_loader(cfg, dataset_name, mapper=DatasetMapper(cfg, False))
+
+    @classmethod
+    def build_train_loader(cls, cfg: CfgNode):
+        data_loader = build_detection_train_loader(cfg, mapper=DatasetMapper(cfg, True))
+        if not has_inference_based_loaders(cfg):
+            return data_loader
+        model = cls.build_model(cfg)
+        model.to(cfg.BOOTSTRAP_MODEL.DEVICE)
+        DetectionCheckpointer(model).resume_or_load(cfg.BOOTSTRAP_MODEL.WEIGHTS, resume=False)
+        inference_based_loaders, ratios = build_inference_based_loaders(cfg, model)
+        loaders = [data_loader] + inference_based_loaders
+        ratios = [1.0] + ratios
+        combined_data_loader = build_combined_loader(cfg, loaders, ratios)
+        sample_counting_loader = SampleCountingLoader(combined_data_loader)
+        return sample_counting_loader
+
+    def build_writers(self):
+        writers = super().build_writers()
+        writers.append(SampleCountMetricPrinter())
+        return writers
+
+    @classmethod
+    def test_with_TTA(cls, cfg: CfgNode, model):
+        logger = logging.getLogger("detectron2.trainer")
+        # In the end of training, run an evaluation with TTA
+        # Only support some R-CNN models.
+        logger.info("Running inference with test-time augmentation ...")
+        transform_data = load_from_cfg(cfg)
+        model = DensePoseGeneralizedRCNNWithTTA(
+            cfg, model, transform_data, DensePoseDatasetMapperTTA(cfg)
+        )
+        evaluators = [
+            cls.build_evaluator(
+                cfg, name, output_folder=os.path.join(cfg.OUTPUT_DIR, "inference_TTA")
+            )
+            for name in cfg.DATASETS.TEST
+        ]
+        res = cls.test(cfg, model, evaluators)  # pyre-ignore[6]
+        res = OrderedDict({k + "_TTA": v for k, v in res.items()})
+        return res
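For orientation, a minimal launch sketch in the spirit of detectron2's standard train_net.py; it assumes `add_densepose_config` is exposed by the densepose package (as in the upstream DensePose project) and the config path is purely illustrative:

from detectron2.config import get_cfg
from detectron2.engine import default_argument_parser, default_setup

from densepose import add_densepose_config  # assumed export, see densepose/config.py
from densepose.engine import Trainer


def setup(config_file: str):
    cfg = get_cfg()
    add_densepose_config(cfg)         # register DensePose-specific config keys
    cfg.merge_from_file(config_file)
    cfg.freeze()
    args = default_argument_parser().parse_args([])
    default_setup(cfg, args)          # logging, output dir, seed, ...
    return cfg


cfg = setup("configs/densepose_rcnn_R_50_FPN_s1x.yaml")  # hypothetical config path
trainer = Trainer(cfg)
trainer.resume_or_load(resume=False)
trainer.train()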
densepose/evaluation/__init__.py
ADDED
@@ -0,0 +1,5 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+
+# pyre-unsafe
+
+from .evaluator import DensePoseCOCOEvaluator
densepose/evaluation/d2_evaluator_adapter.py
ADDED
@@ -0,0 +1,52 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+
+# pyre-unsafe
+
+from detectron2.data.catalog import Metadata
+from detectron2.evaluation import COCOEvaluator
+
+from densepose.data.datasets.coco import (
+    get_contiguous_id_to_category_id_map,
+    maybe_filter_categories_cocoapi,
+)
+
+
+def _maybe_add_iscrowd_annotations(cocoapi) -> None:
+    for ann in cocoapi.dataset["annotations"]:
+        if "iscrowd" not in ann:
+            ann["iscrowd"] = 0
+
+
+class Detectron2COCOEvaluatorAdapter(COCOEvaluator):
+    def __init__(
+        self,
+        dataset_name,
+        output_dir=None,
+        distributed=True,
+    ):
+        super().__init__(dataset_name, output_dir=output_dir, distributed=distributed)
+        maybe_filter_categories_cocoapi(dataset_name, self._coco_api)
+        _maybe_add_iscrowd_annotations(self._coco_api)
+        # substitute category metadata to account for categories
+        # that are mapped to the same contiguous id
+        if hasattr(self._metadata, "thing_dataset_id_to_contiguous_id"):
+            self._maybe_substitute_metadata()
+
+    def _maybe_substitute_metadata(self):
+        cont_id_2_cat_id = get_contiguous_id_to_category_id_map(self._metadata)
+        cat_id_2_cont_id = self._metadata.thing_dataset_id_to_contiguous_id
+        if len(cont_id_2_cat_id) == len(cat_id_2_cont_id):
+            return
+
+        cat_id_2_cont_id_injective = {}
+        for cat_id, cont_id in cat_id_2_cont_id.items():
+            if (cont_id in cont_id_2_cat_id) and (cont_id_2_cat_id[cont_id] == cat_id):
+                cat_id_2_cont_id_injective[cat_id] = cont_id
+
+        metadata_new = Metadata(name=self._metadata.name)
+        for key, value in self._metadata.__dict__.items():
+            if key == "thing_dataset_id_to_contiguous_id":
+                setattr(metadata_new, key, cat_id_2_cont_id_injective)
+            else:
+                setattr(metadata_new, key, value)
+        self._metadata = metadata_new
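The metadata substitution above keeps only the injective part of the dataset-id -> contiguous-id mapping; a toy illustration of that filtering with made-up ids:

# Dataset categories 2 and 3 collapse onto contiguous id 1; the reverse map
# elects category 2 as the representative, so category 3 is dropped.
cat_id_2_cont_id = {1: 0, 2: 1, 3: 1}
cont_id_2_cat_id = {0: 1, 1: 2}

cat_id_2_cont_id_injective = {
    cat_id: cont_id
    for cat_id, cont_id in cat_id_2_cont_id.items()
    if cont_id_2_cat_id.get(cont_id) == cat_id
}
print(cat_id_2_cont_id_injective)  # {1: 0, 2: 1}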