rtallam45 committed on
Commit 2907cb7 · 1 Parent(s): 55f15d2

Add files with LFS

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. LICENSE +107 -0
  2. README.md +186 -13
  3. app.py +617 -0
  4. app_flux.py +305 -0
  5. app_p2p.py +567 -0
  6. densepose/__init__.py +22 -0
  7. densepose/config.py +277 -0
  8. densepose/converters/__init__.py +17 -0
  9. densepose/converters/base.py +95 -0
  10. densepose/converters/builtin.py +33 -0
  11. densepose/converters/chart_output_hflip.py +73 -0
  12. densepose/converters/chart_output_to_chart_result.py +190 -0
  13. densepose/converters/hflip.py +36 -0
  14. densepose/converters/segm_to_mask.py +152 -0
  15. densepose/converters/to_chart_result.py +72 -0
  16. densepose/converters/to_mask.py +51 -0
  17. densepose/data/__init__.py +27 -0
  18. densepose/data/build.py +738 -0
  19. densepose/data/combined_loader.py +46 -0
  20. densepose/data/dataset_mapper.py +170 -0
  21. densepose/data/datasets/__init__.py +7 -0
  22. densepose/data/datasets/builtin.py +18 -0
  23. densepose/data/datasets/chimpnsee.py +31 -0
  24. densepose/data/datasets/coco.py +434 -0
  25. densepose/data/datasets/dataset_type.py +13 -0
  26. densepose/data/datasets/lvis.py +259 -0
  27. densepose/data/image_list_dataset.py +74 -0
  28. densepose/data/inference_based_loader.py +174 -0
  29. densepose/data/meshes/__init__.py +7 -0
  30. densepose/data/meshes/builtin.py +103 -0
  31. densepose/data/meshes/catalog.py +73 -0
  32. densepose/data/samplers/__init__.py +10 -0
  33. densepose/data/samplers/densepose_base.py +205 -0
  34. densepose/data/samplers/densepose_confidence_based.py +110 -0
  35. densepose/data/samplers/densepose_cse_base.py +141 -0
  36. densepose/data/samplers/densepose_cse_confidence_based.py +121 -0
  37. densepose/data/samplers/densepose_cse_uniform.py +14 -0
  38. densepose/data/samplers/densepose_uniform.py +43 -0
  39. densepose/data/samplers/mask_from_densepose.py +30 -0
  40. densepose/data/samplers/prediction_to_gt.py +100 -0
  41. densepose/data/transform/__init__.py +5 -0
  42. densepose/data/transform/image.py +41 -0
  43. densepose/data/utils.py +40 -0
  44. densepose/data/video/__init__.py +19 -0
  45. densepose/data/video/frame_selector.py +89 -0
  46. densepose/data/video/video_keyframe_dataset.py +304 -0
  47. densepose/engine/__init__.py +5 -0
  48. densepose/engine/trainer.py +260 -0
  49. densepose/evaluation/__init__.py +5 -0
  50. densepose/evaluation/d2_evaluator_adapter.py +52 -0
LICENSE ADDED
@@ -0,0 +1,107 @@
1
+ Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International
2
+
3
+ Creative Commons Corporation ("Creative Commons") is not a law firm and does not provide legal services or legal advice. Distribution of Creative Commons public licenses does not create a lawyer-client or other relationship. Creative Commons makes its licenses and related information available on an "as-is" basis. Creative Commons gives no warranties regarding its licenses, any material licensed under their terms and conditions, or any related information. Creative Commons disclaims all liability for damages resulting from their use to the fullest extent possible.
4
+
5
+ Using Creative Commons Public Licenses
6
+
7
+ Creative Commons public licenses provide a standard set of terms and conditions that creators and other rights holders may use to share original works of authorship and other material subject to copyright and certain other rights specified in the public license below. The following considerations are for informational purposes only, are not exhaustive, and do not form part of our licenses.
8
+
9
+ Considerations for licensors: Our public licenses are intended for use by those authorized to give the public permission to use material in ways otherwise restricted by copyright and certain other rights. Our licenses are irrevocable. Licensors should read and understand the terms and conditions of the license they choose before applying it. Licensors should also secure all rights necessary before applying our licenses so that the public can reuse the material as expected. Licensors should clearly mark any material not subject to the license. This includes other CC-licensed material, or material used under an exception or limitation to copyright. More considerations for licensors : wiki.creativecommons.org/Considerations_for_licensors
10
+
11
+ Considerations for the public: By using one of our public licenses, a licensor grants the public permission to use the licensed material under specified terms and conditions. If the licensor's permission is not necessary for any reason – for example, because of any applicable exception or limitation to copyright – then that use is not regulated by the license. Our licenses grant only permissions under copyright and certain other rights that a licensor has authority to grant. Use of the licensed material may still be restricted for other reasons, including because others have copyright or other rights in the material. A licensor may make special requests, such as asking that all changes be marked or described. Although not required by our licenses, you are encouraged to respect those requests where reasonable. More considerations for the public: wiki.creativecommons.org/Considerations_for_licensees
12
+
13
+ Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International Public License
14
+
15
+ By exercising the Licensed Rights (defined below), You accept and agree to be bound by the terms and conditions of this Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International Public License ("Public License"). To the extent this Public License may be interpreted as a contract, You are granted the Licensed Rights in consideration of Your acceptance of these terms and conditions, and the Licensor grants You such rights in consideration of benefits the Licensor receives from making the Licensed Material available under these terms and conditions.
16
+
17
+ Section 1 – Definitions.
18
+
19
+ a. Adapted Material means material subject to Copyright and Similar Rights that is derived from or based upon the Licensed Material and in which the Licensed Material is translated, altered, arranged, transformed, or otherwise modified in a manner requiring permission under the Copyright and Similar Rights held by the Licensor. For purposes of this Public License, where the Licensed Material is a musical work, performance, or sound recording, Adapted Material is always produced where the Licensed Material is synched in timed relation with a moving image.
20
+ b. Adapter's License means the license You apply to Your Copyright and Similar Rights in Your contributions to Adapted Material in accordance with the terms and conditions of this Public License.
21
+ c. BY-NC-SA Compatible License means a license listed at creativecommons.org/compatiblelicenses, approved by Creative Commons as essentially the equivalent of this Public License.
22
+ d. Copyright and Similar Rights means copyright and/or similar rights closely related to copyright including, without limitation, performance, broadcast, sound recording, and Sui Generis Database Rights, without regard to how the rights are labeled or categorized. For purposes of this Public License, the rights specified in Section 2(b)(1)-(2) are not Copyright and Similar Rights.
23
+ e. Effective Technological Measures means those measures that, in the absence of proper authority, may not be circumvented under laws fulfilling obligations under Article 11 of the WIPO Copyright Treaty adopted on December 20, 1996, and/or similar international agreements.
24
+ f. Exceptions and Limitations means fair use, fair dealing, and/or any other exception or limitation to Copyright and Similar Rights that applies to Your use of the Licensed Material.
25
+ g. License Elements means the license attributes listed in the name of a Creative Commons Public License. The License Elements of this Public License are Attribution, NonCommercial, and ShareAlike.
26
+ h. Licensed Material means the artistic or literary work, database, or other material to which the Licensor applied this Public License.
27
+ i. Licensed Rights means the rights granted to You subject to the terms and conditions of this Public License, which are limited to all Copyright and Similar Rights that apply to Your use of the Licensed Material and that the Licensor has authority to license.
28
+ j. Licensor means the individual(s) or entity(ies) granting rights under this Public License.
29
+ k. NonCommercial means not primarily intended for or directed towards commercial advantage or monetary compensation. For purposes of this Public License, the exchange of the Licensed Material for other material subject to Copyright and Similar Rights by digital file-sharing or similar means is NonCommercial provided there is no payment of monetary compensation in connection with the exchange.
30
+ l. Share means to provide material to the public by any means or process that requires permission under the Licensed Rights, such as reproduction, public display, public performance, distribution, dissemination, communication, or importation, and to make material available to the public including in ways that members of the public may access the material from a place and at a time individually chosen by them.
31
+ m. Sui Generis Database Rights means rights other than copyright resulting from Directive 96/9/EC of the European Parliament and of the Council of 11 March 1996 on the legal protection of databases, as amended and/or succeeded, as well as other essentially equivalent rights anywhere in the world.
32
+ n. You means the individual or entity exercising the Licensed Rights under this Public License. Your has a corresponding meaning.
33
+ Section 2 – Scope.
34
+
35
+ a. License grant.
36
+ 1. Subject to the terms and conditions of this Public License, the Licensor hereby grants You a worldwide, royalty-free, non-sublicensable, non-exclusive, irrevocable license to exercise the Licensed Rights in the Licensed Material to:
37
+ A. reproduce and Share the Licensed Material, in whole or in part, for NonCommercial purposes only; and
38
+ B. produce, reproduce, and Share Adapted Material for NonCommercial purposes only.
39
+ 2. Exceptions and Limitations. For the avoidance of doubt, where Exceptions and Limitations apply to Your use, this Public License does not apply, and You do not need to comply with its terms and conditions.
40
+ 3. Term. The term of this Public License is specified in Section 6(a).
41
+ 4. Media and formats; technical modifications allowed. The Licensor authorizes You to exercise the Licensed Rights in all media and formats whether now known or hereafter created, and to make technical modifications necessary to do so. The Licensor waives and/or agrees not to assert any right or authority to forbid You from making technical modifications necessary to exercise the Licensed Rights, including technical modifications necessary to circumvent Effective Technological Measures. For purposes of this Public License, simply making modifications authorized by this Section 2(a)(4) never produces Adapted Material.
42
+ 5. Downstream recipients.
43
+ A. Offer from the Licensor – Licensed Material. Every recipient of the Licensed Material automatically receives an offer from the Licensor to exercise the Licensed Rights under the terms and conditions of this Public License.
44
+ B. Additional offer from the Licensor – Adapted Material. Every recipient of Adapted Material from You automatically receives an offer from the Licensor to exercise the Licensed Rights in the Adapted Material under the conditions of the Adapter's License You apply.
45
+ C. No downstream restrictions. You may not offer or impose any additional or different terms or conditions on, or apply any Effective Technological Measures to, the Licensed Material if doing so restricts exercise of the Licensed Rights by any recipient of the Licensed Material.
46
+ 6. No endorsement. Nothing in this Public License constitutes or may be construed as permission to assert or imply that You are, or that Your use of the Licensed Material is, connected with, or sponsored, endorsed, or granted official status by, the Licensor or others designated to receive attribution as provided in Section 3(a)(1)(A)(i).
47
+ b. Other rights.
48
+ 1. Moral rights, such as the right of integrity, are not licensed under this Public License, nor are publicity, privacy, and/or other similar personality rights; however, to the extent possible, the Licensor waives and/or agrees not to assert any such rights held by the Licensor to the limited extent necessary to allow You to exercise the Licensed Rights, but not otherwise.
49
+ 2. Patent and trademark rights are not licensed under this Public License.
50
+ 3. To the extent possible, the Licensor waives any right to collect royalties from You for the exercise of the Licensed Rights, whether directly or through a collecting society under any voluntary or waivable statutory or compulsory licensing scheme. In all other cases the Licensor expressly reserves any right to collect such royalties, including when the Licensed Material is used other than for NonCommercial purposes.
51
+ Section 3 – License Conditions.
52
+
53
+ Your exercise of the Licensed Rights is expressly made subject to the following conditions.
54
+
55
+ a. Attribution.
56
+ 1. If You Share the Licensed Material (including in modified form), You must:
57
+ A. retain the following if it is supplied by the Licensor with the Licensed Material:
58
+ i. identification of the creator(s) of the Licensed Material and any others designated to receive attribution, in any reasonable manner requested by the Licensor (including by pseudonym if designated);
59
+ ii. a copyright notice;
60
+ iii. a notice that refers to this Public License;
61
+ iv. a notice that refers to the disclaimer of warranties;
62
+ v. a URI or hyperlink to the Licensed Material to the extent reasonably practicable;
63
+
64
+ B. indicate if You modified the Licensed Material and retain an indication of any previous modifications; and
65
+ C. indicate the Licensed Material is licensed under this Public License, and include the text of, or the URI or hyperlink to, this Public License.
66
+ 2. You may satisfy the conditions in Section 3(a)(1) in any reasonable manner based on the medium, means, and context in which You Share the Licensed Material. For example, it may be reasonable to satisfy the conditions by providing a URI or hyperlink to a resource that includes the required information.
67
+ 3. If requested by the Licensor, You must remove any of the information required by Section 3(a)(1)(A) to the extent reasonably practicable.
68
+ b. ShareAlike.In addition to the conditions in Section 3(a), if You Share Adapted Material You produce, the following conditions also apply.
69
+ 1. The Adapter's License You apply must be a Creative Commons license with the same License Elements, this version or later, or a BY-NC-SA Compatible License.
70
+ 2. You must include the text of, or the URI or hyperlink to, the Adapter's License You apply. You may satisfy this condition in any reasonable manner based on the medium, means, and context in which You Share Adapted Material.
71
+ 3. You may not offer or impose any additional or different terms or conditions on, or apply any Effective Technological Measures to, Adapted Material that restrict exercise of the rights granted under the Adapter's License You apply.
72
+ Section 4 – Sui Generis Database Rights.
73
+
74
+ Where the Licensed Rights include Sui Generis Database Rights that apply to Your use of the Licensed Material:
75
+
76
+ a. for the avoidance of doubt, Section 2(a)(1) grants You the right to extract, reuse, reproduce, and Share all or a substantial portion of the contents of the database for NonCommercial purposes only;
77
+ b. if You include all or a substantial portion of the database contents in a database in which You have Sui Generis Database Rights, then the database in which You have Sui Generis Database Rights (but not its individual contents) is Adapted Material, including for purposes of Section 3(b); and
78
+ c. You must comply with the conditions in Section 3(a) if You Share all or a substantial portion of the contents of the database.
79
+ For the avoidance of doubt, this Section 4 supplements and does not replace Your obligations under this Public License where the Licensed Rights include other Copyright and Similar Rights.
80
+ Section 5 – Disclaimer of Warranties and Limitation of Liability.
81
+
82
+ a. Unless otherwise separately undertaken by the Licensor, to the extent possible, the Licensor offers the Licensed Material as-is and as-available, and makes no representations or warranties of any kind concerning the Licensed Material, whether express, implied, statutory, or other. This includes, without limitation, warranties of title, merchantability, fitness for a particular purpose, non-infringement, absence of latent or other defects, accuracy, or the presence or absence of errors, whether or not known or discoverable. Where disclaimers of warranties are not allowed in full or in part, this disclaimer may not apply to You.
83
+ b. To the extent possible, in no event will the Licensor be liable to You on any legal theory (including, without limitation, negligence) or otherwise for any direct, special, indirect, incidental, consequential, punitive, exemplary, or other losses, costs, expenses, or damages arising out of this Public License or use of the Licensed Material, even if the Licensor has been advised of the possibility of such losses, costs, expenses, or damages. Where a limitation of liability is not allowed in full or in part, this limitation may not apply to You.
84
+ c. The disclaimer of warranties and limitation of liability provided above shall be interpreted in a manner that, to the extent possible, most closely approximates an absolute disclaimer and waiver of all liability.
85
+ Section 6 – Term and Termination.
86
+
87
+ a. This Public License applies for the term of the Copyright and Similar Rights licensed here. However, if You fail to comply with this Public License, then Your rights under this Public License terminate automatically.
88
+ b. Where Your right to use the Licensed Material has terminated under Section 6(a), it reinstates:
89
+ 1. automatically as of the date the violation is cured, provided it is cured within 30 days of Your discovery of the violation; or
90
+ 2. upon express reinstatement by the Licensor.
91
+ For the avoidance of doubt, this Section 6(b) does not affect any right the Licensor may have to seek remedies for Your violations of this Public License.
92
+
93
+ c. For the avoidance of doubt, the Licensor may also offer the Licensed Material under separate terms or conditions or stop distributing the Licensed Material at any time; however, doing so will not terminate this Public License.
94
+ d. Sections 1, 5, 6, 7, and 8 survive termination of this Public License.
95
+ Section 7 – Other Terms and Conditions.
96
+
97
+ a. The Licensor shall not be bound by any additional or different terms or conditions communicated by You unless expressly agreed.
98
+ b. Any arrangements, understandings, or agreements regarding the Licensed Material not stated herein are separate from and independent of the terms and conditions of this Public License.
99
+ Section 8 – Interpretation.
100
+
101
+ a. For the avoidance of doubt, this Public License does not, and shall not be interpreted to, reduce, limit, restrict, or impose conditions on any use of the Licensed Material that could lawfully be made without permission under this Public License.
102
+ b. To the extent possible, if any provision of this Public License is deemed unenforceable, it shall be automatically reformed to the minimum extent necessary to make it enforceable. If the provision cannot be reformed, it shall be severed from this Public License without affecting the enforceability of the remaining terms and conditions.
103
+ c. No term or condition of this Public License will be waived and no failure to comply consented to unless expressly agreed to by the Licensor.
104
+ d. Nothing in this Public License constitutes or may be interpreted as a limitation upon, or waiver of, any privileges and immunities that apply to the Licensor or You, including from the legal processes of any jurisdiction or authority.
105
+ Creative Commons is not a party to its public licenses. Notwithstanding, Creative Commons may elect to apply one of its public licenses to material it publishes and in those instances will be considered the "Licensor." The text of the Creative Commons public licenses is dedicated to the public domain under the CC0 Public Domain Dedication. Except for the limited purpose of indicating that material is shared under a Creative Commons public license or as otherwise permitted by the Creative Commons policies published at creativecommons.org/policies, Creative Commons does not authorize the use of the trademark "Creative Commons" or any other trademark or logo of Creative Commons without its prior written consent including, without limitation, in connection with any unauthorized modifications to any of its public licenses or any other arrangements, understandings, or agreements concerning use of licensed material. For the avoidance of doubt, this paragraph does not form part of the public licenses.
106
+
107
+ Creative Commons may be contacted at creativecommons.org.
README.md CHANGED
@@ -1,13 +1,186 @@
1
- ---
2
- title: MarketingCopilot
3
- emoji: 👀
4
- colorFrom: yellow
5
- colorTo: pink
6
- sdk: gradio
7
- sdk_version: 5.16.0
8
- app_file: app.py
9
- pinned: false
10
- short_description: Virtual Cloth try on project
11
- ---
12
-
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
+ # 🐈 CatVTON: Concatenation Is All You Need for Virtual Try-On with Diffusion Models
2
+
3
+ <div style="display: flex; justify-content: center; align-items: center;">
4
+ <a href="http://arxiv.org/abs/2407.15886" style="margin: 0 2px;">
5
+ <img src='https://img.shields.io/badge/arXiv-2407.15886-red?style=flat&logo=arXiv&logoColor=red' alt='arxiv'>
6
+ </a>
7
+ <a href='https://huggingface.co/zhengchong/CatVTON' style="margin: 0 2px;">
8
+ <img src='https://img.shields.io/badge/Hugging Face-ckpts-orange?style=flat&logo=HuggingFace&logoColor=orange' alt='huggingface'>
9
+ </a>
10
+ <a href="https://github.com/Zheng-Chong/CatVTON" style="margin: 0 2px;">
11
+ <img src='https://img.shields.io/badge/GitHub-Repo-blue?style=flat&logo=GitHub' alt='GitHub'>
12
+ </a>
13
+ <a href="https://huggingface.co/spaces/zhengchong/CatVTON" style="margin: 0 2px;">
14
+ <img src='https://img.shields.io/badge/Demo-Gradio-gold?style=flat&logo=Gradio&logoColor=red' alt='Demo'>
15
+ </a>
16
+ <a href="https://huggingface.co/spaces/zhengchong/CatVTON" style="margin: 0 2px;">
17
+ <img src='https://img.shields.io/badge/Space-ZeroGPU-orange?style=flat&logo=Gradio&logoColor=red' alt='Demo'>
18
+ </a>
19
+ <a href='https://zheng-chong.github.io/CatVTON/' style="margin: 0 2px;">
20
+ <img src='https://img.shields.io/badge/Webpage-Project-silver?style=flat&logo=&logoColor=orange' alt='webpage'>
21
+ </a>
22
+ <a href="https://github.com/Zheng-Chong/CatVTON/LICENCE" style="margin: 0 2px;">
23
+ <img src='https://img.shields.io/badge/License-CC BY--NC--SA--4.0-lightgreen?style=flat&logo=Lisence' alt='License'>
24
+ </a>
25
+ </div>
26
+
27
+
28
+ **CatVTON** is a simple and efficient virtual try-on diffusion model with ***1) a Lightweight Network (899.06M total parameters)***, ***2) Parameter-Efficient Training (49.57M trainable parameters)***, and ***3) Simplified Inference (< 8G VRAM at 1024×768 resolution)***.
29
+ <div align="center">
30
+ <img src="resource/img/teaser.jpg" width="100%" height="100%"/>
31
+ </div>
32
+
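+ If you prefer to drive the try-on pipeline from Python rather than through the Gradio app, a minimal sketch distilled from `app.py` in this repo looks like the following (image paths, sizes, and sampler settings are illustrative, and `torch.bfloat16` is passed directly instead of going through `init_weight_dtype`):
+ 
+ ```python
+ import torch
+ from huggingface_hub import snapshot_download
+ from PIL import Image
+ 
+ from model.cloth_masker import AutoMasker
+ from model.pipeline import CatVTONPipeline
+ from utils import resize_and_crop, resize_and_padding
+ 
+ repo = snapshot_download("zhengchong/CatVTON")   # attention weights + DensePose/SCHP checkpoints
+ pipe = CatVTONPipeline(
+     base_ckpt="booksforcharlie/stable-diffusion-inpainting",
+     attn_ckpt=repo,
+     attn_ckpt_version="mix",
+     weight_dtype=torch.bfloat16,
+     use_tf32=True,
+     device="cuda",
+ )
+ masker = AutoMasker(densepose_ckpt=f"{repo}/DensePose", schp_ckpt=f"{repo}/SCHP", device="cuda")
+ 
+ person = resize_and_crop(Image.open("person.jpg").convert("RGB"), (768, 1024))
+ cloth = resize_and_padding(Image.open("cloth.jpg").convert("RGB"), (768, 1024))
+ mask = masker(person, "upper")["mask"]           # or supply your own binary mask
+ 
+ result = pipe(
+     image=person,
+     condition_image=cloth,
+     mask=mask,
+     num_inference_steps=50,
+     guidance_scale=2.5,
+     generator=torch.Generator(device="cuda").manual_seed(42),
+ )[0]
+ result.save("result.png")
+ ```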
33
+
34
+
35
+ ## Updates
36
+ - **`2024/12/20`**: 😄 Code for the Gradio app of [**CatVTON-FLUX**] has been released! It is not a stable version yet, but it is a good start!
37
+ - **`2024/12/19`**: [**CatVTON-FLUX**](https://huggingface.co/spaces/zhengchong/CatVTON) has been released! It is an extremely lightweight LoRA (only a 37.4M checkpoint) for [FLUX.1-Fill-dev](https://huggingface.co/black-forest-labs/FLUX.1-Fill-dev). The LoRA weights are available in the **[huggingface repo](https://huggingface.co/zhengchong/CatVTON/tree/main/flux-lora)**; code will be released soon!
38
+ - **`2024/11/26`**: Our **unified vision-based model for image and video try-on** will be released soon, bringing a brand-new virtual try-on experience! While our demo page will be temporarily taken offline, [**the demo on HuggingFace Space**](https://huggingface.co/spaces/zhengchong/CatVTON) will remain available for use!
39
+ - **`2024/10/17`**: The [**Mask-free version**](https://huggingface.co/zhengchong/CatVTON-MaskFree) 🤗 of CatVTON is released!
40
+ - **`2024/10/13`**: We have built a repo [**Awesome-Try-On-Models**](https://github.com/Zheng-Chong/Awesome-Try-On-Models) that focuses on image, video, and 3D-based try-on models published after 2023, aiming to provide insights into the latest technological trends. If you're interested, feel free to contribute or give it a 🌟 star!
41
+ - **`2024/08/13`**: We localize DensePose & SCHP to avoid certain environment issues.
42
+ - **`2024/08/10`**: Our 🤗 [**HuggingFace Space**](https://huggingface.co/spaces/zhengchong/CatVTON) is available now! Thanks for the grant from [**ZeroGPU**](https://huggingface.co/zero-gpu-explorers)!
43
+ - **`2024/08/09`**: [**Evaluation code**](https://github.com/Zheng-Chong/CatVTON?tab=readme-ov-file#3-calculate-metrics) is provided to calculate metrics 📚.
44
+ - **`2024/07/27`**: We provide code and a workflow for deploying CatVTON on [**ComfyUI**](https://github.com/Zheng-Chong/CatVTON?tab=readme-ov-file#comfyui-workflow) 💥.
45
+ - **`2024/07/24`**: Our [**Paper on ArXiv**](http://arxiv.org/abs/2407.15886) is available 🥳!
46
+ - **`2024/07/22`**: Our [**App Code**](https://github.com/Zheng-Chong/CatVTON/blob/main/app.py) is released; deploy and enjoy CatVTON on your machine 🎉!
47
+ - **`2024/07/21`**: Our [**Inference Code**](https://github.com/Zheng-Chong/CatVTON/blob/main/inference.py) and [**Weights** 🤗](https://huggingface.co/zhengchong/CatVTON) are released.
48
+ - **`2024/07/11`**: Our [**Online Demo**](https://huggingface.co/spaces/zhengchong/CatVTON) is released 😁.
49
+
50
+
51
+
52
+
53
+ ## Installation
54
+
55
+ Create a conda environment & install requirements
56
+ ```shell
57
+ conda create -n catvton python==3.9.0
58
+ conda activate catvton
59
+ cd CatVTON-main # or your path to CatVTON project dir
60
+ pip install -r requirements.txt
61
+ ```
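+ 
+ A quick, optional sanity check that PyTorch can see your GPU before going further:
+ 
+ ```python
+ import torch
+ print(torch.__version__, torch.cuda.is_available())  # expect `True` on a correctly set up CUDA machine
+ ```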
62
+
63
+ ## Deployment
64
+ ### ComfyUI Workflow
65
+ We have modified the main code to enable easy deployment of CatVTON on [ComfyUI](https://github.com/comfyanonymous/ComfyUI). Because the code structure is not directly compatible, this part is released separately in the [Releases](https://github.com/Zheng-Chong/CatVTON/releases/tag/ComfyUI), which includes the code to place under ComfyUI's `custom_nodes` folder and our workflow JSON files.
66
+
67
+ To deploy CatVTON to your ComfyUI, follow these steps:
68
+ 1. Install all the requirements for both CatVTON and ComfyUI; refer to the [Installation Guide for CatVTON](https://github.com/Zheng-Chong/CatVTON/blob/main/INSTALL.md) and the [Installation Guide for ComfyUI](https://github.com/comfyanonymous/ComfyUI?tab=readme-ov-file#installing).
69
+ 2. Download [`ComfyUI-CatVTON.zip`](https://github.com/Zheng-Chong/CatVTON/releases/download/ComfyUI/ComfyUI-CatVTON.zip) and unzip it in the `custom_nodes` folder under your ComfyUI project (clone from [ComfyUI](https://github.com/comfyanonymous/ComfyUI)).
70
+ 3. Run the ComfyUI.
71
+ 4. Download [`catvton_workflow.json`](https://github.com/Zheng-Chong/CatVTON/releases/download/ComfyUI/catvton_workflow.json), drag it into your ComfyUI webpage, and enjoy 😆!
72
+
73
+ > For problems under Windows OS, please refer to [issue#8](https://github.com/Zheng-Chong/CatVTON/issues/8).
74
+ >
75
+ When you run the CatVTON workflow for the first time, the weight files will be automatically downloaded, usually taking dozens of minutes.
76
+
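+ If you would rather fetch the weights ahead of time than wait on that first run, the checkpoints can be pre-downloaded into your Hugging Face cache (assuming the ComfyUI nodes pull the same `zhengchong/CatVTON` repo that `app.py` uses):
+ 
+ ```python
+ from huggingface_hub import snapshot_download
+ 
+ # Pre-fetch the CatVTON attention weights plus the DensePose/SCHP checkpoints.
+ snapshot_download("zhengchong/CatVTON")
+ ```
+ 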
77
+ <div align="center">
78
+ <img src="resource/img/comfyui-1.png" width="100%" height="100%"/>
79
+ </div>
80
+
81
+ <!-- <div align="center">
82
+ <img src="resource/img/comfyui.png" width="100%" height="100%"/>
83
+ </div> -->
84
+
85
+ ### Gradio App
86
+
87
+ To deploy the Gradio App for CatVTON on your machine, run the following command, and checkpoints will be automatically downloaded from HuggingFace.
88
+
89
+ ```shell
90
+ CUDA_VISIBLE_DEVICES=0 python app.py \
91
+ --output_dir="resource/demo/output" \
92
+ --mixed_precision="bf16" \
93
+ --allow_tf32
94
+ ```
95
+ When using `bf16` precision, generating results with a resolution of `1024x768` only requires about `8G` VRAM.
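+ 
+ To confirm the memory footprint on your own hardware, peak usage can be read back from PyTorch after a generation (standard `torch.cuda` calls; the try-on call itself is elided here):
+ 
+ ```python
+ import torch
+ 
+ torch.cuda.reset_peak_memory_stats()
+ # ... run one try-on through the app or the pipeline here ...
+ print(f"peak VRAM: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GiB")
+ ```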
96
+
97
+ ## Inference
98
+ ### 1. Data Preparation
99
+ Before inference, you need to download the [VITON-HD](https://github.com/shadow2496/VITON-HD) or [DressCode](https://github.com/aimagelab/dress-code) dataset.
100
+ Once the datasets are downloaded, the folder structures should look like these:
101
+ ```
102
+ ├── VITON-HD
103
+ |   ├── test_pairs_unpaired.txt
104
+ │   ├── test
105
+ |   |   ├── image
106
+ │   │   │   ├── [000006_00.jpg | 000008_00.jpg | ...]
107
+ │   │   ├── cloth
108
+ │   │   │   ├── [000006_00.jpg | 000008_00.jpg | ...]
109
+ │   │   ├── agnostic-mask
110
+ │   │   │   ├── [000006_00_mask.png | 000008_00.png | ...]
111
+ ...
112
+ ```
113
+
114
+ ```
115
+ ├── DressCode
116
+ |   ├── test_pairs_paired.txt
117
+ |   ├── test_pairs_unpaired.txt
118
+ │   ├── [dresses | lower_body | upper_body]
119
+ |   |   ├── test_pairs_paired.txt
120
+ |   |   ├── test_pairs_unpaired.txt
121
+ │   │   ├── images
122
+ │   │   │   ├── [013563_0.jpg | 013563_1.jpg | 013564_0.jpg | 013564_1.jpg | ...]
123
+ │   │   ├── agnostic_masks
124
+ │   │   │   ├── [013563_0.png | 013564_0.png | ...]
125
+ ...
126
+ ```
127
+ For the DressCode dataset, we provide a script to preprocess the agnostic masks; run the following command:
128
+ ```shell
129
+ CUDA_VISIBLE_DEVICES=0 python preprocess_agnostic_mask.py \
130
+ --data_root_path <your_path_to_DressCode>
131
+ ```
132
+
133
+ ### 2. Inference on VITON-HD/DressCode
134
+ To run inference on the DressCode or VITON-HD dataset, use the following command; checkpoints will be automatically downloaded from HuggingFace.
135
+
136
+ ```shell
137
+ CUDA_VISIBLE_DEVICES=0 python inference.py \
138
+ --dataset [dresscode | vitonhd] \
139
+ --data_root_path <path> \
140
+ --output_dir <path> \
141
+ --dataloader_num_workers 8 \
142
+ --batch_size 8 \
143
+ --seed 555 \
144
+ --mixed_precision [no | fp16 | bf16] \
145
+ --allow_tf32 \
146
+ --repaint \
147
+ --eval_pair
148
+ ```
149
+ ### 3. Calculate Metrics
150
+
151
+ After obtaining the inference results, calculate the metrics using the following command:
152
+
153
+ ```shell
154
+ CUDA_VISIBLE_DEVICES=0 python eval.py \
155
+ --gt_folder <your_path_to_gt_image_folder> \
156
+ --pred_folder <your_path_to_predicted_image_folder> \
157
+ --paired \
158
+ --batch_size=16 \
159
+ --num_workers=16
160
+ ```
161
+
162
+ - `--gt_folder` and `--pred_folder` should be folders that contain **only images**; a quick sanity check is sketched below.
163
+ - To evaluate the results in a paired setting, use `--paired`; for an unpaired setting, simply omit it.
164
+ - `--batch_size` and `--num_workers` should be adjusted based on your machine.
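+ 
+ A standalone sketch of the sanity check mentioned above, run before launching `eval.py` (the paths are placeholders; the filename comparison is mainly relevant for the paired setting):
+ 
+ ```python
+ from pathlib import Path
+ 
+ IMG_EXTS = {".jpg", ".jpeg", ".png"}
+ 
+ def check_folder(folder: str) -> set[str]:
+     files = list(Path(folder).iterdir())
+     non_images = [p.name for p in files if p.suffix.lower() not in IMG_EXTS]
+     assert not non_images, f"{folder} contains non-image files: {non_images[:5]}"
+     return {p.name for p in files}
+ 
+ gt, pred = check_folder("path/to/gt_folder"), check_folder("path/to/pred_folder")
+ print("in gt but not in pred:", sorted(gt - pred)[:5])
+ ```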
165
+
166
+
167
+ ## Acknowledgement
168
+ Our code is adapted from [Diffusers](https://github.com/huggingface/diffusers). We adopt [Stable Diffusion v1.5 inpainting](https://huggingface.co/runwayml/stable-diffusion-inpainting) as the base model. We use [SCHP](https://github.com/GoGoDuck912/Self-Correction-Human-Parsing/tree/master) and [DensePose](https://github.com/facebookresearch/DensePose) to automatically generate masks in our [Gradio](https://github.com/gradio-app/gradio) App and [ComfyUI](https://github.com/comfyanonymous/ComfyUI) workflow. Thanks to all the contributors!
169
+
170
+ ## License
171
+ All the materials, including code, checkpoints, and demo, are made available under the [Creative Commons BY-NC-SA 4.0](https://creativecommons.org/licenses/by-nc-sa/4.0/) license. You are free to copy, redistribute, remix, transform, and build upon the project for non-commercial purposes, as long as you give appropriate credit and distribute your contributions under the same license.
172
+
173
+
174
+ ## Citation
175
+
176
+ ```bibtex
177
+ @misc{chong2024catvtonconcatenationneedvirtual,
178
+ title={CatVTON: Concatenation Is All You Need for Virtual Try-On with Diffusion Models},
179
+ author={Zheng Chong and Xiao Dong and Haoxiang Li and Shiyue Zhang and Wenqing Zhang and Xujie Zhang and Hanqing Zhao and Xiaodan Liang},
180
+ year={2024},
181
+ eprint={2407.15886},
182
+ archivePrefix={arXiv},
183
+ primaryClass={cs.CV},
184
+ url={https://arxiv.org/abs/2407.15886},
185
+ }
186
+ ```
app.py ADDED
@@ -0,0 +1,617 @@
1
+ import argparse
2
+ import os
3
+ from datetime import datetime
4
+ from openai import AzureOpenAI
5
+ import gradio as gr
6
+ import numpy as np
7
+ import torch
8
+ from diffusers.image_processor import VaeImageProcessor
9
+ from huggingface_hub import snapshot_download
10
+ from PIL import Image
11
+ import base64
12
+ from model.cloth_masker import AutoMasker, vis_mask
13
+ from model.pipeline import CatVTONPipeline
14
+ from utils import init_weight_dtype, resize_and_crop, resize_and_padding
15
+ from dotenv import load_dotenv
16
+ from io import BytesIO
17
+
18
+
19
+ load_dotenv()
20
+
21
+ import gc
22
+ from transformers import T5EncoderModel
23
+ from diffusers import FluxPipeline, FluxTransformer2DModel
24
+
25
+ openai_client = AzureOpenAI(
26
+ api_key=os.getenv("AZURE_OPENAI_API_KEY"),
27
+ azure_endpoint=os.getenv("AZURE_ENDPOINT"),
28
+ api_version="2024-02-15-preview",
29
+ azure_deployment="gpt-4o-mvp-dev"
30
+ )
31
+
32
+ def parse_args():
33
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
34
+ parser.add_argument(
35
+ "--base_model_path",
36
+ type=str,
37
+ default="booksforcharlie/stable-diffusion-inpainting",  # Changed to a mirror repo since runwayml deleted the original repo
38
+ help=(
39
+ "The path to the base model to use for evaluation. This can be a local path or a model identifier from the Model Hub."
40
+ ),
41
+ )
42
+ parser.add_argument(
43
+ "--resume_path",
44
+ type=str,
45
+ default="zhengchong/CatVTON",
46
+ help=(
47
+ "The path to the checkpoint of the trained try-on model."
48
+ ),
49
+ )
50
+ parser.add_argument(
51
+ "--output_dir",
52
+ type=str,
53
+ default="resource/demo/output",
54
+ help="The output directory where the model predictions will be written.",
55
+ )
56
+
57
+ parser.add_argument(
58
+ "--width",
59
+ type=int,
60
+ default=768,
61
+ help=(
62
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
63
+ " resolution"
64
+ ),
65
+ )
66
+ parser.add_argument(
67
+ "--height",
68
+ type=int,
69
+ default=1024,
70
+ help=(
71
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
72
+ " resolution"
73
+ ),
74
+ )
75
+ parser.add_argument(
76
+ "--repaint",
77
+ action="store_true",
78
+ help="Whether to repaint the result image with the original background."
79
+ )
80
+ parser.add_argument(
81
+ "--allow_tf32",
82
+ action="store_true",
83
+ default=True,
84
+ help=(
85
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
86
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
87
+ ),
88
+ )
89
+ parser.add_argument(
90
+ "--mixed_precision",
91
+ type=str,
92
+ default="bf16",
93
+ choices=["no", "fp16", "bf16"],
94
+ help=(
95
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
96
+ " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
97
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
98
+ ),
99
+ )
100
+
101
+ args = parser.parse_args()
102
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
103
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
104
+ args.local_rank = env_local_rank
105
+
106
+ return args
107
+
108
+ def image_grid(imgs, rows, cols):
109
+ assert len(imgs) == rows * cols
110
+
111
+ w, h = imgs[0].size
112
+ grid = Image.new("RGB", size=(cols * w, rows * h))
113
+
114
+ for i, img in enumerate(imgs):
115
+ grid.paste(img, box=(i % cols * w, i // cols * h))
116
+ return grid
117
+
118
+
119
+ args = parse_args()
120
+ repo_path = snapshot_download(repo_id=args.resume_path)
121
+
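+ # flush() frees cached GPU memory between stages so that the 4-bit FLUX text encoder,
+ # the FLUX transformer, and the CatVTON try-on pipeline do not all stay resident at once.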
122
+ def flush():
123
+ gc.collect()
124
+ torch.cuda.empty_cache()
125
+ torch.cuda.reset_max_memory_allocated()
126
+ torch.cuda.reset_peak_memory_stats()
127
+
128
+ flush()
129
+
130
+ ckpt_4bit_id = "sayakpaul/flux.1-dev-nf4-pkg"
131
+
132
+ text_encoder_2_4bit = T5EncoderModel.from_pretrained(
133
+ ckpt_4bit_id,
134
+ subfolder="text_encoder_2",
135
+ )
136
+
137
+ # CatVTON try-on pipeline (distinct from the FLUX image-generation pipeline built later)
138
+ # Pipeline
139
+ pipeline = CatVTONPipeline(
140
+ base_ckpt=args.base_model_path,
141
+ attn_ckpt=repo_path,
142
+ attn_ckpt_version="mix",
143
+ weight_dtype=init_weight_dtype(args.mixed_precision),
144
+ use_tf32=args.allow_tf32,
145
+ device='cuda'
146
+ )
147
+ # AutoMasker
148
+ mask_processor = VaeImageProcessor(vae_scale_factor=8, do_normalize=False, do_binarize=True, do_convert_grayscale=True)
149
+ automasker = AutoMasker(
150
+ densepose_ckpt=os.path.join(repo_path, "DensePose"),
151
+ schp_ckpt=os.path.join(repo_path, "SCHP"),
152
+ device='cuda',
153
+ )
154
+
155
+ def submit_function(
156
+ person_image,
157
+ cloth_image,
158
+ cloth_type,
159
+ num_inference_steps,
160
+ guidance_scale,
161
+ seed,
162
+ show_type,
163
+ campaign_context,
164
+ ):
165
+ person_image, mask = person_image["background"], person_image["layers"][0]
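+ # Gradio's ImageEditor returns the uploaded photo as "background" and any brush strokes
+ # as layers; a layer with a single gray value means nothing was drawn, so the mask is set
+ # to None here and generated automatically by AutoMasker further below.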
166
+ mask = Image.open(mask).convert("L")
167
+ if len(np.unique(np.array(mask))) == 1:
168
+ mask = None
169
+ else:
170
+ mask = np.array(mask)
171
+ mask[mask > 0] = 255
172
+ mask = Image.fromarray(mask)
173
+
174
+ tmp_folder = args.output_dir
175
+ date_str = datetime.now().strftime("%Y%m%d%H%M%S")
176
+ result_save_path = os.path.join(tmp_folder, date_str[:8], date_str[8:] + ".png")
177
+ if not os.path.exists(os.path.join(tmp_folder, date_str[:8])):
178
+ os.makedirs(os.path.join(tmp_folder, date_str[:8]))
179
+
180
+ generator = None
181
+ if seed != -1:
182
+ generator = torch.Generator(device='cuda').manual_seed(seed)
183
+
184
+ person_image = Image.open(person_image).convert("RGB")
185
+ cloth_image = Image.open(cloth_image).convert("RGB")
186
+ person_image = resize_and_crop(person_image, (args.width, args.height))
187
+ cloth_image = resize_and_padding(cloth_image, (args.width, args.height))
188
+
189
+
190
+ # Process mask
191
+ if mask is not None:
192
+ mask = resize_and_crop(mask, (args.width, args.height))
193
+ else:
194
+ mask = automasker(
195
+ person_image,
196
+ cloth_type
197
+ )['mask']
198
+ mask = mask_processor.blur(mask, blur_factor=9)
199
+
200
+ # Inference
201
+ # try:
202
+ result_image = pipeline(
203
+ image=person_image,
204
+ condition_image=cloth_image,
205
+ mask=mask,
206
+ num_inference_steps=num_inference_steps,
207
+ guidance_scale=guidance_scale,
208
+ generator=generator
209
+ )[0]
210
+ # except Exception as e:
211
+ # raise gr.Error(
212
+ # "An error occurred. Please try again later: {}".format(e)
213
+ # )
214
+
215
+ # Post-process
216
+ masked_person = vis_mask(person_image, mask)
217
+ save_result_image = image_grid([person_image, masked_person, cloth_image, result_image], 1, 4)
218
+ save_result_image.save(result_save_path)
219
+ # Generate product description
220
+ product_description = generate_upper_cloth_description(cloth_image, cloth_type)
221
+
222
+ # Generate captions for the campaign
223
+ captions = generate_captions(product_description, campaign_context)
224
+
225
+ if show_type == "result only":
226
+ return result_image
227
+ else:
228
+ width, height = person_image.size
229
+ if show_type == "input & result":
230
+ condition_width = width // 2
231
+ conditions = image_grid([person_image, cloth_image], 2, 1)
232
+ else:
233
+ condition_width = width // 3
234
+ conditions = image_grid([person_image, masked_person , cloth_image], 3, 1)
235
+ conditions = conditions.resize((condition_width, height), Image.NEAREST)
236
+ new_result_image = Image.new("RGB", (width + condition_width + 5, height))
237
+ new_result_image.paste(conditions, (0, 0))
238
+ new_result_image.paste(result_image, (condition_width + 5, 0))
239
+ return new_result_image, captions
240
+
241
+
242
+ def person_example_fn(image_path):
243
+ return image_path
244
+
245
+ def generate_person_image(text, cloth_description):
246
+ """
247
+ Generates a person image from the text prompt using FLUX.1-dev.
248
+ Returns the path to the generated image.
249
+ """
250
+ prompt = generate_ai_model_prompt(text, cloth_description)
251
+ ckpt_id = "black-forest-labs/FLUX.1-dev"
252
+
253
+ print("generating image with prompt: ", prompt)
254
+ image_gen_pipeline = FluxPipeline.from_pretrained(
255
+ ckpt_id,
256
+ text_encoder_2=text_encoder_2_4bit,
257
+ transformer=None,
258
+ vae=None,
259
+ torch_dtype=torch.float16,
260
+ )
261
+ image_gen_pipeline.enable_model_cpu_offload()
262
+ # Create a new image with a random background color
263
+
264
+ with torch.no_grad():
265
+ print("Encoding prompts.")
266
+ prompt_embeds, pooled_prompt_embeds, text_ids = image_gen_pipeline.encode_prompt(
267
+ prompt=prompt, prompt_2=None, max_sequence_length=256
268
+ )
269
+
270
+ image_gen_pipeline = image_gen_pipeline.to("cpu")
271
+ del image_gen_pipeline
272
+
273
+ flush()
274
+
275
+ print(f"prompt_embeds shape: {prompt_embeds.shape}")
276
+ print(f"pooled_prompt_embeds shape: {pooled_prompt_embeds.shape}")
277
+ # Stage 2: load the 4-bit FLUX transformer and denoise using the precomputed prompt embeddings
278
+ transformer_4bit = FluxTransformer2DModel.from_pretrained(ckpt_4bit_id, subfolder="transformer")
279
+ image_gen_pipeline = FluxPipeline.from_pretrained(
280
+ ckpt_id,
281
+ text_encoder=None,
282
+ text_encoder_2=None,
283
+ tokenizer=None,
284
+ tokenizer_2=None,
285
+ transformer=transformer_4bit,
286
+ torch_dtype=torch.float16,
287
+ )
288
+ image_gen_pipeline.enable_model_cpu_offload()
289
+
290
+ print("Running denoising.")
291
+ height, width = 1024, 1024
292
+
293
+ images = image_gen_pipeline(
294
+ prompt_embeds=prompt_embeds,
295
+ pooled_prompt_embeds=pooled_prompt_embeds,
296
+ num_inference_steps=50,
297
+ guidance_scale=5.5,
298
+ height=height,
299
+ width=width,
300
+ output_type="pil",
301
+ ).images
302
+
303
+ # Add current time to make each image unique
304
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
305
+
306
+ # Create output directory if it doesn't exist
307
+ os.makedirs('generated_images', exist_ok=True)
308
+
309
+ # Save the image
310
+ output_path = f'generated_images/generated_{timestamp}.png'
311
+ images[0].save(output_path)
312
+
313
+ return output_path
314
+
315
+ def pil_image_to_base64(image, format: str = "PNG") -> str:
316
+ """
317
+ Converts an image to a Base64 encoded string.
318
+
319
+ Args:
320
+ image: Either a file path (str) or a PIL Image object
321
+ format (str): The format to save the image as (default is PNG).
322
+
323
+ Returns:
324
+ str: A Base64 encoded string of the image.
325
+ """
326
+ try:
327
+ # If image is a file path, open it
328
+ if isinstance(image, str):
329
+ image = Image.open(image)
330
+ elif not isinstance(image, Image.Image):
331
+ raise ValueError("Input must be either a file path or a PIL Image object")
332
+
333
+ # Convert the image to Base64
334
+ buffered = BytesIO()
335
+ image.save(buffered, format=format)
336
+ buffered.seek(0) # Go to the start of the BytesIO stream
337
+ image_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
338
+ return image_base64
339
+ except Exception as e:
340
+ print(f"Error converting image to Base64: {e}")
341
+ raise e
342
+
343
+ def generate_upper_cloth_description(product_image, cloth_type: str):
344
+ try:
345
+ base_64_image = pil_image_to_base64(product_image)
346
+
347
+ if cloth_type == "upper":
348
+ system_prompt = """
349
+ You are a world-class fashion designer.
350
+ Your task is to write a detailed description of the upper-body garment shown in the image, focusing on its fit, sleeve style, fabric type, neckline, and any notable design elements or features, in one or two lines for the given image.
351
+ Don't start with "This image shows a pair of beige cargo ..." but instead start with "a pair of beige cargo ..."
352
+ """
353
+ elif cloth_type == "lower":
354
+ system_prompt = """
355
+ You are a world-class fashion designer.
356
+ Your task is to write a detailed description of the lower-body garment shown in the image, focusing on its fit, fabric type, waist style, and any notable design elements or features, in one or two lines for the given image.
357
+ Don't start with "This image shows a pair of beige cargo ..." but instead start with "a pair of beige cargo ..."
358
+ """
359
+ elif cloth_type == "overall":
360
+ system_prompt = """
361
+ You are a world-class fashion designer.
362
+ Your task is to write a detailed description of the overall garment shown in the image, focusing on its fit, fabric type, sleeve style, neckline, and any notable design elements or features, in one or two lines for the given image.
363
+ Don't start with "This image shows a pair of beige cargo ..." but instead start with "a pair of beige cargo ..."
364
+ """
365
+ else:
366
+ system_prompt = """
367
+ You are a world-class fashion designer.
368
+ Your task is to write a detailed description of the upper-body garment shown in the image, focusing on its fit, sleeve style, fabric type, neckline, and any notable design elements or features, in one or two lines for the given image.
369
+ Don't start with "This image shows a pair of beige cargo ..." but instead start with "a pair of beige cargo ..."
370
+ """
371
+
372
+ response = openai_client.chat.completions.create(
373
+ model="gpt-4o",
374
+ messages=[
375
+ {"role": "system", "content": system_prompt},
376
+ {"role": "user", "content": [
377
+ {
378
+ "type": "image_url",
379
+ "image_url": {
380
+ "url": f"data:image/jpeg;base64,{base_64_image}"
381
+ }
382
+ }
383
+ ]},
384
+ ],
385
+ )
386
+
387
+ return response.choices[0].message.content
388
+ except Exception as e:
389
+ print(f"Error in generate_upper_cloth_description: {e}")
390
+ raise e
391
+
392
+ def generate_caption_for_image(image):
393
+ """
394
+ Generates a caption for the given image using OpenAI's vision model.
395
+ """
396
+ if image is None:
397
+ return "Please generate a try-on result first."
398
+
399
+ # Convert the image to base64
400
+ if isinstance(image, str):
401
+ base64_image = pil_image_to_base64(image)
402
+ else:
403
+ # Convert numpy array to PIL Image
404
+ if isinstance(image, np.ndarray):
405
+ image = Image.fromarray((image * 255).astype(np.uint8))
406
+ buffered = BytesIO()
407
+ image.save(buffered, format="PNG")
408
+ base64_image = base64.b64encode(buffered.getvalue()).decode("utf-8")
409
+
410
+ system_prompt = """
411
+ You are a world-class campaign generator for the clothing the model is wearing.
412
+ Create a campaign caption for the image shown below.
413
+ Create engaging campaign captions for products in the merchandise, for Instagram stories, that attract, convert, and retain customers.
414
+ """
415
+
416
+ try:
417
+ response = openai_client.chat.completions.create(
418
+ model="gpt-4o",
419
+ messages=[
420
+ {"role": "system", "content": system_prompt},
421
+ {"role": "user", "content": [
422
+ {
423
+ "type": "image_url",
424
+ "image_url": {
425
+ "url": f"data:image/jpeg;base64,{base64_image}"
426
+ }
427
+ }
428
+ ]},
429
+ ],
430
+ )
431
+ return response.choices[0].message.content
432
+ except Exception as e:
433
+ return f"Error generating caption: {str(e)}"
434
+
435
+ def generate_ai_model_prompt(model_description, product_description):
436
+ print("prompt for ai model generation", f" {model_description} wearing {product_description}.")
437
+ return f" {model_description} wearing {product_description}, full image"
438
+
439
+ def generate_captions(product_description, campaign_context):
440
+
441
+ #system prompt
442
+ system_prompt = """
443
+ You are a world-class marketing expert.
444
+ Your task is to create engaging, professional, and contextually relevant campaign captions based on the details provided.
445
+ Use creative language to highlight the product's key features and align with the campaign's goals.
446
+ Ensure the captions are tailored to the specific advertising context provided.
447
+ """
448
+
449
+ # user prompt
450
+ user_prompt = f"""
451
+ Campaign Context: {campaign_context}
452
+ Product Description: {product_description}
453
+ Generate captivating captions for this campaign that align with the provided context.
454
+ """
455
+
456
+ # Call OpenAI API
457
+ response = openai_client.chat.completions.create(
458
+ model="gpt-4o",
459
+ messages=[
460
+ {"role": "system", "content": system_prompt},
461
+ {"role": "user", "content": user_prompt},
462
+ ],
463
+ )
464
+
465
+ # Extract generated captions
466
+ captions = response.choices[0].message.content.strip()
467
+ return captions
468
+
469
+
470
+ HEADER = """
471
+ """
472
+
473
+ def app_gradio():
474
+ with gr.Blocks(title="CatVTON") as demo:
475
+ gr.Markdown(HEADER)
476
+ with gr.Row():
477
+ with gr.Column(scale=1, min_width=350):
478
+ text_prompt = gr.Textbox(
479
+ label="Describe the person (e.g., 'a young woman in a neutral pose')",
480
+ lines=3
481
+ )
482
+
483
+
484
+ generate_button = gr.Button("Generate Person Image")
485
+
486
+ # Hidden image path component
487
+ image_path = gr.Image(
488
+ type="filepath",
489
+ interactive=True,
490
+ visible=False,
491
+ )
492
+
493
+ # Display generated person image
494
+ person_image = gr.ImageEditor(
495
+ interactive=True,
496
+ label="Generated Person Image",
497
+ type="filepath"
498
+ )
499
+
500
+ campaign_context = gr.Textbox(
501
+ label="Describe your campaign context (e.g., 'Summer sale campaign focusing on vibrant colors')",
502
+ lines=3,
503
+ placeholder="What message do you want to convey in this campaign?",
504
+ )
505
+
506
+
507
+ with gr.Row():
508
+ with gr.Column(scale=1, min_width=230):
509
+ cloth_image = gr.Image(
510
+ interactive=True, label="Condition Image", type="filepath"
511
+ )
512
+
513
+ cloth_description = gr.Textbox(
514
+ label="Cloth Description",
515
+ interactive=False,
516
+ lines=3
517
+ )
518
+
519
+
520
+
521
+ with gr.Column(scale=1, min_width=120):
522
+ gr.Markdown(
523
+ '<span style="color: #808080; font-size: small;">Two ways to provide Mask:<br>1. Upload the person image and use the `🖌️` above to draw the Mask (higher priority)<br>2. Select the `Try-On Cloth Type` to generate automatically </span>'
524
+ )
525
+ cloth_type = gr.Radio(
526
+ label="Try-On Cloth Type",
527
+ choices=["upper", "lower", "overall"],
528
+ value="upper",
529
+ )
530
+
531
+ cloth_image.change(
532
+ generate_upper_cloth_description,
533
+ inputs=[cloth_image, cloth_type],
534
+ outputs=[cloth_description],
535
+ )
536
+
537
+
538
+ submit = gr.Button("Submit")
539
+ gr.Markdown(
540
+ '<center><span style="color: #FF0000">!!! Click only Once, Wait for Delay !!!</span></center>'
541
+ )
542
+
543
+ gr.Markdown(
544
+ '<span style="color: #808080; font-size: small;">Advanced options can adjust details:<br>1. `Inference Step` may enhance details;<br>2. `CFG` is highly correlated with saturation;<br>3. `Random seed` may improve pseudo-shadow.</span>'
545
+ )
546
+ with gr.Accordion("Advanced Options", open=False):
547
+ num_inference_steps = gr.Slider(
548
+ label="Inference Step", minimum=10, maximum=100, step=5, value=50
549
+ )
550
+ # Guidance Scale
551
+ guidance_scale = gr.Slider(
552
+ label="CFG Strength", minimum=0.0, maximum=7.5, step=0.5, value=2.5
553
+ )
554
+ # Random Seed
555
+ seed = gr.Slider(
556
+ label="Seed", minimum=-1, maximum=10000, step=1, value=42
557
+ )
558
+ show_type = gr.Radio(
559
+ label="Show Type",
560
+ choices=["result only", "input & result", "input & mask & result"],
561
+ value="input & mask & result",
562
+ )
563
+
564
+ with gr.Column(scale=2, min_width=500):
565
+ # single or multiple image
566
+
567
+ result_image = gr.Image(interactive=False, label="Result")
568
+ captions_textbox = gr.Textbox(
569
+ label="Generated Campaign Captions",
570
+ interactive=False,
571
+ lines=6
572
+ )
573
+
574
+
575
+ with gr.Row():
576
+ # Photo Examples
577
+ root_path = "resource/demo/example"
578
+
579
+
580
+ image_path.change(
581
+ person_example_fn, inputs=image_path, outputs=person_image
582
+ )
583
+
584
+
585
+ # Connect the generation button
586
+ generate_button.click(
587
+ generate_person_image,
588
+ inputs=[text_prompt, cloth_description],
589
+ outputs=[person_image]
590
+ )
591
+
592
+
593
+ submit.click(
594
+ submit_function,
595
+ [
596
+ person_image,
597
+ cloth_image,
598
+ cloth_type,
599
+ num_inference_steps,
600
+ guidance_scale,
601
+ seed,
602
+ show_type,
603
+ campaign_context,
604
+ ],
605
+ [result_image, captions_textbox]
606
+ )
607
+
608
+ # generate_caption_btn.click(
609
+ # generate_caption_for_image,
610
+ # inputs=[result_image],
611
+ # outputs=[caption_text]
612
+ # )
613
+ demo.queue().launch(share=True, show_error=True)
614
+
615
+
616
+ if __name__ == "__main__":
617
+ app_gradio()
app_flux.py ADDED
@@ -0,0 +1,305 @@
1
+ import os
2
+ import argparse
3
+ import gradio as gr
4
+ from datetime import datetime
5
+
6
+ import numpy as np
7
+ import torch
8
+ from diffusers.image_processor import VaeImageProcessor
9
+ from huggingface_hub import snapshot_download
10
+ from PIL import Image
11
+
12
+ from model.cloth_masker import AutoMasker, vis_mask
13
+ from model.flux.pipeline_flux_tryon import FluxTryOnPipeline
14
+ from utils import resize_and_crop, resize_and_padding
15
+
16
+ def parse_args():
17
+ parser = argparse.ArgumentParser(description="FLUX Try-On Demo")
18
+ parser.add_argument(
19
+ "--base_model_path",
20
+ type=str,
21
+ default="black-forest-labs/FLUX.1-Fill-dev",
22
+ # default="Models/FLUX.1-Fill-dev",
23
+ help="The path to the base model to use for evaluation."
24
+ )
25
+ parser.add_argument(
26
+ "--resume_path",
27
+ type=str,
28
+ default="zhengchong/CatVTON",
29
+ help="The Path to the checkpoint of trained tryon model."
30
+ )
31
+ parser.add_argument(
32
+ "--output_dir",
33
+ type=str,
34
+ default="resource/demo/output",
35
+ help="The output directory where the model predictions will be written."
36
+ )
37
+ parser.add_argument(
38
+ "--mixed_precision",
39
+ type=str,
40
+ default="bf16",
41
+ choices=["no", "fp16", "bf16"],
42
+ help="Whether to use mixed precision."
43
+ )
44
+ parser.add_argument(
45
+ "--allow_tf32",
46
+ action="store_true",
47
+ default=True,
48
+ help="Whether or not to allow TF32 on Ampere GPUs."
49
+ )
50
+ parser.add_argument(
51
+ "--width",
52
+ type=int,
53
+ default=768,
54
+ help="The width of the input image."
55
+ )
56
+ parser.add_argument(
57
+ "--height",
58
+ type=int,
59
+ default=1024,
60
+ help="The height of the input image."
61
+ )
62
+ return parser.parse_args()
63
+
64
+ def image_grid(imgs, rows, cols):
65
+ assert len(imgs) == rows * cols
66
+ w, h = imgs[0].size
67
+ grid = Image.new("RGB", size=(cols * w, rows * h))
68
+ for i, img in enumerate(imgs):
69
+ grid.paste(img, box=(i % cols * w, i // cols * h))
70
+ return grid
71
+
72
+
73
+ def submit_function_flux(
74
+ person_image,
75
+ cloth_image,
76
+ cloth_type,
77
+ num_inference_steps,
78
+ guidance_scale,
79
+ seed,
80
+ show_type
81
+ ):
82
+
83
+ # Process image editor input
84
+ person_image, mask = person_image["background"], person_image["layers"][0]
85
+ mask = Image.open(mask).convert("L")
86
+ if len(np.unique(np.array(mask))) == 1:
87
+ mask = None
88
+ else:
89
+ mask = np.array(mask)
90
+ mask[mask > 0] = 255
91
+ mask = Image.fromarray(mask)
92
+
93
+ # Set random seed
94
+ generator = None
95
+ if seed != -1:
96
+ generator = torch.Generator(device='cuda').manual_seed(seed)
97
+
98
+ # Process input images
99
+ person_image = Image.open(person_image).convert("RGB")
100
+ cloth_image = Image.open(cloth_image).convert("RGB")
101
+
102
+ # Adjust image sizes
103
+ person_image = resize_and_crop(person_image, (args.width, args.height))
104
+ cloth_image = resize_and_padding(cloth_image, (args.width, args.height))
105
+
106
+ # Process mask
107
+ if mask is not None:
108
+ mask = resize_and_crop(mask, (args.width, args.height))
109
+ else:
110
+ mask = automasker(
111
+ person_image,
112
+ cloth_type
113
+ )['mask']
114
+ mask = mask_processor.blur(mask, blur_factor=9)
115
+
116
+ # Inference
117
+ result_image = pipeline_flux(
118
+ image=person_image,
119
+ condition_image=cloth_image,
120
+ mask_image=mask,
121
+ height=args.height,
122
+ width=args.width,
123
+ num_inference_steps=num_inference_steps,
124
+ guidance_scale=guidance_scale,
125
+ generator=generator
126
+ ).images[0]
127
+
128
+ # Post-processing
129
+ masked_person = vis_mask(person_image, mask)
130
+
131
+ # Return result based on show type
132
+ if show_type == "result only":
133
+ return result_image
134
+ else:
135
+ width, height = person_image.size
136
+ if show_type == "input & result":
137
+ condition_width = width // 2
138
+ conditions = image_grid([person_image, cloth_image], 2, 1)
139
+ else:
140
+ condition_width = width // 3
141
+ conditions = image_grid([person_image, masked_person, cloth_image], 3, 1)
142
+
143
+ conditions = conditions.resize((condition_width, height), Image.NEAREST)
144
+ new_result_image = Image.new("RGB", (width + condition_width + 5, height))
145
+ new_result_image.paste(conditions, (0, 0))
146
+ new_result_image.paste(result_image, (condition_width + 5, 0))
147
+ return new_result_image
148
+
149
+ def person_example_fn(image_path):
150
+ return image_path
151
+
152
+
153
+ def app_gradio():
154
+ with gr.Blocks(title="CatVTON with FLUX.1-Fill-dev") as demo:
155
+ gr.Markdown("# CatVTON with FLUX.1-Fill-dev")
156
+ with gr.Row():
157
+ with gr.Column(scale=1, min_width=350):
158
+ with gr.Row():
159
+ image_path_flux = gr.Image(
160
+ type="filepath",
161
+ interactive=True,
162
+ visible=False,
163
+ )
164
+ person_image_flux = gr.ImageEditor(
165
+ interactive=True, label="Person Image", type="filepath"
166
+ )
167
+
168
+ with gr.Row():
169
+ with gr.Column(scale=1, min_width=230):
170
+ cloth_image_flux = gr.Image(
171
+ interactive=True, label="Condition Image", type="filepath"
172
+ )
173
+ with gr.Column(scale=1, min_width=120):
174
+ gr.Markdown(
175
+ '<span style="color: #808080; font-size: small;">Two ways to provide Mask:<br>1. Upload the person image and use the `πŸ–ŒοΈ` above to draw the Mask (higher priority)<br>2. Select the `Try-On Cloth Type` to generate automatically </span>'
176
+ )
177
+ cloth_type = gr.Radio(
178
+ label="Try-On Cloth Type",
179
+ choices=["upper", "lower", "overall"],
180
+ value="upper",
181
+ )
182
+
183
+ submit_flux = gr.Button("Submit")
184
+ gr.Markdown(
185
+ '<center><span style="color: #FF0000">!!! Click only Once, Wait for Delay !!!</span></center>'
186
+ )
187
+
188
+ with gr.Accordion("Advanced Options", open=False):
189
+ num_inference_steps_flux = gr.Slider(
190
+ label="Inference Step", minimum=10, maximum=100, step=5, value=50
191
+ )
192
+ # Guidance Scale
193
+ guidance_scale_flux = gr.Slider(
194
+ label="CFG Strength", minimum=0.0, maximum=50, step=0.5, value=30
195
+ )
196
+ # Random Seed
197
+ seed_flux = gr.Slider(
198
+ label="Seed", minimum=-1, maximum=10000, step=1, value=42
199
+ )
200
+ show_type = gr.Radio(
201
+ label="Show Type",
202
+ choices=["result only", "input & result", "input & mask & result"],
203
+ value="input & mask & result",
204
+ )
205
+
206
+ with gr.Column(scale=2, min_width=500):
207
+ result_image_flux = gr.Image(interactive=False, label="Result")
208
+ with gr.Row():
209
+ # Photo Examples
210
+ root_path = "resource/demo/example"
211
+ with gr.Column():
212
+ gr.Examples(
213
+ examples=[
214
+ os.path.join(root_path, "person", "men", _)
215
+ for _ in os.listdir(os.path.join(root_path, "person", "men"))
216
+ ],
217
+ examples_per_page=4,
218
+ inputs=image_path_flux,
219
+ label="Person Examples β‘ ",
220
+ )
221
+ gr.Examples(
222
+ examples=[
223
+ os.path.join(root_path, "person", "women", _)
224
+ for _ in os.listdir(os.path.join(root_path, "person", "women"))
225
+ ],
226
+ examples_per_page=4,
227
+ inputs=image_path_flux,
228
+ label="Person Examples β‘‘",
229
+ )
230
+ gr.Markdown(
231
+ '<span style="color: #808080; font-size: small;">*Person examples come from the demos of <a href="https://huggingface.co/spaces/levihsu/OOTDiffusion">OOTDiffusion</a> and <a href="https://www.outfitanyone.org">OutfitAnyone</a>. </span>'
232
+ )
233
+ with gr.Column():
234
+ gr.Examples(
235
+ examples=[
236
+ os.path.join(root_path, "condition", "upper", _)
237
+ for _ in os.listdir(os.path.join(root_path, "condition", "upper"))
238
+ ],
239
+ examples_per_page=4,
240
+ inputs=cloth_image_flux,
241
+ label="Condition Upper Examples",
242
+ )
243
+ gr.Examples(
244
+ examples=[
245
+ os.path.join(root_path, "condition", "overall", _)
246
+ for _ in os.listdir(os.path.join(root_path, "condition", "overall"))
247
+ ],
248
+ examples_per_page=4,
249
+ inputs=cloth_image_flux,
250
+ label="Condition Overall Examples",
251
+ )
252
+ condition_person_exm = gr.Examples(
253
+ examples=[
254
+ os.path.join(root_path, "condition", "person", _)
255
+ for _ in os.listdir(os.path.join(root_path, "condition", "person"))
256
+ ],
257
+ examples_per_page=4,
258
+ inputs=cloth_image_flux,
259
+ label="Condition Reference Person Examples",
260
+ )
261
+ gr.Markdown(
262
+ '<span style="color: #808080; font-size: small;">*Condition examples come from the Internet. </span>'
263
+ )
264
+
265
+
266
+ image_path_flux.change(
267
+ person_example_fn, inputs=image_path_flux, outputs=person_image_flux
268
+ )
269
+
270
+ submit_flux.click(
271
+ submit_function_flux,
272
+ [person_image_flux, cloth_image_flux, cloth_type, num_inference_steps_flux, guidance_scale_flux, seed_flux, show_type],
273
+ result_image_flux,
274
+ )
275
+
276
+
277
+ demo.queue().launch(share=True, show_error=True)
278
+
279
+ # Parse arguments
280
+ args = parse_args()
281
+
282
+ # Load models
283
+ repo_path = snapshot_download(repo_id=args.resume_path)
284
+ pipeline_flux = FluxTryOnPipeline.from_pretrained(args.base_model_path)
285
+ pipeline_flux.load_lora_weights(
286
+ os.path.join(repo_path, "flux-lora"),
287
+ weight_name='pytorch_lora_weights.safetensors'
288
+ )
289
+ pipeline_flux.to("cuda", torch.bfloat16)
290
+
291
+ # Initialize the AutoMasker
292
+ mask_processor = VaeImageProcessor(
293
+ vae_scale_factor=8,
294
+ do_normalize=False,
295
+ do_binarize=True,
296
+ do_convert_grayscale=True
297
+ )
298
+ automasker = AutoMasker(
299
+ densepose_ckpt=os.path.join(repo_path, "DensePose"),
300
+ schp_ckpt=os.path.join(repo_path, "SCHP"),
301
+ device='cuda'
302
+ )
303
+
304
+ if __name__ == "__main__":
305
+ app_gradio()
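Note on app_flux.py above: `submit_function_flux` gives a hand-drawn mask from the ImageEditor layer priority and only falls back to the AutoMasker when the drawn layer is empty. A small sketch of just that decision, assuming PIL and NumPy, with `automasker` standing in for the real AutoMasker instance:

    # Sketch of the mask-priority rule used in submit_function_flux above.
    import numpy as np
    from PIL import Image

    def choose_mask(drawn_layer_path, person_image, cloth_type, automasker):
        mask = Image.open(drawn_layer_path).convert("L")
        if len(np.unique(np.array(mask))) == 1:
            # the layer contains a single gray level, i.e. nothing was drawn:
            # fall back to automatic masking driven by the cloth type
            return automasker(person_image, cloth_type)["mask"]
        binary = np.array(mask)
        binary[binary > 0] = 255  # binarize the brush strokes
        return Image.fromarray(binary)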
app_p2p.py ADDED
@@ -0,0 +1,567 @@
 
 
1
+ import argparse
2
+ import os
3
+ from datetime import datetime
4
+
5
+ import gradio as gr
6
+ import numpy as np
7
+ import torch
8
+ from diffusers.image_processor import VaeImageProcessor
9
+ from huggingface_hub import snapshot_download
10
+ from PIL import Image
11
+
12
+ from model.cloth_masker import AutoMasker, vis_mask
13
+ from model.pipeline import CatVTONPipeline, CatVTONPix2PixPipeline
14
+ from utils import init_weight_dtype, resize_and_crop, resize_and_padding
15
+
16
+
17
+ def parse_args():
18
+ parser = argparse.ArgumentParser(description="Gradio demo for CatVTON virtual try-on (mask-based and mask-free).")
19
+ parser.add_argument(
20
+ "--p2p_base_model_path",
21
+ type=str,
22
+ default="timbrooks/instruct-pix2pix",
23
+ help=(
24
+ "The path to the base model to use for evaluation. This can be a local path or a model identifier from the Model Hub."
25
+ ),
26
+ )
27
+ parser.add_argument(
28
+ "--ip_base_model_path",
29
+ type=str,
30
+ default="booksforcharlie/stable-diffusion-inpainting",
31
+ help=(
32
+ "The path to the base model to use for evaluation. This can be a local path or a model identifier from the Model Hub."
33
+ ),
34
+ )
35
+ parser.add_argument(
36
+ "--p2p_resume_path",
37
+ type=str,
38
+ default="zhengchong/CatVTON-MaskFree",
39
+ help=(
40
+ "The Path to the checkpoint of trained tryon model."
41
+ ),
42
+ )
43
+ parser.add_argument(
44
+ "--ip_resume_path",
45
+ type=str,
46
+ default="zhengchong/CatVTON",
47
+ help=(
48
+ "The Path to the checkpoint of trained tryon model."
49
+ ),
50
+ )
51
+ parser.add_argument(
52
+ "--output_dir",
53
+ type=str,
54
+ default="resource/demo/output",
55
+ help="The output directory where the model predictions will be written.",
56
+ )
57
+
58
+ parser.add_argument(
59
+ "--width",
60
+ type=int,
61
+ default=768,
62
+ help=(
63
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
64
+ " resolution"
65
+ ),
66
+ )
67
+ parser.add_argument(
68
+ "--height",
69
+ type=int,
70
+ default=1024,
71
+ help=(
72
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
73
+ " resolution"
74
+ ),
75
+ )
76
+ parser.add_argument(
77
+ "--repaint",
78
+ action="store_true",
79
+ help="Whether to repaint the result image with the original background."
80
+ )
81
+ parser.add_argument(
82
+ "--allow_tf32",
83
+ action="store_true",
84
+ default=True,
85
+ help=(
86
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
87
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
88
+ ),
89
+ )
90
+ parser.add_argument(
91
+ "--mixed_precision",
92
+ type=str,
93
+ default="bf16",
94
+ choices=["no", "fp16", "bf16"],
95
+ help=(
96
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
97
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
98
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
99
+ ),
100
+ )
101
+
102
+ args = parser.parse_args()
103
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
104
+ if env_local_rank != -1 and env_local_rank != getattr(args, "local_rank", -1):  # parser defines no --local_rank; avoid AttributeError
105
+ args.local_rank = env_local_rank
106
+
107
+ return args
108
+
109
+ def image_grid(imgs, rows, cols):
110
+ assert len(imgs) == rows * cols
111
+
112
+ w, h = imgs[0].size
113
+ grid = Image.new("RGB", size=(cols * w, rows * h))
114
+
115
+ for i, img in enumerate(imgs):
116
+ grid.paste(img, box=(i % cols * w, i // cols * h))
117
+ return grid
118
+
119
+
120
+ args = parse_args()
121
+ repo_path = snapshot_download(repo_id=args.ip_resume_path)
122
+ # Pipeline
123
+ pipeline_p2p = CatVTONPix2PixPipeline(
124
+ base_ckpt=args.p2p_base_model_path,
125
+ attn_ckpt=repo_path,
126
+ attn_ckpt_version="mix-48k-1024",
127
+ weight_dtype=init_weight_dtype(args.mixed_precision),
128
+ use_tf32=args.allow_tf32,
129
+ device='cuda'
130
+ )
131
+
132
+ # Pipeline
133
+ repo_path = snapshot_download(repo_id=args.ip_resume_path)
134
+ pipeline = CatVTONPipeline(
135
+ base_ckpt=args.ip_base_model_path,
136
+ attn_ckpt=repo_path,
137
+ attn_ckpt_version="mix",
138
+ weight_dtype=init_weight_dtype(args.mixed_precision),
139
+ use_tf32=args.allow_tf32,
140
+ device='cuda'
141
+ )
142
+
143
+ # AutoMasker
144
+ mask_processor = VaeImageProcessor(vae_scale_factor=8, do_normalize=False, do_binarize=True, do_convert_grayscale=True)
145
+ automasker = AutoMasker(
146
+ densepose_ckpt=os.path.join(repo_path, "DensePose"),
147
+ schp_ckpt=os.path.join(repo_path, "SCHP"),
148
+ device='cuda',
149
+ )
150
+
151
+
152
+ def submit_function_p2p(
153
+ person_image,
154
+ cloth_image,
155
+ num_inference_steps,
156
+ guidance_scale,
157
+ seed):
158
+ person_image = person_image["background"]
159
+
160
+ tmp_folder = args.output_dir
161
+ date_str = datetime.now().strftime("%Y%m%d%H%M%S")
162
+ result_save_path = os.path.join(tmp_folder, date_str[:8], date_str[8:] + ".png")
163
+ if not os.path.exists(os.path.join(tmp_folder, date_str[:8])):
164
+ os.makedirs(os.path.join(tmp_folder, date_str[:8]))
165
+
166
+ generator = None
167
+ if seed != -1:
168
+ generator = torch.Generator(device='cuda').manual_seed(seed)
169
+
170
+ person_image = Image.open(person_image).convert("RGB")
171
+ cloth_image = Image.open(cloth_image).convert("RGB")
172
+ person_image = resize_and_crop(person_image, (args.width, args.height))
173
+ cloth_image = resize_and_padding(cloth_image, (args.width, args.height))
174
+
175
+ # Inference
176
+ try:
177
+ result_image = pipeline_p2p(
178
+ image=person_image,
179
+ condition_image=cloth_image,
180
+ num_inference_steps=num_inference_steps,
181
+ guidance_scale=guidance_scale,
182
+ generator=generator
183
+ )[0]
184
+ except Exception as e:
185
+ raise gr.Error(
186
+ "An error occurred. Please try again later: {}".format(e)
187
+ )
188
+
189
+ # Post-process
190
+ save_result_image = image_grid([person_image, cloth_image, result_image], 1, 3)
191
+ save_result_image.save(result_save_path)
192
+ return result_image
193
+
194
+ def submit_function(
195
+ person_image,
196
+ cloth_image,
197
+ cloth_type,
198
+ num_inference_steps,
199
+ guidance_scale,
200
+ seed,
201
+ show_type
202
+ ):
203
+ person_image, mask = person_image["background"], person_image["layers"][0]
204
+ mask = Image.open(mask).convert("L")
205
+ if len(np.unique(np.array(mask))) == 1:
206
+ mask = None
207
+ else:
208
+ mask = np.array(mask)
209
+ mask[mask > 0] = 255
210
+ mask = Image.fromarray(mask)
211
+
212
+ tmp_folder = args.output_dir
213
+ date_str = datetime.now().strftime("%Y%m%d%H%M%S")
214
+ result_save_path = os.path.join(tmp_folder, date_str[:8], date_str[8:] + ".png")
215
+ if not os.path.exists(os.path.join(tmp_folder, date_str[:8])):
216
+ os.makedirs(os.path.join(tmp_folder, date_str[:8]))
217
+
218
+ generator = None
219
+ if seed != -1:
220
+ generator = torch.Generator(device='cuda').manual_seed(seed)
221
+
222
+ person_image = Image.open(person_image).convert("RGB")
223
+ cloth_image = Image.open(cloth_image).convert("RGB")
224
+ person_image = resize_and_crop(person_image, (args.width, args.height))
225
+ cloth_image = resize_and_padding(cloth_image, (args.width, args.height))
226
+
227
+ # Process mask
228
+ if mask is not None:
229
+ mask = resize_and_crop(mask, (args.width, args.height))
230
+ else:
231
+ mask = automasker(
232
+ person_image,
233
+ cloth_type
234
+ )['mask']
235
+ mask = mask_processor.blur(mask, blur_factor=9)
236
+
237
+ # Inference
238
+ # try:
239
+ result_image = pipeline(
240
+ image=person_image,
241
+ condition_image=cloth_image,
242
+ mask=mask,
243
+ num_inference_steps=num_inference_steps,
244
+ guidance_scale=guidance_scale,
245
+ generator=generator
246
+ )[0]
247
+ # except Exception as e:
248
+ # raise gr.Error(
249
+ # "An error occurred. Please try again later: {}".format(e)
250
+ # )
251
+
252
+ # Post-process
253
+ masked_person = vis_mask(person_image, mask)
254
+ save_result_image = image_grid([person_image, masked_person, cloth_image, result_image], 1, 4)
255
+ save_result_image.save(result_save_path)
256
+ if show_type == "result only":
257
+ return result_image
258
+ else:
259
+ width, height = person_image.size
260
+ if show_type == "input & result":
261
+ condition_width = width // 2
262
+ conditions = image_grid([person_image, cloth_image], 2, 1)
263
+ else:
264
+ condition_width = width // 3
265
+ conditions = image_grid([person_image, masked_person, cloth_image], 3, 1)
266
+ conditions = conditions.resize((condition_width, height), Image.NEAREST)
267
+ new_result_image = Image.new("RGB", (width + condition_width + 5, height))
268
+ new_result_image.paste(conditions, (0, 0))
269
+ new_result_image.paste(result_image, (condition_width + 5, 0))
270
+ return new_result_image
271
+
272
+
273
+
274
+ def person_example_fn(image_path):
275
+ return image_path
276
+
277
+ HEADER = """
278
+ <h1 style="text-align: center;"> 🐈 CatVTON: Concatenation Is All You Need for Virtual Try-On with Diffusion Models </h1>
279
+ <div style="display: flex; justify-content: center; align-items: center;">
280
+ <a href="http://arxiv.org/abs/2407.15886" style="margin: 0 2px;">
281
+ <img src='https://img.shields.io/badge/arXiv-2407.15886-red?style=flat&logo=arXiv&logoColor=red' alt='arxiv'>
282
+ </a>
283
+ <a href='https://huggingface.co/zhengchong/CatVTON' style="margin: 0 2px;">
284
+ <img src='https://img.shields.io/badge/Hugging Face-ckpts-orange?style=flat&logo=HuggingFace&logoColor=orange' alt='huggingface'>
285
+ </a>
286
+ <a href="https://github.com/Zheng-Chong/CatVTON" style="margin: 0 2px;">
287
+ <img src='https://img.shields.io/badge/GitHub-Repo-blue?style=flat&logo=GitHub' alt='GitHub'>
288
+ </a>
289
+ <a href="http://120.76.142.206:8888" style="margin: 0 2px;">
290
+ <img src='https://img.shields.io/badge/Demo-Gradio-gold?style=flat&logo=Gradio&logoColor=red' alt='Demo'>
291
+ </a>
292
+ <a href="https://huggingface.co/spaces/zhengchong/CatVTON" style="margin: 0 2px;">
293
+ <img src='https://img.shields.io/badge/Space-ZeroGPU-orange?style=flat&logo=Gradio&logoColor=red' alt='Demo'>
294
+ </a>
295
+ <a href='https://zheng-chong.github.io/CatVTON/' style="margin: 0 2px;">
296
+ <img src='https://img.shields.io/badge/Webpage-Project-silver?style=flat&logo=&logoColor=orange' alt='webpage'>
297
+ </a>
298
+ <a href="https://github.com/Zheng-Chong/CatVTON/LICENCE" style="margin: 0 2px;">
299
+ <img src='https://img.shields.io/badge/License-CC BY--NC--SA--4.0-lightgreen?style=flat&logo=Lisence' alt='License'>
300
+ </a>
301
+ </div>
302
+ <br>
303
+ Β· This demo and our weights are only for <span>Non-commercial Use</span>. <br>
304
+ Β· You can try CatVTON in our <a href="https://huggingface.co/spaces/zhengchong/CatVTON">HuggingFace Space</a> or our <a href="http://120.76.142.206:8888">online demo</a> (run on 3090). <br>
305
+ Β· Thanks to <a href="https://huggingface.co/zero-gpu-explorers">ZeroGPU</a> for providing A100 for our <a href="https://huggingface.co/spaces/zhengchong/CatVTON">HuggingFace Space</a>. <br>
306
+ Β· SafetyChecker is set to filter NSFW content, but it may block normal results too. Please adjust the <span>`seed`</span> for normal outcomes.<br>
307
+ """
308
+
309
+ def app_gradio():
310
+ with gr.Blocks(title="CatVTON") as demo:
311
+ gr.Markdown(HEADER)
312
+ with gr.Tab("Mask-based Virtual Try-On"):
313
+ with gr.Row():
314
+ with gr.Column(scale=1, min_width=350):
315
+ with gr.Row():
316
+ image_path = gr.Image(
317
+ type="filepath",
318
+ interactive=True,
319
+ visible=False,
320
+ )
321
+ person_image = gr.ImageEditor(
322
+ interactive=True, label="Person Image", type="filepath"
323
+ )
324
+
325
+ with gr.Row():
326
+ with gr.Column(scale=1, min_width=230):
327
+ cloth_image = gr.Image(
328
+ interactive=True, label="Condition Image", type="filepath"
329
+ )
330
+ with gr.Column(scale=1, min_width=120):
331
+ gr.Markdown(
332
+ '<span style="color: #808080; font-size: small;">Two ways to provide Mask:<br>1. Upload the person image and use the `πŸ–ŒοΈ` above to draw the Mask (higher priority)<br>2. Select the `Try-On Cloth Type` to generate automatically </span>'
333
+ )
334
+ cloth_type = gr.Radio(
335
+ label="Try-On Cloth Type",
336
+ choices=["upper", "lower", "overall"],
337
+ value="upper",
338
+ )
339
+
340
+
341
+ submit = gr.Button("Submit")
342
+ gr.Markdown(
343
+ '<center><span style="color: #FF0000">!!! Click only Once, Wait for Delay !!!</span></center>'
344
+ )
345
+
346
+ gr.Markdown(
347
+ '<span style="color: #808080; font-size: small;">Advanced options can adjust details:<br>1. `Inference Step` may enhance details;<br>2. `CFG` is highly correlated with saturation;<br>3. `Random seed` may improve pseudo-shadow.</span>'
348
+ )
349
+ with gr.Accordion("Advanced Options", open=False):
350
+ num_inference_steps = gr.Slider(
351
+ label="Inference Step", minimum=10, maximum=100, step=5, value=50
352
+ )
353
+ # Guidance Scale
354
+ guidance_scale = gr.Slider(
355
+ label="CFG Strength", minimum=0.0, maximum=7.5, step=0.5, value=2.5
356
+ )
357
+ # Random Seed
358
+ seed = gr.Slider(
359
+ label="Seed", minimum=-1, maximum=10000, step=1, value=42
360
+ )
361
+ show_type = gr.Radio(
362
+ label="Show Type",
363
+ choices=["result only", "input & result", "input & mask & result"],
364
+ value="input & mask & result",
365
+ )
366
+
367
+ with gr.Column(scale=2, min_width=500):
368
+ result_image = gr.Image(interactive=False, label="Result")
369
+ with gr.Row():
370
+ # Photo Examples
371
+ root_path = "resource/demo/example"
372
+ with gr.Column():
373
+ men_exm = gr.Examples(
374
+ examples=[
375
+ os.path.join(root_path, "person", "men", _)
376
+ for _ in os.listdir(os.path.join(root_path, "person", "men"))
377
+ ],
378
+ examples_per_page=4,
379
+ inputs=image_path,
380
+ label="Person Examples β‘ ",
381
+ )
382
+ women_exm = gr.Examples(
383
+ examples=[
384
+ os.path.join(root_path, "person", "women", _)
385
+ for _ in os.listdir(os.path.join(root_path, "person", "women"))
386
+ ],
387
+ examples_per_page=4,
388
+ inputs=image_path,
389
+ label="Person Examples β‘‘",
390
+ )
391
+ gr.Markdown(
392
+ '<span style="color: #808080; font-size: small;">*Person examples come from the demos of <a href="https://huggingface.co/spaces/levihsu/OOTDiffusion">OOTDiffusion</a> and <a href="https://www.outfitanyone.org">OutfitAnyone</a>. </span>'
393
+ )
394
+ with gr.Column():
395
+ condition_upper_exm = gr.Examples(
396
+ examples=[
397
+ os.path.join(root_path, "condition", "upper", _)
398
+ for _ in os.listdir(os.path.join(root_path, "condition", "upper"))
399
+ ],
400
+ examples_per_page=4,
401
+ inputs=cloth_image,
402
+ label="Condition Upper Examples",
403
+ )
404
+ condition_overall_exm = gr.Examples(
405
+ examples=[
406
+ os.path.join(root_path, "condition", "overall", _)
407
+ for _ in os.listdir(os.path.join(root_path, "condition", "overall"))
408
+ ],
409
+ examples_per_page=4,
410
+ inputs=cloth_image,
411
+ label="Condition Overall Examples",
412
+ )
413
+ condition_person_exm = gr.Examples(
414
+ examples=[
415
+ os.path.join(root_path, "condition", "person", _)
416
+ for _ in os.listdir(os.path.join(root_path, "condition", "person"))
417
+ ],
418
+ examples_per_page=4,
419
+ inputs=cloth_image,
420
+ label="Condition Reference Person Examples",
421
+ )
422
+ gr.Markdown(
423
+ '<span style="color: #808080; font-size: small;">*Condition examples come from the Internet. </span>'
424
+ )
425
+
426
+ image_path.change(
427
+ person_example_fn, inputs=image_path, outputs=person_image
428
+ )
429
+
430
+ submit.click(
431
+ submit_function,
432
+ [
433
+ person_image,
434
+ cloth_image,
435
+ cloth_type,
436
+ num_inference_steps,
437
+ guidance_scale,
438
+ seed,
439
+ show_type,
440
+ ],
441
+ result_image,
442
+ )
443
+
444
+ with gr.Tab("Mask-Free Virtual Try-On"):
445
+ with gr.Row():
446
+ with gr.Column(scale=1, min_width=350):
447
+ with gr.Row():
448
+ image_path_p2p = gr.Image(
449
+ type="filepath",
450
+ interactive=True,
451
+ visible=False,
452
+ )
453
+ person_image_p2p = gr.ImageEditor(
454
+ interactive=True, label="Person Image", type="filepath"
455
+ )
456
+
457
+ with gr.Row():
458
+ with gr.Column(scale=1, min_width=230):
459
+ cloth_image_p2p = gr.Image(
460
+ interactive=True, label="Condition Image", type="filepath"
461
+ )
462
+
463
+ submit_p2p = gr.Button("Submit")
464
+ gr.Markdown(
465
+ '<center><span style="color: #FF0000">!!! Click only Once, Wait for Delay !!!</span></center>'
466
+ )
467
+
468
+ gr.Markdown(
469
+ '<span style="color: #808080; font-size: small;">Advanced options can adjust details:<br>1. `Inference Step` may enhance details;<br>2. `CFG` is highly correlated with saturation;<br>3. `Random seed` may improve pseudo-shadow.</span>'
470
+ )
471
+ with gr.Accordion("Advanced Options", open=False):
472
+ num_inference_steps_p2p = gr.Slider(
473
+ label="Inference Step", minimum=10, maximum=100, step=5, value=50
474
+ )
475
+ # Guidance Scale
476
+ guidance_scale_p2p = gr.Slider(
477
+ label="CFG Strength", minimum=0.0, maximum=7.5, step=0.5, value=2.5
478
+ )
479
+ # Random Seed
480
+ seed_p2p = gr.Slider(
481
+ label="Seed", minimum=-1, maximum=10000, step=1, value=42
482
+ )
483
+ # show_type = gr.Radio(
484
+ # label="Show Type",
485
+ # choices=["result only", "input & result", "input & mask & result"],
486
+ # value="input & mask & result",
487
+ # )
488
+
489
+ with gr.Column(scale=2, min_width=500):
490
+ result_image_p2p = gr.Image(interactive=False, label="Result")
491
+ with gr.Row():
492
+ # Photo Examples
493
+ root_path = "resource/demo/example"
494
+ with gr.Column():
495
+ gr.Examples(
496
+ examples=[
497
+ os.path.join(root_path, "person", "men", _)
498
+ for _ in os.listdir(os.path.join(root_path, "person", "men"))
499
+ ],
500
+ examples_per_page=4,
501
+ inputs=image_path_p2p,
502
+ label="Person Examples β‘ ",
503
+ )
504
+ gr.Examples(
505
+ examples=[
506
+ os.path.join(root_path, "person", "women", _)
507
+ for _ in os.listdir(os.path.join(root_path, "person", "women"))
508
+ ],
509
+ examples_per_page=4,
510
+ inputs=image_path_p2p,
511
+ label="Person Examples β‘‘",
512
+ )
513
+ gr.Markdown(
514
+ '<span style="color: #808080; font-size: small;">*Person examples come from the demos of <a href="https://huggingface.co/spaces/levihsu/OOTDiffusion">OOTDiffusion</a> and <a href="https://www.outfitanyone.org">OutfitAnyone</a>. </span>'
515
+ )
516
+ with gr.Column():
517
+ gr.Examples(
518
+ examples=[
519
+ os.path.join(root_path, "condition", "upper", _)
520
+ for _ in os.listdir(os.path.join(root_path, "condition", "upper"))
521
+ ],
522
+ examples_per_page=4,
523
+ inputs=cloth_image_p2p,
524
+ label="Condition Upper Examples",
525
+ )
526
+ gr.Examples(
527
+ examples=[
528
+ os.path.join(root_path, "condition", "overall", _)
529
+ for _ in os.listdir(os.path.join(root_path, "condition", "overall"))
530
+ ],
531
+ examples_per_page=4,
532
+ inputs=cloth_image_p2p,
533
+ label="Condition Overall Examples",
534
+ )
535
+ condition_person_exm = gr.Examples(
536
+ examples=[
537
+ os.path.join(root_path, "condition", "person", _)
538
+ for _ in os.listdir(os.path.join(root_path, "condition", "person"))
539
+ ],
540
+ examples_per_page=4,
541
+ inputs=cloth_image_p2p,
542
+ label="Condition Reference Person Examples",
543
+ )
544
+ gr.Markdown(
545
+ '<span style="color: #808080; font-size: small;">*Condition examples come from the Internet. </span>'
546
+ )
547
+
548
+ image_path_p2p.change(
549
+ person_example_fn, inputs=image_path_p2p, outputs=person_image_p2p
550
+ )
551
+
552
+ submit_p2p.click(
553
+ submit_function_p2p,
554
+ [
555
+ person_image_p2p,
556
+ cloth_image_p2p,
557
+ num_inference_steps_p2p,
558
+ guidance_scale_p2p,
559
+ seed_p2p],
560
+ result_image_p2p,
561
+ )
562
+
563
+ demo.queue().launch(share=True, show_error=True)
564
+
565
+
566
+ if __name__ == "__main__":
567
+ app_gradio()
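Note on app_p2p.py above: when `show_type` is "input & mask & result", `submit_function` stacks the person, masked person, and cloth images into a thin column with `image_grid` and pastes it to the left of the result. A self-contained PIL sketch of that compositing, assuming all four images share the person image's size:

    # Sketch of the "input & mask & result" view composed in submit_function above.
    from PIL import Image

    def compose_view(person, masked_person, cloth, result, gap=5):
        width, height = person.size
        column = Image.new("RGB", (width, 3 * height))
        for i, img in enumerate([person, masked_person, cloth]):
            column.paste(img, (0, i * height))
        # squeeze the stacked conditions into a one-third-width side panel
        column = column.resize((width // 3, height), Image.NEAREST)
        canvas = Image.new("RGB", (width + width // 3 + gap, height))
        canvas.paste(column, (0, 0))
        canvas.paste(result, (width // 3 + gap, 0))
        return canvas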
densepose/__init__.py ADDED
@@ -0,0 +1,22 @@
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # pyre-unsafe
4
+ from .data.datasets import builtin # just to register data
5
+ from .converters import builtin as builtin_converters # register converters
6
+ from .config import (
7
+ add_densepose_config,
8
+ add_densepose_head_config,
9
+ add_hrnet_config,
10
+ add_dataset_category_config,
11
+ add_bootstrap_config,
12
+ load_bootstrap_config,
13
+ )
14
+ from .structures import DensePoseDataRelative, DensePoseList, DensePoseTransformData
15
+ from .evaluation import DensePoseCOCOEvaluator
16
+ from .modeling.roi_heads import DensePoseROIHeads
17
+ from .modeling.test_time_augmentation import (
18
+ DensePoseGeneralizedRCNNWithTTA,
19
+ DensePoseDatasetMapperTTA,
20
+ )
21
+ from .utils.transform import load_from_cfg
22
+ from .modeling.hrfpn import build_hrfpn_backbone
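Note on densepose/__init__.py above: the `builtin` imports are there purely for their side effects, as the inline comments say ("just to register data", "register converters"). A generic sketch of that register-on-import pattern; the names here are illustrative and not part of this package:

    # Generic register-on-import sketch (illustrative names only).
    REGISTRY = {}

    def register(name):
        def deco(fn):
            REGISTRY[name] = fn
            return fn
        return deco

    @register("coco_densepose")
    def load_coco_densepose():
        return "dataset dicts would be built here"

    # Importing the module that holds the decorated functions is enough to
    # populate REGISTRY, which is why the package __init__ imports the builtin
    # modules without using them directly.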
densepose/config.py ADDED
@@ -0,0 +1,277 @@
 
 
1
+ # -*- coding = utf-8 -*-
2
+ # Copyright (c) Facebook, Inc. and its affiliates.
3
+ # pyre-ignore-all-errors
4
+
5
+ from detectron2.config import CfgNode as CN
6
+
7
+
8
+ def add_dataset_category_config(cfg: CN) -> None:
9
+ """
10
+ Add config for additional category-related dataset options
11
+ - category whitelisting
12
+ - category mapping
13
+ """
14
+ _C = cfg
15
+ _C.DATASETS.CATEGORY_MAPS = CN(new_allowed=True)
16
+ _C.DATASETS.WHITELISTED_CATEGORIES = CN(new_allowed=True)
17
+ # class to mesh mapping
18
+ _C.DATASETS.CLASS_TO_MESH_NAME_MAPPING = CN(new_allowed=True)
19
+
20
+
21
+ def add_evaluation_config(cfg: CN) -> None:
22
+ _C = cfg
23
+ _C.DENSEPOSE_EVALUATION = CN()
24
+ # evaluator type, possible values:
25
+ # - "iou": evaluator for models that produce iou data
26
+ # - "cse": evaluator for models that produce cse data
27
+ _C.DENSEPOSE_EVALUATION.TYPE = "iou"
28
+ # storage for DensePose results, possible values:
29
+ # - "none": no explicit storage, all the results are stored in the
30
+ # dictionary with predictions, memory intensive;
31
+ # historically the default storage type
32
+ # - "ram": RAM storage, uses per-process RAM storage, which is
33
+ # reduced to a single process storage on later stages,
34
+ # less memory intensive
35
+ # - "file": file storage, uses per-process file-based storage,
36
+ # the least memory intensive, but may create bottlenecks
37
+ # on file system accesses
38
+ _C.DENSEPOSE_EVALUATION.STORAGE = "none"
39
+ # minimum threshold for IOU values: the lower this value is,
40
+ # the more matches are produced (and the higher the AP score)
41
+ _C.DENSEPOSE_EVALUATION.MIN_IOU_THRESHOLD = 0.5
42
+ # Non-distributed inference is slower (at inference time) but can avoid RAM OOM
43
+ _C.DENSEPOSE_EVALUATION.DISTRIBUTED_INFERENCE = True
44
+ # evaluate mesh alignment based on vertex embeddings, only makes sense in CSE context
45
+ _C.DENSEPOSE_EVALUATION.EVALUATE_MESH_ALIGNMENT = False
46
+ # meshes to compute mesh alignment for
47
+ _C.DENSEPOSE_EVALUATION.MESH_ALIGNMENT_MESH_NAMES = []
48
+
49
+
50
+ def add_bootstrap_config(cfg: CN) -> None:
51
+ """Add config options for bootstrap datasets and the bootstrap model."""
52
+ _C = cfg
53
+ _C.BOOTSTRAP_DATASETS = []
54
+ _C.BOOTSTRAP_MODEL = CN()
55
+ _C.BOOTSTRAP_MODEL.WEIGHTS = ""
56
+ _C.BOOTSTRAP_MODEL.DEVICE = "cuda"
57
+
58
+
59
+ def get_bootstrap_dataset_config() -> CN:
60
+ _C = CN()
61
+ _C.DATASET = ""
62
+ # ratio used to mix data loaders
63
+ _C.RATIO = 0.1
64
+ # image loader
65
+ _C.IMAGE_LOADER = CN(new_allowed=True)
66
+ _C.IMAGE_LOADER.TYPE = ""
67
+ _C.IMAGE_LOADER.BATCH_SIZE = 4
68
+ _C.IMAGE_LOADER.NUM_WORKERS = 4
69
+ _C.IMAGE_LOADER.CATEGORIES = []
70
+ _C.IMAGE_LOADER.MAX_COUNT_PER_CATEGORY = 1_000_000
71
+ _C.IMAGE_LOADER.CATEGORY_TO_CLASS_MAPPING = CN(new_allowed=True)
72
+ # inference
73
+ _C.INFERENCE = CN()
74
+ # batch size for model inputs
75
+ _C.INFERENCE.INPUT_BATCH_SIZE = 4
76
+ # batch size to group model outputs
77
+ _C.INFERENCE.OUTPUT_BATCH_SIZE = 2
78
+ # sampled data
79
+ _C.DATA_SAMPLER = CN(new_allowed=True)
80
+ _C.DATA_SAMPLER.TYPE = ""
81
+ _C.DATA_SAMPLER.USE_GROUND_TRUTH_CATEGORIES = False
82
+ # filter
83
+ _C.FILTER = CN(new_allowed=True)
84
+ _C.FILTER.TYPE = ""
85
+ return _C
86
+
87
+
88
+ def load_bootstrap_config(cfg: CN) -> None:
89
+ """
90
+ Bootstrap datasets are given as a list of `dict` that are not automatically
91
+ converted into CfgNode. This method processes all bootstrap dataset entries
92
+ and ensures that they are in CfgNode format and comply with the specification
93
+ """
94
+ if not cfg.BOOTSTRAP_DATASETS:
95
+ return
96
+
97
+ bootstrap_datasets_cfgnodes = []
98
+ for dataset_cfg in cfg.BOOTSTRAP_DATASETS:
99
+ _C = get_bootstrap_dataset_config().clone()
100
+ _C.merge_from_other_cfg(CN(dataset_cfg))
101
+ bootstrap_datasets_cfgnodes.append(_C)
102
+ cfg.BOOTSTRAP_DATASETS = bootstrap_datasets_cfgnodes
103
+
104
+
105
+ def add_densepose_head_cse_config(cfg: CN) -> None:
106
+ """
107
+ Add configuration options for Continuous Surface Embeddings (CSE)
108
+ """
109
+ _C = cfg
110
+ _C.MODEL.ROI_DENSEPOSE_HEAD.CSE = CN()
111
+ # Dimensionality D of the embedding space
112
+ _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBED_SIZE = 16
113
+ # Embedder specifications for various mesh IDs
114
+ _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBEDDERS = CN(new_allowed=True)
115
+ # normalization coefficient for embedding distances
116
+ _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBEDDING_DIST_GAUSS_SIGMA = 0.01
117
+ # normalization coefficient for geodesic distances
118
+ _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.GEODESIC_DIST_GAUSS_SIGMA = 0.01
119
+ # embedding loss weight
120
+ _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBED_LOSS_WEIGHT = 0.6
121
+ # embedding loss name, currently the following options are supported:
122
+ # - EmbeddingLoss: cross-entropy on vertex labels
123
+ # - SoftEmbeddingLoss: cross-entropy on vertex label combined with
124
+ # Gaussian penalty on distance between vertices
125
+ _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBED_LOSS_NAME = "EmbeddingLoss"
126
+ # optimizer hyperparameters
127
+ _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.FEATURES_LR_FACTOR = 1.0
128
+ _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBEDDING_LR_FACTOR = 1.0
129
+ # Shape to shape cycle consistency loss parameters:
130
+ _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.SHAPE_TO_SHAPE_CYCLE_LOSS = CN({"ENABLED": False})
131
+ # shape to shape cycle consistency loss weight
132
+ _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.SHAPE_TO_SHAPE_CYCLE_LOSS.WEIGHT = 0.025
133
+ # norm type used for loss computation
134
+ _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.SHAPE_TO_SHAPE_CYCLE_LOSS.NORM_P = 2
135
+ # normalization term for embedding similarity matrices
136
+ _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.SHAPE_TO_SHAPE_CYCLE_LOSS.TEMPERATURE = 0.05
137
+ # maximum number of vertices to include into shape to shape cycle loss
138
+ # if negative or zero, all vertices are considered
139
+ # if positive, random subset of vertices of given size is considered
140
+ _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.SHAPE_TO_SHAPE_CYCLE_LOSS.MAX_NUM_VERTICES = 4936
141
+ # Pixel to shape cycle consistency loss parameters:
142
+ _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.PIX_TO_SHAPE_CYCLE_LOSS = CN({"ENABLED": False})
143
+ # pixel to shape cycle consistency loss weight
144
+ _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.PIX_TO_SHAPE_CYCLE_LOSS.WEIGHT = 0.0001
145
+ # norm type used for loss computation
146
+ _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.PIX_TO_SHAPE_CYCLE_LOSS.NORM_P = 2
147
+ # map images to all meshes and back (if false, use only gt meshes from the batch)
148
+ _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.PIX_TO_SHAPE_CYCLE_LOSS.USE_ALL_MESHES_NOT_GT_ONLY = False
149
+ # Randomly select at most this number of pixels from every instance
150
+ # if negative or zero, all vertices are considered
151
+ _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.PIX_TO_SHAPE_CYCLE_LOSS.NUM_PIXELS_TO_SAMPLE = 100
152
+ # normalization factor for pixel to pixel distances (higher value = smoother distribution)
153
+ _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.PIX_TO_SHAPE_CYCLE_LOSS.PIXEL_SIGMA = 5.0
154
+ _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.PIX_TO_SHAPE_CYCLE_LOSS.TEMPERATURE_PIXEL_TO_VERTEX = 0.05
155
+ _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.PIX_TO_SHAPE_CYCLE_LOSS.TEMPERATURE_VERTEX_TO_PIXEL = 0.05
156
+
157
+
158
+ def add_densepose_head_config(cfg: CN) -> None:
159
+ """
160
+ Add config for densepose head.
161
+ """
162
+ _C = cfg
163
+
164
+ _C.MODEL.DENSEPOSE_ON = True
165
+
166
+ _C.MODEL.ROI_DENSEPOSE_HEAD = CN()
167
+ _C.MODEL.ROI_DENSEPOSE_HEAD.NAME = ""
168
+ _C.MODEL.ROI_DENSEPOSE_HEAD.NUM_STACKED_CONVS = 8
169
+ # Number of parts used for point labels
170
+ _C.MODEL.ROI_DENSEPOSE_HEAD.NUM_PATCHES = 24
171
+ _C.MODEL.ROI_DENSEPOSE_HEAD.DECONV_KERNEL = 4
172
+ _C.MODEL.ROI_DENSEPOSE_HEAD.CONV_HEAD_DIM = 512
173
+ _C.MODEL.ROI_DENSEPOSE_HEAD.CONV_HEAD_KERNEL = 3
174
+ _C.MODEL.ROI_DENSEPOSE_HEAD.UP_SCALE = 2
175
+ _C.MODEL.ROI_DENSEPOSE_HEAD.HEATMAP_SIZE = 112
176
+ _C.MODEL.ROI_DENSEPOSE_HEAD.POOLER_TYPE = "ROIAlignV2"
177
+ _C.MODEL.ROI_DENSEPOSE_HEAD.POOLER_RESOLUTION = 28
178
+ _C.MODEL.ROI_DENSEPOSE_HEAD.POOLER_SAMPLING_RATIO = 2
179
+ _C.MODEL.ROI_DENSEPOSE_HEAD.NUM_COARSE_SEGM_CHANNELS = 2 # 15 or 2
180
+ # Overlap threshold for an RoI to be considered foreground (if >= FG_IOU_THRESHOLD)
181
+ _C.MODEL.ROI_DENSEPOSE_HEAD.FG_IOU_THRESHOLD = 0.7
182
+ # Loss weights for annotation masks.(14 Parts)
183
+ _C.MODEL.ROI_DENSEPOSE_HEAD.INDEX_WEIGHTS = 5.0
184
+ # Loss weights for surface parts. (24 Parts)
185
+ _C.MODEL.ROI_DENSEPOSE_HEAD.PART_WEIGHTS = 1.0
186
+ # Loss weights for UV regression.
187
+ _C.MODEL.ROI_DENSEPOSE_HEAD.POINT_REGRESSION_WEIGHTS = 0.01
188
+ # Coarse segmentation is trained using instance segmentation task data
189
+ _C.MODEL.ROI_DENSEPOSE_HEAD.COARSE_SEGM_TRAINED_BY_MASKS = False
190
+ # For Decoder
191
+ _C.MODEL.ROI_DENSEPOSE_HEAD.DECODER_ON = True
192
+ _C.MODEL.ROI_DENSEPOSE_HEAD.DECODER_NUM_CLASSES = 256
193
+ _C.MODEL.ROI_DENSEPOSE_HEAD.DECODER_CONV_DIMS = 256
194
+ _C.MODEL.ROI_DENSEPOSE_HEAD.DECODER_NORM = ""
195
+ _C.MODEL.ROI_DENSEPOSE_HEAD.DECODER_COMMON_STRIDE = 4
196
+ # For DeepLab head
197
+ _C.MODEL.ROI_DENSEPOSE_HEAD.DEEPLAB = CN()
198
+ _C.MODEL.ROI_DENSEPOSE_HEAD.DEEPLAB.NORM = "GN"
199
+ _C.MODEL.ROI_DENSEPOSE_HEAD.DEEPLAB.NONLOCAL_ON = 0
200
+ # Predictor class name, must be registered in DENSEPOSE_PREDICTOR_REGISTRY
201
+ # Some registered predictors:
202
+ # "DensePoseChartPredictor": predicts segmentation and UV coordinates for predefined charts
203
+ # "DensePoseChartWithConfidencePredictor": predicts segmentation, UV coordinates
204
+ # and associated confidences for predefined charts (default)
205
+ # "DensePoseEmbeddingWithConfidencePredictor": predicts segmentation, embeddings
206
+ # and associated confidences for CSE
207
+ _C.MODEL.ROI_DENSEPOSE_HEAD.PREDICTOR_NAME = "DensePoseChartWithConfidencePredictor"
208
+ # Loss class name, must be registered in DENSEPOSE_LOSS_REGISTRY
209
+ # Some registered losses:
210
+ # "DensePoseChartLoss": loss for chart-based models that estimate
211
+ # segmentation and UV coordinates
212
+ # "DensePoseChartWithConfidenceLoss": loss for chart-based models that estimate
213
+ # segmentation, UV coordinates and the corresponding confidences (default)
214
+ _C.MODEL.ROI_DENSEPOSE_HEAD.LOSS_NAME = "DensePoseChartWithConfidenceLoss"
215
+ # Confidences
216
+ # Enable learning UV confidences (variances) along with the actual values
217
+ _C.MODEL.ROI_DENSEPOSE_HEAD.UV_CONFIDENCE = CN({"ENABLED": False})
218
+ # UV confidence lower bound
219
+ _C.MODEL.ROI_DENSEPOSE_HEAD.UV_CONFIDENCE.EPSILON = 0.01
220
+ # Enable learning segmentation confidences (variances) along with the actual values
221
+ _C.MODEL.ROI_DENSEPOSE_HEAD.SEGM_CONFIDENCE = CN({"ENABLED": False})
222
+ # Segmentation confidence lower bound
223
+ _C.MODEL.ROI_DENSEPOSE_HEAD.SEGM_CONFIDENCE.EPSILON = 0.01
224
+ # Statistical model type for confidence learning, possible values:
225
+ # - "iid_iso": statistically independent identically distributed residuals
226
+ # with isotropic covariance
227
+ # - "indep_aniso": statistically independent residuals with anisotropic
228
+ # covariances
229
+ _C.MODEL.ROI_DENSEPOSE_HEAD.UV_CONFIDENCE.TYPE = "iid_iso"
230
+ # List of angles for rotation in data augmentation during training
231
+ _C.INPUT.ROTATION_ANGLES = [0]
232
+ _C.TEST.AUG.ROTATION_ANGLES = () # Rotation TTA
233
+
234
+ add_densepose_head_cse_config(cfg)
235
+
236
+
237
+ def add_hrnet_config(cfg: CN) -> None:
238
+ """
239
+ Add config for HRNet backbone.
240
+ """
241
+ _C = cfg
242
+
243
+ # For HigherHRNet w32
244
+ _C.MODEL.HRNET = CN()
245
+ _C.MODEL.HRNET.STEM_INPLANES = 64
246
+ _C.MODEL.HRNET.STAGE2 = CN()
247
+ _C.MODEL.HRNET.STAGE2.NUM_MODULES = 1
248
+ _C.MODEL.HRNET.STAGE2.NUM_BRANCHES = 2
249
+ _C.MODEL.HRNET.STAGE2.BLOCK = "BASIC"
250
+ _C.MODEL.HRNET.STAGE2.NUM_BLOCKS = [4, 4]
251
+ _C.MODEL.HRNET.STAGE2.NUM_CHANNELS = [32, 64]
252
+ _C.MODEL.HRNET.STAGE2.FUSE_METHOD = "SUM"
253
+ _C.MODEL.HRNET.STAGE3 = CN()
254
+ _C.MODEL.HRNET.STAGE3.NUM_MODULES = 4
255
+ _C.MODEL.HRNET.STAGE3.NUM_BRANCHES = 3
256
+ _C.MODEL.HRNET.STAGE3.BLOCK = "BASIC"
257
+ _C.MODEL.HRNET.STAGE3.NUM_BLOCKS = [4, 4, 4]
258
+ _C.MODEL.HRNET.STAGE3.NUM_CHANNELS = [32, 64, 128]
259
+ _C.MODEL.HRNET.STAGE3.FUSE_METHOD = "SUM"
260
+ _C.MODEL.HRNET.STAGE4 = CN()
261
+ _C.MODEL.HRNET.STAGE4.NUM_MODULES = 3
262
+ _C.MODEL.HRNET.STAGE4.NUM_BRANCHES = 4
263
+ _C.MODEL.HRNET.STAGE4.BLOCK = "BASIC"
264
+ _C.MODEL.HRNET.STAGE4.NUM_BLOCKS = [4, 4, 4, 4]
265
+ _C.MODEL.HRNET.STAGE4.NUM_CHANNELS = [32, 64, 128, 256]
266
+ _C.MODEL.HRNET.STAGE4.FUSE_METHOD = "SUM"
267
+
268
+ _C.MODEL.HRNET.HRFPN = CN()
269
+ _C.MODEL.HRNET.HRFPN.OUT_CHANNELS = 256
270
+
271
+
272
+ def add_densepose_config(cfg: CN) -> None:
273
+ add_densepose_head_config(cfg)
274
+ add_hrnet_config(cfg)
275
+ add_bootstrap_config(cfg)
276
+ add_dataset_category_config(cfg)
277
+ add_evaluation_config(cfg)
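Note on densepose/config.py above: these helpers extend a detectron2 `CfgNode` in place so that DensePose-specific keys exist before a YAML config is merged. A hedged usage sketch; the YAML path and checkpoint name below are placeholders, not files shipped with this repo:

    # How the config helpers above are typically applied (placeholder paths).
    from detectron2.config import get_cfg
    from densepose import add_densepose_config

    cfg = get_cfg()                    # base detectron2 config
    add_densepose_config(cfg)          # adds ROI_DENSEPOSE_HEAD, HRNet, bootstrap, evaluation keys
    cfg.merge_from_file("configs/densepose_rcnn_R_50_FPN_s1x.yaml")  # placeholder config file
    cfg.MODEL.WEIGHTS = "model_final.pkl"                            # placeholder checkpoint
    print(cfg.MODEL.ROI_DENSEPOSE_HEAD.HEATMAP_SIZE)                 # 112 unless overridden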
densepose/converters/__init__.py ADDED
@@ -0,0 +1,17 @@
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # pyre-unsafe
4
+
5
+ from .hflip import HFlipConverter
6
+ from .to_mask import ToMaskConverter
7
+ from .to_chart_result import ToChartResultConverter, ToChartResultConverterWithConfidences
8
+ from .segm_to_mask import (
9
+ predictor_output_with_fine_and_coarse_segm_to_mask,
10
+ predictor_output_with_coarse_segm_to_mask,
11
+ resample_fine_and_coarse_segm_to_bbox,
12
+ )
13
+ from .chart_output_to_chart_result import (
14
+ densepose_chart_predictor_output_to_result,
15
+ densepose_chart_predictor_output_to_result_with_confidences,
16
+ )
17
+ from .chart_output_hflip import densepose_chart_predictor_output_hflip
densepose/converters/base.py ADDED
@@ -0,0 +1,95 @@
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # pyre-unsafe
4
+
5
+ from typing import Any, Tuple, Type
6
+ import torch
7
+
8
+
9
+ class BaseConverter:
10
+ """
11
+ Converter base class to be reused by various converters.
12
+ Converter allows one to convert data from various source types to a particular
13
+ destination type. Each source type needs to register its converter. The
14
+ registration for each source type is valid for all descendants of that type.
15
+ """
16
+
17
+ @classmethod
18
+ def register(cls, from_type: Type, converter: Any = None):
19
+ """
20
+ Registers a converter for the specified type.
21
+ Can be used as a decorator (if converter is None), or called as a method.
22
+
23
+ Args:
24
+ from_type (type): type to register the converter for;
25
+ all instances of this type will use the same converter
26
+ converter (callable): converter to be registered for the given
27
+ type; if None, this method is assumed to be a decorator for the converter
28
+ """
29
+
30
+ if converter is not None:
31
+ cls._do_register(from_type, converter)
32
+
33
+ def wrapper(converter: Any) -> Any:
34
+ cls._do_register(from_type, converter)
35
+ return converter
36
+
37
+ return wrapper
38
+
39
+ @classmethod
40
+ def _do_register(cls, from_type: Type, converter: Any):
41
+ cls.registry[from_type] = converter # pyre-ignore[16]
42
+
43
+ @classmethod
44
+ def _lookup_converter(cls, from_type: Type) -> Any:
45
+ """
46
+ Perform recursive lookup for the given type
47
+ to find registered converter. If a converter was found for some base
48
+ class, it gets registered for this class to save on further lookups.
49
+
50
+ Args:
51
+ from_type: type for which to find a converter
52
+ Return:
53
+ callable or None - registered converter or None
54
+ if no suitable entry was found in the registry
55
+ """
56
+ if from_type in cls.registry: # pyre-ignore[16]
57
+ return cls.registry[from_type]
58
+ for base in from_type.__bases__:
59
+ converter = cls._lookup_converter(base)
60
+ if converter is not None:
61
+ cls._do_register(from_type, converter)
62
+ return converter
63
+ return None
64
+
65
+ @classmethod
66
+ def convert(cls, instance: Any, *args, **kwargs):
67
+ """
68
+ Convert an instance to the destination type using some registered
69
+ converter. Does recursive lookup for base classes, so there's no need
70
+ for explicit registration for derived classes.
71
+
72
+ Args:
73
+ instance: source instance to convert to the destination type
74
+ Return:
75
+ An instance of the destination type obtained from the source instance
76
+ Raises KeyError, if no suitable converter found
77
+ """
78
+ instance_type = type(instance)
79
+ converter = cls._lookup_converter(instance_type)
80
+ if converter is None:
81
+ if cls.dst_type is None: # pyre-ignore[16]
82
+ output_type_str = "itself"
83
+ else:
84
+ output_type_str = cls.dst_type
85
+ raise KeyError(f"Could not find converter from {instance_type} to {output_type_str}")
86
+ return converter(instance, *args, **kwargs)
87
+
88
+
89
+ IntTupleBox = Tuple[int, int, int, int]
90
+
91
+
92
+ def make_int_box(box: torch.Tensor) -> IntTupleBox:
93
+ int_box = [0, 0, 0, 0]
94
+ int_box[0], int_box[1], int_box[2], int_box[3] = tuple(box.long().tolist())
95
+ return int_box[0], int_box[1], int_box[2], int_box[3]
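Note on densepose/converters/base.py above: `BaseConverter` expects each concrete converter class to supply its own `registry` dict and `dst_type` (the pyre-ignore comments hint at this), and lookup walks base classes so derived types need no explicit registration. A toy illustration, assuming `BaseConverter` from the module above is in scope; the temperature types are made up for the example:

    # Toy use of the BaseConverter contract defined above (illustrative types).
    class Celsius:
        def __init__(self, value): self.value = value

    class WarmCelsius(Celsius):  # derived type, no explicit registration needed
        pass

    class ToFahrenheitConverter(BaseConverter):
        registry = {}
        dst_type = "Fahrenheit"

    @ToFahrenheitConverter.register(Celsius)  # decorator form of register()
    def celsius_to_fahrenheit(instance):
        return instance.value * 9 / 5 + 32

    print(ToFahrenheitConverter.convert(Celsius(100)))     # 212.0
    print(ToFahrenheitConverter.convert(WarmCelsius(37)))  # 98.6 via base-class lookup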
densepose/converters/builtin.py ADDED
@@ -0,0 +1,33 @@
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # pyre-unsafe
4
+
5
+ from ..structures import DensePoseChartPredictorOutput, DensePoseEmbeddingPredictorOutput
6
+ from . import (
7
+ HFlipConverter,
8
+ ToChartResultConverter,
9
+ ToChartResultConverterWithConfidences,
10
+ ToMaskConverter,
11
+ densepose_chart_predictor_output_hflip,
12
+ densepose_chart_predictor_output_to_result,
13
+ densepose_chart_predictor_output_to_result_with_confidences,
14
+ predictor_output_with_coarse_segm_to_mask,
15
+ predictor_output_with_fine_and_coarse_segm_to_mask,
16
+ )
17
+
18
+ ToMaskConverter.register(
19
+ DensePoseChartPredictorOutput, predictor_output_with_fine_and_coarse_segm_to_mask
20
+ )
21
+ ToMaskConverter.register(
22
+ DensePoseEmbeddingPredictorOutput, predictor_output_with_coarse_segm_to_mask
23
+ )
24
+
25
+ ToChartResultConverter.register(
26
+ DensePoseChartPredictorOutput, densepose_chart_predictor_output_to_result
27
+ )
28
+
29
+ ToChartResultConverterWithConfidences.register(
30
+ DensePoseChartPredictorOutput, densepose_chart_predictor_output_to_result_with_confidences
31
+ )
32
+
33
+ HFlipConverter.register(DensePoseChartPredictorOutput, densepose_chart_predictor_output_hflip)
densepose/converters/chart_output_hflip.py ADDED
@@ -0,0 +1,73 @@
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # pyre-unsafe
4
+ from dataclasses import fields
5
+ import torch
6
+
7
+ from densepose.structures import DensePoseChartPredictorOutput, DensePoseTransformData
8
+
9
+
10
+ def densepose_chart_predictor_output_hflip(
11
+ densepose_predictor_output: DensePoseChartPredictorOutput,
12
+ transform_data: DensePoseTransformData,
13
+ ) -> DensePoseChartPredictorOutput:
14
+ """
15
+ Change the predictor output to take a horizontal flip into account.
16
+ """
17
+ if len(densepose_predictor_output) > 0:
18
+
19
+ PredictorOutput = type(densepose_predictor_output)
20
+ output_dict = {}
21
+
22
+ for field in fields(densepose_predictor_output):
23
+ field_value = getattr(densepose_predictor_output, field.name)
24
+ # flip tensors
25
+ if isinstance(field_value, torch.Tensor):
26
+ setattr(densepose_predictor_output, field.name, torch.flip(field_value, [3]))
27
+
28
+ densepose_predictor_output = _flip_iuv_semantics_tensor(
29
+ densepose_predictor_output, transform_data
30
+ )
31
+ densepose_predictor_output = _flip_segm_semantics_tensor(
32
+ densepose_predictor_output, transform_data
33
+ )
34
+
35
+ for field in fields(densepose_predictor_output):
36
+ output_dict[field.name] = getattr(densepose_predictor_output, field.name)
37
+
38
+ return PredictorOutput(**output_dict)
39
+ else:
40
+ return densepose_predictor_output
41
+
42
+
43
+ def _flip_iuv_semantics_tensor(
44
+ densepose_predictor_output: DensePoseChartPredictorOutput,
45
+ dp_transform_data: DensePoseTransformData,
46
+ ) -> DensePoseChartPredictorOutput:
47
+ point_label_symmetries = dp_transform_data.point_label_symmetries
48
+ uv_symmetries = dp_transform_data.uv_symmetries
49
+
50
+ N, C, H, W = densepose_predictor_output.u.shape
51
+ u_loc = (densepose_predictor_output.u[:, 1:, :, :].clamp(0, 1) * 255).long()
52
+ v_loc = (densepose_predictor_output.v[:, 1:, :, :].clamp(0, 1) * 255).long()
53
+ Iindex = torch.arange(C - 1, device=densepose_predictor_output.u.device)[
54
+ None, :, None, None
55
+ ].expand(N, C - 1, H, W)
56
+ densepose_predictor_output.u[:, 1:, :, :] = uv_symmetries["U_transforms"][Iindex, v_loc, u_loc]
57
+ densepose_predictor_output.v[:, 1:, :, :] = uv_symmetries["V_transforms"][Iindex, v_loc, u_loc]
58
+
59
+ for el in ["fine_segm", "u", "v"]:
60
+ densepose_predictor_output.__dict__[el] = densepose_predictor_output.__dict__[el][
61
+ :, point_label_symmetries, :, :
62
+ ]
63
+ return densepose_predictor_output
64
+
65
+
66
+ def _flip_segm_semantics_tensor(
67
+ densepose_predictor_output: DensePoseChartPredictorOutput, dp_transform_data
68
+ ):
69
+ if densepose_predictor_output.coarse_segm.shape[1] > 2:
70
+ densepose_predictor_output.coarse_segm = densepose_predictor_output.coarse_segm[
71
+ :, dp_transform_data.mask_label_symmetries, :, :
72
+ ]
73
+ return densepose_predictor_output
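Note on chart_output_hflip.py above: the converter combines two operations, mirroring every [N, C, H, W] tensor along the width axis and then permuting the part channels with a label-symmetry index so left/right parts swap. A tiny sketch of those two steps; the 3-part symmetry list is a toy example, not the real DensePose one:

    # Sketch of the flip + channel permutation performed above (toy symmetry list).
    import torch

    scores = torch.rand(1, 3, 4, 4)          # background + "left" + "right" parts
    flipped = torch.flip(scores, dims=[3])   # mirror along the width axis
    point_label_symmetries = [0, 2, 1]       # background stays, left and right swap
    flipped = flipped[:, point_label_symmetries, :, :]
    print(flipped.shape)                      # torch.Size([1, 3, 4, 4])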
densepose/converters/chart_output_to_chart_result.py ADDED
@@ -0,0 +1,190 @@
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # pyre-unsafe
4
+
5
+ from typing import Dict
6
+ import torch
7
+ from torch.nn import functional as F
8
+
9
+ from detectron2.structures.boxes import Boxes, BoxMode
10
+
11
+ from ..structures import (
12
+ DensePoseChartPredictorOutput,
13
+ DensePoseChartResult,
14
+ DensePoseChartResultWithConfidences,
15
+ )
16
+ from . import resample_fine_and_coarse_segm_to_bbox
17
+ from .base import IntTupleBox, make_int_box
18
+
19
+
20
+ def resample_uv_tensors_to_bbox(
21
+ u: torch.Tensor,
22
+ v: torch.Tensor,
23
+ labels: torch.Tensor,
24
+ box_xywh_abs: IntTupleBox,
25
+ ) -> torch.Tensor:
26
+ """
27
+ Resamples U and V coordinate estimates for the given bounding box
28
+
29
+ Args:
30
+ u (tensor [1, C, H, W] of float): U coordinates
31
+ v (tensor [1, C, H, W] of float): V coordinates
32
+ labels (tensor [H, W] of long): labels obtained by resampling segmentation
33
+ outputs for the given bounding box
34
+ box_xywh_abs (tuple of 4 int): bounding box that corresponds to predictor outputs
35
+ Return:
36
+ Resampled U and V coordinates - a tensor [2, H, W] of float
37
+ """
38
+ x, y, w, h = box_xywh_abs
39
+ w = max(int(w), 1)
40
+ h = max(int(h), 1)
41
+ u_bbox = F.interpolate(u, (h, w), mode="bilinear", align_corners=False)
42
+ v_bbox = F.interpolate(v, (h, w), mode="bilinear", align_corners=False)
43
+ uv = torch.zeros([2, h, w], dtype=torch.float32, device=u.device)
44
+ for part_id in range(1, u_bbox.size(1)):
45
+ uv[0][labels == part_id] = u_bbox[0, part_id][labels == part_id]
46
+ uv[1][labels == part_id] = v_bbox[0, part_id][labels == part_id]
47
+ return uv
48
+
49
+
50
+ def resample_uv_to_bbox(
51
+ predictor_output: DensePoseChartPredictorOutput,
52
+ labels: torch.Tensor,
53
+ box_xywh_abs: IntTupleBox,
54
+ ) -> torch.Tensor:
55
+ """
56
+ Resamples U and V coordinate estimates for the given bounding box
57
+
58
+ Args:
59
+ predictor_output (DensePoseChartPredictorOutput): DensePose predictor
60
+ output to be resampled
61
+ labels (tensor [H, W] of long): labels obtained by resampling segmentation
62
+ outputs for the given bounding box
63
+ box_xywh_abs (tuple of 4 int): bounding box that corresponds to predictor outputs
64
+ Return:
65
+ Resampled U and V coordinates - a tensor [2, H, W] of float
66
+ """
67
+ return resample_uv_tensors_to_bbox(
68
+ predictor_output.u,
69
+ predictor_output.v,
70
+ labels,
71
+ box_xywh_abs,
72
+ )
73
+
74
+
75
+ def densepose_chart_predictor_output_to_result(
76
+ predictor_output: DensePoseChartPredictorOutput, boxes: Boxes
77
+ ) -> DensePoseChartResult:
78
+ """
79
+ Convert densepose chart predictor outputs to results
80
+
81
+ Args:
82
+ predictor_output (DensePoseChartPredictorOutput): DensePose predictor
83
+ output to be converted to results, must contain only 1 output
84
+ boxes (Boxes): bounding box that corresponds to the predictor output,
85
+ must contain only 1 bounding box
86
+ Return:
87
+ DensePose chart-based result (DensePoseChartResult)
88
+ """
89
+     assert len(predictor_output) == 1 and len(boxes) == 1, (
90
+         f"Predictor output to result conversion can operate only on single outputs"
91
+ f", got {len(predictor_output)} predictor outputs and {len(boxes)} boxes"
92
+ )
93
+
94
+ boxes_xyxy_abs = boxes.tensor.clone()
95
+ boxes_xywh_abs = BoxMode.convert(boxes_xyxy_abs, BoxMode.XYXY_ABS, BoxMode.XYWH_ABS)
96
+ box_xywh = make_int_box(boxes_xywh_abs[0])
97
+
98
+ labels = resample_fine_and_coarse_segm_to_bbox(predictor_output, box_xywh).squeeze(0)
99
+ uv = resample_uv_to_bbox(predictor_output, labels, box_xywh)
100
+ return DensePoseChartResult(labels=labels, uv=uv)
101
+
102
+
103
+ def resample_confidences_to_bbox(
104
+ predictor_output: DensePoseChartPredictorOutput,
105
+ labels: torch.Tensor,
106
+ box_xywh_abs: IntTupleBox,
107
+ ) -> Dict[str, torch.Tensor]:
108
+ """
109
+ Resamples confidences for the given bounding box
110
+
111
+ Args:
112
+ predictor_output (DensePoseChartPredictorOutput): DensePose predictor
113
+ output to be resampled
114
+ labels (tensor [H, W] of long): labels obtained by resampling segmentation
115
+ outputs for the given bounding box
116
+ box_xywh_abs (tuple of 4 int): bounding box that corresponds to predictor outputs
117
+ Return:
118
+ Resampled confidences - a dict of [H, W] tensors of float
119
+ """
120
+
121
+ x, y, w, h = box_xywh_abs
122
+ w = max(int(w), 1)
123
+ h = max(int(h), 1)
124
+
125
+ confidence_names = [
126
+ "sigma_1",
127
+ "sigma_2",
128
+ "kappa_u",
129
+ "kappa_v",
130
+ "fine_segm_confidence",
131
+ "coarse_segm_confidence",
132
+ ]
133
+ confidence_results = {key: None for key in confidence_names}
134
+ confidence_names = [
135
+ key for key in confidence_names if getattr(predictor_output, key) is not None
136
+ ]
137
+ confidence_base = torch.zeros([h, w], dtype=torch.float32, device=predictor_output.u.device)
138
+
139
+ # assign data from channels that correspond to the labels
140
+ for key in confidence_names:
141
+ resampled_confidence = F.interpolate(
142
+ getattr(predictor_output, key),
143
+ (h, w),
144
+ mode="bilinear",
145
+ align_corners=False,
146
+ )
147
+ result = confidence_base.clone()
148
+ for part_id in range(1, predictor_output.u.size(1)):
149
+ if resampled_confidence.size(1) != predictor_output.u.size(1):
150
+ # confidence is not part-based, don't try to fill it part by part
151
+ continue
152
+ result[labels == part_id] = resampled_confidence[0, part_id][labels == part_id]
153
+
154
+ if resampled_confidence.size(1) != predictor_output.u.size(1):
155
+ # confidence is not part-based, fill the data with the first channel
156
+ # (targeted for segmentation confidences that have only 1 channel)
157
+ result = resampled_confidence[0, 0]
158
+
159
+ confidence_results[key] = result
160
+
161
+ return confidence_results # pyre-ignore[7]
162
+
163
+
164
+ def densepose_chart_predictor_output_to_result_with_confidences(
165
+ predictor_output: DensePoseChartPredictorOutput, boxes: Boxes
166
+ ) -> DensePoseChartResultWithConfidences:
167
+ """
168
+ Convert densepose chart predictor outputs to results
169
+
170
+ Args:
171
+ predictor_output (DensePoseChartPredictorOutput): DensePose predictor
172
+ output with confidences to be converted to results, must contain only 1 output
173
+ boxes (Boxes): bounding box that corresponds to the predictor output,
174
+ must contain only 1 bounding box
175
+ Return:
176
+ DensePose chart-based result with confidences (DensePoseChartResultWithConfidences)
177
+ """
178
+     assert len(predictor_output) == 1 and len(boxes) == 1, (
179
+         f"Predictor output to result conversion can operate only on single outputs"
180
+ f", got {len(predictor_output)} predictor outputs and {len(boxes)} boxes"
181
+ )
182
+
183
+ boxes_xyxy_abs = boxes.tensor.clone()
184
+ boxes_xywh_abs = BoxMode.convert(boxes_xyxy_abs, BoxMode.XYXY_ABS, BoxMode.XYWH_ABS)
185
+ box_xywh = make_int_box(boxes_xywh_abs[0])
186
+
187
+ labels = resample_fine_and_coarse_segm_to_bbox(predictor_output, box_xywh).squeeze(0)
188
+ uv = resample_uv_to_bbox(predictor_output, labels, box_xywh)
189
+ confidences = resample_confidences_to_bbox(predictor_output, labels, box_xywh)
190
+ return DensePoseChartResultWithConfidences(labels=labels, uv=uv, **confidences)
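
A minimal usage sketch, assuming `instances` is the output of a chart-based DensePose model with `pred_densepose` and `pred_boxes` fields; since the converter above asserts a single predictor output and a single box, each detection is converted separately:

from densepose.converters.chart_output_to_chart_result import (
    densepose_chart_predictor_output_to_result,
)


def instances_to_chart_results(instances):
    # One conversion per detected instance: indexing the predictor output and
    # the Boxes object yields the single-element views the converter expects.
    return [
        densepose_chart_predictor_output_to_result(
            instances.pred_densepose[i], instances.pred_boxes[i]
        )
        for i in range(len(instances))
    ]
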
densepose/converters/hflip.py ADDED
@@ -0,0 +1,36 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # pyre-unsafe
4
+
5
+ from typing import Any
6
+
7
+ from .base import BaseConverter
8
+
9
+
10
+ class HFlipConverter(BaseConverter):
11
+ """
12
+     Applies horizontal flips to various DensePose predictor outputs.
13
+     Each DensePose predictor output type has to register its conversion strategy.
14
+ """
15
+
16
+ registry = {}
17
+ dst_type = None
18
+
19
+ @classmethod
20
+ # pyre-fixme[14]: `convert` overrides method defined in `BaseConverter`
21
+ # inconsistently.
22
+ def convert(cls, predictor_outputs: Any, transform_data: Any, *args, **kwargs):
23
+ """
24
+         Performs a horizontal flip on DensePose predictor outputs.
25
+ Does recursive lookup for base classes, so there's no need
26
+ for explicit registration for derived classes.
27
+
28
+ Args:
29
+             predictor_outputs: DensePose predictor outputs to be flipped
30
+             transform_data: DensePose transform data used to perform the flip (e.g. label symmetries)
31
+ Return:
32
+ An instance of the same type as predictor_outputs
33
+ """
34
+ return super(HFlipConverter, cls).convert(
35
+ predictor_outputs, transform_data, *args, **kwargs
36
+ )
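
A hedged sketch of one way this converter might be invoked, e.g. to undo a horizontal test-time flip; the dataset name is only an example, and the transform data is loaded from the `densepose_transform_src` metadata entry in the same way dataset_mapper.py below does:

from detectron2.data import MetadataCatalog
from detectron2.utils.file_io import PathManager

from densepose.converters.hflip import HFlipConverter
from densepose.structures import DensePoseTransformData


def hflip_densepose_outputs(predictor_outputs, dataset_name="densepose_coco_2014_train"):
    # Point/mask label symmetries and other flip data are stored alongside the dataset.
    fpath = PathManager.get_local_path(MetadataCatalog.get(dataset_name).densepose_transform_src)
    transform_data = DensePoseTransformData.load(fpath)
    return HFlipConverter.convert(predictor_outputs, transform_data)
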
densepose/converters/segm_to_mask.py ADDED
@@ -0,0 +1,152 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # pyre-unsafe
4
+
5
+ from typing import Any
6
+ import torch
7
+ from torch.nn import functional as F
8
+
9
+ from detectron2.structures import BitMasks, Boxes, BoxMode
10
+
11
+ from .base import IntTupleBox, make_int_box
12
+ from .to_mask import ImageSizeType
13
+
14
+
15
+ def resample_coarse_segm_tensor_to_bbox(coarse_segm: torch.Tensor, box_xywh_abs: IntTupleBox):
16
+ """
17
+ Resample coarse segmentation tensor to the given
18
+ bounding box and derive labels for each pixel of the bounding box
19
+
20
+ Args:
21
+ coarse_segm: float tensor of shape [1, K, Hout, Wout]
22
+ box_xywh_abs (tuple of 4 int): bounding box given by its upper-left
23
+ corner coordinates, width (W) and height (H)
24
+ Return:
25
+ Labels for each pixel of the bounding box, a long tensor of size [1, H, W]
26
+ """
27
+ x, y, w, h = box_xywh_abs
28
+ w = max(int(w), 1)
29
+ h = max(int(h), 1)
30
+ labels = F.interpolate(coarse_segm, (h, w), mode="bilinear", align_corners=False).argmax(dim=1)
31
+ return labels
32
+
33
+
34
+ def resample_fine_and_coarse_segm_tensors_to_bbox(
35
+ fine_segm: torch.Tensor, coarse_segm: torch.Tensor, box_xywh_abs: IntTupleBox
36
+ ):
37
+ """
38
+ Resample fine and coarse segmentation tensors to the given
39
+ bounding box and derive labels for each pixel of the bounding box
40
+
41
+ Args:
42
+ fine_segm: float tensor of shape [1, C, Hout, Wout]
43
+ coarse_segm: float tensor of shape [1, K, Hout, Wout]
44
+ box_xywh_abs (tuple of 4 int): bounding box given by its upper-left
45
+ corner coordinates, width (W) and height (H)
46
+ Return:
47
+ Labels for each pixel of the bounding box, a long tensor of size [1, H, W]
48
+ """
49
+ x, y, w, h = box_xywh_abs
50
+ w = max(int(w), 1)
51
+ h = max(int(h), 1)
52
+ # coarse segmentation
53
+ coarse_segm_bbox = F.interpolate(
54
+ coarse_segm,
55
+ (h, w),
56
+ mode="bilinear",
57
+ align_corners=False,
58
+ ).argmax(dim=1)
59
+ # combined coarse and fine segmentation
60
+ labels = (
61
+ F.interpolate(fine_segm, (h, w), mode="bilinear", align_corners=False).argmax(dim=1)
62
+ * (coarse_segm_bbox > 0).long()
63
+ )
64
+ return labels
65
+
66
+
67
+ def resample_fine_and_coarse_segm_to_bbox(predictor_output: Any, box_xywh_abs: IntTupleBox):
68
+ """
69
+ Resample fine and coarse segmentation outputs from a predictor to the given
70
+ bounding box and derive labels for each pixel of the bounding box
71
+
72
+ Args:
73
+ predictor_output: DensePose predictor output that contains segmentation
74
+ results to be resampled
75
+ box_xywh_abs (tuple of 4 int): bounding box given by its upper-left
76
+ corner coordinates, width (W) and height (H)
77
+ Return:
78
+ Labels for each pixel of the bounding box, a long tensor of size [1, H, W]
79
+ """
80
+ return resample_fine_and_coarse_segm_tensors_to_bbox(
81
+ predictor_output.fine_segm,
82
+ predictor_output.coarse_segm,
83
+ box_xywh_abs,
84
+ )
85
+
86
+
87
+ def predictor_output_with_coarse_segm_to_mask(
88
+ predictor_output: Any, boxes: Boxes, image_size_hw: ImageSizeType
89
+ ) -> BitMasks:
90
+ """
91
+     Convert predictor output with coarse segmentation to a mask.
92
+ Assumes that predictor output has the following attributes:
93
+ - coarse_segm (tensor of size [N, D, H, W]): coarse segmentation
94
+ unnormalized scores for N instances; D is the number of coarse
95
+         segmentation labels, H and W are the resolution of the estimate
96
+
97
+ Args:
98
+ predictor_output: DensePose predictor output to be converted to mask
99
+ boxes (Boxes): bounding boxes that correspond to the DensePose
100
+ predictor outputs
101
+ image_size_hw (tuple [int, int]): image height Himg and width Wimg
102
+ Return:
103
+ BitMasks that contain a bool tensor of size [N, Himg, Wimg] with
104
+ a mask of the size of the image for each instance
105
+ """
106
+ H, W = image_size_hw
107
+ boxes_xyxy_abs = boxes.tensor.clone()
108
+ boxes_xywh_abs = BoxMode.convert(boxes_xyxy_abs, BoxMode.XYXY_ABS, BoxMode.XYWH_ABS)
109
+ N = len(boxes_xywh_abs)
110
+ masks = torch.zeros((N, H, W), dtype=torch.bool, device=boxes.tensor.device)
111
+ for i in range(len(boxes_xywh_abs)):
112
+ box_xywh = make_int_box(boxes_xywh_abs[i])
113
+ box_mask = resample_coarse_segm_tensor_to_bbox(predictor_output[i].coarse_segm, box_xywh)
114
+ x, y, w, h = box_xywh
115
+ masks[i, y : y + h, x : x + w] = box_mask
116
+
117
+ return BitMasks(masks)
118
+
119
+
120
+ def predictor_output_with_fine_and_coarse_segm_to_mask(
121
+ predictor_output: Any, boxes: Boxes, image_size_hw: ImageSizeType
122
+ ) -> BitMasks:
123
+ """
124
+ Convert predictor output with coarse and fine segmentation to a mask.
125
+ Assumes that predictor output has the following attributes:
126
+ - coarse_segm (tensor of size [N, D, H, W]): coarse segmentation
127
+ unnormalized scores for N instances; D is the number of coarse
128
+         segmentation labels, H and W are the resolution of the estimate
129
+ - fine_segm (tensor of size [N, C, H, W]): fine segmentation
130
+ unnormalized scores for N instances; C is the number of fine
131
+         segmentation labels, H and W are the resolution of the estimate
132
+
133
+ Args:
134
+ predictor_output: DensePose predictor output to be converted to mask
135
+ boxes (Boxes): bounding boxes that correspond to the DensePose
136
+ predictor outputs
137
+ image_size_hw (tuple [int, int]): image height Himg and width Wimg
138
+ Return:
139
+ BitMasks that contain a bool tensor of size [N, Himg, Wimg] with
140
+ a mask of the size of the image for each instance
141
+ """
142
+ H, W = image_size_hw
143
+ boxes_xyxy_abs = boxes.tensor.clone()
144
+ boxes_xywh_abs = BoxMode.convert(boxes_xyxy_abs, BoxMode.XYXY_ABS, BoxMode.XYWH_ABS)
145
+ N = len(boxes_xywh_abs)
146
+ masks = torch.zeros((N, H, W), dtype=torch.bool, device=boxes.tensor.device)
147
+ for i in range(len(boxes_xywh_abs)):
148
+ box_xywh = make_int_box(boxes_xywh_abs[i])
149
+ labels_i = resample_fine_and_coarse_segm_to_bbox(predictor_output[i], box_xywh)
150
+ x, y, w, h = box_xywh
151
+ masks[i, y : y + h, x : x + w] = labels_i > 0
152
+ return BitMasks(masks)
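
A short sketch (with assumed inputs) of using the resampling helpers above directly to get per-pixel part labels inside a single detection box; `predictor_output` is a single chart predictor output and `box_xyxy` a tensor of 4 absolute XYXY coordinates:

import torch

from densepose.converters.base import make_int_box
from densepose.converters.segm_to_mask import resample_fine_and_coarse_segm_to_bbox


def part_labels_in_box(predictor_output, box_xyxy: torch.Tensor) -> torch.Tensor:
    # Convert the box to integer XYWH and resample the combined coarse+fine
    # segmentation to it; the result is a long tensor [1, H, W] with 0 for
    # background and fine-segmentation part indices elsewhere.
    x0, y0, x1, y1 = box_xyxy.tolist()
    box_xywh = make_int_box(torch.tensor([x0, y0, x1 - x0, y1 - y0]))
    return resample_fine_and_coarse_segm_to_bbox(predictor_output, box_xywh)
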
densepose/converters/to_chart_result.py ADDED
@@ -0,0 +1,72 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # pyre-unsafe
4
+
5
+ from typing import Any
6
+
7
+ from detectron2.structures import Boxes
8
+
9
+ from ..structures import DensePoseChartResult, DensePoseChartResultWithConfidences
10
+ from .base import BaseConverter
11
+
12
+
13
+ class ToChartResultConverter(BaseConverter):
14
+ """
15
+ Converts various DensePose predictor outputs to DensePose results.
16
+     Each DensePose predictor output type has to register its conversion strategy.
17
+ """
18
+
19
+ registry = {}
20
+ dst_type = DensePoseChartResult
21
+
22
+ @classmethod
23
+ # pyre-fixme[14]: `convert` overrides method defined in `BaseConverter`
24
+ # inconsistently.
25
+ def convert(cls, predictor_outputs: Any, boxes: Boxes, *args, **kwargs) -> DensePoseChartResult:
26
+ """
27
+ Convert DensePose predictor outputs to DensePoseResult using some registered
28
+ converter. Does recursive lookup for base classes, so there's no need
29
+ for explicit registration for derived classes.
30
+
31
+ Args:
32
+             densepose_predictor_outputs: DensePose predictor output to be
33
+                 converted to a chart result
34
+ boxes (Boxes): bounding boxes that correspond to the DensePose
35
+ predictor outputs
36
+ Return:
37
+             An instance of DensePoseChartResult. If no suitable converter is found, a KeyError is raised
38
+ """
39
+ return super(ToChartResultConverter, cls).convert(predictor_outputs, boxes, *args, **kwargs)
40
+
41
+
42
+ class ToChartResultConverterWithConfidences(BaseConverter):
43
+ """
44
+ Converts various DensePose predictor outputs to DensePose results.
45
+     Each DensePose predictor output type has to register its conversion strategy.
46
+ """
47
+
48
+ registry = {}
49
+ dst_type = DensePoseChartResultWithConfidences
50
+
51
+ @classmethod
52
+ # pyre-fixme[14]: `convert` overrides method defined in `BaseConverter`
53
+ # inconsistently.
54
+ def convert(
55
+ cls, predictor_outputs: Any, boxes: Boxes, *args, **kwargs
56
+ ) -> DensePoseChartResultWithConfidences:
57
+ """
58
+ Convert DensePose predictor outputs to DensePoseResult with confidences
59
+ using some registered converter. Does recursive lookup for base classes,
60
+ so there's no need for explicit registration for derived classes.
61
+
62
+ Args:
63
+             densepose_predictor_outputs: DensePose predictor output with confidences
64
+                 to be converted to a chart result
65
+ boxes (Boxes): bounding boxes that correspond to the DensePose
66
+ predictor outputs
67
+ Return:
68
+             An instance of DensePoseChartResultWithConfidences. If no suitable converter is found, a KeyError is raised
69
+ """
70
+ return super(ToChartResultConverterWithConfidences, cls).convert(
71
+ predictor_outputs, boxes, *args, **kwargs
72
+ )
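
A minimal sketch of the typical call site for these converters, assuming `instances` is the detectron2 `Instances` object produced by a chart-based DensePose model and that a converter for its `pred_densepose` type has been registered elsewhere in the package:

from densepose.converters.to_chart_result import ToChartResultConverter


def chart_results_for_instances(instances):
    # One DensePoseChartResult per detection, each carrying `labels` and `uv`
    # tensors resampled to the corresponding bounding box.
    return [
        ToChartResultConverter.convert(instances.pred_densepose[i], instances.pred_boxes[i])
        for i in range(len(instances))
    ]
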
densepose/converters/to_mask.py ADDED
@@ -0,0 +1,51 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # pyre-unsafe
4
+
5
+ from typing import Any, Tuple
6
+
7
+ from detectron2.structures import BitMasks, Boxes
8
+
9
+ from .base import BaseConverter
10
+
11
+ ImageSizeType = Tuple[int, int]
12
+
13
+
14
+ class ToMaskConverter(BaseConverter):
15
+ """
16
+ Converts various DensePose predictor outputs to masks
17
+ in bit mask format (see `BitMasks`). Each DensePose predictor output type
18
+     has to register its conversion strategy.
19
+ """
20
+
21
+ registry = {}
22
+ dst_type = BitMasks
23
+
24
+ @classmethod
25
+ # pyre-fixme[14]: `convert` overrides method defined in `BaseConverter`
26
+ # inconsistently.
27
+ def convert(
28
+ cls,
29
+ densepose_predictor_outputs: Any,
30
+ boxes: Boxes,
31
+ image_size_hw: ImageSizeType,
32
+ *args,
33
+ **kwargs
34
+ ) -> BitMasks:
35
+ """
36
+ Convert DensePose predictor outputs to BitMasks using some registered
37
+ converter. Does recursive lookup for base classes, so there's no need
38
+ for explicit registration for derived classes.
39
+
40
+ Args:
41
+ densepose_predictor_outputs: DensePose predictor output to be
42
+ converted to BitMasks
43
+ boxes (Boxes): bounding boxes that correspond to the DensePose
44
+ predictor outputs
45
+ image_size_hw (tuple [int, int]): image height and width
46
+ Return:
47
+ An instance of `BitMasks`. If no suitable converter was found, raises KeyError
48
+ """
49
+ return super(ToMaskConverter, cls).convert(
50
+ densepose_predictor_outputs, boxes, image_size_hw, *args, **kwargs
51
+ )
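
A sketch, under the same assumptions about `instances`, of turning the DensePose outputs of one image into instance bit masks:

from densepose.converters.to_mask import ToMaskConverter


def densepose_bitmasks(instances):
    # detectron2 Instances store the image size as (height, width), which is
    # exactly the `image_size_hw` argument the converter expects.
    return ToMaskConverter.convert(
        instances.pred_densepose, instances.pred_boxes, instances.image_size
    )
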
densepose/data/__init__.py ADDED
@@ -0,0 +1,27 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # pyre-unsafe
4
+
5
+ from .meshes import builtin
6
+ from .build import (
7
+ build_detection_test_loader,
8
+ build_detection_train_loader,
9
+ build_combined_loader,
10
+ build_frame_selector,
11
+ build_inference_based_loaders,
12
+ has_inference_based_loaders,
13
+ BootstrapDatasetFactoryCatalog,
14
+ )
15
+ from .combined_loader import CombinedDataLoader
16
+ from .dataset_mapper import DatasetMapper
17
+ from .inference_based_loader import InferenceBasedLoader, ScoreBasedFilter
18
+ from .image_list_dataset import ImageListDataset
19
+ from .utils import is_relative_local_path, maybe_prepend_base_path
20
+
21
+ # ensure the builtin datasets are registered
22
+ from . import datasets
23
+
24
+ # ensure the bootstrap datasets builders are registered
25
+ from . import build
26
+
27
+ __all__ = [k for k in globals().keys() if not k.startswith("_")]
densepose/data/build.py ADDED
@@ -0,0 +1,738 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # pyre-unsafe
4
+
5
+ import itertools
6
+ import logging
7
+ import numpy as np
8
+ from collections import UserDict, defaultdict
9
+ from dataclasses import dataclass
10
+ from typing import Any, Callable, Collection, Dict, Iterable, List, Optional, Sequence, Tuple
11
+ import torch
12
+ from torch.utils.data.dataset import Dataset
13
+
14
+ from detectron2.config import CfgNode
15
+ from detectron2.data.build import build_detection_test_loader as d2_build_detection_test_loader
16
+ from detectron2.data.build import build_detection_train_loader as d2_build_detection_train_loader
17
+ from detectron2.data.build import (
18
+ load_proposals_into_dataset,
19
+ print_instances_class_histogram,
20
+ trivial_batch_collator,
21
+ worker_init_reset_seed,
22
+ )
23
+ from detectron2.data.catalog import DatasetCatalog, Metadata, MetadataCatalog
24
+ from detectron2.data.samplers import TrainingSampler
25
+ from detectron2.utils.comm import get_world_size
26
+
27
+ from densepose.config import get_bootstrap_dataset_config
28
+ from densepose.modeling import build_densepose_embedder
29
+
30
+ from .combined_loader import CombinedDataLoader, Loader
31
+ from .dataset_mapper import DatasetMapper
32
+ from .datasets.coco import DENSEPOSE_CSE_KEYS_WITHOUT_MASK, DENSEPOSE_IUV_KEYS_WITHOUT_MASK
33
+ from .datasets.dataset_type import DatasetType
34
+ from .inference_based_loader import InferenceBasedLoader, ScoreBasedFilter
35
+ from .samplers import (
36
+ DensePoseConfidenceBasedSampler,
37
+ DensePoseCSEConfidenceBasedSampler,
38
+ DensePoseCSEUniformSampler,
39
+ DensePoseUniformSampler,
40
+ MaskFromDensePoseSampler,
41
+ PredictionToGroundTruthSampler,
42
+ )
43
+ from .transform import ImageResizeTransform
44
+ from .utils import get_category_to_class_mapping, get_class_to_mesh_name_mapping
45
+ from .video import (
46
+ FirstKFramesSelector,
47
+ FrameSelectionStrategy,
48
+ LastKFramesSelector,
49
+ RandomKFramesSelector,
50
+ VideoKeyframeDataset,
51
+ video_list_from_file,
52
+ )
53
+
54
+ __all__ = ["build_detection_train_loader", "build_detection_test_loader"]
55
+
56
+
57
+ Instance = Dict[str, Any]
58
+ InstancePredicate = Callable[[Instance], bool]
59
+
60
+
61
+ def _compute_num_images_per_worker(cfg: CfgNode) -> int:
62
+ num_workers = get_world_size()
63
+ images_per_batch = cfg.SOLVER.IMS_PER_BATCH
64
+ assert (
65
+ images_per_batch % num_workers == 0
66
+ ), "SOLVER.IMS_PER_BATCH ({}) must be divisible by the number of workers ({}).".format(
67
+ images_per_batch, num_workers
68
+ )
69
+ assert (
70
+ images_per_batch >= num_workers
71
+ ), "SOLVER.IMS_PER_BATCH ({}) must be larger than the number of workers ({}).".format(
72
+ images_per_batch, num_workers
73
+ )
74
+ images_per_worker = images_per_batch // num_workers
75
+ return images_per_worker
76
+
77
+
78
+ def _map_category_id_to_contiguous_id(dataset_name: str, dataset_dicts: Iterable[Instance]) -> None:
79
+ meta = MetadataCatalog.get(dataset_name)
80
+ for dataset_dict in dataset_dicts:
81
+ for ann in dataset_dict["annotations"]:
82
+ ann["category_id"] = meta.thing_dataset_id_to_contiguous_id[ann["category_id"]]
83
+
84
+
85
+ @dataclass
86
+ class _DatasetCategory:
87
+ """
88
+ Class representing category data in a dataset:
89
+ - id: category ID, as specified in the dataset annotations file
90
+ - name: category name, as specified in the dataset annotations file
91
+ - mapped_id: category ID after applying category maps (DATASETS.CATEGORY_MAPS config option)
92
+ - mapped_name: category name after applying category maps
93
+ - dataset_name: dataset in which the category is defined
94
+
95
+ For example, when training models in a class-agnostic manner, one could take LVIS 1.0
96
+ dataset and map the animal categories to the same category as human data from COCO:
97
+ id = 225
98
+ name = "cat"
99
+ mapped_id = 1
100
+ mapped_name = "person"
101
+ dataset_name = "lvis_v1_animals_dp_train"
102
+ """
103
+
104
+ id: int
105
+ name: str
106
+ mapped_id: int
107
+ mapped_name: str
108
+ dataset_name: str
109
+
110
+
111
+ _MergedCategoriesT = Dict[int, List[_DatasetCategory]]
112
+
113
+
114
+ def _add_category_id_to_contiguous_id_maps_to_metadata(
115
+ merged_categories: _MergedCategoriesT,
116
+ ) -> None:
117
+ merged_categories_per_dataset = {}
118
+ for contiguous_cat_id, cat_id in enumerate(sorted(merged_categories.keys())):
119
+ for cat in merged_categories[cat_id]:
120
+ if cat.dataset_name not in merged_categories_per_dataset:
121
+ merged_categories_per_dataset[cat.dataset_name] = defaultdict(list)
122
+ merged_categories_per_dataset[cat.dataset_name][cat_id].append(
123
+ (
124
+ contiguous_cat_id,
125
+ cat,
126
+ )
127
+ )
128
+
129
+ logger = logging.getLogger(__name__)
130
+ for dataset_name, merged_categories in merged_categories_per_dataset.items():
131
+ meta = MetadataCatalog.get(dataset_name)
132
+ if not hasattr(meta, "thing_classes"):
133
+ meta.thing_classes = []
134
+ meta.thing_dataset_id_to_contiguous_id = {}
135
+ meta.thing_dataset_id_to_merged_id = {}
136
+ else:
137
+ meta.thing_classes.clear()
138
+ meta.thing_dataset_id_to_contiguous_id.clear()
139
+ meta.thing_dataset_id_to_merged_id.clear()
140
+ logger.info(f"Dataset {dataset_name}: category ID to contiguous ID mapping:")
141
+ for _cat_id, categories in sorted(merged_categories.items()):
142
+ added_to_thing_classes = False
143
+ for contiguous_cat_id, cat in categories:
144
+ if not added_to_thing_classes:
145
+ meta.thing_classes.append(cat.mapped_name)
146
+ added_to_thing_classes = True
147
+ meta.thing_dataset_id_to_contiguous_id[cat.id] = contiguous_cat_id
148
+ meta.thing_dataset_id_to_merged_id[cat.id] = cat.mapped_id
149
+ logger.info(f"{cat.id} ({cat.name}) -> {contiguous_cat_id}")
150
+
151
+
152
+ def _maybe_create_general_keep_instance_predicate(cfg: CfgNode) -> Optional[InstancePredicate]:
153
+ def has_annotations(instance: Instance) -> bool:
154
+ return "annotations" in instance
155
+
156
+ def has_only_crowd_anotations(instance: Instance) -> bool:
157
+ for ann in instance["annotations"]:
158
+ if ann.get("is_crowd", 0) == 0:
159
+ return False
160
+ return True
161
+
162
+ def general_keep_instance_predicate(instance: Instance) -> bool:
163
+ return has_annotations(instance) and not has_only_crowd_anotations(instance)
164
+
165
+ if not cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS:
166
+ return None
167
+ return general_keep_instance_predicate
168
+
169
+
170
+ def _maybe_create_keypoints_keep_instance_predicate(cfg: CfgNode) -> Optional[InstancePredicate]:
171
+
172
+ min_num_keypoints = cfg.MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE
173
+
174
+ def has_sufficient_num_keypoints(instance: Instance) -> bool:
175
+ num_kpts = sum(
176
+ (np.array(ann["keypoints"][2::3]) > 0).sum()
177
+ for ann in instance["annotations"]
178
+ if "keypoints" in ann
179
+ )
180
+ return num_kpts >= min_num_keypoints
181
+
182
+ if cfg.MODEL.KEYPOINT_ON and (min_num_keypoints > 0):
183
+ return has_sufficient_num_keypoints
184
+ return None
185
+
186
+
187
+ def _maybe_create_mask_keep_instance_predicate(cfg: CfgNode) -> Optional[InstancePredicate]:
188
+ if not cfg.MODEL.MASK_ON:
189
+ return None
190
+
191
+ def has_mask_annotations(instance: Instance) -> bool:
192
+ return any("segmentation" in ann for ann in instance["annotations"])
193
+
194
+ return has_mask_annotations
195
+
196
+
197
+ def _maybe_create_densepose_keep_instance_predicate(cfg: CfgNode) -> Optional[InstancePredicate]:
198
+ if not cfg.MODEL.DENSEPOSE_ON:
199
+ return None
200
+
201
+ use_masks = cfg.MODEL.ROI_DENSEPOSE_HEAD.COARSE_SEGM_TRAINED_BY_MASKS
202
+
203
+ def has_densepose_annotations(instance: Instance) -> bool:
204
+ for ann in instance["annotations"]:
205
+ if all(key in ann for key in DENSEPOSE_IUV_KEYS_WITHOUT_MASK) or all(
206
+ key in ann for key in DENSEPOSE_CSE_KEYS_WITHOUT_MASK
207
+ ):
208
+ return True
209
+ if use_masks and "segmentation" in ann:
210
+ return True
211
+ return False
212
+
213
+ return has_densepose_annotations
214
+
215
+
216
+ def _maybe_create_specific_keep_instance_predicate(cfg: CfgNode) -> Optional[InstancePredicate]:
217
+ specific_predicate_creators = [
218
+ _maybe_create_keypoints_keep_instance_predicate,
219
+ _maybe_create_mask_keep_instance_predicate,
220
+ _maybe_create_densepose_keep_instance_predicate,
221
+ ]
222
+ predicates = [creator(cfg) for creator in specific_predicate_creators]
223
+ predicates = [p for p in predicates if p is not None]
224
+ if not predicates:
225
+ return None
226
+
227
+ def combined_predicate(instance: Instance) -> bool:
228
+ return any(p(instance) for p in predicates)
229
+
230
+ return combined_predicate
231
+
232
+
233
+ def _get_train_keep_instance_predicate(cfg: CfgNode):
234
+ general_keep_predicate = _maybe_create_general_keep_instance_predicate(cfg)
235
+ combined_specific_keep_predicate = _maybe_create_specific_keep_instance_predicate(cfg)
236
+
237
+ def combined_general_specific_keep_predicate(instance: Instance) -> bool:
238
+ return general_keep_predicate(instance) and combined_specific_keep_predicate(instance)
239
+
240
+ if (general_keep_predicate is None) and (combined_specific_keep_predicate is None):
241
+ return None
242
+ if general_keep_predicate is None:
243
+ return combined_specific_keep_predicate
244
+ if combined_specific_keep_predicate is None:
245
+ return general_keep_predicate
246
+ return combined_general_specific_keep_predicate
247
+
248
+
249
+ def _get_test_keep_instance_predicate(cfg: CfgNode):
250
+ general_keep_predicate = _maybe_create_general_keep_instance_predicate(cfg)
251
+ return general_keep_predicate
252
+
253
+
254
+ def _maybe_filter_and_map_categories(
255
+ dataset_name: str, dataset_dicts: List[Instance]
256
+ ) -> List[Instance]:
257
+ meta = MetadataCatalog.get(dataset_name)
258
+ category_id_map = meta.thing_dataset_id_to_contiguous_id
259
+ filtered_dataset_dicts = []
260
+ for dataset_dict in dataset_dicts:
261
+ anns = []
262
+ for ann in dataset_dict["annotations"]:
263
+ cat_id = ann["category_id"]
264
+ if cat_id not in category_id_map:
265
+ continue
266
+ ann["category_id"] = category_id_map[cat_id]
267
+ anns.append(ann)
268
+ dataset_dict["annotations"] = anns
269
+ filtered_dataset_dicts.append(dataset_dict)
270
+ return filtered_dataset_dicts
271
+
272
+
273
+ def _add_category_whitelists_to_metadata(cfg: CfgNode) -> None:
274
+ for dataset_name, whitelisted_cat_ids in cfg.DATASETS.WHITELISTED_CATEGORIES.items():
275
+ meta = MetadataCatalog.get(dataset_name)
276
+ meta.whitelisted_categories = whitelisted_cat_ids
277
+ logger = logging.getLogger(__name__)
278
+ logger.info(
279
+ "Whitelisted categories for dataset {}: {}".format(
280
+ dataset_name, meta.whitelisted_categories
281
+ )
282
+ )
283
+
284
+
285
+ def _add_category_maps_to_metadata(cfg: CfgNode) -> None:
286
+ for dataset_name, category_map in cfg.DATASETS.CATEGORY_MAPS.items():
287
+ category_map = {
288
+ int(cat_id_src): int(cat_id_dst) for cat_id_src, cat_id_dst in category_map.items()
289
+ }
290
+ meta = MetadataCatalog.get(dataset_name)
291
+ meta.category_map = category_map
292
+ logger = logging.getLogger(__name__)
293
+ logger.info("Category maps for dataset {}: {}".format(dataset_name, meta.category_map))
294
+
295
+
296
+ def _add_category_info_to_bootstrapping_metadata(dataset_name: str, dataset_cfg: CfgNode) -> None:
297
+ meta = MetadataCatalog.get(dataset_name)
298
+ meta.category_to_class_mapping = get_category_to_class_mapping(dataset_cfg)
299
+ meta.categories = dataset_cfg.CATEGORIES
300
+ meta.max_count_per_category = dataset_cfg.MAX_COUNT_PER_CATEGORY
301
+ logger = logging.getLogger(__name__)
302
+ logger.info(
303
+ "Category to class mapping for dataset {}: {}".format(
304
+ dataset_name, meta.category_to_class_mapping
305
+ )
306
+ )
307
+
308
+
309
+ def _maybe_add_class_to_mesh_name_map_to_metadata(dataset_names: List[str], cfg: CfgNode) -> None:
310
+ for dataset_name in dataset_names:
311
+ meta = MetadataCatalog.get(dataset_name)
312
+ if not hasattr(meta, "class_to_mesh_name"):
313
+ meta.class_to_mesh_name = get_class_to_mesh_name_mapping(cfg)
314
+
315
+
316
+ def _merge_categories(dataset_names: Collection[str]) -> _MergedCategoriesT:
317
+ merged_categories = defaultdict(list)
318
+ category_names = {}
319
+ for dataset_name in dataset_names:
320
+ meta = MetadataCatalog.get(dataset_name)
321
+ whitelisted_categories = meta.get("whitelisted_categories")
322
+ category_map = meta.get("category_map", {})
323
+ cat_ids = (
324
+ whitelisted_categories if whitelisted_categories is not None else meta.categories.keys()
325
+ )
326
+ for cat_id in cat_ids:
327
+ cat_name = meta.categories[cat_id]
328
+ cat_id_mapped = category_map.get(cat_id, cat_id)
329
+ if cat_id_mapped == cat_id or cat_id_mapped in cat_ids:
330
+ category_names[cat_id] = cat_name
331
+ else:
332
+ category_names[cat_id] = str(cat_id_mapped)
333
+ # assign temporary mapped category name, this name can be changed
334
+ # during the second pass, since mapped ID can correspond to a category
335
+ # from a different dataset
336
+ cat_name_mapped = meta.categories[cat_id_mapped]
337
+ merged_categories[cat_id_mapped].append(
338
+ _DatasetCategory(
339
+ id=cat_id,
340
+ name=cat_name,
341
+ mapped_id=cat_id_mapped,
342
+ mapped_name=cat_name_mapped,
343
+ dataset_name=dataset_name,
344
+ )
345
+ )
346
+ # second pass to assign proper mapped category names
347
+ for cat_id, categories in merged_categories.items():
348
+ for cat in categories:
349
+ if cat_id in category_names and cat.mapped_name != category_names[cat_id]:
350
+ cat.mapped_name = category_names[cat_id]
351
+
352
+ return merged_categories
353
+
354
+
355
+ def _warn_if_merged_different_categories(merged_categories: _MergedCategoriesT) -> None:
356
+ logger = logging.getLogger(__name__)
357
+ for cat_id in merged_categories:
358
+ merged_categories_i = merged_categories[cat_id]
359
+ first_cat_name = merged_categories_i[0].name
360
+ if len(merged_categories_i) > 1 and not all(
361
+ cat.name == first_cat_name for cat in merged_categories_i[1:]
362
+ ):
363
+ cat_summary_str = ", ".join(
364
+ [f"{cat.id} ({cat.name}) from {cat.dataset_name}" for cat in merged_categories_i]
365
+ )
366
+ logger.warning(
367
+ f"Merged category {cat_id} corresponds to the following categories: "
368
+ f"{cat_summary_str}"
369
+ )
370
+
371
+
372
+ def combine_detection_dataset_dicts(
373
+ dataset_names: Collection[str],
374
+ keep_instance_predicate: Optional[InstancePredicate] = None,
375
+ proposal_files: Optional[Collection[str]] = None,
376
+ ) -> List[Instance]:
377
+ """
378
+ Load and prepare dataset dicts for training / testing
379
+
380
+ Args:
381
+ dataset_names (Collection[str]): a list of dataset names
382
+ keep_instance_predicate (Callable: Dict[str, Any] -> bool): predicate
383
+ applied to instance dicts which defines whether to keep the instance
384
+ proposal_files (Collection[str]): if given, a list of object proposal files
385
+ that match each dataset in `dataset_names`.
386
+ """
387
+ assert len(dataset_names)
388
+ if proposal_files is None:
389
+ proposal_files = [None] * len(dataset_names)
390
+ assert len(dataset_names) == len(proposal_files)
391
+ # load datasets and metadata
392
+ dataset_name_to_dicts = {}
393
+ for dataset_name in dataset_names:
394
+ dataset_name_to_dicts[dataset_name] = DatasetCatalog.get(dataset_name)
395
+ assert len(dataset_name_to_dicts), f"Dataset '{dataset_name}' is empty!"
396
+ # merge categories, requires category metadata to be loaded
397
+ # cat_id -> [(orig_cat_id, cat_name, dataset_name)]
398
+ merged_categories = _merge_categories(dataset_names)
399
+ _warn_if_merged_different_categories(merged_categories)
400
+ merged_category_names = [
401
+ merged_categories[cat_id][0].mapped_name for cat_id in sorted(merged_categories)
402
+ ]
403
+ # map to contiguous category IDs
404
+ _add_category_id_to_contiguous_id_maps_to_metadata(merged_categories)
405
+ # load annotations and dataset metadata
406
+ for dataset_name, proposal_file in zip(dataset_names, proposal_files):
407
+ dataset_dicts = dataset_name_to_dicts[dataset_name]
408
+ assert len(dataset_dicts), f"Dataset '{dataset_name}' is empty!"
409
+ if proposal_file is not None:
410
+ dataset_dicts = load_proposals_into_dataset(dataset_dicts, proposal_file)
411
+ dataset_dicts = _maybe_filter_and_map_categories(dataset_name, dataset_dicts)
412
+ print_instances_class_histogram(dataset_dicts, merged_category_names)
413
+ dataset_name_to_dicts[dataset_name] = dataset_dicts
414
+
415
+ if keep_instance_predicate is not None:
416
+ all_datasets_dicts_plain = [
417
+ d
418
+ for d in itertools.chain.from_iterable(dataset_name_to_dicts.values())
419
+ if keep_instance_predicate(d)
420
+ ]
421
+ else:
422
+ all_datasets_dicts_plain = list(
423
+ itertools.chain.from_iterable(dataset_name_to_dicts.values())
424
+ )
425
+ return all_datasets_dicts_plain
426
+
427
+
428
+ def build_detection_train_loader(cfg: CfgNode, mapper=None):
429
+ """
430
+ A data loader is created in a way similar to that of Detectron2.
431
+ The main differences are:
432
+     - it allows combining datasets with different but compatible object category sets
433
+
434
+ The data loader is created by the following steps:
435
+ 1. Use the dataset names in config to query :class:`DatasetCatalog`, and obtain a list of dicts.
436
+ 2. Start workers to work on the dicts. Each worker will:
437
+ * Map each metadata dict into another format to be consumed by the model.
438
+ * Batch them by simply putting dicts into a list.
439
+ The batched ``list[mapped_dict]`` is what this dataloader will return.
440
+
441
+ Args:
442
+ cfg (CfgNode): the config
443
+ mapper (callable): a callable which takes a sample (dict) from dataset and
444
+ returns the format to be consumed by the model.
445
+ By default it will be `DatasetMapper(cfg, True)`.
446
+
447
+ Returns:
448
+ an infinite iterator of training data
449
+ """
450
+
451
+ _add_category_whitelists_to_metadata(cfg)
452
+ _add_category_maps_to_metadata(cfg)
453
+ _maybe_add_class_to_mesh_name_map_to_metadata(cfg.DATASETS.TRAIN, cfg)
454
+ dataset_dicts = combine_detection_dataset_dicts(
455
+ cfg.DATASETS.TRAIN,
456
+ keep_instance_predicate=_get_train_keep_instance_predicate(cfg),
457
+ proposal_files=cfg.DATASETS.PROPOSAL_FILES_TRAIN if cfg.MODEL.LOAD_PROPOSALS else None,
458
+ )
459
+ if mapper is None:
460
+ mapper = DatasetMapper(cfg, True)
461
+ return d2_build_detection_train_loader(cfg, dataset=dataset_dicts, mapper=mapper)
462
+
463
+
464
+ def build_detection_test_loader(cfg, dataset_name, mapper=None):
465
+ """
466
+ Similar to `build_detection_train_loader`.
467
+ But this function uses the given `dataset_name` argument (instead of the names in cfg),
468
+ and uses batch size 1.
469
+
470
+ Args:
471
+ cfg: a detectron2 CfgNode
472
+ dataset_name (str): a name of the dataset that's available in the DatasetCatalog
473
+ mapper (callable): a callable which takes a sample (dict) from dataset
474
+ and returns the format to be consumed by the model.
475
+ By default it will be `DatasetMapper(cfg, False)`.
476
+
477
+ Returns:
478
+ DataLoader: a torch DataLoader, that loads the given detection
479
+ dataset, with test-time transformation and batching.
480
+ """
481
+ _add_category_whitelists_to_metadata(cfg)
482
+ _add_category_maps_to_metadata(cfg)
483
+ _maybe_add_class_to_mesh_name_map_to_metadata([dataset_name], cfg)
484
+ dataset_dicts = combine_detection_dataset_dicts(
485
+ [dataset_name],
486
+ keep_instance_predicate=_get_test_keep_instance_predicate(cfg),
487
+ proposal_files=(
488
+ [cfg.DATASETS.PROPOSAL_FILES_TEST[list(cfg.DATASETS.TEST).index(dataset_name)]]
489
+ if cfg.MODEL.LOAD_PROPOSALS
490
+ else None
491
+ ),
492
+ )
493
+ sampler = None
494
+ if not cfg.DENSEPOSE_EVALUATION.DISTRIBUTED_INFERENCE:
495
+ sampler = torch.utils.data.SequentialSampler(dataset_dicts)
496
+ if mapper is None:
497
+ mapper = DatasetMapper(cfg, False)
498
+ return d2_build_detection_test_loader(
499
+ dataset_dicts, mapper=mapper, num_workers=cfg.DATALOADER.NUM_WORKERS, sampler=sampler
500
+ )
501
+
502
+
503
+ def build_frame_selector(cfg: CfgNode):
504
+ strategy = FrameSelectionStrategy(cfg.STRATEGY)
505
+ if strategy == FrameSelectionStrategy.RANDOM_K:
506
+ frame_selector = RandomKFramesSelector(cfg.NUM_IMAGES)
507
+ elif strategy == FrameSelectionStrategy.FIRST_K:
508
+ frame_selector = FirstKFramesSelector(cfg.NUM_IMAGES)
509
+ elif strategy == FrameSelectionStrategy.LAST_K:
510
+ frame_selector = LastKFramesSelector(cfg.NUM_IMAGES)
511
+ elif strategy == FrameSelectionStrategy.ALL:
512
+ frame_selector = None
513
+ # pyre-fixme[61]: `frame_selector` may not be initialized here.
514
+ return frame_selector
515
+
516
+
517
+ def build_transform(cfg: CfgNode, data_type: str):
518
+ if cfg.TYPE == "resize":
519
+ if data_type == "image":
520
+ return ImageResizeTransform(cfg.MIN_SIZE, cfg.MAX_SIZE)
521
+ raise ValueError(f"Unknown transform {cfg.TYPE} for data type {data_type}")
522
+
523
+
524
+ def build_combined_loader(cfg: CfgNode, loaders: Collection[Loader], ratios: Sequence[float]):
525
+ images_per_worker = _compute_num_images_per_worker(cfg)
526
+ return CombinedDataLoader(loaders, images_per_worker, ratios)
527
+
528
+
529
+ def build_bootstrap_dataset(dataset_name: str, cfg: CfgNode) -> Sequence[torch.Tensor]:
530
+ """
531
+ Build dataset that provides data to bootstrap on
532
+
533
+ Args:
534
+ dataset_name (str): Name of the dataset, needs to have associated metadata
535
+ to load the data
536
+ cfg (CfgNode): bootstrapping config
537
+ Returns:
538
+ Sequence[Tensor] - dataset that provides image batches, Tensors of size
539
+ [N, C, H, W] of type float32
540
+ """
541
+ logger = logging.getLogger(__name__)
542
+ _add_category_info_to_bootstrapping_metadata(dataset_name, cfg)
543
+ meta = MetadataCatalog.get(dataset_name)
544
+ factory = BootstrapDatasetFactoryCatalog.get(meta.dataset_type)
545
+ dataset = None
546
+ if factory is not None:
547
+ dataset = factory(meta, cfg)
548
+ if dataset is None:
549
+ logger.warning(f"Failed to create dataset {dataset_name} of type {meta.dataset_type}")
550
+ return dataset
551
+
552
+
553
+ def build_data_sampler(cfg: CfgNode, sampler_cfg: CfgNode, embedder: Optional[torch.nn.Module]):
554
+ if sampler_cfg.TYPE == "densepose_uniform":
555
+ data_sampler = PredictionToGroundTruthSampler()
556
+ # transform densepose pred -> gt
557
+ data_sampler.register_sampler(
558
+ "pred_densepose",
559
+ "gt_densepose",
560
+ DensePoseUniformSampler(count_per_class=sampler_cfg.COUNT_PER_CLASS),
561
+ )
562
+ data_sampler.register_sampler("pred_densepose", "gt_masks", MaskFromDensePoseSampler())
563
+ return data_sampler
564
+ elif sampler_cfg.TYPE == "densepose_UV_confidence":
565
+ data_sampler = PredictionToGroundTruthSampler()
566
+ # transform densepose pred -> gt
567
+ data_sampler.register_sampler(
568
+ "pred_densepose",
569
+ "gt_densepose",
570
+ DensePoseConfidenceBasedSampler(
571
+ confidence_channel="sigma_2",
572
+ count_per_class=sampler_cfg.COUNT_PER_CLASS,
573
+ search_proportion=0.5,
574
+ ),
575
+ )
576
+ data_sampler.register_sampler("pred_densepose", "gt_masks", MaskFromDensePoseSampler())
577
+ return data_sampler
578
+ elif sampler_cfg.TYPE == "densepose_fine_segm_confidence":
579
+ data_sampler = PredictionToGroundTruthSampler()
580
+ # transform densepose pred -> gt
581
+ data_sampler.register_sampler(
582
+ "pred_densepose",
583
+ "gt_densepose",
584
+ DensePoseConfidenceBasedSampler(
585
+ confidence_channel="fine_segm_confidence",
586
+ count_per_class=sampler_cfg.COUNT_PER_CLASS,
587
+ search_proportion=0.5,
588
+ ),
589
+ )
590
+ data_sampler.register_sampler("pred_densepose", "gt_masks", MaskFromDensePoseSampler())
591
+ return data_sampler
592
+ elif sampler_cfg.TYPE == "densepose_coarse_segm_confidence":
593
+ data_sampler = PredictionToGroundTruthSampler()
594
+ # transform densepose pred -> gt
595
+ data_sampler.register_sampler(
596
+ "pred_densepose",
597
+ "gt_densepose",
598
+ DensePoseConfidenceBasedSampler(
599
+ confidence_channel="coarse_segm_confidence",
600
+ count_per_class=sampler_cfg.COUNT_PER_CLASS,
601
+ search_proportion=0.5,
602
+ ),
603
+ )
604
+ data_sampler.register_sampler("pred_densepose", "gt_masks", MaskFromDensePoseSampler())
605
+ return data_sampler
606
+ elif sampler_cfg.TYPE == "densepose_cse_uniform":
607
+ assert embedder is not None
608
+ data_sampler = PredictionToGroundTruthSampler()
609
+ # transform densepose pred -> gt
610
+ data_sampler.register_sampler(
611
+ "pred_densepose",
612
+ "gt_densepose",
613
+ DensePoseCSEUniformSampler(
614
+ cfg=cfg,
615
+ use_gt_categories=sampler_cfg.USE_GROUND_TRUTH_CATEGORIES,
616
+ embedder=embedder,
617
+ count_per_class=sampler_cfg.COUNT_PER_CLASS,
618
+ ),
619
+ )
620
+ data_sampler.register_sampler("pred_densepose", "gt_masks", MaskFromDensePoseSampler())
621
+ return data_sampler
622
+ elif sampler_cfg.TYPE == "densepose_cse_coarse_segm_confidence":
623
+ assert embedder is not None
624
+ data_sampler = PredictionToGroundTruthSampler()
625
+ # transform densepose pred -> gt
626
+ data_sampler.register_sampler(
627
+ "pred_densepose",
628
+ "gt_densepose",
629
+ DensePoseCSEConfidenceBasedSampler(
630
+ cfg=cfg,
631
+ use_gt_categories=sampler_cfg.USE_GROUND_TRUTH_CATEGORIES,
632
+ embedder=embedder,
633
+ confidence_channel="coarse_segm_confidence",
634
+ count_per_class=sampler_cfg.COUNT_PER_CLASS,
635
+ search_proportion=0.5,
636
+ ),
637
+ )
638
+ data_sampler.register_sampler("pred_densepose", "gt_masks", MaskFromDensePoseSampler())
639
+ return data_sampler
640
+
641
+ raise ValueError(f"Unknown data sampler type {sampler_cfg.TYPE}")
642
+
643
+
644
+ def build_data_filter(cfg: CfgNode):
645
+ if cfg.TYPE == "detection_score":
646
+ min_score = cfg.MIN_VALUE
647
+ return ScoreBasedFilter(min_score=min_score)
648
+ raise ValueError(f"Unknown data filter type {cfg.TYPE}")
649
+
650
+
651
+ def build_inference_based_loader(
652
+ cfg: CfgNode,
653
+ dataset_cfg: CfgNode,
654
+ model: torch.nn.Module,
655
+ embedder: Optional[torch.nn.Module] = None,
656
+ ) -> InferenceBasedLoader:
657
+ """
658
+ Constructs data loader based on inference results of a model.
659
+ """
660
+ dataset = build_bootstrap_dataset(dataset_cfg.DATASET, dataset_cfg.IMAGE_LOADER)
661
+ meta = MetadataCatalog.get(dataset_cfg.DATASET)
662
+ training_sampler = TrainingSampler(len(dataset))
663
+ data_loader = torch.utils.data.DataLoader(
664
+ dataset, # pyre-ignore[6]
665
+ batch_size=dataset_cfg.IMAGE_LOADER.BATCH_SIZE,
666
+ sampler=training_sampler,
667
+ num_workers=dataset_cfg.IMAGE_LOADER.NUM_WORKERS,
668
+ collate_fn=trivial_batch_collator,
669
+ worker_init_fn=worker_init_reset_seed,
670
+ )
671
+ return InferenceBasedLoader(
672
+ model,
673
+ data_loader=data_loader,
674
+ data_sampler=build_data_sampler(cfg, dataset_cfg.DATA_SAMPLER, embedder),
675
+ data_filter=build_data_filter(dataset_cfg.FILTER),
676
+ shuffle=True,
677
+ batch_size=dataset_cfg.INFERENCE.OUTPUT_BATCH_SIZE,
678
+ inference_batch_size=dataset_cfg.INFERENCE.INPUT_BATCH_SIZE,
679
+ category_to_class_mapping=meta.category_to_class_mapping,
680
+ )
681
+
682
+
683
+ def has_inference_based_loaders(cfg: CfgNode) -> bool:
684
+ """
685
+     Returns True if at least one inference-based loader must
686
+ be instantiated for training
687
+ """
688
+ return len(cfg.BOOTSTRAP_DATASETS) > 0
689
+
690
+
691
+ def build_inference_based_loaders(
692
+ cfg: CfgNode, model: torch.nn.Module
693
+ ) -> Tuple[List[InferenceBasedLoader], List[float]]:
694
+ loaders = []
695
+ ratios = []
696
+ embedder = build_densepose_embedder(cfg).to(device=model.device) # pyre-ignore[16]
697
+ for dataset_spec in cfg.BOOTSTRAP_DATASETS:
698
+ dataset_cfg = get_bootstrap_dataset_config().clone()
699
+ dataset_cfg.merge_from_other_cfg(CfgNode(dataset_spec))
700
+ loader = build_inference_based_loader(cfg, dataset_cfg, model, embedder)
701
+ loaders.append(loader)
702
+ ratios.append(dataset_cfg.RATIO)
703
+ return loaders, ratios
704
+
705
+
706
+ def build_video_list_dataset(meta: Metadata, cfg: CfgNode):
707
+ video_list_fpath = meta.video_list_fpath
708
+ video_base_path = meta.video_base_path
709
+ category = meta.category
710
+ if cfg.TYPE == "video_keyframe":
711
+ frame_selector = build_frame_selector(cfg.SELECT)
712
+ transform = build_transform(cfg.TRANSFORM, data_type="image")
713
+ video_list = video_list_from_file(video_list_fpath, video_base_path)
714
+ keyframe_helper_fpath = getattr(cfg, "KEYFRAME_HELPER", None)
715
+ return VideoKeyframeDataset(
716
+ video_list, category, frame_selector, transform, keyframe_helper_fpath
717
+ )
718
+
719
+
720
+ class _BootstrapDatasetFactoryCatalog(UserDict):
721
+ """
722
+ A global dictionary that stores information about bootstrapped datasets creation functions
723
+ from metadata and config, for diverse DatasetType
724
+ """
725
+
726
+ def register(self, dataset_type: DatasetType, factory: Callable[[Metadata, CfgNode], Dataset]):
727
+ """
728
+ Args:
729
+ dataset_type (DatasetType): a DatasetType e.g. DatasetType.VIDEO_LIST
730
+ factory (Callable[Metadata, CfgNode]): a callable which takes Metadata and cfg
731
+ arguments and returns a dataset object.
732
+ """
733
+ assert dataset_type not in self, "Dataset '{}' is already registered!".format(dataset_type)
734
+ self[dataset_type] = factory
735
+
736
+
737
+ BootstrapDatasetFactoryCatalog = _BootstrapDatasetFactoryCatalog()
738
+ BootstrapDatasetFactoryCatalog.register(DatasetType.VIDEO_LIST, build_video_list_dataset)
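
A hedged end-to-end sketch of driving the loader builders above from a config; the config file path and dataset name are examples, and `add_densepose_config` is assumed to be the config helper provided by densepose/config.py in this commit:

from detectron2.config import get_cfg

from densepose.config import add_densepose_config
from densepose.data import build_detection_train_loader

cfg = get_cfg()
add_densepose_config(cfg)
cfg.merge_from_file("configs/densepose_rcnn_R_50_FPN_s1x.yaml")  # example path
cfg.DATASETS.TRAIN = ("densepose_coco_2014_train",)

# Infinite iterator over batched dataset dicts, as documented above.
train_loader = build_detection_train_loader(cfg)
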
densepose/data/combined_loader.py ADDED
@@ -0,0 +1,46 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # pyre-unsafe
4
+
5
+ import random
6
+ from collections import deque
7
+ from typing import Any, Collection, Deque, Iterable, Iterator, List, Sequence
8
+
9
+ Loader = Iterable[Any]
10
+
11
+
12
+ def _pooled_next(iterator: Iterator[Any], pool: Deque[Any]):
13
+ if not pool:
14
+ pool.extend(next(iterator))
15
+ return pool.popleft()
16
+
17
+
18
+ class CombinedDataLoader:
19
+ """
20
+ Combines data loaders using the provided sampling ratios
21
+ """
22
+
23
+ BATCH_COUNT = 100
24
+
25
+ def __init__(self, loaders: Collection[Loader], batch_size: int, ratios: Sequence[float]):
26
+ self.loaders = loaders
27
+ self.batch_size = batch_size
28
+ self.ratios = ratios
29
+
30
+ def __iter__(self) -> Iterator[List[Any]]:
31
+ iters = [iter(loader) for loader in self.loaders]
32
+ indices = []
33
+         pool = [deque() for _ in iters]  # a separate pool per loader (avoid aliasing one shared deque)
34
+ # infinite iterator, as in D2
35
+ while True:
36
+ if not indices:
37
+ # just a buffer of indices, its size doesn't matter
38
+ # as long as it's a multiple of batch_size
39
+ k = self.batch_size * self.BATCH_COUNT
40
+ indices = random.choices(range(len(self.loaders)), self.ratios, k=k)
41
+ try:
42
+ batch = [_pooled_next(iters[i], pool[i]) for i in indices[: self.batch_size]]
43
+ except StopIteration:
44
+ break
45
+ indices = indices[self.batch_size :]
46
+ yield batch
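
A toy illustration of what CombinedDataLoader does; the two endless "loaders" here are stand-ins for the regular train loader and a bootstrap loader:

import itertools

from densepose.data.combined_loader import CombinedDataLoader

# Each loader only needs to be an iterable whose iterator yields batches (lists of samples).
loader_a = itertools.cycle([["a1", "a2"], ["a3", "a4"]])
loader_b = itertools.cycle([["b1", "b2"], ["b3", "b4"]])

combined = CombinedDataLoader([loader_a, loader_b], batch_size=4, ratios=[0.75, 0.25])
first_batch = next(iter(combined))  # a random, roughly 3:1 mix of samples from the two loaders
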
densepose/data/dataset_mapper.py ADDED
@@ -0,0 +1,170 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Facebook, Inc. and its affiliates.
3
+
4
+ # pyre-unsafe
5
+
6
+ import copy
7
+ import logging
8
+ from typing import Any, Dict, List, Tuple
9
+ import torch
10
+
11
+ from detectron2.data import MetadataCatalog
12
+ from detectron2.data import detection_utils as utils
13
+ from detectron2.data import transforms as T
14
+ from detectron2.layers import ROIAlign
15
+ from detectron2.structures import BoxMode
16
+ from detectron2.utils.file_io import PathManager
17
+
18
+ from densepose.structures import DensePoseDataRelative, DensePoseList, DensePoseTransformData
19
+
20
+
21
+ def build_augmentation(cfg, is_train):
22
+ logger = logging.getLogger(__name__)
23
+ result = utils.build_augmentation(cfg, is_train)
24
+ if is_train:
25
+ random_rotation = T.RandomRotation(
26
+ cfg.INPUT.ROTATION_ANGLES, expand=False, sample_style="choice"
27
+ )
28
+ result.append(random_rotation)
29
+ logger.info("DensePose-specific augmentation used in training: " + str(random_rotation))
30
+ return result
31
+
32
+
33
+ class DatasetMapper:
34
+ """
35
+ A customized version of `detectron2.data.DatasetMapper`
36
+ """
37
+
38
+ def __init__(self, cfg, is_train=True):
39
+ self.augmentation = build_augmentation(cfg, is_train)
40
+
41
+ # fmt: off
42
+ self.img_format = cfg.INPUT.FORMAT
43
+ self.mask_on = (
44
+ cfg.MODEL.MASK_ON or (
45
+ cfg.MODEL.DENSEPOSE_ON
46
+ and cfg.MODEL.ROI_DENSEPOSE_HEAD.COARSE_SEGM_TRAINED_BY_MASKS)
47
+ )
48
+ self.keypoint_on = cfg.MODEL.KEYPOINT_ON
49
+ self.densepose_on = cfg.MODEL.DENSEPOSE_ON
50
+ assert not cfg.MODEL.LOAD_PROPOSALS, "not supported yet"
51
+ # fmt: on
52
+ if self.keypoint_on and is_train:
53
+ # Flip only makes sense in training
54
+ self.keypoint_hflip_indices = utils.create_keypoint_hflip_indices(cfg.DATASETS.TRAIN)
55
+ else:
56
+ self.keypoint_hflip_indices = None
57
+
58
+ if self.densepose_on:
59
+ densepose_transform_srcs = [
60
+ MetadataCatalog.get(ds).densepose_transform_src
61
+ for ds in cfg.DATASETS.TRAIN + cfg.DATASETS.TEST
62
+ ]
63
+ assert len(densepose_transform_srcs) > 0
64
+ # TODO: check that DensePose transformation data is the same for
65
+ # all the datasets. Otherwise one would have to pass DB ID with
66
+ # each entry to select proper transformation data. For now, since
67
+ # all DensePose annotated data uses the same data semantics, we
68
+ # omit this check.
69
+ densepose_transform_data_fpath = PathManager.get_local_path(densepose_transform_srcs[0])
70
+ self.densepose_transform_data = DensePoseTransformData.load(
71
+ densepose_transform_data_fpath
72
+ )
73
+
74
+ self.is_train = is_train
75
+
76
+ def __call__(self, dataset_dict):
77
+ """
78
+ Args:
79
+ dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.
80
+
81
+ Returns:
82
+ dict: a format that builtin models in detectron2 accept
83
+ """
84
+ dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below
85
+ image = utils.read_image(dataset_dict["file_name"], format=self.img_format)
86
+ utils.check_image_size(dataset_dict, image)
87
+
88
+ image, transforms = T.apply_transform_gens(self.augmentation, image)
89
+ image_shape = image.shape[:2] # h, w
90
+ dataset_dict["image"] = torch.as_tensor(image.transpose(2, 0, 1).astype("float32"))
91
+
92
+ if not self.is_train:
93
+ dataset_dict.pop("annotations", None)
94
+ return dataset_dict
95
+
96
+ for anno in dataset_dict["annotations"]:
97
+ if not self.mask_on:
98
+ anno.pop("segmentation", None)
99
+ if not self.keypoint_on:
100
+ anno.pop("keypoints", None)
101
+
102
+ # USER: Implement additional transformations if you have other types of data
103
+ # USER: Don't call transpose_densepose if you don't need
104
+ annos = [
105
+ self._transform_densepose(
106
+ utils.transform_instance_annotations(
107
+ obj, transforms, image_shape, keypoint_hflip_indices=self.keypoint_hflip_indices
108
+ ),
109
+ transforms,
110
+ )
111
+ for obj in dataset_dict.pop("annotations")
112
+ if obj.get("iscrowd", 0) == 0
113
+ ]
114
+
115
+ if self.mask_on:
116
+ self._add_densepose_masks_as_segmentation(annos, image_shape)
117
+
118
+ instances = utils.annotations_to_instances(annos, image_shape, mask_format="bitmask")
119
+ densepose_annotations = [obj.get("densepose") for obj in annos]
120
+ if densepose_annotations and not all(v is None for v in densepose_annotations):
121
+ instances.gt_densepose = DensePoseList(
122
+ densepose_annotations, instances.gt_boxes, image_shape
123
+ )
124
+
125
+ dataset_dict["instances"] = instances[instances.gt_boxes.nonempty()]
126
+ return dataset_dict
127
+
128
+ def _transform_densepose(self, annotation, transforms):
129
+ if not self.densepose_on:
130
+ return annotation
131
+
132
+ # Handle densepose annotations
133
+ is_valid, reason_not_valid = DensePoseDataRelative.validate_annotation(annotation)
134
+ if is_valid:
135
+ densepose_data = DensePoseDataRelative(annotation, cleanup=True)
136
+ densepose_data.apply_transform(transforms, self.densepose_transform_data)
137
+ annotation["densepose"] = densepose_data
138
+ else:
139
+ # logger = logging.getLogger(__name__)
140
+ # logger.debug("Could not load DensePose annotation: {}".format(reason_not_valid))
141
+ DensePoseDataRelative.cleanup_annotation(annotation)
142
+             # NOTE: annotations for certain instances may be unavailable.
143
+             # 'None' is accepted by the DensePoseList data structure.
144
+ annotation["densepose"] = None
145
+ return annotation
146
+
147
+ def _add_densepose_masks_as_segmentation(
148
+ self, annotations: List[Dict[str, Any]], image_shape_hw: Tuple[int, int]
149
+ ):
150
+ for obj in annotations:
151
+ if ("densepose" not in obj) or ("segmentation" in obj):
152
+ continue
153
+ # DP segmentation: torch.Tensor [S, S] of float32, S=256
154
+ segm_dp = torch.zeros_like(obj["densepose"].segm)
155
+ segm_dp[obj["densepose"].segm > 0] = 1
156
+ segm_h, segm_w = segm_dp.shape
157
+ bbox_segm_dp = torch.tensor((0, 0, segm_h - 1, segm_w - 1), dtype=torch.float32)
158
+ # image bbox
159
+ x0, y0, x1, y1 = (
160
+ v.item() for v in BoxMode.convert(obj["bbox"], obj["bbox_mode"], BoxMode.XYXY_ABS)
161
+ )
162
+ segm_aligned = (
163
+ ROIAlign((y1 - y0, x1 - x0), 1.0, 0, aligned=True)
164
+ .forward(segm_dp.view(1, 1, *segm_dp.shape), bbox_segm_dp)
165
+ .squeeze()
166
+ )
167
+ image_mask = torch.zeros(*image_shape_hw, dtype=torch.float32)
168
+ image_mask[y0:y1, x0:x1] = segm_aligned
169
+ # segmentation for BitMask: np.array [H, W] of bool
170
+ obj["segmentation"] = image_mask >= 0.5
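The last step of `__call__` above relies on `Boxes.nonempty()` to drop instances whose ground-truth boxes are degenerate. A minimal, self-contained sketch of that idiom with made-up boxes (illustrative only, not part of the committed file):

    import torch
    from detectron2.structures import Boxes, Instances

    # Two hypothetical ground-truth boxes: one valid, one with zero area.
    instances = Instances((480, 640))
    instances.gt_boxes = Boxes(torch.tensor([[10.0, 10.0, 50.0, 80.0],
                                             [30.0, 30.0, 30.0, 30.0]]))

    # Keep only instances whose boxes have positive width and height,
    # mirroring `instances[instances.gt_boxes.nonempty()]` in the mapper.
    instances = instances[instances.gt_boxes.nonempty()]
    assert len(instances) == 1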
densepose/data/datasets/__init__.py ADDED
@@ -0,0 +1,7 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+
+# pyre-unsafe
+
+from . import builtin  # ensure the builtin datasets are registered
+
+__all__ = [k for k in globals().keys() if "builtin" not in k and not k.startswith("_")]
densepose/data/datasets/builtin.py ADDED
@@ -0,0 +1,18 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+
+# pyre-unsafe
+from .chimpnsee import register_dataset as register_chimpnsee_dataset
+from .coco import BASE_DATASETS as BASE_COCO_DATASETS
+from .coco import DATASETS as COCO_DATASETS
+from .coco import register_datasets as register_coco_datasets
+from .lvis import DATASETS as LVIS_DATASETS
+from .lvis import register_datasets as register_lvis_datasets
+
+DEFAULT_DATASETS_ROOT = "datasets"
+
+
+register_coco_datasets(COCO_DATASETS, DEFAULT_DATASETS_ROOT)
+register_coco_datasets(BASE_COCO_DATASETS, DEFAULT_DATASETS_ROOT)
+register_lvis_datasets(LVIS_DATASETS, DEFAULT_DATASETS_ROOT)
+
+register_chimpnsee_dataset(DEFAULT_DATASETS_ROOT)  # pyre-ignore[19]
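Importing this module triggers the registrations above, after which the datasets can be looked up through the standard Detectron2 catalogs. A small usage sketch (the annotation and image files must exist under the `datasets/` root for the actual load to succeed):

    from detectron2.data import DatasetCatalog, MetadataCatalog

    import densepose.data.datasets.builtin  # noqa: F401  (runs the registrations above)

    # Metadata is available immediately after registration.
    meta = MetadataCatalog.get("densepose_coco_2014_train")
    print(meta.json_file, meta.image_root)

    # Loading the per-image records requires the files referenced above.
    dataset_dicts = DatasetCatalog.get("densepose_coco_2014_train")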
densepose/data/datasets/chimpnsee.py ADDED
@@ -0,0 +1,31 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+
+# pyre-unsafe
+
+from typing import Optional
+
+from detectron2.data import DatasetCatalog, MetadataCatalog
+
+from ..utils import maybe_prepend_base_path
+from .dataset_type import DatasetType
+
+CHIMPNSEE_DATASET_NAME = "chimpnsee"
+
+
+def register_dataset(datasets_root: Optional[str] = None) -> None:
+    def empty_load_callback():
+        pass
+
+    video_list_fpath = maybe_prepend_base_path(
+        datasets_root,
+        "chimpnsee/cdna.eva.mpg.de/video_list.txt",
+    )
+    video_base_path = maybe_prepend_base_path(datasets_root, "chimpnsee/cdna.eva.mpg.de")
+
+    DatasetCatalog.register(CHIMPNSEE_DATASET_NAME, empty_load_callback)
+    MetadataCatalog.get(CHIMPNSEE_DATASET_NAME).set(
+        dataset_type=DatasetType.VIDEO_LIST,
+        video_list_fpath=video_list_fpath,
+        video_base_path=video_base_path,
+        category="chimpanzee",
+    )
densepose/data/datasets/coco.py ADDED
@@ -0,0 +1,434 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # pyre-unsafe
4
+ import contextlib
5
+ import io
6
+ import logging
7
+ import os
8
+ from collections import defaultdict
9
+ from dataclasses import dataclass
10
+ from typing import Any, Dict, Iterable, List, Optional
11
+ from fvcore.common.timer import Timer
12
+
13
+ from detectron2.data import DatasetCatalog, MetadataCatalog
14
+ from detectron2.structures import BoxMode
15
+ from detectron2.utils.file_io import PathManager
16
+
17
+ from ..utils import maybe_prepend_base_path
18
+
19
+ DENSEPOSE_MASK_KEY = "dp_masks"
20
+ DENSEPOSE_IUV_KEYS_WITHOUT_MASK = ["dp_x", "dp_y", "dp_I", "dp_U", "dp_V"]
21
+ DENSEPOSE_CSE_KEYS_WITHOUT_MASK = ["dp_x", "dp_y", "dp_vertex", "ref_model"]
22
+ DENSEPOSE_ALL_POSSIBLE_KEYS = set(
23
+ DENSEPOSE_IUV_KEYS_WITHOUT_MASK + DENSEPOSE_CSE_KEYS_WITHOUT_MASK + [DENSEPOSE_MASK_KEY]
24
+ )
25
+ DENSEPOSE_METADATA_URL_PREFIX = "https://dl.fbaipublicfiles.com/densepose/data/"
26
+
27
+
28
+ @dataclass
29
+ class CocoDatasetInfo:
30
+ name: str
31
+ images_root: str
32
+ annotations_fpath: str
33
+
34
+
35
+ DATASETS = [
36
+ CocoDatasetInfo(
37
+ name="densepose_coco_2014_train",
38
+ images_root="coco/train2014",
39
+ annotations_fpath="coco/annotations/densepose_train2014.json",
40
+ ),
41
+ CocoDatasetInfo(
42
+ name="densepose_coco_2014_minival",
43
+ images_root="coco/val2014",
44
+ annotations_fpath="coco/annotations/densepose_minival2014.json",
45
+ ),
46
+ CocoDatasetInfo(
47
+ name="densepose_coco_2014_minival_100",
48
+ images_root="coco/val2014",
49
+ annotations_fpath="coco/annotations/densepose_minival2014_100.json",
50
+ ),
51
+ CocoDatasetInfo(
52
+ name="densepose_coco_2014_valminusminival",
53
+ images_root="coco/val2014",
54
+ annotations_fpath="coco/annotations/densepose_valminusminival2014.json",
55
+ ),
56
+ CocoDatasetInfo(
57
+ name="densepose_coco_2014_train_cse",
58
+ images_root="coco/train2014",
59
+ annotations_fpath="coco_cse/densepose_train2014_cse.json",
60
+ ),
61
+ CocoDatasetInfo(
62
+ name="densepose_coco_2014_minival_cse",
63
+ images_root="coco/val2014",
64
+ annotations_fpath="coco_cse/densepose_minival2014_cse.json",
65
+ ),
66
+ CocoDatasetInfo(
67
+ name="densepose_coco_2014_minival_100_cse",
68
+ images_root="coco/val2014",
69
+ annotations_fpath="coco_cse/densepose_minival2014_100_cse.json",
70
+ ),
71
+ CocoDatasetInfo(
72
+ name="densepose_coco_2014_valminusminival_cse",
73
+ images_root="coco/val2014",
74
+ annotations_fpath="coco_cse/densepose_valminusminival2014_cse.json",
75
+ ),
76
+ CocoDatasetInfo(
77
+ name="densepose_chimps",
78
+ images_root="densepose_chimps/images",
79
+ annotations_fpath="densepose_chimps/densepose_chimps_densepose.json",
80
+ ),
81
+ CocoDatasetInfo(
82
+ name="densepose_chimps_cse_train",
83
+ images_root="densepose_chimps/images",
84
+ annotations_fpath="densepose_chimps/densepose_chimps_cse_train.json",
85
+ ),
86
+ CocoDatasetInfo(
87
+ name="densepose_chimps_cse_val",
88
+ images_root="densepose_chimps/images",
89
+ annotations_fpath="densepose_chimps/densepose_chimps_cse_val.json",
90
+ ),
91
+ CocoDatasetInfo(
92
+ name="posetrack2017_train",
93
+ images_root="posetrack2017/posetrack_data_2017",
94
+ annotations_fpath="posetrack2017/densepose_posetrack_train2017.json",
95
+ ),
96
+ CocoDatasetInfo(
97
+ name="posetrack2017_val",
98
+ images_root="posetrack2017/posetrack_data_2017",
99
+ annotations_fpath="posetrack2017/densepose_posetrack_val2017.json",
100
+ ),
101
+ CocoDatasetInfo(
102
+ name="lvis_v05_train",
103
+ images_root="coco/train2017",
104
+ annotations_fpath="lvis/lvis_v0.5_plus_dp_train.json",
105
+ ),
106
+ CocoDatasetInfo(
107
+ name="lvis_v05_val",
108
+ images_root="coco/val2017",
109
+ annotations_fpath="lvis/lvis_v0.5_plus_dp_val.json",
110
+ ),
111
+ ]
112
+
113
+
114
+ BASE_DATASETS = [
115
+ CocoDatasetInfo(
116
+ name="base_coco_2017_train",
117
+ images_root="coco/train2017",
118
+ annotations_fpath="coco/annotations/instances_train2017.json",
119
+ ),
120
+ CocoDatasetInfo(
121
+ name="base_coco_2017_val",
122
+ images_root="coco/val2017",
123
+ annotations_fpath="coco/annotations/instances_val2017.json",
124
+ ),
125
+ CocoDatasetInfo(
126
+ name="base_coco_2017_val_100",
127
+ images_root="coco/val2017",
128
+ annotations_fpath="coco/annotations/instances_val2017_100.json",
129
+ ),
130
+ ]
131
+
132
+
133
+ def get_metadata(base_path: Optional[str]) -> Dict[str, Any]:
134
+ """
135
+ Returns metadata associated with COCO DensePose datasets
136
+
137
+ Args:
138
+ base_path: Optional[str]
139
+ Base path used to load metadata from
140
+
141
+ Returns:
142
+ Dict[str, Any]
143
+ Metadata in the form of a dictionary
144
+ """
145
+ meta = {
146
+ "densepose_transform_src": maybe_prepend_base_path(base_path, "UV_symmetry_transforms.mat"),
147
+ "densepose_smpl_subdiv": maybe_prepend_base_path(base_path, "SMPL_subdiv.mat"),
148
+ "densepose_smpl_subdiv_transform": maybe_prepend_base_path(
149
+ base_path,
150
+ "SMPL_SUBDIV_TRANSFORM.mat",
151
+ ),
152
+ }
153
+ return meta
154
+
155
+
156
+ def _load_coco_annotations(json_file: str):
157
+ """
158
+ Load COCO annotations from a JSON file
159
+
160
+ Args:
161
+ json_file: str
162
+ Path to the file to load annotations from
163
+ Returns:
164
+ Instance of `pycocotools.coco.COCO` that provides access to annotations
165
+ data
166
+ """
167
+ from pycocotools.coco import COCO
168
+
169
+ logger = logging.getLogger(__name__)
170
+ timer = Timer()
171
+ with contextlib.redirect_stdout(io.StringIO()):
172
+ coco_api = COCO(json_file)
173
+ if timer.seconds() > 1:
174
+ logger.info("Loading {} takes {:.2f} seconds.".format(json_file, timer.seconds()))
175
+ return coco_api
176
+
177
+
178
+ def _add_categories_metadata(dataset_name: str, categories: List[Dict[str, Any]]):
179
+ meta = MetadataCatalog.get(dataset_name)
180
+ meta.categories = {c["id"]: c["name"] for c in categories}
181
+ logger = logging.getLogger(__name__)
182
+ logger.info("Dataset {} categories: {}".format(dataset_name, meta.categories))
183
+
184
+
185
+ def _verify_annotations_have_unique_ids(json_file: str, anns: List[List[Dict[str, Any]]]):
186
+ if "minival" in json_file:
187
+ # Skip validation on COCO2014 valminusminival and minival annotations
188
+ # The ratio of buggy annotations there is tiny and does not affect accuracy
189
+ # Therefore we explicitly white-list them
190
+ return
191
+ ann_ids = [ann["id"] for anns_per_image in anns for ann in anns_per_image]
192
+ assert len(set(ann_ids)) == len(ann_ids), "Annotation ids in '{}' are not unique!".format(
193
+ json_file
194
+ )
195
+
196
+
197
+ def _maybe_add_bbox(obj: Dict[str, Any], ann_dict: Dict[str, Any]):
198
+ if "bbox" not in ann_dict:
199
+ return
200
+ obj["bbox"] = ann_dict["bbox"]
201
+ obj["bbox_mode"] = BoxMode.XYWH_ABS
202
+
203
+
204
+ def _maybe_add_segm(obj: Dict[str, Any], ann_dict: Dict[str, Any]):
205
+ if "segmentation" not in ann_dict:
206
+ return
207
+ segm = ann_dict["segmentation"]
208
+ if not isinstance(segm, dict):
209
+ # filter out invalid polygons (< 3 points)
210
+ segm = [poly for poly in segm if len(poly) % 2 == 0 and len(poly) >= 6]
211
+ if len(segm) == 0:
212
+ return
213
+ obj["segmentation"] = segm
214
+
215
+
216
+ def _maybe_add_keypoints(obj: Dict[str, Any], ann_dict: Dict[str, Any]):
217
+ if "keypoints" not in ann_dict:
218
+ return
219
+ keypts = ann_dict["keypoints"] # list[int]
220
+ for idx, v in enumerate(keypts):
221
+ if idx % 3 != 2:
222
+ # COCO's segmentation coordinates are floating points in [0, H or W],
223
+ # but keypoint coordinates are integers in [0, H-1 or W-1]
224
+ # Therefore we assume the coordinates are "pixel indices" and
225
+ # add 0.5 to convert to floating point coordinates.
226
+ keypts[idx] = v + 0.5
227
+ obj["keypoints"] = keypts
228
+
229
+
230
+ def _maybe_add_densepose(obj: Dict[str, Any], ann_dict: Dict[str, Any]):
231
+ for key in DENSEPOSE_ALL_POSSIBLE_KEYS:
232
+ if key in ann_dict:
233
+ obj[key] = ann_dict[key]
234
+
235
+
236
+ def _combine_images_with_annotations(
237
+ dataset_name: str,
238
+ image_root: str,
239
+ img_datas: Iterable[Dict[str, Any]],
240
+ ann_datas: Iterable[Iterable[Dict[str, Any]]],
241
+ ):
242
+
243
+ ann_keys = ["iscrowd", "category_id"]
244
+ dataset_dicts = []
245
+ contains_video_frame_info = False
246
+
247
+ for img_dict, ann_dicts in zip(img_datas, ann_datas):
248
+ record = {}
249
+ record["file_name"] = os.path.join(image_root, img_dict["file_name"])
250
+ record["height"] = img_dict["height"]
251
+ record["width"] = img_dict["width"]
252
+ record["image_id"] = img_dict["id"]
253
+ record["dataset"] = dataset_name
254
+ if "frame_id" in img_dict:
255
+ record["frame_id"] = img_dict["frame_id"]
256
+ record["video_id"] = img_dict.get("vid_id", None)
257
+ contains_video_frame_info = True
258
+ objs = []
259
+ for ann_dict in ann_dicts:
260
+ assert ann_dict["image_id"] == record["image_id"]
261
+ assert ann_dict.get("ignore", 0) == 0
262
+ obj = {key: ann_dict[key] for key in ann_keys if key in ann_dict}
263
+ _maybe_add_bbox(obj, ann_dict)
264
+ _maybe_add_segm(obj, ann_dict)
265
+ _maybe_add_keypoints(obj, ann_dict)
266
+ _maybe_add_densepose(obj, ann_dict)
267
+ objs.append(obj)
268
+ record["annotations"] = objs
269
+ dataset_dicts.append(record)
270
+ if contains_video_frame_info:
271
+ create_video_frame_mapping(dataset_name, dataset_dicts)
272
+ return dataset_dicts
273
+
274
+
275
+ def get_contiguous_id_to_category_id_map(metadata):
276
+ cat_id_2_cont_id = metadata.thing_dataset_id_to_contiguous_id
277
+ cont_id_2_cat_id = {}
278
+ for cat_id, cont_id in cat_id_2_cont_id.items():
279
+ if cont_id in cont_id_2_cat_id:
280
+ continue
281
+ cont_id_2_cat_id[cont_id] = cat_id
282
+ return cont_id_2_cat_id
283
+
284
+
285
+ def maybe_filter_categories_cocoapi(dataset_name, coco_api):
286
+ meta = MetadataCatalog.get(dataset_name)
287
+ cont_id_2_cat_id = get_contiguous_id_to_category_id_map(meta)
288
+ cat_id_2_cont_id = meta.thing_dataset_id_to_contiguous_id
289
+ # filter categories
290
+ cats = []
291
+ for cat in coco_api.dataset["categories"]:
292
+ cat_id = cat["id"]
293
+ if cat_id not in cat_id_2_cont_id:
294
+ continue
295
+ cont_id = cat_id_2_cont_id[cat_id]
296
+ if (cont_id in cont_id_2_cat_id) and (cont_id_2_cat_id[cont_id] == cat_id):
297
+ cats.append(cat)
298
+ coco_api.dataset["categories"] = cats
299
+ # filter annotations, if multiple categories are mapped to a single
300
+ # contiguous ID, use only one category ID and map all annotations to that category ID
301
+ anns = []
302
+ for ann in coco_api.dataset["annotations"]:
303
+ cat_id = ann["category_id"]
304
+ if cat_id not in cat_id_2_cont_id:
305
+ continue
306
+ cont_id = cat_id_2_cont_id[cat_id]
307
+ ann["category_id"] = cont_id_2_cat_id[cont_id]
308
+ anns.append(ann)
309
+ coco_api.dataset["annotations"] = anns
310
+ # recreate index
311
+ coco_api.createIndex()
312
+
313
+
314
+ def maybe_filter_and_map_categories_cocoapi(dataset_name, coco_api):
315
+ meta = MetadataCatalog.get(dataset_name)
316
+ category_id_map = meta.thing_dataset_id_to_contiguous_id
317
+ # map categories
318
+ cats = []
319
+ for cat in coco_api.dataset["categories"]:
320
+ cat_id = cat["id"]
321
+ if cat_id not in category_id_map:
322
+ continue
323
+ cat["id"] = category_id_map[cat_id]
324
+ cats.append(cat)
325
+ coco_api.dataset["categories"] = cats
326
+ # map annotation categories
327
+ anns = []
328
+ for ann in coco_api.dataset["annotations"]:
329
+ cat_id = ann["category_id"]
330
+ if cat_id not in category_id_map:
331
+ continue
332
+ ann["category_id"] = category_id_map[cat_id]
333
+ anns.append(ann)
334
+ coco_api.dataset["annotations"] = anns
335
+ # recreate index
336
+ coco_api.createIndex()
337
+
338
+
339
+ def create_video_frame_mapping(dataset_name, dataset_dicts):
340
+ mapping = defaultdict(dict)
341
+ for d in dataset_dicts:
342
+ video_id = d.get("video_id")
343
+ if video_id is None:
344
+ continue
345
+ mapping[video_id].update({d["frame_id"]: d["file_name"]})
346
+ MetadataCatalog.get(dataset_name).set(video_frame_mapping=mapping)
347
+
348
+
349
+ def load_coco_json(annotations_json_file: str, image_root: str, dataset_name: str):
350
+ """
351
+ Loads a JSON file with annotations in COCO instances format.
352
+ Replaces `detectron2.data.datasets.coco.load_coco_json` to handle metadata
353
+ in a more flexible way. Postpones category mapping to a later stage to be
354
+ able to combine several datasets with different (but coherent) sets of
355
+ categories.
356
+
357
+ Args:
358
+
359
+ annotations_json_file: str
360
+ Path to the JSON file with annotations in COCO instances format.
361
+ image_root: str
362
+ directory that contains all the images
363
+ dataset_name: str
364
+ the name that identifies a dataset, e.g. "densepose_coco_2014_train"
365
+ extra_annotation_keys: Optional[List[str]]
366
+ If provided, these keys are used to extract additional data from
367
+ the annotations.
368
+ """
369
+ coco_api = _load_coco_annotations(PathManager.get_local_path(annotations_json_file))
370
+ _add_categories_metadata(dataset_name, coco_api.loadCats(coco_api.getCatIds()))
371
+ # sort indices for reproducible results
372
+ img_ids = sorted(coco_api.imgs.keys())
373
+ # imgs is a list of dicts, each looks something like:
374
+ # {'license': 4,
375
+ # 'url': 'http://farm6.staticflickr.com/5454/9413846304_881d5e5c3b_z.jpg',
376
+ # 'file_name': 'COCO_val2014_000000001268.jpg',
377
+ # 'height': 427,
378
+ # 'width': 640,
379
+ # 'date_captured': '2013-11-17 05:57:24',
380
+ # 'id': 1268}
381
+ imgs = coco_api.loadImgs(img_ids)
382
+ logger = logging.getLogger(__name__)
383
+ logger.info("Loaded {} images in COCO format from {}".format(len(imgs), annotations_json_file))
384
+ # anns is a list[list[dict]], where each dict is an annotation
385
+ # record for an object. The inner list enumerates the objects in an image
386
+ # and the outer list enumerates over images.
387
+ anns = [coco_api.imgToAnns[img_id] for img_id in img_ids]
388
+ _verify_annotations_have_unique_ids(annotations_json_file, anns)
389
+ dataset_records = _combine_images_with_annotations(dataset_name, image_root, imgs, anns)
390
+ return dataset_records
391
+
392
+
393
+ def register_dataset(dataset_data: CocoDatasetInfo, datasets_root: Optional[str] = None):
394
+ """
395
+ Registers provided COCO DensePose dataset
396
+
397
+ Args:
398
+ dataset_data: CocoDatasetInfo
399
+ Dataset data
400
+ datasets_root: Optional[str]
401
+ Datasets root folder (default: None)
402
+ """
403
+ annotations_fpath = maybe_prepend_base_path(datasets_root, dataset_data.annotations_fpath)
404
+ images_root = maybe_prepend_base_path(datasets_root, dataset_data.images_root)
405
+
406
+ def load_annotations():
407
+ return load_coco_json(
408
+ annotations_json_file=annotations_fpath,
409
+ image_root=images_root,
410
+ dataset_name=dataset_data.name,
411
+ )
412
+
413
+ DatasetCatalog.register(dataset_data.name, load_annotations)
414
+ MetadataCatalog.get(dataset_data.name).set(
415
+ json_file=annotations_fpath,
416
+ image_root=images_root,
417
+ **get_metadata(DENSEPOSE_METADATA_URL_PREFIX)
418
+ )
419
+
420
+
421
+ def register_datasets(
422
+ datasets_data: Iterable[CocoDatasetInfo], datasets_root: Optional[str] = None
423
+ ):
424
+ """
425
+ Registers provided COCO DensePose datasets
426
+
427
+ Args:
428
+ datasets_data: Iterable[CocoDatasetInfo]
429
+ An iterable of dataset datas
430
+ datasets_root: Optional[str]
431
+ Datasets root folder (default: None)
432
+ """
433
+ for dataset_data in datasets_data:
434
+ register_dataset(dataset_data, datasets_root)
densepose/data/datasets/dataset_type.py ADDED
@@ -0,0 +1,13 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+
+# pyre-unsafe
+
+from enum import Enum
+
+
+class DatasetType(Enum):
+    """
+    Dataset type, mostly used for datasets that contain data to bootstrap models on
+    """
+
+    VIDEO_LIST = "video_list"
densepose/data/datasets/lvis.py ADDED
@@ -0,0 +1,259 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # pyre-unsafe
4
+ import logging
5
+ import os
6
+ from typing import Any, Dict, Iterable, List, Optional
7
+ from fvcore.common.timer import Timer
8
+
9
+ from detectron2.data import DatasetCatalog, MetadataCatalog
10
+ from detectron2.data.datasets.lvis import get_lvis_instances_meta
11
+ from detectron2.structures import BoxMode
12
+ from detectron2.utils.file_io import PathManager
13
+
14
+ from ..utils import maybe_prepend_base_path
15
+ from .coco import (
16
+ DENSEPOSE_ALL_POSSIBLE_KEYS,
17
+ DENSEPOSE_METADATA_URL_PREFIX,
18
+ CocoDatasetInfo,
19
+ get_metadata,
20
+ )
21
+
22
+ DATASETS = [
23
+ CocoDatasetInfo(
24
+ name="densepose_lvis_v1_ds1_train_v1",
25
+ images_root="coco_",
26
+ annotations_fpath="lvis/densepose_lvis_v1_ds1_train_v1.json",
27
+ ),
28
+ CocoDatasetInfo(
29
+ name="densepose_lvis_v1_ds1_val_v1",
30
+ images_root="coco_",
31
+ annotations_fpath="lvis/densepose_lvis_v1_ds1_val_v1.json",
32
+ ),
33
+ CocoDatasetInfo(
34
+ name="densepose_lvis_v1_ds2_train_v1",
35
+ images_root="coco_",
36
+ annotations_fpath="lvis/densepose_lvis_v1_ds2_train_v1.json",
37
+ ),
38
+ CocoDatasetInfo(
39
+ name="densepose_lvis_v1_ds2_val_v1",
40
+ images_root="coco_",
41
+ annotations_fpath="lvis/densepose_lvis_v1_ds2_val_v1.json",
42
+ ),
43
+ CocoDatasetInfo(
44
+ name="densepose_lvis_v1_ds1_val_animals_100",
45
+ images_root="coco_",
46
+ annotations_fpath="lvis/densepose_lvis_v1_val_animals_100_v2.json",
47
+ ),
48
+ ]
49
+
50
+
51
+ def _load_lvis_annotations(json_file: str):
52
+ """
53
+ Load LVIS annotations from a JSON file
54
+
55
+ Args:
56
+ json_file: str
57
+ Path to the file to load annotations from
58
+ Returns:
59
+ Instance of `lvis.LVIS` that provides access to annotations
60
+ data
61
+ """
62
+ from lvis import LVIS
63
+
64
+ json_file = PathManager.get_local_path(json_file)
65
+ logger = logging.getLogger(__name__)
66
+ timer = Timer()
67
+ lvis_api = LVIS(json_file)
68
+ if timer.seconds() > 1:
69
+ logger.info("Loading {} takes {:.2f} seconds.".format(json_file, timer.seconds()))
70
+ return lvis_api
71
+
72
+
73
+ def _add_categories_metadata(dataset_name: str) -> None:
74
+ metadict = get_lvis_instances_meta(dataset_name)
75
+ categories = metadict["thing_classes"]
76
+ metadata = MetadataCatalog.get(dataset_name)
77
+ metadata.categories = {i + 1: categories[i] for i in range(len(categories))}
78
+ logger = logging.getLogger(__name__)
79
+ logger.info(f"Dataset {dataset_name} has {len(categories)} categories")
80
+
81
+
82
+ def _verify_annotations_have_unique_ids(json_file: str, anns: List[List[Dict[str, Any]]]) -> None:
83
+ ann_ids = [ann["id"] for anns_per_image in anns for ann in anns_per_image]
84
+ assert len(set(ann_ids)) == len(ann_ids), "Annotation ids in '{}' are not unique!".format(
85
+ json_file
86
+ )
87
+
88
+
89
+ def _maybe_add_bbox(obj: Dict[str, Any], ann_dict: Dict[str, Any]) -> None:
90
+ if "bbox" not in ann_dict:
91
+ return
92
+ obj["bbox"] = ann_dict["bbox"]
93
+ obj["bbox_mode"] = BoxMode.XYWH_ABS
94
+
95
+
96
+ def _maybe_add_segm(obj: Dict[str, Any], ann_dict: Dict[str, Any]) -> None:
97
+ if "segmentation" not in ann_dict:
98
+ return
99
+ segm = ann_dict["segmentation"]
100
+ if not isinstance(segm, dict):
101
+ # filter out invalid polygons (< 3 points)
102
+ segm = [poly for poly in segm if len(poly) % 2 == 0 and len(poly) >= 6]
103
+ if len(segm) == 0:
104
+ return
105
+ obj["segmentation"] = segm
106
+
107
+
108
+ def _maybe_add_keypoints(obj: Dict[str, Any], ann_dict: Dict[str, Any]) -> None:
109
+ if "keypoints" not in ann_dict:
110
+ return
111
+ keypts = ann_dict["keypoints"] # list[int]
112
+ for idx, v in enumerate(keypts):
113
+ if idx % 3 != 2:
114
+ # COCO's segmentation coordinates are floating points in [0, H or W],
115
+ # but keypoint coordinates are integers in [0, H-1 or W-1]
116
+ # Therefore we assume the coordinates are "pixel indices" and
117
+ # add 0.5 to convert to floating point coordinates.
118
+ keypts[idx] = v + 0.5
119
+ obj["keypoints"] = keypts
120
+
121
+
122
+ def _maybe_add_densepose(obj: Dict[str, Any], ann_dict: Dict[str, Any]) -> None:
123
+ for key in DENSEPOSE_ALL_POSSIBLE_KEYS:
124
+ if key in ann_dict:
125
+ obj[key] = ann_dict[key]
126
+
127
+
128
+ def _combine_images_with_annotations(
129
+ dataset_name: str,
130
+ image_root: str,
131
+ img_datas: Iterable[Dict[str, Any]],
132
+ ann_datas: Iterable[Iterable[Dict[str, Any]]],
133
+ ):
134
+
135
+ dataset_dicts = []
136
+
137
+ def get_file_name(img_root, img_dict):
138
+ # Determine the path including the split folder ("train2017", "val2017", "test2017") from
139
+ # the coco_url field. Example:
140
+ # 'coco_url': 'http://images.cocodataset.org/train2017/000000155379.jpg'
141
+ split_folder, file_name = img_dict["coco_url"].split("/")[-2:]
142
+ return os.path.join(img_root + split_folder, file_name)
143
+
144
+ for img_dict, ann_dicts in zip(img_datas, ann_datas):
145
+ record = {}
146
+ record["file_name"] = get_file_name(image_root, img_dict)
147
+ record["height"] = img_dict["height"]
148
+ record["width"] = img_dict["width"]
149
+ record["not_exhaustive_category_ids"] = img_dict.get("not_exhaustive_category_ids", [])
150
+ record["neg_category_ids"] = img_dict.get("neg_category_ids", [])
151
+ record["image_id"] = img_dict["id"]
152
+ record["dataset"] = dataset_name
153
+
154
+ objs = []
155
+ for ann_dict in ann_dicts:
156
+ assert ann_dict["image_id"] == record["image_id"]
157
+ obj = {}
158
+ _maybe_add_bbox(obj, ann_dict)
159
+ obj["iscrowd"] = ann_dict.get("iscrowd", 0)
160
+ obj["category_id"] = ann_dict["category_id"]
161
+ _maybe_add_segm(obj, ann_dict)
162
+ _maybe_add_keypoints(obj, ann_dict)
163
+ _maybe_add_densepose(obj, ann_dict)
164
+ objs.append(obj)
165
+ record["annotations"] = objs
166
+ dataset_dicts.append(record)
167
+ return dataset_dicts
168
+
169
+
170
+ def load_lvis_json(annotations_json_file: str, image_root: str, dataset_name: str):
171
+ """
172
+ Loads a JSON file with annotations in LVIS instances format.
173
+ Replaces `detectron2.data.datasets.lvis.load_lvis_json` to handle metadata
174
+ in a more flexible way. Postpones category mapping to a later stage to be
175
+ able to combine several datasets with different (but coherent) sets of
176
+ categories.
177
+
178
+ Args:
179
+
180
+ annotations_json_file: str
181
+ Path to the JSON file with annotations in COCO instances format.
182
+ image_root: str
183
+ directory that contains all the images
184
+ dataset_name: str
185
+ the name that identifies a dataset, e.g. "densepose_coco_2014_train"
186
+ extra_annotation_keys: Optional[List[str]]
187
+ If provided, these keys are used to extract additional data from
188
+ the annotations.
189
+ """
190
+ lvis_api = _load_lvis_annotations(PathManager.get_local_path(annotations_json_file))
191
+
192
+ _add_categories_metadata(dataset_name)
193
+
194
+ # sort indices for reproducible results
195
+ img_ids = sorted(lvis_api.imgs.keys())
196
+ # imgs is a list of dicts, each looks something like:
197
+ # {'license': 4,
198
+ # 'url': 'http://farm6.staticflickr.com/5454/9413846304_881d5e5c3b_z.jpg',
199
+ # 'file_name': 'COCO_val2014_000000001268.jpg',
200
+ # 'height': 427,
201
+ # 'width': 640,
202
+ # 'date_captured': '2013-11-17 05:57:24',
203
+ # 'id': 1268}
204
+ imgs = lvis_api.load_imgs(img_ids)
205
+ logger = logging.getLogger(__name__)
206
+ logger.info("Loaded {} images in LVIS format from {}".format(len(imgs), annotations_json_file))
207
+ # anns is a list[list[dict]], where each dict is an annotation
208
+ # record for an object. The inner list enumerates the objects in an image
209
+ # and the outer list enumerates over images.
210
+ anns = [lvis_api.img_ann_map[img_id] for img_id in img_ids]
211
+
212
+ _verify_annotations_have_unique_ids(annotations_json_file, anns)
213
+ dataset_records = _combine_images_with_annotations(dataset_name, image_root, imgs, anns)
214
+ return dataset_records
215
+
216
+
217
+ def register_dataset(dataset_data: CocoDatasetInfo, datasets_root: Optional[str] = None) -> None:
218
+ """
219
+ Registers provided LVIS DensePose dataset
220
+
221
+ Args:
222
+ dataset_data: CocoDatasetInfo
223
+ Dataset data
224
+ datasets_root: Optional[str]
225
+ Datasets root folder (default: None)
226
+ """
227
+ annotations_fpath = maybe_prepend_base_path(datasets_root, dataset_data.annotations_fpath)
228
+ images_root = maybe_prepend_base_path(datasets_root, dataset_data.images_root)
229
+
230
+ def load_annotations():
231
+ return load_lvis_json(
232
+ annotations_json_file=annotations_fpath,
233
+ image_root=images_root,
234
+ dataset_name=dataset_data.name,
235
+ )
236
+
237
+ DatasetCatalog.register(dataset_data.name, load_annotations)
238
+ MetadataCatalog.get(dataset_data.name).set(
239
+ json_file=annotations_fpath,
240
+ image_root=images_root,
241
+ evaluator_type="lvis",
242
+ **get_metadata(DENSEPOSE_METADATA_URL_PREFIX),
243
+ )
244
+
245
+
246
+ def register_datasets(
247
+ datasets_data: Iterable[CocoDatasetInfo], datasets_root: Optional[str] = None
248
+ ) -> None:
249
+ """
250
+ Registers provided LVIS DensePose datasets
251
+
252
+ Args:
253
+ datasets_data: Iterable[CocoDatasetInfo]
254
+ An iterable of dataset datas
255
+ datasets_root: Optional[str]
256
+ Datasets root folder (default: None)
257
+ """
258
+ for dataset_data in datasets_data:
259
+ register_dataset(dataset_data, datasets_root)
densepose/data/image_list_dataset.py ADDED
@@ -0,0 +1,74 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Facebook, Inc. and its affiliates.
3
+
4
+ # pyre-unsafe
5
+
6
+ import logging
7
+ import numpy as np
8
+ from typing import Any, Callable, Dict, List, Optional, Union
9
+ import torch
10
+ from torch.utils.data.dataset import Dataset
11
+
12
+ from detectron2.data.detection_utils import read_image
13
+
14
+ ImageTransform = Callable[[torch.Tensor], torch.Tensor]
15
+
16
+
17
+ class ImageListDataset(Dataset):
18
+ """
19
+ Dataset that provides images from a list.
20
+ """
21
+
22
+ _EMPTY_IMAGE = torch.empty((0, 3, 1, 1))
23
+
24
+ def __init__(
25
+ self,
26
+ image_list: List[str],
27
+ category_list: Union[str, List[str], None] = None,
28
+ transform: Optional[ImageTransform] = None,
29
+ ):
30
+ """
31
+ Args:
32
+ image_list (List[str]): list of paths to image files
33
+ category_list (Union[str, List[str], None]): list of animal categories for
34
+ each image. If it is a string, or None, this applies to all images
35
+ """
36
+ if type(category_list) is list:
37
+ self.category_list = category_list
38
+ else:
39
+ self.category_list = [category_list] * len(image_list)
40
+ assert len(image_list) == len(
41
+ self.category_list
42
+ ), "length of image and category lists must be equal"
43
+ self.image_list = image_list
44
+ self.transform = transform
45
+
46
+ def __getitem__(self, idx: int) -> Dict[str, Any]:
47
+ """
48
+ Gets selected images from the list
49
+
50
+ Args:
51
+ idx (int): image index in the image list
52
+ Returns:
53
+ A dictionary containing two keys:
54
+ images (torch.Tensor): tensor of size [N, 3, H, W] (N = 1, or 0 for _EMPTY_IMAGE)
55
+ categories (List[str]): categories of the frames
56
+ """
57
+ categories = [self.category_list[idx]]
58
+ fpath = self.image_list[idx]
59
+ transform = self.transform
60
+
61
+ try:
62
+ image = torch.from_numpy(np.ascontiguousarray(read_image(fpath, format="BGR")))
63
+ image = image.permute(2, 0, 1).unsqueeze(0).float() # HWC -> NCHW
64
+ if transform is not None:
65
+ image = transform(image)
66
+ return {"images": image, "categories": categories}
67
+ except (OSError, RuntimeError) as e:
68
+ logger = logging.getLogger(__name__)
69
+ logger.warning(f"Error opening image file container {fpath}: {e}")
70
+
71
+ return {"images": self._EMPTY_IMAGE, "categories": []}
72
+
73
+ def __len__(self):
74
+ return len(self.image_list)
densepose/data/inference_based_loader.py ADDED
@@ -0,0 +1,174 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # pyre-unsafe
4
+
5
+ import random
6
+ from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple
7
+ import torch
8
+ from torch import nn
9
+
10
+ SampledData = Any
11
+ ModelOutput = Any
12
+
13
+
14
+ def _grouper(iterable: Iterable[Any], n: int, fillvalue=None) -> Iterator[Tuple[Any]]:
15
+ """
16
+ Group elements of an iterable by chunks of size `n`, e.g.
17
+ grouper(range(9), 4) ->
18
+ (0, 1, 2, 3), (4, 5, 6, 7), (8, None, None, None)
19
+ """
20
+ it = iter(iterable)
21
+ while True:
22
+ values = []
23
+ for _ in range(n):
24
+ try:
25
+ value = next(it)
26
+ except StopIteration:
27
+ if values:
28
+ values.extend([fillvalue] * (n - len(values)))
29
+ yield tuple(values)
30
+ return
31
+ values.append(value)
32
+ yield tuple(values)
33
+
34
+
35
+ class ScoreBasedFilter:
36
+ """
37
+ Filters entries in model output based on their scores
38
+ Discards all entries with score less than the specified minimum
39
+ """
40
+
41
+ def __init__(self, min_score: float = 0.8):
42
+ self.min_score = min_score
43
+
44
+ def __call__(self, model_output: ModelOutput) -> ModelOutput:
45
+ for model_output_i in model_output:
46
+ instances = model_output_i["instances"]
47
+ if not instances.has("scores"):
48
+ continue
49
+ instances_filtered = instances[instances.scores >= self.min_score]
50
+ model_output_i["instances"] = instances_filtered
51
+ return model_output
52
+
53
+
54
+ class InferenceBasedLoader:
55
+ """
56
+ Data loader based on results inferred by a model. Consists of:
57
+ - a data loader that provides batches of images
58
+ - a model that is used to infer the results
59
+ - a data sampler that converts inferred results to annotations
60
+ """
61
+
62
+ def __init__(
63
+ self,
64
+ model: nn.Module,
65
+ data_loader: Iterable[List[Dict[str, Any]]],
66
+ data_sampler: Optional[Callable[[ModelOutput], List[SampledData]]] = None,
67
+ data_filter: Optional[Callable[[ModelOutput], ModelOutput]] = None,
68
+ shuffle: bool = True,
69
+ batch_size: int = 4,
70
+ inference_batch_size: int = 4,
71
+ drop_last: bool = False,
72
+ category_to_class_mapping: Optional[dict] = None,
73
+ ):
74
+ """
75
+ Constructor
76
+
77
+ Args:
78
+ model (torch.nn.Module): model used to produce data
79
+ data_loader (Iterable[List[Dict[str, Any]]]): iterable that provides
80
+ dictionaries with "images" and "categories" fields to perform inference on
81
+ data_sampler (Callable: ModelOutput -> SampledData): functor
82
+ that produces annotation data from inference results;
83
+ (optional, default: None)
84
+ data_filter (Callable: ModelOutput -> ModelOutput): filter
85
+ that selects model outputs for further processing
86
+ (optional, default: None)
87
+ shuffle (bool): if True, the input images get shuffled
88
+ batch_size (int): batch size for the produced annotation data
89
+ inference_batch_size (int): batch size for input images
90
+ drop_last (bool): if True, drop the last batch if it is undersized
91
+ category_to_class_mapping (dict): category to class mapping
92
+ """
93
+ self.model = model
94
+ self.model.eval()
95
+ self.data_loader = data_loader
96
+ self.data_sampler = data_sampler
97
+ self.data_filter = data_filter
98
+ self.shuffle = shuffle
99
+ self.batch_size = batch_size
100
+ self.inference_batch_size = inference_batch_size
101
+ self.drop_last = drop_last
102
+ if category_to_class_mapping is not None:
103
+ self.category_to_class_mapping = category_to_class_mapping
104
+ else:
105
+ self.category_to_class_mapping = {}
106
+
107
+ def __iter__(self) -> Iterator[List[SampledData]]:
108
+ for batch in self.data_loader:
109
+ # batch : List[Dict[str: Tensor[N, C, H, W], str: Optional[str]]]
110
+ # images_batch : Tensor[N, C, H, W]
111
+ # image : Tensor[C, H, W]
112
+ images_and_categories = [
113
+ {"image": image, "category": category}
114
+ for element in batch
115
+ for image, category in zip(element["images"], element["categories"])
116
+ ]
117
+ if not images_and_categories:
118
+ continue
119
+ if self.shuffle:
120
+ random.shuffle(images_and_categories)
121
+ yield from self._produce_data(images_and_categories) # pyre-ignore[6]
122
+
123
+ def _produce_data(
124
+ self, images_and_categories: List[Tuple[torch.Tensor, Optional[str]]]
125
+ ) -> Iterator[List[SampledData]]:
126
+ """
127
+ Produce batches of data from images
128
+
129
+ Args:
130
+ images_and_categories (List[Tuple[torch.Tensor, Optional[str]]]):
131
+ list of images and corresponding categories to process
132
+
133
+ Returns:
134
+ Iterator over batches of data sampled from model outputs
135
+ """
136
+ data_batches: List[SampledData] = []
137
+ category_to_class_mapping = self.category_to_class_mapping
138
+ batched_images_and_categories = _grouper(images_and_categories, self.inference_batch_size)
139
+ for batch in batched_images_and_categories:
140
+ batch = [
141
+ {
142
+ "image": image_and_category["image"].to(self.model.device),
143
+ "category": image_and_category["category"],
144
+ }
145
+ for image_and_category in batch
146
+ if image_and_category is not None
147
+ ]
148
+ if not batch:
149
+ continue
150
+ with torch.no_grad():
151
+ model_output = self.model(batch)
152
+ for model_output_i, batch_i in zip(model_output, batch):
153
+ assert len(batch_i["image"].shape) == 3
154
+ model_output_i["image"] = batch_i["image"]
155
+ instance_class = category_to_class_mapping.get(batch_i["category"], 0)
156
+ model_output_i["instances"].dataset_classes = torch.tensor(
157
+ [instance_class] * len(model_output_i["instances"])
158
+ )
159
+ model_output_filtered = (
160
+ model_output if self.data_filter is None else self.data_filter(model_output)
161
+ )
162
+ data = (
163
+ model_output_filtered
164
+ if self.data_sampler is None
165
+ else self.data_sampler(model_output_filtered)
166
+ )
167
+ for data_i in data:
168
+ if len(data_i["instances"]):
169
+ data_batches.append(data_i)
170
+ if len(data_batches) >= self.batch_size:
171
+ yield data_batches[: self.batch_size]
172
+ data_batches = data_batches[self.batch_size :]
173
+ if not self.drop_last and data_batches:
174
+ yield data_batches
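A quick, standalone check of the chunking behaviour documented in `_grouper`'s docstring above (assumes the `densepose` package added in this commit is importable):

    from densepose.data.inference_based_loader import _grouper

    # The last chunk is padded with the fill value, as in the docstring example.
    assert list(_grouper(range(9), 4)) == [(0, 1, 2, 3), (4, 5, 6, 7), (8, None, None, None)]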
densepose/data/meshes/__init__.py ADDED
@@ -0,0 +1,7 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+
+# pyre-unsafe
+
+from . import builtin
+
+__all__ = [k for k in globals().keys() if "builtin" not in k and not k.startswith("_")]
densepose/data/meshes/builtin.py ADDED
@@ -0,0 +1,103 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2
+
3
+ # pyre-unsafe
4
+
5
+ from .catalog import MeshInfo, register_meshes
6
+
7
+ DENSEPOSE_MESHES_DIR = "https://dl.fbaipublicfiles.com/densepose/meshes/"
8
+
9
+ MESHES = [
10
+ MeshInfo(
11
+ name="smpl_27554",
12
+ data="smpl_27554.pkl",
13
+ geodists="geodists/geodists_smpl_27554.pkl",
14
+ symmetry="symmetry/symmetry_smpl_27554.pkl",
15
+ texcoords="texcoords/texcoords_smpl_27554.pkl",
16
+ ),
17
+ MeshInfo(
18
+ name="chimp_5029",
19
+ data="chimp_5029.pkl",
20
+ geodists="geodists/geodists_chimp_5029.pkl",
21
+ symmetry="symmetry/symmetry_chimp_5029.pkl",
22
+ texcoords="texcoords/texcoords_chimp_5029.pkl",
23
+ ),
24
+ MeshInfo(
25
+ name="cat_5001",
26
+ data="cat_5001.pkl",
27
+ geodists="geodists/geodists_cat_5001.pkl",
28
+ symmetry="symmetry/symmetry_cat_5001.pkl",
29
+ texcoords="texcoords/texcoords_cat_5001.pkl",
30
+ ),
31
+ MeshInfo(
32
+ name="cat_7466",
33
+ data="cat_7466.pkl",
34
+ geodists="geodists/geodists_cat_7466.pkl",
35
+ symmetry="symmetry/symmetry_cat_7466.pkl",
36
+ texcoords="texcoords/texcoords_cat_7466.pkl",
37
+ ),
38
+ MeshInfo(
39
+ name="sheep_5004",
40
+ data="sheep_5004.pkl",
41
+ geodists="geodists/geodists_sheep_5004.pkl",
42
+ symmetry="symmetry/symmetry_sheep_5004.pkl",
43
+ texcoords="texcoords/texcoords_sheep_5004.pkl",
44
+ ),
45
+ MeshInfo(
46
+ name="zebra_5002",
47
+ data="zebra_5002.pkl",
48
+ geodists="geodists/geodists_zebra_5002.pkl",
49
+ symmetry="symmetry/symmetry_zebra_5002.pkl",
50
+ texcoords="texcoords/texcoords_zebra_5002.pkl",
51
+ ),
52
+ MeshInfo(
53
+ name="horse_5004",
54
+ data="horse_5004.pkl",
55
+ geodists="geodists/geodists_horse_5004.pkl",
56
+ symmetry="symmetry/symmetry_horse_5004.pkl",
57
+ texcoords="texcoords/texcoords_zebra_5002.pkl",
58
+ ),
59
+ MeshInfo(
60
+ name="giraffe_5002",
61
+ data="giraffe_5002.pkl",
62
+ geodists="geodists/geodists_giraffe_5002.pkl",
63
+ symmetry="symmetry/symmetry_giraffe_5002.pkl",
64
+ texcoords="texcoords/texcoords_giraffe_5002.pkl",
65
+ ),
66
+ MeshInfo(
67
+ name="elephant_5002",
68
+ data="elephant_5002.pkl",
69
+ geodists="geodists/geodists_elephant_5002.pkl",
70
+ symmetry="symmetry/symmetry_elephant_5002.pkl",
71
+ texcoords="texcoords/texcoords_elephant_5002.pkl",
72
+ ),
73
+ MeshInfo(
74
+ name="dog_5002",
75
+ data="dog_5002.pkl",
76
+ geodists="geodists/geodists_dog_5002.pkl",
77
+ symmetry="symmetry/symmetry_dog_5002.pkl",
78
+ texcoords="texcoords/texcoords_dog_5002.pkl",
79
+ ),
80
+ MeshInfo(
81
+ name="dog_7466",
82
+ data="dog_7466.pkl",
83
+ geodists="geodists/geodists_dog_7466.pkl",
84
+ symmetry="symmetry/symmetry_dog_7466.pkl",
85
+ texcoords="texcoords/texcoords_dog_7466.pkl",
86
+ ),
87
+ MeshInfo(
88
+ name="cow_5002",
89
+ data="cow_5002.pkl",
90
+ geodists="geodists/geodists_cow_5002.pkl",
91
+ symmetry="symmetry/symmetry_cow_5002.pkl",
92
+ texcoords="texcoords/texcoords_cow_5002.pkl",
93
+ ),
94
+ MeshInfo(
95
+ name="bear_4936",
96
+ data="bear_4936.pkl",
97
+ geodists="geodists/geodists_bear_4936.pkl",
98
+ symmetry="symmetry/symmetry_bear_4936.pkl",
99
+ texcoords="texcoords/texcoords_bear_4936.pkl",
100
+ ),
101
+ ]
102
+
103
+ register_meshes(MESHES, DENSEPOSE_MESHES_DIR)
densepose/data/meshes/catalog.py ADDED
@@ -0,0 +1,73 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2
+
3
+ # pyre-unsafe
4
+
5
+ import logging
6
+ from collections import UserDict
7
+ from dataclasses import dataclass
8
+ from typing import Iterable, Optional
9
+
10
+ from ..utils import maybe_prepend_base_path
11
+
12
+
13
+ @dataclass
14
+ class MeshInfo:
15
+ name: str
16
+ data: str
17
+ geodists: Optional[str] = None
18
+ symmetry: Optional[str] = None
19
+ texcoords: Optional[str] = None
20
+
21
+
22
+ class _MeshCatalog(UserDict):
23
+ def __init__(self, *args, **kwargs):
24
+ super().__init__(*args, **kwargs)
25
+ self.mesh_ids = {}
26
+ self.mesh_names = {}
27
+ self.max_mesh_id = -1
28
+
29
+ def __setitem__(self, key, value):
30
+ if key in self:
31
+ logger = logging.getLogger(__name__)
32
+ logger.warning(
33
+ f"Overwriting mesh catalog entry '{key}': old value {self[key]}"
34
+ f", new value {value}"
35
+ )
36
+ mesh_id = self.mesh_ids[key]
37
+ else:
38
+ self.max_mesh_id += 1
39
+ mesh_id = self.max_mesh_id
40
+ super().__setitem__(key, value)
41
+ self.mesh_ids[key] = mesh_id
42
+ self.mesh_names[mesh_id] = key
43
+
44
+ def get_mesh_id(self, shape_name: str) -> int:
45
+ return self.mesh_ids[shape_name]
46
+
47
+ def get_mesh_name(self, mesh_id: int) -> str:
48
+ return self.mesh_names[mesh_id]
49
+
50
+
51
+ MeshCatalog = _MeshCatalog()
52
+
53
+
54
+ def register_mesh(mesh_info: MeshInfo, base_path: Optional[str]) -> None:
55
+ geodists, symmetry, texcoords = mesh_info.geodists, mesh_info.symmetry, mesh_info.texcoords
56
+ if geodists:
57
+ geodists = maybe_prepend_base_path(base_path, geodists)
58
+ if symmetry:
59
+ symmetry = maybe_prepend_base_path(base_path, symmetry)
60
+ if texcoords:
61
+ texcoords = maybe_prepend_base_path(base_path, texcoords)
62
+ MeshCatalog[mesh_info.name] = MeshInfo(
63
+ name=mesh_info.name,
64
+ data=maybe_prepend_base_path(base_path, mesh_info.data),
65
+ geodists=geodists,
66
+ symmetry=symmetry,
67
+ texcoords=texcoords,
68
+ )
69
+
70
+
71
+ def register_meshes(mesh_infos: Iterable[MeshInfo], base_path: Optional[str]) -> None:
72
+ for mesh_info in mesh_infos:
73
+ register_mesh(mesh_info, base_path)
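Once the meshes listed in `builtin.py` above have been registered, the catalog can be queried by name or by numeric id. A small sketch (assumes the `densepose` package added in this commit is importable):

    import densepose.data.meshes.builtin  # noqa: F401  (registers the meshes listed above)
    from densepose.data.meshes.catalog import MeshCatalog

    mesh_info = MeshCatalog["smpl_27554"]            # MeshInfo with resolved file URLs
    mesh_id = MeshCatalog.get_mesh_id("smpl_27554")
    assert MeshCatalog.get_mesh_name(mesh_id) == "smpl_27554"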
densepose/data/samplers/__init__.py ADDED
@@ -0,0 +1,10 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+
+# pyre-unsafe
+
+from .densepose_uniform import DensePoseUniformSampler
+from .densepose_confidence_based import DensePoseConfidenceBasedSampler
+from .densepose_cse_uniform import DensePoseCSEUniformSampler
+from .densepose_cse_confidence_based import DensePoseCSEConfidenceBasedSampler
+from .mask_from_densepose import MaskFromDensePoseSampler
+from .prediction_to_gt import PredictionToGroundTruthSampler
densepose/data/samplers/densepose_base.py ADDED
@@ -0,0 +1,205 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # pyre-unsafe
4
+
5
+ from typing import Any, Dict, List, Tuple
6
+ import torch
7
+ from torch.nn import functional as F
8
+
9
+ from detectron2.structures import BoxMode, Instances
10
+
11
+ from densepose.converters import ToChartResultConverter
12
+ from densepose.converters.base import IntTupleBox, make_int_box
13
+ from densepose.structures import DensePoseDataRelative, DensePoseList
14
+
15
+
16
+ class DensePoseBaseSampler:
17
+ """
18
+ Base DensePose sampler to produce DensePose data from DensePose predictions.
19
+ Samples for each class are drawn according to some distribution over all pixels estimated
20
+ to belong to that class.
21
+ """
22
+
23
+ def __init__(self, count_per_class: int = 8):
24
+ """
25
+ Constructor
26
+
27
+ Args:
28
+ count_per_class (int): the sampler produces at most `count_per_class`
29
+ samples for each category
30
+ """
31
+ self.count_per_class = count_per_class
32
+
33
+ def __call__(self, instances: Instances) -> DensePoseList:
34
+ """
35
+ Convert DensePose predictions (an instance of `DensePoseChartPredictorOutput`)
36
+ into DensePose annotations data (an instance of `DensePoseList`)
37
+ """
38
+ boxes_xyxy_abs = instances.pred_boxes.tensor.clone().cpu()
39
+ boxes_xywh_abs = BoxMode.convert(boxes_xyxy_abs, BoxMode.XYXY_ABS, BoxMode.XYWH_ABS)
40
+ dp_datas = []
41
+ for i in range(len(boxes_xywh_abs)):
42
+ annotation_i = self._sample(instances[i], make_int_box(boxes_xywh_abs[i]))
43
+ annotation_i[DensePoseDataRelative.S_KEY] = self._resample_mask( # pyre-ignore[6]
44
+ instances[i].pred_densepose
45
+ )
46
+ dp_datas.append(DensePoseDataRelative(annotation_i))
47
+ # create densepose annotations on CPU
48
+ dp_list = DensePoseList(dp_datas, boxes_xyxy_abs, instances.image_size)
49
+ return dp_list
50
+
51
+ def _sample(self, instance: Instances, bbox_xywh: IntTupleBox) -> Dict[str, List[Any]]:
52
+ """
53
+ Sample DensePoseDataRelative from estimation results
54
+ """
55
+ labels, dp_result = self._produce_labels_and_results(instance)
56
+ annotation = {
57
+ DensePoseDataRelative.X_KEY: [],
58
+ DensePoseDataRelative.Y_KEY: [],
59
+ DensePoseDataRelative.U_KEY: [],
60
+ DensePoseDataRelative.V_KEY: [],
61
+ DensePoseDataRelative.I_KEY: [],
62
+ }
63
+ n, h, w = dp_result.shape
64
+ for part_id in range(1, DensePoseDataRelative.N_PART_LABELS + 1):
65
+ # indices - tuple of 3 1D tensors of size k
66
+ # 0: index along the first dimension N
67
+ # 1: index along H dimension
68
+ # 2: index along W dimension
69
+ indices = torch.nonzero(labels.expand(n, h, w) == part_id, as_tuple=True)
70
+ # values - an array of size [n, k]
71
+ # n: number of channels (U, V, confidences)
72
+ # k: number of points labeled with part_id
73
+ values = dp_result[indices].view(n, -1)
74
+ k = values.shape[1]
75
+ count = min(self.count_per_class, k)
76
+ if count <= 0:
77
+ continue
78
+ index_sample = self._produce_index_sample(values, count)
79
+ sampled_values = values[:, index_sample]
80
+ sampled_y = indices[1][index_sample] + 0.5
81
+ sampled_x = indices[2][index_sample] + 0.5
82
+ # prepare / normalize data
83
+ x = (sampled_x / w * 256.0).cpu().tolist()
84
+ y = (sampled_y / h * 256.0).cpu().tolist()
85
+ u = sampled_values[0].clamp(0, 1).cpu().tolist()
86
+ v = sampled_values[1].clamp(0, 1).cpu().tolist()
87
+ fine_segm_labels = [part_id] * count
88
+ # extend annotations
89
+ annotation[DensePoseDataRelative.X_KEY].extend(x)
90
+ annotation[DensePoseDataRelative.Y_KEY].extend(y)
91
+ annotation[DensePoseDataRelative.U_KEY].extend(u)
92
+ annotation[DensePoseDataRelative.V_KEY].extend(v)
93
+ annotation[DensePoseDataRelative.I_KEY].extend(fine_segm_labels)
94
+ return annotation
95
+
96
+ def _produce_index_sample(self, values: torch.Tensor, count: int):
97
+ """
98
+ Abstract method to produce a sample of indices to select data
99
+ To be implemented in descendants
100
+
101
+ Args:
102
+ values (torch.Tensor): an array of size [n, k] that contains
103
+ estimated values (U, V, confidences);
104
+ n: number of channels (U, V, confidences)
105
+ k: number of points labeled with part_id
106
+ count (int): number of samples to produce, should be positive and <= k
107
+
108
+ Return:
109
+ list(int): indices of values (along axis 1) selected as a sample
110
+ """
111
+ raise NotImplementedError
112
+
113
+ def _produce_labels_and_results(self, instance: Instances) -> Tuple[torch.Tensor, torch.Tensor]:
114
+ """
115
+ Method to get labels and DensePose results from an instance
116
+
117
+ Args:
118
+ instance (Instances): an instance of `DensePoseChartPredictorOutput`
119
+
120
+ Return:
121
+ labels (torch.Tensor): shape [H, W], DensePose segmentation labels
122
+ dp_result (torch.Tensor): shape [2, H, W], stacked DensePose results u and v
123
+ """
124
+ converter = ToChartResultConverter
125
+ chart_result = converter.convert(instance.pred_densepose, instance.pred_boxes)
126
+ labels, dp_result = chart_result.labels.cpu(), chart_result.uv.cpu()
127
+ return labels, dp_result
128
+
129
+ def _resample_mask(self, output: Any) -> torch.Tensor:
130
+ """
131
+ Convert DensePose predictor output to segmentation annotation - tensors of size
132
+ (256, 256) and type `int64`.
133
+
134
+ Args:
135
+ output: DensePose predictor output with the following attributes:
136
+ - coarse_segm: tensor of size [N, D, H, W] with unnormalized coarse
137
+ segmentation scores
138
+ - fine_segm: tensor of size [N, C, H, W] with unnormalized fine
139
+ segmentation scores
140
+ Return:
141
+ Tensor of size (S, S) and type `int64` with coarse segmentation annotations,
142
+ where S = DensePoseDataRelative.MASK_SIZE
143
+ """
144
+ sz = DensePoseDataRelative.MASK_SIZE
145
+ S = (
146
+ F.interpolate(output.coarse_segm, (sz, sz), mode="bilinear", align_corners=False)
147
+ .argmax(dim=1)
148
+ .long()
149
+ )
150
+ I = (
151
+ (
152
+ F.interpolate(
153
+ output.fine_segm,
154
+ (sz, sz),
155
+ mode="bilinear",
156
+ align_corners=False,
157
+ ).argmax(dim=1)
158
+ * (S > 0).long()
159
+ )
160
+ .squeeze()
161
+ .cpu()
162
+ )
163
+ # Map fine segmentation results to coarse segmentation ground truth
164
+ # TODO: extract this into separate classes
165
+ # coarse segmentation: 1 = Torso, 2 = Right Hand, 3 = Left Hand,
166
+ # 4 = Left Foot, 5 = Right Foot, 6 = Upper Leg Right, 7 = Upper Leg Left,
167
+ # 8 = Lower Leg Right, 9 = Lower Leg Left, 10 = Upper Arm Left,
168
+ # 11 = Upper Arm Right, 12 = Lower Arm Left, 13 = Lower Arm Right,
169
+ # 14 = Head
170
+ # fine segmentation: 1, 2 = Torso, 3 = Right Hand, 4 = Left Hand,
171
+ # 5 = Left Foot, 6 = Right Foot, 7, 9 = Upper Leg Right,
172
+ # 8, 10 = Upper Leg Left, 11, 13 = Lower Leg Right,
173
+ # 12, 14 = Lower Leg Left, 15, 17 = Upper Arm Left,
174
+ # 16, 18 = Upper Arm Right, 19, 21 = Lower Arm Left,
175
+ # 20, 22 = Lower Arm Right, 23, 24 = Head
176
+ FINE_TO_COARSE_SEGMENTATION = {
177
+ 1: 1,
178
+ 2: 1,
179
+ 3: 2,
180
+ 4: 3,
181
+ 5: 4,
182
+ 6: 5,
183
+ 7: 6,
184
+ 8: 7,
185
+ 9: 6,
186
+ 10: 7,
187
+ 11: 8,
188
+ 12: 9,
189
+ 13: 8,
190
+ 14: 9,
191
+ 15: 10,
192
+ 16: 11,
193
+ 17: 10,
194
+ 18: 11,
195
+ 19: 12,
196
+ 20: 13,
197
+ 21: 12,
198
+ 22: 13,
199
+ 23: 14,
200
+ 24: 14,
201
+ }
202
+ mask = torch.zeros((sz, sz), dtype=torch.int64, device=torch.device("cpu"))
203
+ for i in range(DensePoseDataRelative.N_PART_LABELS):
204
+ mask[I == i + 1] = FINE_TO_COARSE_SEGMENTATION[i + 1]
205
+ return mask
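The fine-to-coarse relabeling at the end of `_resample_mask` above can be exercised in isolation; the mapping is the one from the code, only the input labels are made up:

    import torch

    FINE_TO_COARSE_SEGMENTATION = {
        1: 1, 2: 1, 3: 2, 4: 3, 5: 4, 6: 5, 7: 6, 8: 7, 9: 6, 10: 7, 11: 8, 12: 9,
        13: 8, 14: 9, 15: 10, 16: 11, 17: 10, 18: 11, 19: 12, 20: 13, 21: 12, 22: 13,
        23: 14, 24: 14,
    }

    fine = torch.randint(0, 25, (256, 256))   # hypothetical fine segmentation labels in [0, 24]
    coarse = torch.zeros_like(fine)
    for fine_id, coarse_id in FINE_TO_COARSE_SEGMENTATION.items():
        coarse[fine == fine_id] = coarse_id   # background (label 0) stays 0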
densepose/data/samplers/densepose_confidence_based.py ADDED
@@ -0,0 +1,110 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # pyre-unsafe
4
+
5
+ import random
6
+ from typing import Optional, Tuple
7
+ import torch
8
+
9
+ from densepose.converters import ToChartResultConverterWithConfidences
10
+
11
+ from .densepose_base import DensePoseBaseSampler
12
+
13
+
14
+ class DensePoseConfidenceBasedSampler(DensePoseBaseSampler):
15
+ """
16
+ Samples DensePose data from DensePose predictions.
17
+ Samples for each class are drawn using confidence value estimates.
18
+ """
19
+
20
+ def __init__(
21
+ self,
22
+ confidence_channel: str,
23
+ count_per_class: int = 8,
24
+ search_count_multiplier: Optional[float] = None,
25
+ search_proportion: Optional[float] = None,
26
+ ):
27
+ """
28
+ Constructor
29
+
30
+ Args:
31
+ confidence_channel (str): confidence channel to use for sampling;
32
+ possible values:
33
+ "sigma_2": confidences for UV values
34
+ "fine_segm_confidence": confidences for fine segmentation
35
+ "coarse_segm_confidence": confidences for coarse segmentation
36
+ (default: "sigma_2")
37
+ count_per_class (int): the sampler produces at most `count_per_class`
38
+ samples for each category (default: 8)
39
+ search_count_multiplier (float or None): if not None, the total number
40
+ of the most confident estimates of a given class to consider is
41
+ defined as `min(search_count_multiplier * count_per_class, N)`,
42
+ where `N` is the total number of estimates of the class; cannot be
43
+ specified together with `search_proportion` (default: None)
44
+ search_proportion (float or None): if not None, the total number of the
45
+ most confident estimates of a given class to consider is
46
+ defined as `min(max(search_proportion * N, count_per_class), N)`,
47
+ where `N` is the total number of estimates of the class; cannot be
48
+ specified together with `search_count_multiplier` (default: None)
49
+ """
50
+ super().__init__(count_per_class)
51
+ self.confidence_channel = confidence_channel
52
+ self.search_count_multiplier = search_count_multiplier
53
+ self.search_proportion = search_proportion
54
+ assert (search_count_multiplier is None) or (search_proportion is None), (
55
+ f"Cannot specify both search_count_multiplier (={search_count_multiplier})"
56
+ f" and search_proportion (={search_proportion})"
57
+ )
58
+
59
+ def _produce_index_sample(self, values: torch.Tensor, count: int):
60
+ """
61
+ Produce a sample of indices to select data based on confidences
62
+
63
+ Args:
64
+ values (torch.Tensor): an array of size [n, k] that contains
65
+ estimated values (U, V, confidences);
66
+ n: number of channels (U, V, confidences)
67
+ k: number of points labeled with part_id
68
+ count (int): number of samples to produce, should be positive and <= k
69
+
70
+ Return:
71
+ list(int): indices of values (along axis 1) selected as a sample
72
+ """
73
+ k = values.shape[1]
74
+ if k == count:
75
+ index_sample = list(range(k))
76
+ else:
77
+ # take the best count * search_count_multiplier pixels,
78
+ # sample from them uniformly
79
+ # (here best = smallest variance)
80
+ _, sorted_confidence_indices = torch.sort(values[2])
81
+ if self.search_count_multiplier is not None:
82
+ search_count = min(int(count * self.search_count_multiplier), k)
83
+ elif self.search_proportion is not None:
84
+ search_count = min(max(int(k * self.search_proportion), count), k)
85
+ else:
86
+ search_count = min(count, k)
87
+ sample_from_top = random.sample(range(search_count), count)
88
+ index_sample = sorted_confidence_indices[:search_count][sample_from_top]
89
+ return index_sample
90
+
91
+ def _produce_labels_and_results(self, instance) -> Tuple[torch.Tensor, torch.Tensor]:
92
+ """
93
+ Method to get labels and DensePose results from an instance, with confidences
94
+
95
+ Args:
96
+ instance (Instances): an instance of `DensePoseChartPredictorOutputWithConfidences`
97
+
98
+ Return:
99
+ labels (torch.Tensor): shape [H, W], DensePose segmentation labels
100
+ dp_result (torch.Tensor): shape [3, H, W], DensePose results u and v
101
+ stacked with the confidence channel
102
+ """
103
+ converter = ToChartResultConverterWithConfidences
104
+ chart_result = converter.convert(instance.pred_densepose, instance.pred_boxes)
105
+ labels, dp_result = chart_result.labels.cpu(), chart_result.uv.cpu()
106
+ dp_result = torch.cat(
107
+ (dp_result, getattr(chart_result, self.confidence_channel)[None].cpu())
108
+ )
109
+
110
+ return labels, dp_result
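The core of _produce_index_sample above — sort points by confidence, keep the search_count most confident ones, then pick count of them uniformly — can be illustrated in isolation. A minimal sketch, assuming a 1-D tensor of per-point variances (names and values are made up):

import random
import torch

def sample_most_confident(confidences: torch.Tensor, count: int, multiplier: float = 1.5):
    # "Best" means smallest variance, matching the sort order used above.
    k = confidences.shape[0]
    if k == count:
        return list(range(k))
    _, order = torch.sort(confidences)                  # most confident first
    search_count = min(int(count * multiplier), k)
    picked = random.sample(range(search_count), count)  # uniform within the top candidates
    return order[:search_count][picked]

conf = torch.tensor([0.9, 0.1, 0.5, 0.2, 0.8])
print(sample_most_confident(conf, count=2))             # two of the indices {1, 3, 2}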
densepose/data/samplers/densepose_cse_base.py ADDED
@@ -0,0 +1,141 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # pyre-unsafe
4
+
5
+ from typing import Any, Dict, List, Tuple
6
+ import torch
7
+ from torch.nn import functional as F
8
+
9
+ from detectron2.config import CfgNode
10
+ from detectron2.structures import Instances
11
+
12
+ from densepose.converters.base import IntTupleBox
13
+ from densepose.data.utils import get_class_to_mesh_name_mapping
14
+ from densepose.modeling.cse.utils import squared_euclidean_distance_matrix
15
+ from densepose.structures import DensePoseDataRelative
16
+
17
+ from .densepose_base import DensePoseBaseSampler
18
+
19
+
20
+ class DensePoseCSEBaseSampler(DensePoseBaseSampler):
21
+ """
22
+ Base DensePose sampler to produce DensePose data from DensePose predictions.
23
+ Samples for each class are drawn according to some distribution over all pixels estimated
24
+ to belong to that class.
25
+ """
26
+
27
+ def __init__(
28
+ self,
29
+ cfg: CfgNode,
30
+ use_gt_categories: bool,
31
+ embedder: torch.nn.Module,
32
+ count_per_class: int = 8,
33
+ ):
34
+ """
35
+ Constructor
36
+
37
+ Args:
38
+ cfg (CfgNode): the config of the model
39
+ embedder (torch.nn.Module): necessary to compute mesh vertex embeddings
40
+ count_per_class (int): the sampler produces at most `count_per_class`
41
+ samples for each category
42
+ """
43
+ super().__init__(count_per_class)
44
+ self.embedder = embedder
45
+ self.class_to_mesh_name = get_class_to_mesh_name_mapping(cfg)
46
+ self.use_gt_categories = use_gt_categories
47
+
48
+ def _sample(self, instance: Instances, bbox_xywh: IntTupleBox) -> Dict[str, List[Any]]:
49
+ """
50
+ Sample DensePoseDataRelative from estimation results
51
+ """
52
+ if self.use_gt_categories:
53
+ instance_class = instance.dataset_classes.tolist()[0]
54
+ else:
55
+ instance_class = instance.pred_classes.tolist()[0]
56
+ mesh_name = self.class_to_mesh_name[instance_class]
57
+
58
+ annotation = {
59
+ DensePoseDataRelative.X_KEY: [],
60
+ DensePoseDataRelative.Y_KEY: [],
61
+ DensePoseDataRelative.VERTEX_IDS_KEY: [],
62
+ DensePoseDataRelative.MESH_NAME_KEY: mesh_name,
63
+ }
64
+
65
+ mask, embeddings, other_values = self._produce_mask_and_results(instance, bbox_xywh)
66
+ indices = torch.nonzero(mask, as_tuple=True)
67
+ selected_embeddings = embeddings.permute(1, 2, 0)[indices].cpu()
68
+ values = other_values[:, indices[0], indices[1]]
69
+ k = values.shape[1]
70
+
71
+ count = min(self.count_per_class, k)
72
+ if count <= 0:
73
+ return annotation
74
+
75
+ index_sample = self._produce_index_sample(values, count)
76
+ closest_vertices = squared_euclidean_distance_matrix(
77
+ selected_embeddings[index_sample], self.embedder(mesh_name)
78
+ )
79
+ closest_vertices = torch.argmin(closest_vertices, dim=1)
80
+
81
+ sampled_y = indices[0][index_sample] + 0.5
82
+ sampled_x = indices[1][index_sample] + 0.5
83
+ # prepare / normalize data
84
+ _, _, w, h = bbox_xywh
85
+ x = (sampled_x / w * 256.0).cpu().tolist()
86
+ y = (sampled_y / h * 256.0).cpu().tolist()
87
+ # extend annotations
88
+ annotation[DensePoseDataRelative.X_KEY].extend(x)
89
+ annotation[DensePoseDataRelative.Y_KEY].extend(y)
90
+ annotation[DensePoseDataRelative.VERTEX_IDS_KEY].extend(closest_vertices.cpu().tolist())
91
+ return annotation
92
+
93
+ def _produce_mask_and_results(
94
+ self, instance: Instances, bbox_xywh: IntTupleBox
95
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
96
+ """
97
+ Method to get labels and DensePose results from an instance
98
+
99
+ Args:
100
+ instance (Instances): an instance of `DensePoseEmbeddingPredictorOutput`
101
+ bbox_xywh (IntTupleBox): the corresponding bounding box
102
+
103
+ Return:
104
+ mask (torch.Tensor): shape [H, W], DensePose segmentation mask
105
+ embeddings (Tuple[torch.Tensor]): a tensor of shape [D, H, W],
106
+ DensePose CSE Embeddings
107
+ other_values (Tuple[torch.Tensor]): a tensor of shape [0, H, W],
108
+ for potential other values
109
+ """
110
+ densepose_output = instance.pred_densepose
111
+ S = densepose_output.coarse_segm
112
+ E = densepose_output.embedding
113
+ _, _, w, h = bbox_xywh
114
+ embeddings = F.interpolate(E, size=(h, w), mode="bilinear")[0]
115
+ coarse_segm_resized = F.interpolate(S, size=(h, w), mode="bilinear")[0]
116
+ mask = coarse_segm_resized.argmax(0) > 0
117
+ other_values = torch.empty((0, h, w), device=E.device)
118
+ return mask, embeddings, other_values
119
+
120
+ def _resample_mask(self, output: Any) -> torch.Tensor:
121
+ """
122
+ Convert DensePose predictor output to segmentation annotation - tensors of size
123
+ (256, 256) and type `int64`.
124
+
125
+ Args:
126
+ output: DensePose predictor output with the following attributes:
127
+ - coarse_segm: tensor of size [N, D, H, W] with unnormalized coarse
128
+ segmentation scores
129
+ Return:
130
+ Tensor of size (S, S) and type `int64` with coarse segmentation annotations,
131
+ where S = DensePoseDataRelative.MASK_SIZE
132
+ """
133
+ sz = DensePoseDataRelative.MASK_SIZE
134
+ mask = (
135
+ F.interpolate(output.coarse_segm, (sz, sz), mode="bilinear", align_corners=False)
136
+ .argmax(dim=1)
137
+ .long()
138
+ .squeeze()
139
+ .cpu()
140
+ )
141
+ return mask
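The vertex assignment performed in _sample above (each sampled pixel embedding is matched to its nearest mesh vertex in embedding space) boils down to an argmin over squared Euclidean distances. A minimal sketch with made-up shapes, not the real embedder output:

import torch

def closest_vertices(pixel_embeddings: torch.Tensor, vertex_embeddings: torch.Tensor) -> torch.Tensor:
    # pixel_embeddings: [P, D], vertex_embeddings: [V, D]
    d2 = torch.cdist(pixel_embeddings, vertex_embeddings) ** 2  # [P, V] squared distances
    return d2.argmin(dim=1)                                     # nearest vertex id per pixel

pixels = torch.randn(4, 16)      # 4 sampled pixel embeddings, D = 16
vertices = torch.randn(100, 16)  # 100 mesh vertex embeddings
print(closest_vertices(pixels, vertices).shape)                 # torch.Size([4])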
densepose/data/samplers/densepose_cse_confidence_based.py ADDED
@@ -0,0 +1,121 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # pyre-unsafe
4
+
5
+ import random
6
+ from typing import Optional, Tuple
7
+ import torch
8
+ from torch.nn import functional as F
9
+
10
+ from detectron2.config import CfgNode
11
+ from detectron2.structures import Instances
12
+
13
+ from densepose.converters.base import IntTupleBox
14
+
15
+ from .densepose_cse_base import DensePoseCSEBaseSampler
16
+
17
+
18
+ class DensePoseCSEConfidenceBasedSampler(DensePoseCSEBaseSampler):
19
+ """
20
+ Samples DensePose data from DensePose predictions.
21
+ Samples for each class are drawn using confidence value estimates.
22
+ """
23
+
24
+ def __init__(
25
+ self,
26
+ cfg: CfgNode,
27
+ use_gt_categories: bool,
28
+ embedder: torch.nn.Module,
29
+ confidence_channel: str,
30
+ count_per_class: int = 8,
31
+ search_count_multiplier: Optional[float] = None,
32
+ search_proportion: Optional[float] = None,
33
+ ):
34
+ """
35
+ Constructor
36
+
37
+ Args:
38
+ cfg (CfgNode): the config of the model
39
+ embedder (torch.nn.Module): necessary to compute mesh vertex embeddings
40
+ confidence_channel (str): confidence channel to use for sampling;
41
+ possible values:
42
+ "coarse_segm_confidence": confidences for coarse segmentation
43
+ (default: "coarse_segm_confidence")
44
+ count_per_class (int): the sampler produces at most `count_per_class`
45
+ samples for each category (default: 8)
46
+ search_count_multiplier (float or None): if not None, the total number
47
+ of the most confident estimates of a given class to consider is
48
+ defined as `min(search_count_multiplier * count_per_class, N)`,
49
+ where `N` is the total number of estimates of the class; cannot be
50
+ specified together with `search_proportion` (default: None)
51
+ search_proportion (float or None): if not None, the total number
52
+ of the most confident estimates of a given class to consider is
53
+ defined as `min(max(search_proportion * N, count_per_class), N)`,
54
+ where `N` is the total number of estimates of the class; cannot be
55
+ specified together with `search_count_multiplier` (default: None)
56
+ """
57
+ super().__init__(cfg, use_gt_categories, embedder, count_per_class)
58
+ self.confidence_channel = confidence_channel
59
+ self.search_count_multiplier = search_count_multiplier
60
+ self.search_proportion = search_proportion
61
+ assert (search_count_multiplier is None) or (search_proportion is None), (
62
+ f"Cannot specify both search_count_multiplier (={search_count_multiplier})"
63
+ f" and search_proportion (={search_proportion})"
64
+ )
65
+
66
+ def _produce_index_sample(self, values: torch.Tensor, count: int):
67
+ """
68
+ Produce a sample of indices to select data based on confidences
69
+
70
+ Args:
71
+ values (torch.Tensor): a tensor of length k that contains confidences
72
+ k: number of points labeled with part_id
73
+ count (int): number of samples to produce, should be positive and <= k
74
+
75
+ Return:
76
+ list(int): indices of values (along axis 1) selected as a sample
77
+ """
78
+ k = values.shape[1]
79
+ if k == count:
80
+ index_sample = list(range(k))
81
+ else:
82
+ # take the best count * search_count_multiplier pixels,
83
+ # sample from them uniformly
84
+ # (here best = smallest variance)
85
+ _, sorted_confidence_indices = torch.sort(values[0])
86
+ if self.search_count_multiplier is not None:
87
+ search_count = min(int(count * self.search_count_multiplier), k)
88
+ elif self.search_proportion is not None:
89
+ search_count = min(max(int(k * self.search_proportion), count), k)
90
+ else:
91
+ search_count = min(count, k)
92
+ sample_from_top = random.sample(range(search_count), count)
93
+ index_sample = sorted_confidence_indices[-search_count:][sample_from_top]
94
+ return index_sample
95
+
96
+ def _produce_mask_and_results(
97
+ self, instance: Instances, bbox_xywh: IntTupleBox
98
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
99
+ """
100
+ Method to get labels and DensePose results from an instance
101
+
102
+ Args:
103
+ instance (Instances): an instance of
104
+ `DensePoseEmbeddingPredictorOutputWithConfidences`
105
+ bbox_xywh (IntTupleBox): the corresponding bounding box
106
+
107
+ Return:
108
+ mask (torch.Tensor): shape [H, W], DensePose segmentation mask
109
+ embeddings (Tuple[torch.Tensor]): a tensor of shape [D, H, W]
110
+ DensePose CSE Embeddings
111
+ other_values: a tensor of shape [1, H, W], DensePose CSE confidence
112
+ """
113
+ _, _, w, h = bbox_xywh
114
+ densepose_output = instance.pred_densepose
115
+ mask, embeddings, _ = super()._produce_mask_and_results(instance, bbox_xywh)
116
+ other_values = F.interpolate(
117
+ getattr(densepose_output, self.confidence_channel),
118
+ size=(h, w),
119
+ mode="bilinear",
120
+ )[0].cpu()
121
+ return mask, embeddings, other_values
densepose/data/samplers/densepose_cse_uniform.py ADDED
@@ -0,0 +1,14 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # pyre-unsafe
4
+
5
+ from .densepose_cse_base import DensePoseCSEBaseSampler
6
+ from .densepose_uniform import DensePoseUniformSampler
7
+
8
+
9
+ class DensePoseCSEUniformSampler(DensePoseCSEBaseSampler, DensePoseUniformSampler):
10
+ """
11
+ Uniform Sampler for CSE
12
+ """
13
+
14
+ pass
densepose/data/samplers/densepose_uniform.py ADDED
@@ -0,0 +1,43 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # pyre-unsafe
4
+
5
+ import random
6
+ import torch
7
+
8
+ from .densepose_base import DensePoseBaseSampler
9
+
10
+
11
+ class DensePoseUniformSampler(DensePoseBaseSampler):
12
+ """
13
+ Samples DensePose data from DensePose predictions.
14
+ Samples for each class are drawn uniformly over all pixels estimated
15
+ to belong to that class.
16
+ """
17
+
18
+ def __init__(self, count_per_class: int = 8):
19
+ """
20
+ Constructor
21
+
22
+ Args:
23
+ count_per_class (int): the sampler produces at most `count_per_class`
24
+ samples for each category
25
+ """
26
+ super().__init__(count_per_class)
27
+
28
+ def _produce_index_sample(self, values: torch.Tensor, count: int):
29
+ """
30
+ Produce a uniform sample of indices to select data
31
+
32
+ Args:
33
+ values (torch.Tensor): an array of size [n, k] that contains
34
+ estimated values (U, V, confidences);
35
+ n: number of channels (U, V, confidences)
36
+ k: number of points labeled with part_id
37
+ count (int): number of samples to produce, should be positive and <= k
38
+
39
+ Return:
40
+ list(int): indices of values (along axis 1) selected as a sample
41
+ """
42
+ k = values.shape[1]
43
+ return random.sample(range(k), count)
densepose/data/samplers/mask_from_densepose.py ADDED
@@ -0,0 +1,30 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # pyre-unsafe
4
+
5
+ from detectron2.structures import BitMasks, Instances
6
+
7
+ from densepose.converters import ToMaskConverter
8
+
9
+
10
+ class MaskFromDensePoseSampler:
11
+ """
12
+ Produce mask GT from DensePose predictions
13
+ This sampler simply converts DensePose predictions to BitMasks
14
+ that contain a bool tensor of the size of the input image
15
+ """
16
+
17
+ def __call__(self, instances: Instances) -> BitMasks:
18
+ """
19
+ Converts predicted data from `instances` into the GT mask data
20
+
21
+ Args:
22
+ instances (Instances): predicted results, expected to have `pred_densepose` field
23
+
24
+ Returns:
25
+ Boolean Tensor of the size of the input image that has non-zero
26
+ values at pixels that are estimated to belong to the detected object
27
+ """
28
+ return ToMaskConverter.convert(
29
+ instances.pred_densepose, instances.pred_boxes, instances.image_size
30
+ )
densepose/data/samplers/prediction_to_gt.py ADDED
@@ -0,0 +1,100 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # pyre-unsafe
4
+
5
+ from dataclasses import dataclass
6
+ from typing import Any, Callable, Dict, List, Optional
7
+
8
+ from detectron2.structures import Instances
9
+
10
+ ModelOutput = Dict[str, Any]
11
+ SampledData = Dict[str, Any]
12
+
13
+
14
+ @dataclass
15
+ class _Sampler:
16
+ """
17
+ Sampler registry entry that contains:
18
+ - src (str): source field to sample from (deleted after sampling)
19
+ - dst (Optional[str]): destination field to sample to, if not None
20
+ - func (Optional[Callable: Any -> Any]): function that performs sampling,
21
+ if None, reference copy is performed
22
+ """
23
+
24
+ src: str
25
+ dst: Optional[str]
26
+ func: Optional[Callable[[Any], Any]]
27
+
28
+
29
+ class PredictionToGroundTruthSampler:
30
+ """
31
+ Sampler implementation that converts predictions to GT using registered
32
+ samplers for different fields of `Instances`.
33
+ """
34
+
35
+ def __init__(self, dataset_name: str = ""):
36
+ self.dataset_name = dataset_name
37
+ self._samplers = {}
38
+ self.register_sampler("pred_boxes", "gt_boxes", None)
39
+ self.register_sampler("pred_classes", "gt_classes", None)
40
+ # delete scores
41
+ self.register_sampler("scores")
42
+
43
+ def __call__(self, model_output: List[ModelOutput]) -> List[SampledData]:
44
+ """
45
+ Transform model output into ground truth data through sampling
46
+
47
+ Args:
48
+ model_output (List[Dict[str, Any]]): model outputs
49
+ Returns:
50
+ List[Dict[str, Any]]: sampled data
51
+ """
52
+ for model_output_i in model_output:
53
+ instances: Instances = model_output_i["instances"]
54
+ # transform data in each field
55
+ for _, sampler in self._samplers.items():
56
+ if not instances.has(sampler.src) or sampler.dst is None:
57
+ continue
58
+ if sampler.func is None:
59
+ instances.set(sampler.dst, instances.get(sampler.src))
60
+ else:
61
+ instances.set(sampler.dst, sampler.func(instances))
62
+ # delete model output data that was transformed
63
+ for _, sampler in self._samplers.items():
64
+ if sampler.src != sampler.dst and instances.has(sampler.src):
65
+ instances.remove(sampler.src)
66
+ model_output_i["dataset"] = self.dataset_name
67
+ return model_output
68
+
69
+ def register_sampler(
70
+ self,
71
+ prediction_attr: str,
72
+ gt_attr: Optional[str] = None,
73
+ func: Optional[Callable[[Any], Any]] = None,
74
+ ):
75
+ """
76
+ Register sampler for a field
77
+
78
+ Args:
79
+ prediction_attr (str): field to replace with a sampled value
80
+ gt_attr (Optional[str]): field to store the sampled value to, if not None
81
+ func (Optional[Callable: Any -> Any]): sampler function
82
+ """
83
+ self._samplers[(prediction_attr, gt_attr)] = _Sampler(
84
+ src=prediction_attr, dst=gt_attr, func=func
85
+ )
86
+
87
+ def remove_sampler(
88
+ self,
89
+ prediction_attr: str,
90
+ gt_attr: Optional[str] = None,
91
+ ):
92
+ """
93
+ Remove sampler for a field
94
+
95
+ Args:
96
+ prediction_attr (str): field to replace with a sampled value
97
+ gt_attr (Optional[str]): field to store the sampled value to, if not None
98
+ """
99
+ assert (prediction_attr, gt_attr) in self._samplers
100
+ del self._samplers[(prediction_attr, gt_attr)]
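A hedged usage sketch (the field names "pred_densepose"/"gt_masks" and the dataset name are illustrative, not mandated by this file): register the mask sampler defined earlier on top of the default box/class samplers, then feed the model output through it.

from densepose.data.samplers.mask_from_densepose import MaskFromDensePoseSampler
from densepose.data.samplers.prediction_to_gt import PredictionToGroundTruthSampler

# pred_boxes -> gt_boxes and pred_classes -> gt_classes are registered by default;
# this adds conversion of DensePose predictions into GT bit masks.
sampler = PredictionToGroundTruthSampler(dataset_name="bootstrap_dataset")
sampler.register_sampler("pred_densepose", "gt_masks", MaskFromDensePoseSampler())

# model_output would be a list of {"instances": Instances, ...} dicts from the model:
# sampled_data = sampler(model_output)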
densepose/data/transform/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # pyre-unsafe
4
+
5
+ from .image import ImageResizeTransform
densepose/data/transform/image.py ADDED
@@ -0,0 +1,41 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # pyre-unsafe
4
+
5
+ import torch
6
+
7
+
8
+ class ImageResizeTransform:
9
+ """
10
+ Transform that resizes images loaded from a dataset
11
+ (BGR data in NCHW channel order, typically uint8) to a format ready to be
12
+ consumed by DensePose training (BGR float32 data in NCHW channel order)
13
+ """
14
+
15
+ def __init__(self, min_size: int = 800, max_size: int = 1333):
16
+ self.min_size = min_size
17
+ self.max_size = max_size
18
+
19
+ def __call__(self, images: torch.Tensor) -> torch.Tensor:
20
+ """
21
+ Args:
22
+ images (torch.Tensor): tensor of size [N, 3, H, W] that contains
23
+ BGR data (typically in uint8)
24
+ Returns:
25
+ images (torch.Tensor): tensor of size [N, 3, H1, W1] where
26
+ H1 and W1 are chosen to respect the specified min and max sizes
27
+ and preserve the original aspect ratio, the data channels
28
+ follow BGR order and the data type is `torch.float32`
29
+ """
30
+ # resize with min size
31
+ images = images.float()
32
+ min_size = min(images.shape[-2:])
33
+ max_size = max(images.shape[-2:])
34
+ scale = min(self.min_size / min_size, self.max_size / max_size)
35
+ images = torch.nn.functional.interpolate(
36
+ images,
37
+ scale_factor=scale,
38
+ mode="bilinear",
39
+ align_corners=False,
40
+ )
41
+ return images
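For example (illustrative sizes): a batch of two 480x640 uint8 BGR images is scaled so that the shorter side reaches 800 pixels, unless that would push the longer side past 1333.

import torch

from densepose.data.transform import ImageResizeTransform

transform = ImageResizeTransform(min_size=800, max_size=1333)
images = torch.randint(0, 256, (2, 3, 480, 640), dtype=torch.uint8)  # NCHW, BGR
resized = transform(images)
print(resized.shape, resized.dtype)  # roughly torch.Size([2, 3, 800, 1066]) torch.float32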
densepose/data/utils.py ADDED
@@ -0,0 +1,40 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # pyre-unsafe
4
+
5
+ import os
6
+ from typing import Dict, Optional
7
+
8
+ from detectron2.config import CfgNode
9
+
10
+
11
+ def is_relative_local_path(path: str) -> bool:
12
+ path_str = os.fsdecode(path)
13
+ return ("://" not in path_str) and not os.path.isabs(path)
14
+
15
+
16
+ def maybe_prepend_base_path(base_path: Optional[str], path: str):
17
+ """
18
+ Prepends the provided path with a base path prefix if:
19
+ 1) base path is not None;
20
+ 2) path is a local path
21
+ """
22
+ if base_path is None:
23
+ return path
24
+ if is_relative_local_path(path):
25
+ return os.path.join(base_path, path)
26
+ return path
27
+
28
+
29
+ def get_class_to_mesh_name_mapping(cfg: CfgNode) -> Dict[int, str]:
30
+ return {
31
+ int(class_id): mesh_name
32
+ for class_id, mesh_name in cfg.DATASETS.CLASS_TO_MESH_NAME_MAPPING.items()
33
+ }
34
+
35
+
36
+ def get_category_to_class_mapping(dataset_cfg: CfgNode) -> Dict[str, int]:
37
+ return {
38
+ category: int(class_id)
39
+ for category, class_id in dataset_cfg.CATEGORY_TO_CLASS_MAPPING.items()
40
+ }
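Behaviour of the path helpers above, with placeholder paths:

from densepose.data.utils import maybe_prepend_base_path

print(maybe_prepend_base_path("/datasets", "videos/clip.mp4"))       # /datasets/videos/clip.mp4
print(maybe_prepend_base_path("/datasets", "s3://bucket/clip.mp4"))  # unchanged: not a local relative path
print(maybe_prepend_base_path(None, "videos/clip.mp4"))              # unchanged: no base path given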
densepose/data/video/__init__.py ADDED
@@ -0,0 +1,19 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # pyre-unsafe
4
+
5
+ from .frame_selector import (
6
+ FrameSelectionStrategy,
7
+ RandomKFramesSelector,
8
+ FirstKFramesSelector,
9
+ LastKFramesSelector,
10
+ FrameTsList,
11
+ FrameSelector,
12
+ )
13
+
14
+ from .video_keyframe_dataset import (
15
+ VideoKeyframeDataset,
16
+ video_list_from_file,
17
+ list_keyframes,
18
+ read_keyframes,
19
+ )
densepose/data/video/frame_selector.py ADDED
@@ -0,0 +1,89 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # pyre-unsafe
4
+
5
+ import random
6
+ from collections.abc import Callable
7
+ from enum import Enum
8
+ from typing import Callable as TCallable
9
+ from typing import List
10
+
11
+ FrameTsList = List[int]
12
+ FrameSelector = TCallable[[FrameTsList], FrameTsList]
13
+
14
+
15
+ class FrameSelectionStrategy(Enum):
16
+ """
17
+ Frame selection strategy used with videos:
18
+ - "random_k": select k random frames
19
+ - "first_k": select k first frames
20
+ - "last_k": select k last frames
21
+ - "all": select all frames
22
+ """
23
+
24
+ # fmt: off
25
+ RANDOM_K = "random_k"
26
+ FIRST_K = "first_k"
27
+ LAST_K = "last_k"
28
+ ALL = "all"
29
+ # fmt: on
30
+
31
+
32
+ class RandomKFramesSelector(Callable): # pyre-ignore[39]
33
+ """
34
+ Selector that retains at most `k` random frames
35
+ """
36
+
37
+ def __init__(self, k: int):
38
+ self.k = k
39
+
40
+ def __call__(self, frame_tss: FrameTsList) -> FrameTsList:
41
+ """
42
+ Select `k` random frames
43
+
44
+ Args:
45
+ frame_tss (List[int]): timestamps of input frames
46
+ Returns:
47
+ List[int]: timestamps of selected frames
48
+ """
49
+ return random.sample(frame_tss, min(self.k, len(frame_tss)))
50
+
51
+
52
+ class FirstKFramesSelector(Callable): # pyre-ignore[39]
53
+ """
54
+ Selector that retains at most `k` first frames
55
+ """
56
+
57
+ def __init__(self, k: int):
58
+ self.k = k
59
+
60
+ def __call__(self, frame_tss: FrameTsList) -> FrameTsList:
61
+ """
62
+ Select `k` first frames
63
+
64
+ Args:
65
+ frame_tss (List[int]): timestamps of input frames
66
+ Returns:
67
+ List[int]: timestamps of selected frames
68
+ """
69
+ return frame_tss[: self.k]
70
+
71
+
72
+ class LastKFramesSelector(Callable): # pyre-ignore[39]
73
+ """
74
+ Selector that retains at most `k` last frames from video data
75
+ """
76
+
77
+ def __init__(self, k: int):
78
+ self.k = k
79
+
80
+ def __call__(self, frame_tss: FrameTsList) -> FrameTsList:
81
+ """
82
+ Select `k` last frames
83
+
84
+ Args:
85
+ frame_tss (List[int]): timestamps of input frames
86
+ Returns:
87
+ List[int]: timestamps of selected frames
88
+ """
89
+ return frame_tss[-self.k :]
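Usage of the selectors above on a list of keyframe timestamps (values are placeholders):

from densepose.data.video.frame_selector import (
    FirstKFramesSelector,
    LastKFramesSelector,
    RandomKFramesSelector,
)

keyframe_ts = [0, 250, 500, 750, 1000]             # timestamps in timebase units
print(FirstKFramesSelector(2)(keyframe_ts))        # [0, 250]
print(LastKFramesSelector(2)(keyframe_ts))         # [750, 1000]
print(len(RandomKFramesSelector(3)(keyframe_ts)))  # 3 (a random subset)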
densepose/data/video/video_keyframe_dataset.py ADDED
@@ -0,0 +1,304 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Facebook, Inc. and its affiliates.
3
+
4
+ # pyre-unsafe
5
+
6
+ import csv
7
+ import logging
8
+ import numpy as np
9
+ from typing import Any, Callable, Dict, List, Optional, Union
10
+ import av
11
+ import torch
12
+ from torch.utils.data.dataset import Dataset
13
+
14
+ from detectron2.utils.file_io import PathManager
15
+
16
+ from ..utils import maybe_prepend_base_path
17
+ from .frame_selector import FrameSelector, FrameTsList
18
+
19
+ FrameList = List[av.frame.Frame] # pyre-ignore[16]
20
+ FrameTransform = Callable[[torch.Tensor], torch.Tensor]
21
+
22
+
23
+ def list_keyframes(video_fpath: str, video_stream_idx: int = 0) -> FrameTsList:
24
+ """
25
+ Traverses all keyframes of a video file. Returns a list of keyframe
26
+ timestamps. Timestamps are counts in timebase units.
27
+
28
+ Args:
29
+ video_fpath (str): Video file path
30
+ video_stream_idx (int): Video stream index (default: 0)
31
+ Returns:
32
+ List[int]: list of keyframe timestaps (timestamp is a count in timebase
33
+ units)
34
+ """
35
+ try:
36
+ with PathManager.open(video_fpath, "rb") as io:
37
+ # pyre-fixme[16]: Module `av` has no attribute `open`.
38
+ container = av.open(io, mode="r")
39
+ stream = container.streams.video[video_stream_idx]
40
+ keyframes = []
41
+ pts = -1
42
+ # Note: even though we request forward seeks for keyframes, sometimes
43
+ # a keyframe in backwards direction is returned. We introduce tolerance
44
+ # as a max count of ignored backward seeks
45
+ tolerance_backward_seeks = 2
46
+ while True:
47
+ try:
48
+ container.seek(pts + 1, backward=False, any_frame=False, stream=stream)
49
+ except av.AVError as e:
50
+ # the exception occurs when the video length is exceeded,
51
+ # we then return whatever data we've already collected
52
+ logger = logging.getLogger(__name__)
53
+ logger.debug(
54
+ f"List keyframes: Error seeking video file {video_fpath}, "
55
+ f"video stream {video_stream_idx}, pts {pts + 1}, AV error: {e}"
56
+ )
57
+ return keyframes
58
+ except OSError as e:
59
+ logger = logging.getLogger(__name__)
60
+ logger.warning(
61
+ f"List keyframes: Error seeking video file {video_fpath}, "
62
+ f"video stream {video_stream_idx}, pts {pts + 1}, OS error: {e}"
63
+ )
64
+ return []
65
+ packet = next(container.demux(video=video_stream_idx))
66
+ if packet.pts is not None and packet.pts <= pts:
67
+ logger = logging.getLogger(__name__)
68
+ logger.warning(
69
+ f"Video file {video_fpath}, stream {video_stream_idx}: "
70
+ f"bad seek for packet {pts + 1} (got packet {packet.pts}), "
71
+ f"tolerance {tolerance_backward_seeks}."
72
+ )
73
+ tolerance_backward_seeks -= 1
74
+ if tolerance_backward_seeks == 0:
75
+ return []
76
+ pts += 1
77
+ continue
78
+ tolerance_backward_seeks = 2
79
+ pts = packet.pts
80
+ if pts is None:
81
+ return keyframes
82
+ if packet.is_keyframe:
83
+ keyframes.append(pts)
84
+ return keyframes
85
+ except OSError as e:
86
+ logger = logging.getLogger(__name__)
87
+ logger.warning(
88
+ f"List keyframes: Error opening video file container {video_fpath}, " f"OS error: {e}"
89
+ )
90
+ except RuntimeError as e:
91
+ logger = logging.getLogger(__name__)
92
+ logger.warning(
93
+ f"List keyframes: Error opening video file container {video_fpath}, "
94
+ f"Runtime error: {e}"
95
+ )
96
+ return []
97
+
98
+
99
+ def read_keyframes(
100
+ video_fpath: str, keyframes: FrameTsList, video_stream_idx: int = 0
101
+ ) -> FrameList: # pyre-ignore[11]
102
+ """
103
+ Reads keyframe data from a video file.
104
+
105
+ Args:
106
+ video_fpath (str): Video file path
107
+ keyframes (List[int]): List of keyframe timestamps (as counts in
108
+ timebase units to be used in container seek operations)
109
+ video_stream_idx (int): Video stream index (default: 0)
110
+ Returns:
111
+ List[Frame]: list of frames that correspond to the specified timestamps
112
+ """
113
+ try:
114
+ with PathManager.open(video_fpath, "rb") as io:
115
+ # pyre-fixme[16]: Module `av` has no attribute `open`.
116
+ container = av.open(io)
117
+ stream = container.streams.video[video_stream_idx]
118
+ frames = []
119
+ for pts in keyframes:
120
+ try:
121
+ container.seek(pts, any_frame=False, stream=stream)
122
+ frame = next(container.decode(video=0))
123
+ frames.append(frame)
124
+ except av.AVError as e:
125
+ logger = logging.getLogger(__name__)
126
+ logger.warning(
127
+ f"Read keyframes: Error seeking video file {video_fpath}, "
128
+ f"video stream {video_stream_idx}, pts {pts}, AV error: {e}"
129
+ )
130
+ container.close()
131
+ return frames
132
+ except OSError as e:
133
+ logger = logging.getLogger(__name__)
134
+ logger.warning(
135
+ f"Read keyframes: Error seeking video file {video_fpath}, "
136
+ f"video stream {video_stream_idx}, pts {pts}, OS error: {e}"
137
+ )
138
+ container.close()
139
+ return frames
140
+ except StopIteration:
141
+ logger = logging.getLogger(__name__)
142
+ logger.warning(
143
+ f"Read keyframes: Error decoding frame from {video_fpath}, "
144
+ f"video stream {video_stream_idx}, pts {pts}"
145
+ )
146
+ container.close()
147
+ return frames
148
+
149
+ container.close()
150
+ return frames
151
+ except OSError as e:
152
+ logger = logging.getLogger(__name__)
153
+ logger.warning(
154
+ f"Read keyframes: Error opening video file container {video_fpath}, OS error: {e}"
155
+ )
156
+ except RuntimeError as e:
157
+ logger = logging.getLogger(__name__)
158
+ logger.warning(
159
+ f"Read keyframes: Error opening video file container {video_fpath}, Runtime error: {e}"
160
+ )
161
+ return []
162
+
163
+
164
+ def video_list_from_file(video_list_fpath: str, base_path: Optional[str] = None):
165
+ """
166
+ Create a list of paths to video files from a text file.
167
+
168
+ Args:
169
+ video_list_fpath (str): path to a plain text file with the list of videos
170
+ base_path (str): base path for entries from the video list (default: None)
171
+ """
172
+ video_list = []
173
+ with PathManager.open(video_list_fpath, "r") as io:
174
+ for line in io:
175
+ video_list.append(maybe_prepend_base_path(base_path, str(line.strip())))
176
+ return video_list
177
+
178
+
179
+ def read_keyframe_helper_data(fpath: str):
180
+ """
181
+ Read keyframe data from a file in CSV format: the header should contain
182
+ "video_id" and "keyframes" fields. Value specifications are:
183
+ video_id: int
184
+ keyframes: list(int)
185
+ Example of contents:
186
+ video_id,keyframes
187
+ 2,"[1,11,21,31,41,51,61,71,81]"
188
+
189
+ Args:
190
+ fpath (str): File containing keyframe data
191
+
192
+ Return:
193
+ video_id_to_keyframes (dict: int -> list(int)): for a given video ID it
194
+ contains a list of keyframes for that video
195
+ """
196
+ video_id_to_keyframes = {}
197
+ try:
198
+ with PathManager.open(fpath, "r") as io:
199
+ csv_reader = csv.reader(io)
200
+ header = next(csv_reader)
201
+ video_id_idx = header.index("video_id")
202
+ keyframes_idx = header.index("keyframes")
203
+ for row in csv_reader:
204
+ video_id = int(row[video_id_idx])
205
+ assert (
206
+ video_id not in video_id_to_keyframes
207
+ ), f"Duplicate keyframes entry for video {fpath}"
208
+ video_id_to_keyframes[video_id] = (
209
+ [int(v) for v in row[keyframes_idx][1:-1].split(",")]
210
+ if len(row[keyframes_idx]) > 2
211
+ else []
212
+ )
213
+ except Exception as e:
214
+ logger = logging.getLogger(__name__)
215
+ logger.warning(f"Error reading keyframe helper data from {fpath}: {e}")
216
+ return video_id_to_keyframes
217
+
218
+
219
+ class VideoKeyframeDataset(Dataset):
220
+ """
221
+ Dataset that provides keyframes for a set of videos.
222
+ """
223
+
224
+ _EMPTY_FRAMES = torch.empty((0, 3, 1, 1))
225
+
226
+ def __init__(
227
+ self,
228
+ video_list: List[str],
229
+ category_list: Union[str, List[str], None] = None,
230
+ frame_selector: Optional[FrameSelector] = None,
231
+ transform: Optional[FrameTransform] = None,
232
+ keyframe_helper_fpath: Optional[str] = None,
233
+ ):
234
+ """
235
+ Dataset constructor
236
+
237
+ Args:
238
+ video_list (List[str]): list of paths to video files
239
+ category_list (Union[str, List[str], None]): list of animal categories for each
240
+ video file. If it is a string, or None, this applies to all videos
241
+ frame_selector (Callable: KeyFrameList -> KeyFrameList):
242
+ selects keyframes to process, keyframes are given by
243
+ packet timestamps in timebase counts. If None, all keyframes
244
+ are selected (default: None)
245
+ transform (Callable: torch.Tensor -> torch.Tensor):
246
+ transforms a batch of RGB images (tensors of size [B, 3, H, W]),
247
+ returns a tensor of the same size. If None, no transform is
248
+ applied (default: None)
249
+
250
+ """
251
+ if type(category_list) is list:
252
+ self.category_list = category_list
253
+ else:
254
+ self.category_list = [category_list] * len(video_list)
255
+ assert len(video_list) == len(
256
+ self.category_list
257
+ ), "length of video and category lists must be equal"
258
+ self.video_list = video_list
259
+ self.frame_selector = frame_selector
260
+ self.transform = transform
261
+ self.keyframe_helper_data = (
262
+ read_keyframe_helper_data(keyframe_helper_fpath)
263
+ if keyframe_helper_fpath is not None
264
+ else None
265
+ )
266
+
267
+ def __getitem__(self, idx: int) -> Dict[str, Any]:
268
+ """
269
+ Gets selected keyframes from a given video
270
+
271
+ Args:
272
+ idx (int): video index in the video list file
273
+ Returns:
274
+ A dictionary containing two keys:
275
+ images (torch.Tensor): tensor of size [N, H, W, 3] or of size
276
+ defined by the transform that contains keyframes data
277
+ categories (List[str]): categories of the frames
278
+ """
279
+ categories = [self.category_list[idx]]
280
+ fpath = self.video_list[idx]
281
+ keyframes = (
282
+ list_keyframes(fpath)
283
+ if self.keyframe_helper_data is None or idx not in self.keyframe_helper_data
284
+ else self.keyframe_helper_data[idx]
285
+ )
286
+ transform = self.transform
287
+ frame_selector = self.frame_selector
288
+ if not keyframes:
289
+ return {"images": self._EMPTY_FRAMES, "categories": []}
290
+ if frame_selector is not None:
291
+ keyframes = frame_selector(keyframes)
292
+ frames = read_keyframes(fpath, keyframes)
293
+ if not frames:
294
+ return {"images": self._EMPTY_FRAMES, "categories": []}
295
+ frames = np.stack([frame.to_rgb().to_ndarray() for frame in frames])
296
+ frames = torch.as_tensor(frames, device=torch.device("cpu"))
297
+ frames = frames[..., [2, 1, 0]] # RGB -> BGR
298
+ frames = frames.permute(0, 3, 1, 2).float() # NHWC -> NCHW
299
+ if transform is not None:
300
+ frames = transform(frames)
301
+ return {"images": frames, "categories": categories}
302
+
303
+ def __len__(self):
304
+ return len(self.video_list)
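Putting the pieces from this commit together; a hedged sketch in which the video path and the category string are placeholders:

from densepose.data.transform import ImageResizeTransform
from densepose.data.video import RandomKFramesSelector, VideoKeyframeDataset

dataset = VideoKeyframeDataset(
    video_list=["/path/to/video.mp4"],        # placeholder path
    category_list="chimpanzee",               # one category applied to every video
    frame_selector=RandomKFramesSelector(4),  # keep at most 4 random keyframes
    transform=ImageResizeTransform(),         # resize and convert to float32
)
sample = dataset[0]  # {"images": Tensor[N, 3, H, W], "categories": List[str]}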
densepose/engine/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # pyre-unsafe
4
+
5
+ from .trainer import Trainer
densepose/engine/trainer.py ADDED
@@ -0,0 +1,260 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2
+
3
+ # pyre-unsafe
4
+
5
+ import logging
6
+ import os
7
+ from collections import OrderedDict
8
+ from typing import List, Optional, Union
9
+ import torch
10
+ from torch import nn
11
+
12
+ from detectron2.checkpoint import DetectionCheckpointer
13
+ from detectron2.config import CfgNode
14
+ from detectron2.engine import DefaultTrainer
15
+ from detectron2.evaluation import (
16
+ DatasetEvaluator,
17
+ DatasetEvaluators,
18
+ inference_on_dataset,
19
+ print_csv_format,
20
+ )
21
+ from detectron2.solver.build import get_default_optimizer_params, maybe_add_gradient_clipping
22
+ from detectron2.utils import comm
23
+ from detectron2.utils.events import EventWriter, get_event_storage
24
+
25
+ from densepose import DensePoseDatasetMapperTTA, DensePoseGeneralizedRCNNWithTTA, load_from_cfg
26
+ from densepose.data import (
27
+ DatasetMapper,
28
+ build_combined_loader,
29
+ build_detection_test_loader,
30
+ build_detection_train_loader,
31
+ build_inference_based_loaders,
32
+ has_inference_based_loaders,
33
+ )
34
+ from densepose.evaluation.d2_evaluator_adapter import Detectron2COCOEvaluatorAdapter
35
+ from densepose.evaluation.evaluator import DensePoseCOCOEvaluator, build_densepose_evaluator_storage
36
+ from densepose.modeling.cse import Embedder
37
+
38
+
39
+ class SampleCountingLoader:
40
+ def __init__(self, loader):
41
+ self.loader = loader
42
+
43
+ def __iter__(self):
44
+ it = iter(self.loader)
45
+ storage = get_event_storage()
46
+ while True:
47
+ try:
48
+ batch = next(it)
49
+ num_inst_per_dataset = {}
50
+ for data in batch:
51
+ dataset_name = data["dataset"]
52
+ if dataset_name not in num_inst_per_dataset:
53
+ num_inst_per_dataset[dataset_name] = 0
54
+ num_inst = len(data["instances"])
55
+ num_inst_per_dataset[dataset_name] += num_inst
56
+ for dataset_name in num_inst_per_dataset:
57
+ storage.put_scalar(f"batch/{dataset_name}", num_inst_per_dataset[dataset_name])
58
+ yield batch
59
+ except StopIteration:
60
+ break
61
+
62
+
63
+ class SampleCountMetricPrinter(EventWriter):
64
+ def __init__(self):
65
+ self.logger = logging.getLogger(__name__)
66
+
67
+ def write(self):
68
+ storage = get_event_storage()
69
+ batch_stats_strs = []
70
+ for key, buf in storage.histories().items():
71
+ if key.startswith("batch/"):
72
+ batch_stats_strs.append(f"{key} {buf.avg(20)}")
73
+ self.logger.info(", ".join(batch_stats_strs))
74
+
75
+
76
+ class Trainer(DefaultTrainer):
77
+ @classmethod
78
+ def extract_embedder_from_model(cls, model: nn.Module) -> Optional[Embedder]:
79
+ if isinstance(model, nn.parallel.DistributedDataParallel):
80
+ model = model.module
81
+ if hasattr(model, "roi_heads") and hasattr(model.roi_heads, "embedder"):
82
+ return model.roi_heads.embedder
83
+ return None
84
+
85
+ # TODO: the only reason to copy the base class code here is to pass the embedder from
86
+ # the model to the evaluator; that should be refactored to avoid unnecessary copy-pasting
87
+ @classmethod
88
+ def test(
89
+ cls,
90
+ cfg: CfgNode,
91
+ model: nn.Module,
92
+ evaluators: Optional[Union[DatasetEvaluator, List[DatasetEvaluator]]] = None,
93
+ ):
94
+ """
95
+ Args:
96
+ cfg (CfgNode):
97
+ model (nn.Module):
98
+ evaluators (DatasetEvaluator, list[DatasetEvaluator] or None): if None, will call
99
+ :meth:`build_evaluator`. Otherwise, must have the same length as
100
+ ``cfg.DATASETS.TEST``.
101
+
102
+ Returns:
103
+ dict: a dict of result metrics
104
+ """
105
+ logger = logging.getLogger(__name__)
106
+ if isinstance(evaluators, DatasetEvaluator):
107
+ evaluators = [evaluators]
108
+ if evaluators is not None:
109
+ assert len(cfg.DATASETS.TEST) == len(evaluators), "{} != {}".format(
110
+ len(cfg.DATASETS.TEST), len(evaluators)
111
+ )
112
+
113
+ results = OrderedDict()
114
+ for idx, dataset_name in enumerate(cfg.DATASETS.TEST):
115
+ data_loader = cls.build_test_loader(cfg, dataset_name)
116
+ # When evaluators are passed in as arguments,
117
+ # implicitly assume that evaluators can be created before data_loader.
118
+ if evaluators is not None:
119
+ evaluator = evaluators[idx]
120
+ else:
121
+ try:
122
+ embedder = cls.extract_embedder_from_model(model)
123
+ evaluator = cls.build_evaluator(cfg, dataset_name, embedder=embedder)
124
+ except NotImplementedError:
125
+ logger.warning(
126
+ "No evaluator found. Use `DefaultTrainer.test(evaluators=)`, "
127
+ "or implement its `build_evaluator` method."
128
+ )
129
+ results[dataset_name] = {}
130
+ continue
131
+ if cfg.DENSEPOSE_EVALUATION.DISTRIBUTED_INFERENCE or comm.is_main_process():
132
+ results_i = inference_on_dataset(model, data_loader, evaluator)
133
+ else:
134
+ results_i = {}
135
+ results[dataset_name] = results_i
136
+ if comm.is_main_process():
137
+ assert isinstance(
138
+ results_i, dict
139
+ ), "Evaluator must return a dict on the main process. Got {} instead.".format(
140
+ results_i
141
+ )
142
+ logger.info("Evaluation results for {} in csv format:".format(dataset_name))
143
+ print_csv_format(results_i)
144
+
145
+ if len(results) == 1:
146
+ results = list(results.values())[0]
147
+ return results
148
+
149
+ @classmethod
150
+ def build_evaluator(
151
+ cls,
152
+ cfg: CfgNode,
153
+ dataset_name: str,
154
+ output_folder: Optional[str] = None,
155
+ embedder: Optional[Embedder] = None,
156
+ ) -> DatasetEvaluators:
157
+ if output_folder is None:
158
+ output_folder = os.path.join(cfg.OUTPUT_DIR, "inference")
159
+ evaluators = []
160
+ distributed = cfg.DENSEPOSE_EVALUATION.DISTRIBUTED_INFERENCE
161
+ # Note: we currently use COCO evaluator for both COCO and LVIS datasets
162
+ # to have compatible metrics. LVIS bbox evaluator could also be used
163
+ # with an adapter to properly handle filtered / mapped categories
164
+ # evaluator_type = MetadataCatalog.get(dataset_name).evaluator_type
165
+ # if evaluator_type == "coco":
166
+ # evaluators.append(COCOEvaluator(dataset_name, output_dir=output_folder))
167
+ # elif evaluator_type == "lvis":
168
+ # evaluators.append(LVISEvaluator(dataset_name, output_dir=output_folder))
169
+ evaluators.append(
170
+ Detectron2COCOEvaluatorAdapter(
171
+ dataset_name, output_dir=output_folder, distributed=distributed
172
+ )
173
+ )
174
+ if cfg.MODEL.DENSEPOSE_ON:
175
+ storage = build_densepose_evaluator_storage(cfg, output_folder)
176
+ evaluators.append(
177
+ DensePoseCOCOEvaluator(
178
+ dataset_name,
179
+ distributed,
180
+ output_folder,
181
+ evaluator_type=cfg.DENSEPOSE_EVALUATION.TYPE,
182
+ min_iou_threshold=cfg.DENSEPOSE_EVALUATION.MIN_IOU_THRESHOLD,
183
+ storage=storage,
184
+ embedder=embedder,
185
+ should_evaluate_mesh_alignment=cfg.DENSEPOSE_EVALUATION.EVALUATE_MESH_ALIGNMENT,
186
+ mesh_alignment_mesh_names=cfg.DENSEPOSE_EVALUATION.MESH_ALIGNMENT_MESH_NAMES,
187
+ )
188
+ )
189
+ return DatasetEvaluators(evaluators)
190
+
191
+ @classmethod
192
+ def build_optimizer(cls, cfg: CfgNode, model: nn.Module):
193
+ params = get_default_optimizer_params(
194
+ model,
195
+ base_lr=cfg.SOLVER.BASE_LR,
196
+ weight_decay_norm=cfg.SOLVER.WEIGHT_DECAY_NORM,
197
+ bias_lr_factor=cfg.SOLVER.BIAS_LR_FACTOR,
198
+ weight_decay_bias=cfg.SOLVER.WEIGHT_DECAY_BIAS,
199
+ overrides={
200
+ "features": {
201
+ "lr": cfg.SOLVER.BASE_LR * cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.FEATURES_LR_FACTOR,
202
+ },
203
+ "embeddings": {
204
+ "lr": cfg.SOLVER.BASE_LR * cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBEDDING_LR_FACTOR,
205
+ },
206
+ },
207
+ )
208
+ optimizer = torch.optim.SGD(
209
+ params,
210
+ cfg.SOLVER.BASE_LR,
211
+ momentum=cfg.SOLVER.MOMENTUM,
212
+ nesterov=cfg.SOLVER.NESTEROV,
213
+ weight_decay=cfg.SOLVER.WEIGHT_DECAY,
214
+ )
215
+ # pyre-fixme[6]: For 2nd param expected `Type[Optimizer]` but got `SGD`.
216
+ return maybe_add_gradient_clipping(cfg, optimizer)
217
+
218
+ @classmethod
219
+ def build_test_loader(cls, cfg: CfgNode, dataset_name):
220
+ return build_detection_test_loader(cfg, dataset_name, mapper=DatasetMapper(cfg, False))
221
+
222
+ @classmethod
223
+ def build_train_loader(cls, cfg: CfgNode):
224
+ data_loader = build_detection_train_loader(cfg, mapper=DatasetMapper(cfg, True))
225
+ if not has_inference_based_loaders(cfg):
226
+ return data_loader
227
+ model = cls.build_model(cfg)
228
+ model.to(cfg.BOOTSTRAP_MODEL.DEVICE)
229
+ DetectionCheckpointer(model).resume_or_load(cfg.BOOTSTRAP_MODEL.WEIGHTS, resume=False)
230
+ inference_based_loaders, ratios = build_inference_based_loaders(cfg, model)
231
+ loaders = [data_loader] + inference_based_loaders
232
+ ratios = [1.0] + ratios
233
+ combined_data_loader = build_combined_loader(cfg, loaders, ratios)
234
+ sample_counting_loader = SampleCountingLoader(combined_data_loader)
235
+ return sample_counting_loader
236
+
237
+ def build_writers(self):
238
+ writers = super().build_writers()
239
+ writers.append(SampleCountMetricPrinter())
240
+ return writers
241
+
242
+ @classmethod
243
+ def test_with_TTA(cls, cfg: CfgNode, model):
244
+ logger = logging.getLogger("detectron2.trainer")
245
+ # In the end of training, run an evaluation with TTA
246
+ # Only support some R-CNN models.
247
+ logger.info("Running inference with test-time augmentation ...")
248
+ transform_data = load_from_cfg(cfg)
249
+ model = DensePoseGeneralizedRCNNWithTTA(
250
+ cfg, model, transform_data, DensePoseDatasetMapperTTA(cfg)
251
+ )
252
+ evaluators = [
253
+ cls.build_evaluator(
254
+ cfg, name, output_folder=os.path.join(cfg.OUTPUT_DIR, "inference_TTA")
255
+ )
256
+ for name in cfg.DATASETS.TEST
257
+ ]
258
+ res = cls.test(cfg, model, evaluators) # pyre-ignore[6]
259
+ res = OrderedDict({k + "_TTA": v for k, v in res.items()})
260
+ return res
densepose/evaluation/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # pyre-unsafe
4
+
5
+ from .evaluator import DensePoseCOCOEvaluator
densepose/evaluation/d2_evaluator_adapter.py ADDED
@@ -0,0 +1,52 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # pyre-unsafe
4
+
5
+ from detectron2.data.catalog import Metadata
6
+ from detectron2.evaluation import COCOEvaluator
7
+
8
+ from densepose.data.datasets.coco import (
9
+ get_contiguous_id_to_category_id_map,
10
+ maybe_filter_categories_cocoapi,
11
+ )
12
+
13
+
14
+ def _maybe_add_iscrowd_annotations(cocoapi) -> None:
15
+ for ann in cocoapi.dataset["annotations"]:
16
+ if "iscrowd" not in ann:
17
+ ann["iscrowd"] = 0
18
+
19
+
20
+ class Detectron2COCOEvaluatorAdapter(COCOEvaluator):
21
+ def __init__(
22
+ self,
23
+ dataset_name,
24
+ output_dir=None,
25
+ distributed=True,
26
+ ):
27
+ super().__init__(dataset_name, output_dir=output_dir, distributed=distributed)
28
+ maybe_filter_categories_cocoapi(dataset_name, self._coco_api)
29
+ _maybe_add_iscrowd_annotations(self._coco_api)
30
+ # substitute category metadata to account for categories
31
+ # that are mapped to the same contiguous id
32
+ if hasattr(self._metadata, "thing_dataset_id_to_contiguous_id"):
33
+ self._maybe_substitute_metadata()
34
+
35
+ def _maybe_substitute_metadata(self):
36
+ cont_id_2_cat_id = get_contiguous_id_to_category_id_map(self._metadata)
37
+ cat_id_2_cont_id = self._metadata.thing_dataset_id_to_contiguous_id
38
+ if len(cont_id_2_cat_id) == len(cat_id_2_cont_id):
39
+ return
40
+
41
+ cat_id_2_cont_id_injective = {}
42
+ for cat_id, cont_id in cat_id_2_cont_id.items():
43
+ if (cont_id in cont_id_2_cat_id) and (cont_id_2_cat_id[cont_id] == cat_id):
44
+ cat_id_2_cont_id_injective[cat_id] = cont_id
45
+
46
+ metadata_new = Metadata(name=self._metadata.name)
47
+ for key, value in self._metadata.__dict__.items():
48
+ if key == "thing_dataset_id_to_contiguous_id":
49
+ setattr(metadata_new, key, cat_id_2_cont_id_injective)
50
+ else:
51
+ setattr(metadata_new, key, value)
52
+ self._metadata = metadata_new