diff --git a/LICENSE.md b/LICENSE.md new file mode 100644 index 0000000000000000000000000000000000000000..49035ff10e9b40aa64301b8e8819dcf46606c4cb --- /dev/null +++ b/LICENSE.md @@ -0,0 +1,159 @@ +# Creative Commons Attribution-NonCommercial 4.0 International + +Creative Commons Corporation (“Creative Commons”) is not a law firm and does not provide legal services or legal advice. Distribution of Creative Commons public licenses does not create a lawyer-client or other relationship. Creative Commons makes its licenses and related information available on an “as-is” basis. Creative Commons gives no warranties regarding its licenses, any material licensed under their terms and conditions, or any related information. Creative Commons disclaims all liability for damages resulting from their use to the fullest extent possible. + +### Using Creative Commons Public Licenses + +Creative Commons public licenses provide a standard set of terms and conditions that creators and other rights holders may use to share original works of authorship and other material subject to copyright and certain other rights specified in the public license below. The following considerations are for informational purposes only, are not exhaustive, and do not form part of our licenses. + +* __Considerations for licensors:__ Our public licenses are intended for use by those authorized to give the public permission to use material in ways otherwise restricted by copyright and certain other rights. Our licenses are irrevocable. Licensors should read and understand the terms and conditions of the license they choose before applying it. Licensors should also secure all rights necessary before applying our licenses so that the public can reuse the material as expected. Licensors should clearly mark any material not subject to the license. This includes other CC-licensed material, or material used under an exception or limitation to copyright. [More considerations for licensors](http://wiki.creativecommons.org/Considerations_for_licensors_and_licensees#Considerations_for_licensors). + +* __Considerations for the public:__ By using one of our public licenses, a licensor grants the public permission to use the licensed material under specified terms and conditions. If the licensor’s permission is not necessary for any reason–for example, because of any applicable exception or limitation to copyright–then that use is not regulated by the license. Our licenses grant only permissions under copyright and certain other rights that a licensor has authority to grant. Use of the licensed material may still be restricted for other reasons, including because others have copyright or other rights in the material. A licensor may make special requests, such as asking that all changes be marked or described. Although not required by our licenses, you are encouraged to respect those requests where reasonable. [More considerations for the public](http://wiki.creativecommons.org/Considerations_for_licensors_and_licensees#Considerations_for_licensees). + +## Creative Commons Attribution-NonCommercial 4.0 International Public License + +By exercising the Licensed Rights (defined below), You accept and agree to be bound by the terms and conditions of this Creative Commons Attribution-NonCommercial 4.0 International Public License ("Public License"). To the extent this Public License may be interpreted as a contract, You are granted the Licensed Rights in consideration of Your acceptance of these terms and conditions, and the Licensor grants You such rights in consideration of benefits the Licensor receives from making the Licensed Material available under these terms and conditions. + +### Section 1 – Definitions. + +a. __Adapted Material__ means material subject to Copyright and Similar Rights that is derived from or based upon the Licensed Material and in which the Licensed Material is translated, altered, arranged, transformed, or otherwise modified in a manner requiring permission under the Copyright and Similar Rights held by the Licensor. For purposes of this Public License, where the Licensed Material is a musical work, performance, or sound recording, Adapted Material is always produced where the Licensed Material is synched in timed relation with a moving image. + +b. __Adapter's License__ means the license You apply to Your Copyright and Similar Rights in Your contributions to Adapted Material in accordance with the terms and conditions of this Public License. + +c. __Copyright and Similar Rights__ means copyright and/or similar rights closely related to copyright including, without limitation, performance, broadcast, sound recording, and Sui Generis Database Rights, without regard to how the rights are labeled or categorized. For purposes of this Public License, the rights specified in Section 2(b)(1)-(2) are not Copyright and Similar Rights. + +d. __Effective Technological Measures__ means those measures that, in the absence of proper authority, may not be circumvented under laws fulfilling obligations under Article 11 of the WIPO Copyright Treaty adopted on December 20, 1996, and/or similar international agreements. + +e. __Exceptions and Limitations__ means fair use, fair dealing, and/or any other exception or limitation to Copyright and Similar Rights that applies to Your use of the Licensed Material. + +f. __Licensed Material__ means the artistic or literary work, database, or other material to which the Licensor applied this Public License. + +g. __Licensed Rights__ means the rights granted to You subject to the terms and conditions of this Public License, which are limited to all Copyright and Similar Rights that apply to Your use of the Licensed Material and that the Licensor has authority to license. + +h. __Licensor__ means the individual(s) or entity(ies) granting rights under this Public License. + +i. __NonCommercial__ means not primarily intended for or directed towards commercial advantage or monetary compensation. For purposes of this Public License, the exchange of the Licensed Material for other material subject to Copyright and Similar Rights by digital file-sharing or similar means is NonCommercial provided there is no payment of monetary compensation in connection with the exchange. + +j. __Share__ means to provide material to the public by any means or process that requires permission under the Licensed Rights, such as reproduction, public display, public performance, distribution, dissemination, communication, or importation, and to make material available to the public including in ways that members of the public may access the material from a place and at a time individually chosen by them. + +k. __Sui Generis Database Rights__ means rights other than copyright resulting from Directive 96/9/EC of the European Parliament and of the Council of 11 March 1996 on the legal protection of databases, as amended and/or succeeded, as well as other essentially equivalent rights anywhere in the world. + +l. __You__ means the individual or entity exercising the Licensed Rights under this Public License. __Your__ has a corresponding meaning. + +### Section 2 – Scope. + +a. ___License grant.___ + + 1. Subject to the terms and conditions of this Public License, the Licensor hereby grants You a worldwide, royalty-free, non-sublicensable, non-exclusive, irrevocable license to exercise the Licensed Rights in the Licensed Material to: + + A. reproduce and Share the Licensed Material, in whole or in part, for NonCommercial purposes only; and + + B. produce, reproduce, and Share Adapted Material for NonCommercial purposes only. + + 2. __Exceptions and Limitations.__ For the avoidance of doubt, where Exceptions and Limitations apply to Your use, this Public License does not apply, and You do not need to comply with its terms and conditions. + + 3. __Term.__ The term of this Public License is specified in Section 6(a). + + 4. __Media and formats; technical modifications allowed.__ The Licensor authorizes You to exercise the Licensed Rights in all media and formats whether now known or hereafter created, and to make technical modifications necessary to do so. The Licensor waives and/or agrees not to assert any right or authority to forbid You from making technical modifications necessary to exercise the Licensed Rights, including technical modifications necessary to circumvent Effective Technological Measures. For purposes of this Public License, simply making modifications authorized by this Section 2(a)(4) never produces Adapted Material. + + 5. __Downstream recipients.__ + + A. __Offer from the Licensor – Licensed Material.__ Every recipient of the Licensed Material automatically receives an offer from the Licensor to exercise the Licensed Rights under the terms and conditions of this Public License. + + B. __No downstream restrictions.__ You may not offer or impose any additional or different terms or conditions on, or apply any Effective Technological Measures to, the Licensed Material if doing so restricts exercise of the Licensed Rights by any recipient of the Licensed Material. + + 6. __No endorsement.__ Nothing in this Public License constitutes or may be construed as permission to assert or imply that You are, or that Your use of the Licensed Material is, connected with, or sponsored, endorsed, or granted official status by, the Licensor or others designated to receive attribution as provided in Section 3(a)(1)(A)(i). + +b. ___Other rights.___ + + 1. Moral rights, such as the right of integrity, are not licensed under this Public License, nor are publicity, privacy, and/or other similar personality rights; however, to the extent possible, the Licensor waives and/or agrees not to assert any such rights held by the Licensor to the limited extent necessary to allow You to exercise the Licensed Rights, but not otherwise. + + 2. Patent and trademark rights are not licensed under this Public License. + + 3. To the extent possible, the Licensor waives any right to collect royalties from You for the exercise of the Licensed Rights, whether directly or through a collecting society under any voluntary or waivable statutory or compulsory licensing scheme. In all other cases the Licensor expressly reserves any right to collect such royalties, including when the Licensed Material is used other than for NonCommercial purposes. + +### Section 3 – License Conditions. + +Your exercise of the Licensed Rights is expressly made subject to the following conditions. + +a. ___Attribution.___ + + 1. If You Share the Licensed Material (including in modified form), You must: + + A. retain the following if it is supplied by the Licensor with the Licensed Material: + + i. identification of the creator(s) of the Licensed Material and any others designated to receive attribution, in any reasonable manner requested by the Licensor (including by pseudonym if designated); + + ii. a copyright notice; + + iii. a notice that refers to this Public License; + + iv. a notice that refers to the disclaimer of warranties; + + v. a URI or hyperlink to the Licensed Material to the extent reasonably practicable; + + B. indicate if You modified the Licensed Material and retain an indication of any previous modifications; and + + C. indicate the Licensed Material is licensed under this Public License, and include the text of, or the URI or hyperlink to, this Public License. + + 2. You may satisfy the conditions in Section 3(a)(1) in any reasonable manner based on the medium, means, and context in which You Share the Licensed Material. For example, it may be reasonable to satisfy the conditions by providing a URI or hyperlink to a resource that includes the required information. + + 3. If requested by the Licensor, You must remove any of the information required by Section 3(a)(1)(A) to the extent reasonably practicable. + + 4. If You Share Adapted Material You produce, the Adapter's License You apply must not prevent recipients of the Adapted Material from complying with this Public License. + +### Section 4 – Sui Generis Database Rights. + +Where the Licensed Rights include Sui Generis Database Rights that apply to Your use of the Licensed Material: + +a. for the avoidance of doubt, Section 2(a)(1) grants You the right to extract, reuse, reproduce, and Share all or a substantial portion of the contents of the database for NonCommercial purposes only; + +b. if You include all or a substantial portion of the database contents in a database in which You have Sui Generis Database Rights, then the database in which You have Sui Generis Database Rights (but not its individual contents) is Adapted Material; and + +c. You must comply with the conditions in Section 3(a) if You Share all or a substantial portion of the contents of the database. + +For the avoidance of doubt, this Section 4 supplements and does not replace Your obligations under this Public License where the Licensed Rights include other Copyright and Similar Rights. + +### Section 5 – Disclaimer of Warranties and Limitation of Liability. + +a. __Unless otherwise separately undertaken by the Licensor, to the extent possible, the Licensor offers the Licensed Material as-is and as-available, and makes no representations or warranties of any kind concerning the Licensed Material, whether express, implied, statutory, or other. This includes, without limitation, warranties of title, merchantability, fitness for a particular purpose, non-infringement, absence of latent or other defects, accuracy, or the presence or absence of errors, whether or not known or discoverable. Where disclaimers of warranties are not allowed in full or in part, this disclaimer may not apply to You.__ + +b. __To the extent possible, in no event will the Licensor be liable to You on any legal theory (including, without limitation, negligence) or otherwise for any direct, special, indirect, incidental, consequential, punitive, exemplary, or other losses, costs, expenses, or damages arising out of this Public License or use of the Licensed Material, even if the Licensor has been advised of the possibility of such losses, costs, expenses, or damages. Where a limitation of liability is not allowed in full or in part, this limitation may not apply to You.__ + +c. The disclaimer of warranties and limitation of liability provided above shall be interpreted in a manner that, to the extent possible, most closely approximates an absolute disclaimer and waiver of all liability. + +### Section 6 – Term and Termination. + +a. This Public License applies for the term of the Copyright and Similar Rights licensed here. However, if You fail to comply with this Public License, then Your rights under this Public License terminate automatically. + +b. Where Your right to use the Licensed Material has terminated under Section 6(a), it reinstates: + + 1. automatically as of the date the violation is cured, provided it is cured within 30 days of Your discovery of the violation; or + + 2. upon express reinstatement by the Licensor. + + For the avoidance of doubt, this Section 6(b) does not affect any right the Licensor may have to seek remedies for Your violations of this Public License. + +c. For the avoidance of doubt, the Licensor may also offer the Licensed Material under separate terms or conditions or stop distributing the Licensed Material at any time; however, doing so will not terminate this Public License. + +d. Sections 1, 5, 6, 7, and 8 survive termination of this Public License. + +### Section 7 – Other Terms and Conditions. + +a. The Licensor shall not be bound by any additional or different terms or conditions communicated by You unless expressly agreed. + +b. Any arrangements, understandings, or agreements regarding the Licensed Material not stated herein are separate from and independent of the terms and conditions of this Public License. + +### Section 8 – Interpretation. + +a. For the avoidance of doubt, this Public License does not, and shall not be interpreted to, reduce, limit, restrict, or impose conditions on any use of the Licensed Material that could lawfully be made without permission under this Public License. + +b. To the extent possible, if any provision of this Public License is deemed unenforceable, it shall be automatically reformed to the minimum extent necessary to make it enforceable. If the provision cannot be reformed, it shall be severed from this Public License without affecting the enforceability of the remaining terms and conditions. + +c. No term or condition of this Public License will be waived and no failure to comply consented to unless expressly agreed to by the Licensor. + +d. Nothing in this Public License constitutes or may be interpreted as a limitation upon, or waiver of, any privileges and immunities that apply to the Licensor or You, including from the legal processes of any jurisdiction or authority. + +> Creative Commons is not a party to its public licenses. Notwithstanding, Creative Commons may elect to apply one of its public licenses to material it publishes and in those instances will be considered the “Licensor.” Except for the limited purpose of indicating that material is shared under a Creative Commons public license or as otherwise permitted by the Creative Commons policies published at [creativecommons.org/policies](http://creativecommons.org/policies), Creative Commons does not authorize the use of the trademark “Creative Commons” or any other trademark or logo of Creative Commons without its prior written consent including, without limitation, in connection with any unauthorized modifications to any of its public licenses or any other arrangements, understandings, or agreements concerning use of licensed material. For the avoidance of doubt, this paragraph does not form part of the public licenses. +> +> Creative Commons may be contacted at creativecommons.org diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..0a3c46d8fa563a1026b3b650b46df77e667f84f6 --- /dev/null +++ b/app.py @@ -0,0 +1,39 @@ +import os +os.system('python setup.py build develop') +os.system('pip install --upgrade --no-cache-dir gdown') +os.system('gdown -O output/mixtrain/ 1XQsikiNY7ILgZvmvOeUf9oPDG4fTp0zs') + +import cv2 +import pandas as pd +import gradio as gr +from tools.demo import TextDemo +from maskrcnn_benchmark.config import cfg + + +def infer(filepath): + cfg.merge_from_file('configs/mixtrain/seg_rec_poly_fuse_feature.yaml') + # manual override some options + cfg.merge_from_list(["MODEL.DEVICE", "cpu"]) + + text_demo = TextDemo( + cfg, + min_image_size=800, + confidence_threshold=0.7, + output_polygon=True + ) + image = cv2.imread(filepath) + result_polygons, result_words = text_demo.run_on_opencv_image(image) + text_demo.visualization(image, result_polygons, result_words) + cv2.imwrite('result.jpg', image) + return 'result.jpg', pd.DataFrame(result_words) + + +iface = gr.Interface( + fn=infer, + title="Mask TextSpotter v3", + description="Mask TextSpotter v3 is an end-to-end trainable scene text spotter that adopts a Segmentation Proposal Network (SPN) instead of an RPN. Mask TextSpotter v3 significantly improves robustness to rotations, aspect ratios, and shapes.", + inputs=[gr.inputs.Image(label="image", type="filepath")], + outputs=[gr.outputs.Image(), gr.outputs.Dataframe(headers=['word'])], + examples=['example1.jpg', 'example2.jpg', 'example3.jpg'], + article="GitHub Repo", +).launch(enable_queue=True, cache_examples=True) diff --git a/configs/mixtrain/seg_rec_poly_fuse_feature.yaml b/configs/mixtrain/seg_rec_poly_fuse_feature.yaml new file mode 100644 index 0000000000000000000000000000000000000000..19cb63394e6c71319b263e1af7f39c3419f1806e --- /dev/null +++ b/configs/mixtrain/seg_rec_poly_fuse_feature.yaml @@ -0,0 +1,97 @@ +MODEL: + META_ARCHITECTURE: "GeneralizedRCNN" + # WEIGHT: './output/path-to-pretrain-model' # for training + WEIGHT: './output/mixtrain/trained_model.pth' # for testing + BACKBONE: + CONV_BODY: "R-50-FPN" + OUT_CHANNELS: 256 + RESNETS: + BACKBONE_OUT_CHANNELS: 256 + RPN: + USE_FPN: True + ANCHOR_STRIDE: (4, 8, 16, 32, 64) + PRE_NMS_TOP_N_TRAIN: 2000 + PRE_NMS_TOP_N_TEST: 1000 + POST_NMS_TOP_N_TEST: 1000 + FPN_POST_NMS_TOP_N_TEST: 1000 + SEG: + USE_FPN: True + USE_FUSE_FEATURE: True + TOP_N_TRAIN: 1000 + TOP_N_TEST: 1000 + BINARY_THRESH: 0.1 + BOX_THRESH: 0.1 + MIN_SIZE: 5 + SHRINK_RATIO: 0.4 + EXPAND_RATIO: 3.0 + ROI_HEADS: + USE_FPN: True + BATCH_SIZE_PER_IMAGE: 512 + ROI_BOX_HEAD: + POOLER_RESOLUTION: 7 + POOLER_SCALES: (0.25,) + POOLER_SAMPLING_RATIO: 2 + FEATURE_EXTRACTOR: "FPN2MLPFeatureExtractor" + PREDICTOR: "FPNPredictor" + NUM_CLASSES: 2 + USE_MASKED_FEATURE: True + ROI_MASK_HEAD: + POOLER_SCALES: (0.25,) + FEATURE_EXTRACTOR: "MaskRCNNFPNFeatureExtractor" + PREDICTOR: "SeqCharMaskRCNNC4Predictor" + POOLER_RESOLUTION: 14 + POOLER_RESOLUTION_H: 32 + POOLER_RESOLUTION_W: 32 + POOLER_SAMPLING_RATIO: 2 + RESOLUTION: 28 + RESOLUTION_H: 64 + RESOLUTION_W: 64 + SHARE_BOX_FEATURE_EXTRACTOR: False + CHAR_NUM_CLASSES: 37 + USE_WEIGHTED_CHAR_MASK: True + MASK_BATCH_SIZE_PER_IM: 64 + USE_MASKED_FEATURE: True + MASK_ON: True + CHAR_MASK_ON: True + SEG_ON: True + # TRAIN_DETECTION_ONLY: True +SEQUENCE: + SEQ_ON: True + NUM_CHAR: 38 + BOS_TOKEN: 0 + MAX_LENGTH: 32 + TEACHER_FORCE_RATIO: 1.0 +DATASETS: + # TRAIN: ("synthtext_train",) + TRAIN: ("synthtext_train","icdar_2013_train","icdar_2015_train","scut-eng-char_train","total_text_train") + RATIOS: [0.25,0.25,0.25,0.125,0.125] + # TEST: ("icdar_2015_test",) + TEST: ("total_text_test",) + # TEST: ("rotated_ic13_test_45",) + AUG: True + IGNORE_DIFFICULT: True + MAX_ROTATE_THETA: 90 +DATALOADER: + SIZE_DIVISIBILITY: 32 + NUM_WORKERS: 4 + ASPECT_RATIO_GROUPING: False +SOLVER: + BASE_LR: 0.002 #0.02 + WARMUP_FACTOR: 0.1 + WEIGHT_DECAY: 0.0001 + STEPS: (100000, 160000) + MAX_ITER: 300000 + IMS_PER_BATCH: 8 + RESUME: False + DISPLAY_FREQ: 20 +OUTPUT_DIR: "./output/mixtrain" +TEST: + VIS: True + CHAR_THRESH: 192 + IMS_PER_BATCH: 1 +INPUT: + MIN_SIZE_TRAIN: (800, 1000, 1200, 1400) + MAX_SIZE_TRAIN: 2333 + MIN_SIZE_TEST: 1000 + # MIN_SIZE_TEST: 1440 + MAX_SIZE_TEST: 4000 diff --git a/configs/pretrain/seg_rec_poly_fuse_feature.yaml b/configs/pretrain/seg_rec_poly_fuse_feature.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8d1267111e3f706e12326f857ba3bdda378cb9bf --- /dev/null +++ b/configs/pretrain/seg_rec_poly_fuse_feature.yaml @@ -0,0 +1,94 @@ +MODEL: + META_ARCHITECTURE: "GeneralizedRCNN" + WEIGHT: "catalog://ImageNetPretrained/MSRA/R-50" + BACKBONE: + CONV_BODY: "R-50-FPN" + OUT_CHANNELS: 256 + RESNETS: + BACKBONE_OUT_CHANNELS: 256 + RPN: + USE_FPN: True + ANCHOR_STRIDE: (4, 8, 16, 32, 64) + PRE_NMS_TOP_N_TRAIN: 2000 + PRE_NMS_TOP_N_TEST: 1000 + POST_NMS_TOP_N_TEST: 1000 + FPN_POST_NMS_TOP_N_TEST: 1000 + SEG: + USE_FPN: True + USE_FUSE_FEATURE: True + TOP_N_TRAIN: 1000 + TOP_N_TEST: 1000 + BINARY_THRESH: 0.1 + BOX_THRESH: 0.1 + MIN_SIZE: 5 + SHRINK_RATIO: 0.4 + EXPAND_RATIO: 3.0 + ROI_HEADS: + USE_FPN: True + BATCH_SIZE_PER_IMAGE: 512 + ROI_BOX_HEAD: + POOLER_RESOLUTION: 7 + POOLER_SCALES: (0.25,) + POOLER_SAMPLING_RATIO: 2 + FEATURE_EXTRACTOR: "FPN2MLPFeatureExtractor" + PREDICTOR: "FPNPredictor" + NUM_CLASSES: 2 + USE_MASKED_FEATURE: True + ROI_MASK_HEAD: + POOLER_SCALES: (0.25,) + FEATURE_EXTRACTOR: "MaskRCNNFPNFeatureExtractor" + PREDICTOR: "SeqCharMaskRCNNC4Predictor" + POOLER_RESOLUTION: 14 + POOLER_RESOLUTION_H: 32 + POOLER_RESOLUTION_W: 32 + POOLER_SAMPLING_RATIO: 2 + RESOLUTION: 28 + RESOLUTION_H: 64 + RESOLUTION_W: 64 + SHARE_BOX_FEATURE_EXTRACTOR: False + CHAR_NUM_CLASSES: 37 + USE_WEIGHTED_CHAR_MASK: True + MASK_BATCH_SIZE_PER_IM: 64 + USE_MASKED_FEATURE: True + MASK_ON: True + CHAR_MASK_ON: True + SEG_ON: True +SEQUENCE: + SEQ_ON: True + NUM_CHAR: 38 + BOS_TOKEN: 0 + MAX_LENGTH: 32 + TEACHER_FORCE_RATIO: 1.0 +DATASETS: + TRAIN: ("synthtext_train",) + # TRAIN: ("synthtext_train","icdar_2013_train","icdar_2015_train","scut-eng-char_train","total_text_train") + # RATIOS: [0.25,0.25,0.25,0.125,0.125] + TEST: ("icdar_2015_test",) + # TEST: ("total_text_test",) + AUG: True + IGNORE_DIFFICULT: True + MAX_ROTATE_THETA: 90 +DATALOADER: + SIZE_DIVISIBILITY: 32 + NUM_WORKERS: 4 + ASPECT_RATIO_GROUPING: False +SOLVER: + BASE_LR: 0.02 #0.02 + WARMUP_FACTOR: 0.1 + WEIGHT_DECAY: 0.0001 + STEPS: (100000, 200000) + MAX_ITER: 300000 + IMS_PER_BATCH: 8 + RESUME: True + DISPLAY_FREQ: 20 +OUTPUT_DIR: "./output/pretrain" +TEST: + VIS: False + CHAR_THRESH: 192 + IMS_PER_BATCH: 1 +INPUT: + MIN_SIZE_TRAIN: (600, 800) + # MIN_SIZE_TRAIN: (800, 1000, 1200, 1400) + MAX_SIZE_TRAIN: 2333 + MIN_SIZE_TEST: 1440 + MAX_SIZE_TEST: 4000 diff --git a/evaluation/icdar2015/e2e/prepare_results.py b/evaluation/icdar2015/e2e/prepare_results.py new file mode 100644 index 0000000000000000000000000000000000000000..81c017aa0819063ecdf9f6bc9f8c7bfca5a405b9 --- /dev/null +++ b/evaluation/icdar2015/e2e/prepare_results.py @@ -0,0 +1,263 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +import sys +import os +sys.path.append('./') +import shapely +from shapely.geometry import Polygon,MultiPoint +import numpy as np +import editdistance +sys.path.append('../../') +from weighted_editdistance import weighted_edit_distance +from tqdm import tqdm +try: + import pickle +except ImportError: + import cPickle as pickle + +def list_from_str(st): + line = st.split(',') + # box[0:4], polygon[4:12], word, seq_word, detection_score, rec_socre, seq_score, char_score_path + new_line = [float(a) for a in line[4:12]]+[float(line[-4])]+[line[-5]]+[line[-6]]+[float(line[-3])]+[float(line[-2])] + [line[-1]] + return new_line + +def polygon_from_list(line): + """ + Create a shapely polygon object from gt or dt line. + """ + polygon_points = np.array(line).reshape(4, 2) + polygon = Polygon(polygon_points).convex_hull + return polygon + +def polygon_iou(list1, list2): + """ + Intersection over union between two shapely polygons. + """ + polygon_points1 = np.array(list1).reshape(4, 2) + poly1 = Polygon(polygon_points1).convex_hull + polygon_points2 = np.array(list2).reshape(4, 2) + poly2 = Polygon(polygon_points2).convex_hull + union_poly = np.concatenate((polygon_points1,polygon_points2)) + if not poly1.intersects(poly2): # this test is fast and can accelerate calculation + iou = 0 + else: + try: + inter_area = poly1.intersection(poly2).area + #union_area = poly1.area + poly2.area - inter_area + union_area = MultiPoint(union_poly).convex_hull.area + iou = float(inter_area) / (union_area+1e-6) + except shapely.geos.TopologicalError: + print('shapely.geos.TopologicalError occured, iou set to 0') + iou = 0 + return iou + +def nms(boxes,overlap): + rec_scores = [b[-2] for b in boxes] + indices = sorted(range(len(rec_scores)), key=lambda k: -rec_scores[k]) + box_num = len(boxes) + nms_flag = [True]*box_num + for i in range(box_num): + ii = indices[i] + if not nms_flag[ii]: + continue + for j in range(box_num): + jj = indices[j] + if j == i: + continue + if not nms_flag[jj]: + continue + box1 = boxes[ii] + box2 = boxes[jj] + box1_score = rec_scores[ii] + box2_score = rec_scores[jj] + str1 = box1[9] + str2 = box2[9] + box_i = [box1[0],box1[1],box1[4],box1[5]] + box_j = [box2[0],box2[1],box2[4],box2[5]] + poly1 = polygon_from_list(box1[0:8]) + poly2 = polygon_from_list(box2[0:8]) + iou = polygon_iou(box1[0:8],box2[0:8]) + thresh = overlap + + if iou > thresh: + if box1_score > box2_score: + nms_flag[jj] = False + if box1_score == box2_score and poly1.area > poly2.area: + nms_flag[jj] = False + if box1_score == box2_score and poly1.area<=poly2.area: + nms_flag[ii] = False + break + + return nms_flag + +def packing(save_dir, cache_dir, pack_name): + files = os.listdir(save_dir) + if not os.path.exists(cache_dir): + os.mkdir(cache_dir) + os.system('zip -r -q -j '+os.path.join(cache_dir, pack_name+'.zip')+' '+save_dir+'/*') + +def test_single(results_dir,lexicon_type=3,cache_dir='./cache_dir',score_det=0.5,score_rec=0.5,score_rec_seq=0.5,overlap=0.2, use_lexicon=True, weighted_ed=True, use_seq=False, use_char=False, mix=False): + ''' + results_dir: result directory + score_det: score of detection bounding box + score_rec: score of the mask recognition branch + socre_rec_seq: score of the sequence recognition branch + overlap: overlap threshold used for nms + lexicon_type: 1 for generic; 2 for weak; 3 for strong + use_seq: use the recognition result of sequence branch + use_mix: use both the recognition result of the mask and sequence branches, selected by score + ''' + print('score_det:', 'score_det:', score_det, 'score_rec:', score_rec, 'score_rec_seq:', score_rec_seq, 'lexicon_type:', lexicon_type, 'weighted_ed:', weighted_ed, 'use_seq:', use_seq, 'use_char:', use_char, 'mix:', mix) + if not os.path.exists(cache_dir): + os.mkdir(cache_dir) + nms_dir = os.path.join(cache_dir,str(score_det)+'_'+str(score_rec)+'_'+str(score_rec_seq)) + if not os.path.exists(nms_dir): + os.mkdir(nms_dir) + if lexicon_type==1: + # generic lexicon + lexicon_path = '../../lexicons/ic15/GenericVocabulary_new.txt' + lexicon_fid=open(lexicon_path, 'r') + pair_list = open('../../lexicons/ic15/GenericVocabulary_pair_list.txt', 'r') + pairs = dict() + for line in pair_list.readlines(): + line=line.strip() + word = line.split(' ')[0].upper() + word_gt = line[len(word)+1:] + pairs[word] = word_gt + lexicon_fid=open(lexicon_path, 'r') + lexicon=[] + for line in lexicon_fid.readlines(): + line=line.strip() + lexicon.append(line) + if lexicon_type==2: + # weak lexicon + lexicon_path = '../../lexicons/ic15/ch4_test_vocabulary_new.txt' + lexicon_fid=open(lexicon_path, 'r') + pair_list = open('../../lexicons/ic15/ch4_test_vocabulary_pair_list.txt', 'r') + pairs = dict() + for line in pair_list.readlines(): + line=line.strip() + word = line.split(' ')[0].upper() + word_gt = line[len(word)+1:] + pairs[word] = word_gt + lexicon_fid=open(lexicon_path, 'r') + lexicon=[] + for line in lexicon_fid.readlines(): + line=line.strip() + lexicon.append(line) + + for i in tqdm(range(1,501)): + img = 'img_'+str(i)+'.jpg' + gt_img = 'gt_img_'+str(i)+'.txt' + if lexicon_type==3: + # weak + lexicon_path = '../../lexicons/ic15/new_strong_lexicon/new_voc_img_' + str(i) + '.txt' + lexicon_fid=open(lexicon_path, 'r') + pair_list = open('../../lexicons/ic15/new_strong_lexicon/pair_voc_img_' + str(i) + '.txt', 'r') + pairs = dict() + for line in pair_list.readlines(): + line=line.strip() + word = line.split(' ')[0].upper() + word_gt = line[len(word)+1:] + pairs[word] = word_gt + lexicon_fid=open(lexicon_path, 'r') + lexicon=[] + for line in lexicon_fid.readlines(): + line=line.strip() + lexicon.append(line) + result_path = os.path.join(results_dir,'res_img_'+str(i)+'.txt') + if os.path.isfile(result_path): + with open(result_path,'r') as f: + dt_lines = [a.strip() for a in f.readlines()] + dt_lines = [list_from_str(dt) for dt in dt_lines] + else: + dt_lines = [] + dt_lines = [dt for dt in dt_lines if dt[-2]>score_rec_seq and dt[-3]>score_rec and dt[-6]>score_det] + nms_flag = nms(dt_lines,overlap) + boxes = [] + for k in range(len(dt_lines)): + dt = dt_lines[k] + if nms_flag[k]: + if dt not in boxes: + boxes.append(dt) + + with open(os.path.join(nms_dir,'res_img_'+str(i)+'.txt'),'w') as f: + for g in boxes: + gt_coors = [int(b) for b in g[0:8]] + with open('../../../' + g[-1], "rb") as input_file: + # with open(g[-1], "rb") as input_file: + dict_scores = pickle.load(input_file) + if use_char and use_seq: + if g[-2]>g[-3]: + word = g[-5] + scores = dict_scores['seq_char_scores'][:,1:-1].swapaxes(0,1) + else: + word = g[-4] + scores = dict_scores['seg_char_scores'] + elif use_seq: + word = g[-5] + scores = dict_scores['seq_char_scores'][:,1:-1].swapaxes(0,1) + else: + word = g[-4] + scores = dict_scores['seg_char_scores'] + match_word, match_dist = find_match_word(word, lexicon, pairs, scores, use_lexicon, weighted_ed) + if match_dist<1.5 or lexicon_type==1: + gt_coor_strs = [str(a) for a in gt_coors]+ [match_word] + f.write(','.join(gt_coor_strs)+'\r\n') + + pack_name = str(score_det)+'_'+str(score_rec)+'_over'+str(overlap) + + packing(nms_dir,cache_dir,pack_name) + submit_file_path = os.path.join(cache_dir, pack_name+'.zip') + return submit_file_path + +def find_match_word(rec_str, lexicon, pairs, scores_numpy, use_ed = True, weighted_ed = False): + if not use_ed: + return rec_str + rec_str = rec_str.upper() + dist_min = 100 + dist_min_pre = 100 + match_word = '' + match_dist = 100 + if not weighted_ed: + for word in lexicon: + word = word.upper() + ed = editdistance.eval(rec_str, word) + length_dist = abs(len(word) - len(rec_str)) + # dist = ed + length_dist + dist = ed + if dist -s= [-o= -p=]' %sys.argv[0]) + sys.exit(2) + + +def load_zip_file_keys(file,fileNameRegExp=''): + """ + Returns an array with the entries of the ZIP file that match with the regular expression. + The key's are the names or the file or the capturing group definied in the fileNameRegExp + """ + try: + archive=zipfile.ZipFile(file, mode='r', allowZip64=True) + except : + raise Exception('Error loading the ZIP archive.') + + pairs = [] + + for name in archive.namelist(): + addFile = True + keyName = name + if fileNameRegExp!="": + m = re.match(fileNameRegExp,name) + if m == None: + addFile = False + else: + if len(m.groups())>0: + keyName = m.group(1) + + if addFile: + pairs.append( keyName ) + + return pairs + + +def load_zip_file(file,fileNameRegExp='',allEntries=False): + """ + Returns an array with the contents (filtered by fileNameRegExp) of a ZIP file. + The key's are the names or the file or the capturing group definied in the fileNameRegExp + allEntries validates that all entries in the ZIP file pass the fileNameRegExp + """ + try: + archive=zipfile.ZipFile(file, mode='r', allowZip64=True) + except : + raise Exception('Error loading the ZIP archive') + + pairs = [] + for name in archive.namelist(): + addFile = True + keyName = name + if fileNameRegExp!="": + m = re.match(fileNameRegExp,name) + if m == None: + addFile = False + else: + if len(m.groups())>0: + keyName = m.group(1) + + if addFile: + pairs.append( [ keyName , archive.read(name)] ) + else: + if allEntries: + raise Exception('ZIP entry not valid: %s' %name) + + return dict(pairs) + +def decode_utf8(raw): + """ + Returns a Unicode object on success, or None on failure + """ + try: + raw = codecs.decode(raw,'utf-8', 'replace') + #extracts BOM if exists + raw = raw.encode('utf8') + if raw.startswith(codecs.BOM_UTF8): + raw = raw.replace(codecs.BOM_UTF8, '', 1) + return raw.decode('utf-8') + except: + return None + +def validate_lines_in_file(fileName,file_contents,CRLF=True,LTRB=True,withTranscription=False,withConfidence=False,imWidth=0,imHeight=0): + """ + This function validates that all lines of the file calling the Line validation function for each line + """ + utf8File = decode_utf8(file_contents) + if (utf8File is None) : + raise Exception("The file %s is not UTF-8" %fileName) + + lines = utf8File.split( "\r\n" if CRLF else "\n" ) + for line in lines: + line = line.replace("\r","").replace("\n","") + if(line != ""): + try: + validate_tl_line(line,LTRB,withTranscription,withConfidence,imWidth,imHeight) + except Exception as e: + raise Exception(("Line in sample not valid. Sample: %s Line: %s Error: %s" %(fileName,line,str(e))).encode('utf-8', 'replace')) + + + +def validate_tl_line(line,LTRB=True,withTranscription=True,withConfidence=True,imWidth=0,imHeight=0): + """ + Validate the format of the line. If the line is not valid an exception will be raised. + If maxWidth and maxHeight are specified, all points must be inside the imgage bounds. + Posible values are: + LTRB=True: xmin,ymin,xmax,ymax[,confidence][,transcription] + LTRB=False: x1,y1,x2,y2,x3,y3,x4,y4[,confidence][,transcription] + """ + get_tl_line_values(line,LTRB,withTranscription,withConfidence,imWidth,imHeight) + + +def get_tl_line_values(line,LTRB=True,withTranscription=False,withConfidence=False,imWidth=0,imHeight=0): + """ + Validate the format of the line. If the line is not valid an exception will be raised. + If maxWidth and maxHeight are specified, all points must be inside the imgage bounds. + Posible values are: + LTRB=True: xmin,ymin,xmax,ymax[,confidence][,transcription] + LTRB=False: x1,y1,x2,y2,x3,y3,x4,y4[,confidence][,transcription] + Returns values from a textline. Points , [Confidences], [Transcriptions] + """ + confidence = 0.0 + transcription = ""; + points = [] + + numPoints = 4; + + if LTRB: + + numPoints = 4; + + if withTranscription and withConfidence: + m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-1].?[0-9]*)\s*,(.*)$',line) + if m == None : + m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-1].?[0-9]*)\s*,(.*)$',line) + raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax,confidence,transcription") + elif withConfidence: + m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-1].?[0-9]*)\s*$',line) + if m == None : + raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax,confidence") + elif withTranscription: + m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,(.*)$',line) + if m == None : + raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax,transcription") + else: + m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,?\s*$',line) + if m == None : + raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax") + + xmin = int(m.group(1)) + ymin = int(m.group(2)) + xmax = int(m.group(3)) + ymax = int(m.group(4)) + if(xmax0 and imHeight>0): + validate_point_inside_bounds(xmin,ymin,imWidth,imHeight); + validate_point_inside_bounds(xmax,ymax,imWidth,imHeight); + + else: + + numPoints = 8; + + if withTranscription and withConfidence: + m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-1].?[0-9]*)\s*,(.*)$',line) + if m == None : + raise Exception("Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4,confidence,transcription") + elif withConfidence: + m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-1].?[0-9]*)\s*$',line) + if m == None : + raise Exception("Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4,confidence") + elif withTranscription: + m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,(.*)$',line) + if m == None : + raise Exception("Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4,transcription") + else: + m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*$',line) + if m == None : + raise Exception("Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4") + + points = [ float(m.group(i)) for i in range(1, (numPoints+1) ) ] + + validate_clockwise_points(points) + + if (imWidth>0 and imHeight>0): + validate_point_inside_bounds(points[0],points[1],imWidth,imHeight); + validate_point_inside_bounds(points[2],points[3],imWidth,imHeight); + validate_point_inside_bounds(points[4],points[5],imWidth,imHeight); + validate_point_inside_bounds(points[6],points[7],imWidth,imHeight); + + + if withConfidence: + try: + confidence = float(m.group(numPoints+1)) + except ValueError: + raise Exception("Confidence value must be a float") + + if withTranscription: + posTranscription = numPoints + (2 if withConfidence else 1) + transcription = m.group(posTranscription) + m2 = re.match(r'^\s*\"(.*)\"\s*$',transcription) + if m2 != None : #Transcription with double quotes, we extract the value and replace escaped characters + transcription = m2.group(1).replace("\\\\", "\\").replace("\\\"", "\"") + + return points,confidence,transcription + + +def validate_point_inside_bounds(x,y,imWidth,imHeight): + if(x<0 or x>imWidth): + raise Exception("X value (%s) not valid. Image dimensions: (%s,%s)" %(xmin,imWidth,imHeight)) + if(y<0 or y>imHeight): + raise Exception("Y value (%s) not valid. Image dimensions: (%s,%s) Sample: %s Line:%s" %(ymin,imWidth,imHeight)) + +def validate_clockwise_points(points): + """ + Validates that the points that the 4 points that dlimite a polygon are in clockwise order. + """ + + if len(points) != 8: + raise Exception("Points list not valid." + str(len(points))) + + point = [ + [int(points[0]) , int(points[1])], + [int(points[2]) , int(points[3])], + [int(points[4]) , int(points[5])], + [int(points[6]) , int(points[7])] + ] + edge = [ + ( point[1][0] - point[0][0])*( point[1][1] + point[0][1]), + ( point[2][0] - point[1][0])*( point[2][1] + point[1][1]), + ( point[3][0] - point[2][0])*( point[3][1] + point[2][1]), + ( point[0][0] - point[3][0])*( point[0][1] + point[3][1]) + ] + + summatory = edge[0] + edge[1] + edge[2] + edge[3]; + if summatory>0: + raise Exception("Points are not clockwise. The coordinates of bounding quadrilaterals have to be given in clockwise order. Regarding the correct interpretation of 'clockwise' remember that the image coordinate system used is the standard one, with the image origin at the upper left, the X axis extending to the right and Y axis extending downwards.") + +def get_tl_line_values_from_file_contents(content,CRLF=True,LTRB=True,withTranscription=False,withConfidence=False,imWidth=0,imHeight=0,sort_by_confidences=True): + """ + Returns all points, confindences and transcriptions of a file in lists. Valid line formats: + xmin,ymin,xmax,ymax,[confidence],[transcription] + x1,y1,x2,y2,x3,y3,x4,y4,[confidence],[transcription] + """ + pointsList = [] + transcriptionsList = [] + confidencesList = [] + + lines = content.split( "\r\n" if CRLF else "\n" ) + for line in lines: + line = line.replace("\r","").replace("\n","") + if(line != "") : + points, confidence, transcription = get_tl_line_values(line,LTRB,withTranscription,withConfidence,imWidth,imHeight); + pointsList.append(points) + transcriptionsList.append(transcription) + confidencesList.append(confidence) + + if withConfidence and len(confidencesList)>0 and sort_by_confidences: + import numpy as np + sorted_ind = np.argsort(-np.array(confidencesList)) + confidencesList = [confidencesList[i] for i in sorted_ind] + pointsList = [pointsList[i] for i in sorted_ind] + transcriptionsList = [transcriptionsList[i] for i in sorted_ind] + + return pointsList,confidencesList,transcriptionsList + +def main_evaluation(p,default_evaluation_params_fn,validate_data_fn,evaluate_method_fn,show_result=True,per_sample=True): + """ + This process validates a method, evaluates it and if it succed generates a ZIP file with a JSON entry for each sample. + Params: + p: Dictionary of parmeters with the GT/submission locations. If None is passed, the parameters send by the system are used. + default_evaluation_params_fn: points to a function that returns a dictionary with the default parameters used for the evaluation + validate_data_fn: points to a method that validates the corrct format of the submission + evaluate_method_fn: points to a function that evaluated the submission and return a Dictionary with the results + """ + + if (p == None): + p = dict([s[1:].split('=') for s in sys.argv[1:]]) + if(len(sys.argv)<3): + print_help() + + evalParams = default_evaluation_params_fn() + if 'p' in p.keys(): + evalParams.update( p['p'] if isinstance(p['p'], dict) else json.loads(p['p'][1:-1]) ) + + resDict={'calculated':True,'Message':'','method':'{}','per_sample':'{}'} + try: + validate_data_fn(p['g'], p['s'], evalParams) + evalData = evaluate_method_fn(p['g'], p['s'], evalParams) + resDict.update(evalData) + + except Exception as e: + resDict['Message']= str(e) + resDict['calculated']=False + + if 'o' in p: + if not os.path.exists(p['o']): + os.makedirs(p['o']) + + resultsOutputname = p['o'] + '/results.zip' + outZip = zipfile.ZipFile(resultsOutputname, mode='w', allowZip64=True) + + del resDict['per_sample'] + if 'output_items' in resDict.keys(): + del resDict['output_items'] + + outZip.writestr('method.json',json.dumps(resDict)) + + if not resDict['calculated']: + if show_result: + sys.stderr.write('Error!\n'+ resDict['Message']+'\n\n') + if 'o' in p: + outZip.close() + return resDict + + if 'o' in p: + if per_sample == True: + for k,v in evalData['per_sample'].items(): + outZip.writestr( k + '.json',json.dumps(v)) + + if 'output_items' in evalData.keys(): + for k, v in evalData['output_items'].items(): + outZip.writestr( k,v) + + outZip.close() + + if show_result: + sys.stdout.write("Calculated!") + sys.stdout.write(json.dumps(resDict['method'])) + + return resDict + + +def main_validation(default_evaluation_params_fn,validate_data_fn): + """ + This process validates a method + Params: + default_evaluation_params_fn: points to a function that returns a dictionary with the default parameters used for the evaluation + validate_data_fn: points to a method that validates the corrct format of the submission + """ + try: + p = dict([s[1:].split('=') for s in sys.argv[1:]]) + evalParams = default_evaluation_params_fn() + if 'p' in p.keys(): + evalParams.update( p['p'] if isinstance(p['p'], dict) else json.loads(p['p'][1:-1]) ) + + validate_data_fn(p['g'], p['s'], evalParams) + print('SUCCESS') + sys.exit(0) + except Exception as e: + print(str(e)) + sys.exit(101) \ No newline at end of file diff --git a/evaluation/icdar2015/e2e/script.py b/evaluation/icdar2015/e2e/script.py new file mode 100644 index 0000000000000000000000000000000000000000..ca418eb4f3ccb4471e01e74bf5b9c4f0f073a4ca --- /dev/null +++ b/evaluation/icdar2015/e2e/script.py @@ -0,0 +1,461 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# encoding=utf8 +from collections import namedtuple +import rrc_evaluation_funcs +import importlib +from prepare_results import prepare_results_for_evaluation + +def evaluation_imports(): + """ + evaluation_imports: Dictionary ( key = module name , value = alias ) with python modules used in the evaluation. + """ + return { + 'Polygon':'plg', + 'numpy':'np' + } + +def default_evaluation_params(): + """ + default_evaluation_params: Default parameters to use for the validation and evaluation. + """ + return { + 'IOU_CONSTRAINT' :0.5, + 'AREA_PRECISION_CONSTRAINT' :0.5, + 'WORD_SPOTTING' :False, + 'MIN_LENGTH_CARE_WORD' :3, + 'GT_SAMPLE_NAME_2_ID':'gt_img_([0-9]+).txt', + 'DET_SAMPLE_NAME_2_ID':'res_img_([0-9]+).txt', + 'LTRB':False, #LTRB:2points(left,top,right,bottom) or 4 points(x1,y1,x2,y2,x3,y3,x4,y4) + 'CRLF':False, # Lines are delimited by Windows CRLF format + 'CONFIDENCES':False, #Detections must include confidence value. MAP and MAR will be calculated, + 'SPECIAL_CHARACTERS':'!?.:,*"()·[]/\'', + 'ONLY_REMOVE_FIRST_LAST_CHARACTER' : True + } + +def validate_data(gtFilePath, submFilePath, evaluationParams): + """ + Method validate_data: validates that all files in the results folder are correct (have the correct name contents). + Validates also that there are no missing files in the folder. + If some error detected, the method raises the error + """ + gt = rrc_evaluation_funcs.load_zip_file(gtFilePath, evaluationParams['GT_SAMPLE_NAME_2_ID']) + + subm = rrc_evaluation_funcs.load_zip_file(submFilePath, evaluationParams['DET_SAMPLE_NAME_2_ID'], True) + + #Validate format of GroundTruth + for k in gt: + rrc_evaluation_funcs.validate_lines_in_file(k,gt[k],evaluationParams['CRLF'],evaluationParams['LTRB'],True) + + #Validate format of results + for k in subm: + if (k in gt) == False : + raise Exception("The sample %s not present in GT" %k) + + rrc_evaluation_funcs.validate_lines_in_file(k,subm[k],evaluationParams['CRLF'],evaluationParams['LTRB'],True,evaluationParams['CONFIDENCES']) + + +def evaluate_method(gtFilePath, submFilePath, evaluationParams): + """ + Method evaluate_method: evaluate method and returns the results + Results. Dictionary with the following values: + - method (required) Global method metrics. Ex: { 'Precision':0.8,'Recall':0.9 } + - samples (optional) Per sample metrics. Ex: {'sample1' : { 'Precision':0.8,'Recall':0.9 } , 'sample2' : { 'Precision':0.8,'Recall':0.9 } + """ + for module,alias in evaluation_imports().items(): + globals()[alias] = importlib.import_module(module) + + def polygon_from_points(points,correctOffset=False): + """ + Returns a Polygon object to use with the Polygon2 class from a list of 8 points: x1,y1,x2,y2,x3,y3,x4,y4 + """ + + if correctOffset: #this will substract 1 from the coordinates that correspond to the xmax and ymax + points[2] -= 1 + points[4] -= 1 + points[5] -= 1 + points[7] -= 1 + + resBoxes=np.empty([1,8],dtype='int32') + resBoxes[0,0]=int(points[0]) + resBoxes[0,4]=int(points[1]) + resBoxes[0,1]=int(points[2]) + resBoxes[0,5]=int(points[3]) + resBoxes[0,2]=int(points[4]) + resBoxes[0,6]=int(points[5]) + resBoxes[0,3]=int(points[6]) + resBoxes[0,7]=int(points[7]) + pointMat = resBoxes[0].reshape([2,4]).T + return plg.Polygon( pointMat) + + def rectangle_to_polygon(rect): + resBoxes=np.empty([1,8],dtype='int32') + resBoxes[0,0]=int(rect.xmin) + resBoxes[0,4]=int(rect.ymax) + resBoxes[0,1]=int(rect.xmin) + resBoxes[0,5]=int(rect.ymin) + resBoxes[0,2]=int(rect.xmax) + resBoxes[0,6]=int(rect.ymin) + resBoxes[0,3]=int(rect.xmax) + resBoxes[0,7]=int(rect.ymax) + + pointMat = resBoxes[0].reshape([2,4]).T + + return plg.Polygon( pointMat) + + def rectangle_to_points(rect): + points = [int(rect.xmin), int(rect.ymax), int(rect.xmax), int(rect.ymax), int(rect.xmax), int(rect.ymin), int(rect.xmin), int(rect.ymin)] + return points + + def get_union(pD,pG): + areaA = pD.area(); + areaB = pG.area(); + return areaA + areaB - get_intersection(pD, pG); + + def get_intersection_over_union(pD,pG): + try: + return get_intersection(pD, pG) / get_union(pD, pG); + except: + return 0 + + def get_intersection(pD,pG): + pInt = pD & pG + if len(pInt) == 0: + return 0 + return pInt.area() + + def compute_ap(confList, matchList,numGtCare): + correct = 0 + AP = 0 + if len(confList)>0: + confList = np.array(confList) + matchList = np.array(matchList) + sorted_ind = np.argsort(-confList) + confList = confList[sorted_ind] + matchList = matchList[sorted_ind] + for n in range(len(confList)): + match = matchList[n] + if match: + correct += 1 + AP += float(correct)/(n + 1) + + if numGtCare>0: + AP /= numGtCare + + return AP + + def transcription_match(transGt,transDet,specialCharacters='!?.:,*"()·[]/\'',onlyRemoveFirstLastCharacterGT=True): + + if onlyRemoveFirstLastCharacterGT: + #special characters in GT are allowed only at initial or final position + if (transGt==transDet): + return True + + if specialCharacters.find(transGt[0])>-1: + if transGt[1:]==transDet: + return True + + if specialCharacters.find(transGt[-1])>-1: + if transGt[0:len(transGt)-1]==transDet: + return True + + if specialCharacters.find(transGt[0])>-1 and specialCharacters.find(transGt[-1])>-1: + if transGt[1:len(transGt)-1]==transDet: + return True + return False + else: + #Special characters are removed from the begining and the end of both Detection and GroundTruth + while len(transGt)>0 and specialCharacters.find(transGt[0])>-1: + transGt = transGt[1:] + + while len(transDet)>0 and specialCharacters.find(transDet[0])>-1: + transDet = transDet[1:] + + while len(transGt)>0 and specialCharacters.find(transGt[-1])>-1 : + transGt = transGt[0:len(transGt)-1] + + while len(transDet)>0 and specialCharacters.find(transDet[-1])>-1: + transDet = transDet[0:len(transDet)-1] + + return transGt == transDet + + + def include_in_dictionary(transcription): + """ + Function used in Word Spotting that finds if the Ground Truth transcription meets the rules to enter into the dictionary. If not, the transcription will be cared as don't care + """ + #special case 's at final + if transcription[len(transcription)-2:]=="'s" or transcription[len(transcription)-2:]=="'S": + transcription = transcription[0:len(transcription)-2] + + #hypens at init or final of the word + transcription = transcription.strip('-'); + + specialCharacters = "'!?.:,*\"()·[]/"; + for character in specialCharacters: + transcription = transcription.replace(character,' ') + + transcription = transcription.strip() + + if len(transcription) != len(transcription.replace(" ","")) : + return False; + + if len(transcription) < evaluationParams['MIN_LENGTH_CARE_WORD']: + return False; + + notAllowed = "×÷·"; + + range1 = [ ord(u'a'), ord(u'z') ] + range2 = [ ord(u'A'), ord(u'Z') ] + range3 = [ ord(u'À'), ord(u'ƿ') ] + range4 = [ ord(u'DŽ'), ord(u'ɿ') ] + range5 = [ ord(u'Ά'), ord(u'Ͽ') ] + range6 = [ ord(u'-'), ord(u'-') ] + + for char in transcription : + charCode = ord(char) + if(notAllowed.find(char) != -1): + return False + + valid = ( charCode>=range1[0] and charCode<=range1[1] ) or ( charCode>=range2[0] and charCode<=range2[1] ) or ( charCode>=range3[0] and charCode<=range3[1] ) or ( charCode>=range4[0] and charCode<=range4[1] ) or ( charCode>=range5[0] and charCode<=range5[1] ) or ( charCode>=range6[0] and charCode<=range6[1] ) + if valid == False: + return False + + return True + + def include_in_dictionary_transcription(transcription): + """ + Function applied to the Ground Truth transcriptions used in Word Spotting. It removes special characters or terminations + """ + #special case 's at final + if transcription[len(transcription)-2:]=="'s" or transcription[len(transcription)-2:]=="'S": + transcription = transcription[0:len(transcription)-2] + + #hypens at init or final of the word + transcription = transcription.strip('-'); + + specialCharacters = "'!?.:,*\"()·[]/"; + for character in specialCharacters: + transcription = transcription.replace(character,' ') + + transcription = transcription.strip() + + return transcription + + perSampleMetrics = {} + + matchedSum = 0 + + Rectangle = namedtuple('Rectangle', 'xmin ymin xmax ymax') + + gt = rrc_evaluation_funcs.load_zip_file(gtFilePath,evaluationParams['GT_SAMPLE_NAME_2_ID']) + subm = rrc_evaluation_funcs.load_zip_file(submFilePath,evaluationParams['DET_SAMPLE_NAME_2_ID'],True) + + numGlobalCareGt = 0; + numGlobalCareDet = 0; + + arrGlobalConfidences = []; + arrGlobalMatches = []; + + for resFile in gt: + + gtFile = rrc_evaluation_funcs.decode_utf8(gt[resFile]) + if (gtFile is None) : + raise Exception("The file %s is not UTF-8" %resFile) + + recall = 0 + precision = 0 + hmean = 0 + detCorrect = 0 + iouMat = np.empty([1,1]) + gtPols = [] + detPols = [] + gtTrans = [] + detTrans = [] + gtPolPoints = [] + detPolPoints = [] + gtDontCarePolsNum = [] #Array of Ground Truth Polygons' keys marked as don't Care + detDontCarePolsNum = [] #Array of Detected Polygons' matched with a don't Care GT + detMatchedNums = [] + pairs = [] + + arrSampleConfidences = []; + arrSampleMatch = []; + sampleAP = 0; + + evaluationLog = "" + + pointsList,_,transcriptionsList = rrc_evaluation_funcs.get_tl_line_values_from_file_contents(gtFile,evaluationParams['CRLF'],evaluationParams['LTRB'],True,False) + for n in range(len(pointsList)): + points = pointsList[n] + transcription = transcriptionsList[n] + dontCare = transcription == "###" + if evaluationParams['LTRB']: + gtRect = Rectangle(*points) + gtPol = rectangle_to_polygon(gtRect) + else: + gtPol = polygon_from_points(points) + gtPols.append(gtPol) + gtPolPoints.append(points) + + #On word spotting we will filter some transcriptions with special characters + if evaluationParams['WORD_SPOTTING'] : + if dontCare == False : + if include_in_dictionary(transcription) == False : + dontCare = True + else: + transcription = include_in_dictionary_transcription(transcription) + + gtTrans.append(transcription) + if dontCare: + gtDontCarePolsNum.append( len(gtPols)-1 ) + + evaluationLog += "GT polygons: " + str(len(gtPols)) + (" (" + str(len(gtDontCarePolsNum)) + " don't care)\n" if len(gtDontCarePolsNum)>0 else "\n") + + if resFile in subm: + + detFile = rrc_evaluation_funcs.decode_utf8(subm[resFile]) + + pointsList,confidencesList,transcriptionsList = rrc_evaluation_funcs.get_tl_line_values_from_file_contents(detFile,evaluationParams['CRLF'],evaluationParams['LTRB'],True,evaluationParams['CONFIDENCES']) + + for n in range(len(pointsList)): + points = pointsList[n] + transcription = transcriptionsList[n] + + if evaluationParams['LTRB']: + detRect = Rectangle(*points) + detPol = rectangle_to_polygon(detRect) + else: + detPol = polygon_from_points(points) + detPols.append(detPol) + detPolPoints.append(points) + detTrans.append(transcription) + + if len(gtDontCarePolsNum)>0 : + for dontCarePol in gtDontCarePolsNum: + dontCarePol = gtPols[dontCarePol] + intersected_area = get_intersection(dontCarePol,detPol) + pdDimensions = detPol.area() + precision = 0 if pdDimensions == 0 else intersected_area / pdDimensions + if (precision > evaluationParams['AREA_PRECISION_CONSTRAINT'] ): + detDontCarePolsNum.append( len(detPols)-1 ) + break + + evaluationLog += "DET polygons: " + str(len(detPols)) + (" (" + str(len(detDontCarePolsNum)) + " don't care)\n" if len(detDontCarePolsNum)>0 else "\n") + + if len(gtPols)>0 and len(detPols)>0: + #Calculate IoU and precision matrixs + outputShape=[len(gtPols),len(detPols)] + iouMat = np.empty(outputShape) + gtRectMat = np.zeros(len(gtPols),np.int8) + detRectMat = np.zeros(len(detPols),np.int8) + for gtNum in range(len(gtPols)): + for detNum in range(len(detPols)): + pG = gtPols[gtNum] + pD = detPols[detNum] + iouMat[gtNum,detNum] = get_intersection_over_union(pD,pG) + + for gtNum in range(len(gtPols)): + for detNum in range(len(detPols)): + if gtRectMat[gtNum] == 0 and detRectMat[detNum] == 0 and gtNum not in gtDontCarePolsNum and detNum not in detDontCarePolsNum : + if iouMat[gtNum,detNum]>evaluationParams['IOU_CONSTRAINT']: + gtRectMat[gtNum] = 1 + detRectMat[detNum] = 1 + #detection matched only if transcription is equal + if evaluationParams['WORD_SPOTTING']: + correct = gtTrans[gtNum].upper() == detTrans[detNum].upper() + else: + correct = transcription_match(gtTrans[gtNum].upper(),detTrans[detNum].upper(),evaluationParams['SPECIAL_CHARACTERS'],evaluationParams['ONLY_REMOVE_FIRST_LAST_CHARACTER'])==True + detCorrect += (1 if correct else 0) + if correct: + detMatchedNums.append(detNum) + pairs.append({'gt':gtNum,'det':detNum,'correct':correct}) + evaluationLog += "Match GT #" + str(gtNum) + " with Det #" + str(detNum) + " trans. correct: " + str(correct) + "\n" + + if evaluationParams['CONFIDENCES']: + for detNum in range(len(detPols)): + if detNum not in detDontCarePolsNum : + #we exclude the don't care detections + match = detNum in detMatchedNums + + arrSampleConfidences.append(confidencesList[detNum]) + arrSampleMatch.append(match) + + arrGlobalConfidences.append(confidencesList[detNum]); + arrGlobalMatches.append(match); + + numGtCare = (len(gtPols) - len(gtDontCarePolsNum)) + numDetCare = (len(detPols) - len(detDontCarePolsNum)) + if numGtCare == 0: + recall = float(1) + precision = float(0) if numDetCare >0 else float(1) + sampleAP = precision + else: + recall = float(detCorrect) / numGtCare + precision = 0 if numDetCare==0 else float(detCorrect) / numDetCare + if evaluationParams['CONFIDENCES']: + sampleAP = compute_ap(arrSampleConfidences, arrSampleMatch, numGtCare ) + + hmean = 0 if (precision + recall)==0 else 2.0 * precision * recall / (precision + recall) + + matchedSum += detCorrect + numGlobalCareGt += numGtCare + numGlobalCareDet += numDetCare + + perSampleMetrics[resFile] = { + 'precision':precision, + 'recall':recall, + 'hmean':hmean, + 'pairs':pairs, + 'AP':sampleAP, + 'iouMat':[] if len(detPols)>100 else iouMat.tolist(), + 'gtPolPoints':gtPolPoints, + 'detPolPoints':detPolPoints, + 'gtTrans':gtTrans, + 'detTrans':detTrans, + 'gtDontCare':gtDontCarePolsNum, + 'detDontCare':detDontCarePolsNum, + 'evaluationParams': evaluationParams, + 'evaluationLog': evaluationLog + } + + # Compute AP + AP = 0 + if evaluationParams['CONFIDENCES']: + AP = compute_ap(arrGlobalConfidences, arrGlobalMatches, numGlobalCareGt) + + methodRecall = 0 if numGlobalCareGt == 0 else float(matchedSum)/numGlobalCareGt + methodPrecision = 0 if numGlobalCareDet == 0 else float(matchedSum)/numGlobalCareDet + methodHmean = 0 if methodRecall + methodPrecision==0 else 2* methodRecall * methodPrecision / (methodRecall + methodPrecision) + + methodMetrics = {'precision':methodPrecision, 'recall':methodRecall,'hmean': methodHmean, 'AP': AP } + + resDict = {'calculated':True,'Message':'','method': methodMetrics,'per_sample': perSampleMetrics} + + + return resDict; + + + +if __name__=='__main__': + ''' + results_dir: result directory + score_det: score of detection bounding box + score_rec: score of the mask recognition branch + score_rec_seq: score of the sequence recognition branch + lexicon_type: 1 for generic; 2 for weak; 3 for strong + ''' + results_dir = '../../../output/mixtrain/inference/icdar_2015_test/model_0250000_1440_results/' + lexicon_type = 3 + score_det = 0.01 + score_rec = 0.4 + # score_rec_seq set to 0.7 for lexicon_type 3 or 2; 0.8 for lexicon_type 1 + score_rec_seq = 0.7 + evaluate_result_path = prepare_results_for_evaluation(results_dir, + lexicon_type=lexicon_type, cache_dir='./cache_files', + score_det=score_det, score_rec=score_rec, score_rec_seq=score_rec_seq) + p = { + 'g': "../gt.zip", + 's': evaluate_result_path + } + rrc_evaluation_funcs.main_evaluation(p,default_evaluation_params,validate_data,evaluate_method) \ No newline at end of file diff --git a/evaluation/icdar2015/gt.zip b/evaluation/icdar2015/gt.zip new file mode 100644 index 0000000000000000000000000000000000000000..aaa2ac31bbe11d346917f9d172d7b7decb40078f Binary files /dev/null and b/evaluation/icdar2015/gt.zip differ diff --git a/evaluation/rotated_icdar2013/e2e/prepare_results.py b/evaluation/rotated_icdar2013/e2e/prepare_results.py new file mode 100644 index 0000000000000000000000000000000000000000..dc0d71dcc642b8ec95f16ce79ce599c122a95836 --- /dev/null +++ b/evaluation/rotated_icdar2013/e2e/prepare_results.py @@ -0,0 +1,267 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +import sys +import os +sys.path.append('./') +import shapely +from shapely.geometry import Polygon,MultiPoint +import numpy as np +import editdistance +sys.path.append('../../') +from weighted_editdistance import weighted_edit_distance +from tqdm import tqdm +try: + import pickle +except ImportError: + import cPickle as pickle + +def list_from_str(st): + line = st.split(',') + # box[0:4], polygon[4:12], word, seq_word, detection_score, rec_socre, seq_score, char_score_path + new_line = [float(a) for a in line[4:12]]+[float(line[-4])]+[line[-5]]+[line[-6]]+[float(line[-3])]+[float(line[-2])] + [line[-1]] + return new_line + +def polygon_from_list(line): + """ + Create a shapely polygon object from gt or dt line. + """ + polygon_points = np.array(line).reshape(4, 2) + polygon = Polygon(polygon_points).convex_hull + return polygon + +def polygon_iou(list1, list2): + """ + Intersection over union between two shapely polygons. + """ + polygon_points1 = np.array(list1).reshape(4, 2) + poly1 = Polygon(polygon_points1).convex_hull + polygon_points2 = np.array(list2).reshape(4, 2) + poly2 = Polygon(polygon_points2).convex_hull + union_poly = np.concatenate((polygon_points1,polygon_points2)) + if not poly1.intersects(poly2): # this test is fast and can accelerate calculation + iou = 0 + else: + try: + inter_area = poly1.intersection(poly2).area + #union_area = poly1.area + poly2.area - inter_area + union_area = MultiPoint(union_poly).convex_hull.area + iou = float(inter_area) / (union_area+1e-6) + except shapely.geos.TopologicalError: + print('shapely.geos.TopologicalError occured, iou set to 0') + iou = 0 + return iou + +def nms(boxes,overlap): + rec_scores = [b[-2] for b in boxes] + indices = sorted(range(len(rec_scores)), key=lambda k: -rec_scores[k]) + box_num = len(boxes) + nms_flag = [True]*box_num + for i in range(box_num): + ii = indices[i] + if not nms_flag[ii]: + continue + for j in range(box_num): + jj = indices[j] + if j == i: + continue + if not nms_flag[jj]: + continue + box1 = boxes[ii] + box2 = boxes[jj] + box1_score = rec_scores[ii] + box2_score = rec_scores[jj] + str1 = box1[9] + str2 = box2[9] + box_i = [box1[0],box1[1],box1[4],box1[5]] + box_j = [box2[0],box2[1],box2[4],box2[5]] + poly1 = polygon_from_list(box1[0:8]) + poly2 = polygon_from_list(box2[0:8]) + iou = polygon_iou(box1[0:8],box2[0:8]) + thresh = overlap + + if iou > thresh: + if box1_score > box2_score: + nms_flag[jj] = False + if box1_score == box2_score and poly1.area > poly2.area: + nms_flag[jj] = False + if box1_score == box2_score and poly1.area<=poly2.area: + nms_flag[ii] = False + break + + return nms_flag + +def packing(save_dir, cache_dir, pack_name): + files = os.listdir(save_dir) + if not os.path.exists(cache_dir): + os.mkdir(cache_dir) + os.system('zip -r -q -j '+os.path.join(cache_dir, pack_name+'.zip')+' '+save_dir+'/*') + +def test_single(results_dir,lexicon_type=3,cache_dir='./cache_dir',score_det=0.5,score_rec=0.5,score_rec_seq=0.5,overlap=0.2, use_lexicon=True, weighted_ed=True, use_seq=False, use_char=False, mix=False): + ''' + results_dir: result directory + score_det: score of detection bounding box + score_rec: score of the mask recognition branch + socre_rec_seq: score of the sequence recognition branch + overlap: overlap threshold used for nms + lexicon_type: 1 for generic; 2 for weak; 3 for strong + use_seq: use the recognition result of sequence branch + use_mix: use both the recognition result of the mask and sequence branches, selected by score + ''' + print('score_det:', 'score_det:', score_det, 'score_rec:', score_rec, 'score_rec_seq:', score_rec_seq, 'lexicon_type:', lexicon_type, 'weighted_ed:', weighted_ed, 'use_seq:', use_seq, 'use_char:', use_char, 'mix:', mix) + if not os.path.exists(cache_dir): + os.mkdir(cache_dir) + nms_dir = os.path.join(cache_dir,str(score_det)+'_'+str(score_rec)+'_'+str(score_rec_seq)) + if not os.path.exists(nms_dir): + os.mkdir(nms_dir) + if lexicon_type==1: + # generic lexicon + lexicon_path = '../../lexicons/ic13/GenericVocabulary_new.txt' + lexicon_fid=open(lexicon_path, 'r') + pair_list = open('../../lexicons/ic13/GenericVocabulary_pair_list.txt', 'r') + pairs = dict() + for line in pair_list.readlines(): + line=line.strip() + word = line.split(' ')[0].upper() + word_gt = line[len(word)+1:] + pairs[word] = word_gt + lexicon_fid=open(lexicon_path, 'r') + lexicon=[] + for line in lexicon_fid.readlines(): + line=line.strip() + lexicon.append(line) + if lexicon_type==2: + # weak lexicon + lexicon_path = '../../lexicons/ic13/ch4_test_vocabulary_new.txt' + lexicon_fid=open(lexicon_path, 'r') + pair_list = open('../../lexicons/ic13/ch4_test_vocabulary_pair_list.txt', 'r') + pairs = dict() + for line in pair_list.readlines(): + line=line.strip() + word = line.split(' ')[0].upper() + word_gt = line[len(word)+1:] + pairs[word] = word_gt + lexicon_fid=open(lexicon_path, 'r') + lexicon=[] + for line in lexicon_fid.readlines(): + line=line.strip() + lexicon.append(line) + + for i in tqdm(range(1,234)): + img = 'img_'+str(i)+'.jpg' + gt_img = 'gt_img_'+str(i)+'.txt' + if lexicon_type==3: + # weak + lexicon_path = '../../lexicons/ic13/new_strong_lexicon/new_voc_img_' + str(i) + '.txt' + lexicon_fid=open(lexicon_path, 'r') + pair_list = open('../../lexicons/ic13/new_strong_lexicon/pair_voc_img_' + str(i) + '.txt', 'r') + pairs = dict() + for line in pair_list.readlines(): + line=line.strip() + word = line.split(' ')[0].upper() + word_gt = line[len(word)+1:] + pairs[word] = word_gt + lexicon_fid=open(lexicon_path, 'r') + lexicon=[] + for line in lexicon_fid.readlines(): + line=line.strip() + lexicon.append(line) + result_path = os.path.join(results_dir,'res_img_'+str(i)+'.txt') + if os.path.isfile(result_path): + with open(result_path,'r') as f: + dt_lines = [a.strip() for a in f.readlines()] + dt_lines = [list_from_str(dt) for dt in dt_lines] + else: + dt_lines = [] + dt_lines = [dt for dt in dt_lines if dt[-2]>score_rec_seq and dt[-3]>score_rec and dt[-6]>score_det] + nms_flag = nms(dt_lines,overlap) + boxes = [] + for k in range(len(dt_lines)): + dt = dt_lines[k] + if nms_flag[k]: + if dt not in boxes: + boxes.append(dt) + + with open(os.path.join(nms_dir,'res_img_'+str(i)+'.txt'),'w') as f: + for g in boxes: + gt_coors = [int(b) for b in g[0:8]] + with open('../../../' + g[-1], "rb") as input_file: + # with open(g[-1], "rb") as input_file: + dict_scores = pickle.load(input_file) + if use_char and use_seq: + if g[-2]>g[-3]: + word = g[-5] + scores = dict_scores['seq_char_scores'][:,1:-1].swapaxes(0,1) + else: + word = g[-4] + scores = dict_scores['seg_char_scores'] + elif use_seq: + word = g[-5] + scores = dict_scores['seq_char_scores'][:,1:-1].swapaxes(0,1) + else: + word = g[-4] + scores = dict_scores['seg_char_scores'] + if not use_lexicon: + match_word = word + match_dist = 0. + else: + match_word, match_dist = find_match_word(word, lexicon, pairs, scores, use_lexicon, weighted_ed) + if match_dist<1.5 or lexicon_type==1: + gt_coor_strs = [str(a) for a in gt_coors]+ [match_word] + f.write(','.join(gt_coor_strs)+'\r\n') + + pack_name = str(score_det)+'_'+str(score_rec)+'_over'+str(overlap) + + packing(nms_dir,cache_dir,pack_name) + submit_file_path = os.path.join(cache_dir, pack_name+'.zip') + return submit_file_path + +def find_match_word(rec_str, lexicon, pairs, scores_numpy, use_ed = True, weighted_ed = False): + if not use_ed: + return rec_str + rec_str = rec_str.upper() + dist_min = 100 + dist_min_pre = 100 + match_word = '' + match_dist = 100 + if not weighted_ed: + for word in lexicon: + word = word.upper() + ed = editdistance.eval(rec_str, word) + length_dist = abs(len(word) - len(rec_str)) + # dist = ed + length_dist + dist = ed + if dist -s= [-o= -p=]' %sys.argv[0]) + sys.exit(2) + + +def load_zip_file_keys(file,fileNameRegExp=''): + """ + Returns an array with the entries of the ZIP file that match with the regular expression. + The key's are the names or the file or the capturing group definied in the fileNameRegExp + """ + try: + archive=zipfile.ZipFile(file, mode='r', allowZip64=True) + except : + raise Exception('Error loading the ZIP archive.') + + pairs = [] + + for name in archive.namelist(): + addFile = True + keyName = name + if fileNameRegExp!="": + m = re.match(fileNameRegExp,name) + if m == None: + addFile = False + else: + if len(m.groups())>0: + keyName = m.group(1) + + if addFile: + pairs.append( keyName ) + + return pairs + + +def load_zip_file(file,fileNameRegExp='',allEntries=False): + """ + Returns an array with the contents (filtered by fileNameRegExp) of a ZIP file. + The key's are the names or the file or the capturing group definied in the fileNameRegExp + allEntries validates that all entries in the ZIP file pass the fileNameRegExp + """ + try: + archive=zipfile.ZipFile(file, mode='r', allowZip64=True) + except : + raise Exception('Error loading the ZIP archive') + + pairs = [] + for name in archive.namelist(): + addFile = True + keyName = name + if fileNameRegExp!="": + m = re.match(fileNameRegExp,name) + if m == None: + addFile = False + else: + if len(m.groups())>0: + keyName = m.group(1) + + if addFile: + pairs.append( [ keyName , archive.read(name)] ) + else: + if allEntries: + raise Exception('ZIP entry not valid: %s' %name) + + return dict(pairs) + +def decode_utf8(raw): + """ + Returns a Unicode object on success, or None on failure + """ + try: + raw = codecs.decode(raw,'utf-8', 'replace') + #extracts BOM if exists + raw = raw.encode('utf8') + if raw.startswith(codecs.BOM_UTF8): + raw = raw.replace(codecs.BOM_UTF8, '', 1) + return raw.decode('utf-8') + except: + return None + +def validate_lines_in_file(fileName,file_contents,CRLF=True,LTRB=True,withTranscription=False,withConfidence=False,imWidth=0,imHeight=0): + """ + This function validates that all lines of the file calling the Line validation function for each line + """ + utf8File = decode_utf8(file_contents) + if (utf8File is None) : + raise Exception("The file %s is not UTF-8" %fileName) + + lines = utf8File.split( "\r\n" if CRLF else "\n" ) + for line in lines: + line = line.replace("\r","").replace("\n","") + if(line != ""): + try: + validate_tl_line(line,LTRB,withTranscription,withConfidence,imWidth,imHeight) + except Exception as e: + raise Exception(("Line in sample not valid. Sample: %s Line: %s Error: %s" %(fileName,line,str(e))).encode('utf-8', 'replace')) + + + +def validate_tl_line(line,LTRB=True,withTranscription=True,withConfidence=True,imWidth=0,imHeight=0): + """ + Validate the format of the line. If the line is not valid an exception will be raised. + If maxWidth and maxHeight are specified, all points must be inside the imgage bounds. + Posible values are: + LTRB=True: xmin,ymin,xmax,ymax[,confidence][,transcription] + LTRB=False: x1,y1,x2,y2,x3,y3,x4,y4[,confidence][,transcription] + """ + get_tl_line_values(line,LTRB,withTranscription,withConfidence,imWidth,imHeight) + + +def get_tl_line_values(line,LTRB=True,withTranscription=False,withConfidence=False,imWidth=0,imHeight=0): + """ + Validate the format of the line. If the line is not valid an exception will be raised. + If maxWidth and maxHeight are specified, all points must be inside the imgage bounds. + Posible values are: + LTRB=True: xmin,ymin,xmax,ymax[,confidence][,transcription] + LTRB=False: x1,y1,x2,y2,x3,y3,x4,y4[,confidence][,transcription] + Returns values from a textline. Points , [Confidences], [Transcriptions] + """ + confidence = 0.0 + transcription = ""; + points = [] + + numPoints = 4; + + if LTRB: + + numPoints = 4; + + if withTranscription and withConfidence: + m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-1].?[0-9]*)\s*,(.*)$',line) + if m == None : + m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-1].?[0-9]*)\s*,(.*)$',line) + raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax,confidence,transcription") + elif withConfidence: + m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-1].?[0-9]*)\s*$',line) + if m == None : + raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax,confidence") + elif withTranscription: + m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,(.*)$',line) + if m == None : + raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax,transcription") + else: + m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,?\s*$',line) + if m == None : + raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax") + + xmin = int(m.group(1)) + ymin = int(m.group(2)) + xmax = int(m.group(3)) + ymax = int(m.group(4)) + if(xmax0 and imHeight>0): + validate_point_inside_bounds(xmin,ymin,imWidth,imHeight); + validate_point_inside_bounds(xmax,ymax,imWidth,imHeight); + + else: + + numPoints = 8; + + if withTranscription and withConfidence: + m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-1].?[0-9]*)\s*,(.*)$',line) + if m == None : + raise Exception("Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4,confidence,transcription") + elif withConfidence: + m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-1].?[0-9]*)\s*$',line) + if m == None : + raise Exception("Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4,confidence") + elif withTranscription: + m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,(.*)$',line) + if m == None : + raise Exception("Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4,transcription") + else: + m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*$',line) + if m == None : + raise Exception("Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4") + + points = [ float(m.group(i)) for i in range(1, (numPoints+1) ) ] + + validate_clockwise_points(points) + + if (imWidth>0 and imHeight>0): + validate_point_inside_bounds(points[0],points[1],imWidth,imHeight); + validate_point_inside_bounds(points[2],points[3],imWidth,imHeight); + validate_point_inside_bounds(points[4],points[5],imWidth,imHeight); + validate_point_inside_bounds(points[6],points[7],imWidth,imHeight); + + + if withConfidence: + try: + confidence = float(m.group(numPoints+1)) + except ValueError: + raise Exception("Confidence value must be a float") + + if withTranscription: + posTranscription = numPoints + (2 if withConfidence else 1) + transcription = m.group(posTranscription) + m2 = re.match(r'^\s*\"(.*)\"\s*$',transcription) + if m2 != None : #Transcription with double quotes, we extract the value and replace escaped characters + transcription = m2.group(1).replace("\\\\", "\\").replace("\\\"", "\"") + + return points,confidence,transcription + + +def validate_point_inside_bounds(x,y,imWidth,imHeight): + if(x<0 or x>imWidth): + raise Exception("X value (%s) not valid. Image dimensions: (%s,%s)" %(xmin,imWidth,imHeight)) + if(y<0 or y>imHeight): + raise Exception("Y value (%s) not valid. Image dimensions: (%s,%s) Sample: %s Line:%s" %(ymin,imWidth,imHeight)) + +def validate_clockwise_points(points): + """ + Validates that the points that the 4 points that dlimite a polygon are in clockwise order. + """ + + if len(points) != 8: + raise Exception("Points list not valid." + str(len(points))) + + point = [ + [int(points[0]) , int(points[1])], + [int(points[2]) , int(points[3])], + [int(points[4]) , int(points[5])], + [int(points[6]) , int(points[7])] + ] + edge = [ + ( point[1][0] - point[0][0])*( point[1][1] + point[0][1]), + ( point[2][0] - point[1][0])*( point[2][1] + point[1][1]), + ( point[3][0] - point[2][0])*( point[3][1] + point[2][1]), + ( point[0][0] - point[3][0])*( point[0][1] + point[3][1]) + ] + + summatory = edge[0] + edge[1] + edge[2] + edge[3]; + if summatory>0: + raise Exception("Points are not clockwise. The coordinates of bounding quadrilaterals have to be given in clockwise order. Regarding the correct interpretation of 'clockwise' remember that the image coordinate system used is the standard one, with the image origin at the upper left, the X axis extending to the right and Y axis extending downwards.") + +def get_tl_line_values_from_file_contents(content,CRLF=True,LTRB=True,withTranscription=False,withConfidence=False,imWidth=0,imHeight=0,sort_by_confidences=True): + """ + Returns all points, confindences and transcriptions of a file in lists. Valid line formats: + xmin,ymin,xmax,ymax,[confidence],[transcription] + x1,y1,x2,y2,x3,y3,x4,y4,[confidence],[transcription] + """ + pointsList = [] + transcriptionsList = [] + confidencesList = [] + + lines = content.split( "\r\n" if CRLF else "\n" ) + for line in lines: + line = line.replace("\r","").replace("\n","") + if(line != "") : + points, confidence, transcription = get_tl_line_values(line,LTRB,withTranscription,withConfidence,imWidth,imHeight); + pointsList.append(points) + transcriptionsList.append(transcription) + confidencesList.append(confidence) + + if withConfidence and len(confidencesList)>0 and sort_by_confidences: + import numpy as np + sorted_ind = np.argsort(-np.array(confidencesList)) + confidencesList = [confidencesList[i] for i in sorted_ind] + pointsList = [pointsList[i] for i in sorted_ind] + transcriptionsList = [transcriptionsList[i] for i in sorted_ind] + + return pointsList,confidencesList,transcriptionsList + +def main_evaluation(p,default_evaluation_params_fn,validate_data_fn,evaluate_method_fn,show_result=True,per_sample=True): + """ + This process validates a method, evaluates it and if it succed generates a ZIP file with a JSON entry for each sample. + Params: + p: Dictionary of parmeters with the GT/submission locations. If None is passed, the parameters send by the system are used. + default_evaluation_params_fn: points to a function that returns a dictionary with the default parameters used for the evaluation + validate_data_fn: points to a method that validates the corrct format of the submission + evaluate_method_fn: points to a function that evaluated the submission and return a Dictionary with the results + """ + + if (p == None): + p = dict([s[1:].split('=') for s in sys.argv[1:]]) + if(len(sys.argv)<3): + print_help() + + evalParams = default_evaluation_params_fn() + if 'p' in p.keys(): + evalParams.update( p['p'] if isinstance(p['p'], dict) else json.loads(p['p'][1:-1]) ) + + resDict={'calculated':True,'Message':'','method':'{}','per_sample':'{}'} + try: + validate_data_fn(p['g'], p['s'], evalParams) + evalData = evaluate_method_fn(p['g'], p['s'], evalParams) + resDict.update(evalData) + + except Exception as e: + resDict['Message']= str(e) + resDict['calculated']=False + + if 'o' in p: + if not os.path.exists(p['o']): + os.makedirs(p['o']) + + resultsOutputname = p['o'] + '/results.zip' + outZip = zipfile.ZipFile(resultsOutputname, mode='w', allowZip64=True) + + del resDict['per_sample'] + if 'output_items' in resDict.keys(): + del resDict['output_items'] + + outZip.writestr('method.json',json.dumps(resDict)) + + if not resDict['calculated']: + if show_result: + sys.stderr.write('Error!\n'+ resDict['Message']+'\n\n') + if 'o' in p: + outZip.close() + return resDict + + if 'o' in p: + if per_sample == True: + for k,v in evalData['per_sample'].items(): + outZip.writestr( k + '.json',json.dumps(v)) + + if 'output_items' in evalData.keys(): + for k, v in evalData['output_items'].items(): + outZip.writestr( k,v) + + outZip.close() + + if show_result: + sys.stdout.write("Calculated!") + sys.stdout.write(json.dumps(resDict['method'])) + + return resDict + + +def main_validation(default_evaluation_params_fn,validate_data_fn): + """ + This process validates a method + Params: + default_evaluation_params_fn: points to a function that returns a dictionary with the default parameters used for the evaluation + validate_data_fn: points to a method that validates the corrct format of the submission + """ + try: + p = dict([s[1:].split('=') for s in sys.argv[1:]]) + evalParams = default_evaluation_params_fn() + if 'p' in p.keys(): + evalParams.update( p['p'] if isinstance(p['p'], dict) else json.loads(p['p'][1:-1]) ) + + validate_data_fn(p['g'], p['s'], evalParams) + print('SUCCESS') + sys.exit(0) + except Exception as e: + print(str(e)) + sys.exit(101) \ No newline at end of file diff --git a/evaluation/rotated_icdar2013/e2e/script.py b/evaluation/rotated_icdar2013/e2e/script.py new file mode 100644 index 0000000000000000000000000000000000000000..e8b60296315e6749838c7268452d630a5efac366 --- /dev/null +++ b/evaluation/rotated_icdar2013/e2e/script.py @@ -0,0 +1,460 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# encoding=utf8 +from collections import namedtuple +import rrc_evaluation_funcs +import importlib +from prepare_results import prepare_results_for_evaluation + +def evaluation_imports(): + """ + evaluation_imports: Dictionary ( key = module name , value = alias ) with python modules used in the evaluation. + """ + return { + 'Polygon':'plg', + 'numpy':'np' + } + +def default_evaluation_params(): + """ + default_evaluation_params: Default parameters to use for the validation and evaluation. + """ + return { + 'IOU_CONSTRAINT' :0.5, + 'AREA_PRECISION_CONSTRAINT' :0.5, + 'WORD_SPOTTING' :False, + 'MIN_LENGTH_CARE_WORD' :3, + 'GT_SAMPLE_NAME_2_ID':'gt_img_([0-9]+).txt', + 'DET_SAMPLE_NAME_2_ID':'res_img_([0-9]+).txt', + 'LTRB':False, #LTRB:2points(left,top,right,bottom) or 4 points(x1,y1,x2,y2,x3,y3,x4,y4) + 'CRLF':False, # Lines are delimited by Windows CRLF format + 'CONFIDENCES':False, #Detections must include confidence value. MAP and MAR will be calculated, + 'SPECIAL_CHARACTERS':'!?.:,*"()·[]/\'', + 'ONLY_REMOVE_FIRST_LAST_CHARACTER' : True + } + +def validate_data(gtFilePath, submFilePath, evaluationParams): + """ + Method validate_data: validates that all files in the results folder are correct (have the correct name contents). + Validates also that there are no missing files in the folder. + If some error detected, the method raises the error + """ + gt = rrc_evaluation_funcs.load_zip_file(gtFilePath, evaluationParams['GT_SAMPLE_NAME_2_ID']) + + subm = rrc_evaluation_funcs.load_zip_file(submFilePath, evaluationParams['DET_SAMPLE_NAME_2_ID'], True) + + #Validate format of GroundTruth + for k in gt: + rrc_evaluation_funcs.validate_lines_in_file(k,gt[k],evaluationParams['CRLF'],evaluationParams['LTRB'],True) + + #Validate format of results + for k in subm: + if (k in gt) == False : + raise Exception("The sample %s not present in GT" %k) + + rrc_evaluation_funcs.validate_lines_in_file(k,subm[k],evaluationParams['CRLF'],evaluationParams['LTRB'],True,evaluationParams['CONFIDENCES']) + + +def evaluate_method(gtFilePath, submFilePath, evaluationParams): + """ + Method evaluate_method: evaluate method and returns the results + Results. Dictionary with the following values: + - method (required) Global method metrics. Ex: { 'Precision':0.8,'Recall':0.9 } + - samples (optional) Per sample metrics. Ex: {'sample1' : { 'Precision':0.8,'Recall':0.9 } , 'sample2' : { 'Precision':0.8,'Recall':0.9 } + """ + for module,alias in evaluation_imports().items(): + globals()[alias] = importlib.import_module(module) + + def polygon_from_points(points,correctOffset=False): + """ + Returns a Polygon object to use with the Polygon2 class from a list of 8 points: x1,y1,x2,y2,x3,y3,x4,y4 + """ + + if correctOffset: #this will substract 1 from the coordinates that correspond to the xmax and ymax + points[2] -= 1 + points[4] -= 1 + points[5] -= 1 + points[7] -= 1 + + resBoxes=np.empty([1,8],dtype='int32') + resBoxes[0,0]=int(points[0]) + resBoxes[0,4]=int(points[1]) + resBoxes[0,1]=int(points[2]) + resBoxes[0,5]=int(points[3]) + resBoxes[0,2]=int(points[4]) + resBoxes[0,6]=int(points[5]) + resBoxes[0,3]=int(points[6]) + resBoxes[0,7]=int(points[7]) + pointMat = resBoxes[0].reshape([2,4]).T + return plg.Polygon( pointMat) + + def rectangle_to_polygon(rect): + resBoxes=np.empty([1,8],dtype='int32') + resBoxes[0,0]=int(rect.xmin) + resBoxes[0,4]=int(rect.ymax) + resBoxes[0,1]=int(rect.xmin) + resBoxes[0,5]=int(rect.ymin) + resBoxes[0,2]=int(rect.xmax) + resBoxes[0,6]=int(rect.ymin) + resBoxes[0,3]=int(rect.xmax) + resBoxes[0,7]=int(rect.ymax) + + pointMat = resBoxes[0].reshape([2,4]).T + + return plg.Polygon( pointMat) + + def rectangle_to_points(rect): + points = [int(rect.xmin), int(rect.ymax), int(rect.xmax), int(rect.ymax), int(rect.xmax), int(rect.ymin), int(rect.xmin), int(rect.ymin)] + return points + + def get_union(pD,pG): + areaA = pD.area(); + areaB = pG.area(); + return areaA + areaB - get_intersection(pD, pG); + + def get_intersection_over_union(pD,pG): + try: + return get_intersection(pD, pG) / get_union(pD, pG); + except: + return 0 + + def get_intersection(pD,pG): + pInt = pD & pG + if len(pInt) == 0: + return 0 + return pInt.area() + + def compute_ap(confList, matchList,numGtCare): + correct = 0 + AP = 0 + if len(confList)>0: + confList = np.array(confList) + matchList = np.array(matchList) + sorted_ind = np.argsort(-confList) + confList = confList[sorted_ind] + matchList = matchList[sorted_ind] + for n in range(len(confList)): + match = matchList[n] + if match: + correct += 1 + AP += float(correct)/(n + 1) + + if numGtCare>0: + AP /= numGtCare + + return AP + + def transcription_match(transGt,transDet,specialCharacters='!?.:,*"()·[]/\'',onlyRemoveFirstLastCharacterGT=True): + + if onlyRemoveFirstLastCharacterGT: + #special characters in GT are allowed only at initial or final position + if (transGt==transDet): + return True + + if specialCharacters.find(transGt[0])>-1: + if transGt[1:]==transDet: + return True + + if specialCharacters.find(transGt[-1])>-1: + if transGt[0:len(transGt)-1]==transDet: + return True + + if specialCharacters.find(transGt[0])>-1 and specialCharacters.find(transGt[-1])>-1: + if transGt[1:len(transGt)-1]==transDet: + return True + return False + else: + #Special characters are removed from the begining and the end of both Detection and GroundTruth + while len(transGt)>0 and specialCharacters.find(transGt[0])>-1: + transGt = transGt[1:] + + while len(transDet)>0 and specialCharacters.find(transDet[0])>-1: + transDet = transDet[1:] + + while len(transGt)>0 and specialCharacters.find(transGt[-1])>-1 : + transGt = transGt[0:len(transGt)-1] + + while len(transDet)>0 and specialCharacters.find(transDet[-1])>-1: + transDet = transDet[0:len(transDet)-1] + + return transGt == transDet + + + def include_in_dictionary(transcription): + """ + Function used in Word Spotting that finds if the Ground Truth transcription meets the rules to enter into the dictionary. If not, the transcription will be cared as don't care + """ + #special case 's at final + if transcription[len(transcription)-2:]=="'s" or transcription[len(transcription)-2:]=="'S": + transcription = transcription[0:len(transcription)-2] + + #hypens at init or final of the word + transcription = transcription.strip('-'); + + specialCharacters = "'!?.:,*\"()·[]/"; + for character in specialCharacters: + transcription = transcription.replace(character,' ') + + transcription = transcription.strip() + + if len(transcription) != len(transcription.replace(" ","")) : + return False; + + if len(transcription) < evaluationParams['MIN_LENGTH_CARE_WORD']: + return False; + + notAllowed = "×÷·"; + + range1 = [ ord(u'a'), ord(u'z') ] + range2 = [ ord(u'A'), ord(u'Z') ] + range3 = [ ord(u'À'), ord(u'ƿ') ] + range4 = [ ord(u'DŽ'), ord(u'ɿ') ] + range5 = [ ord(u'Ά'), ord(u'Ͽ') ] + range6 = [ ord(u'-'), ord(u'-') ] + + for char in transcription : + charCode = ord(char) + if(notAllowed.find(char) != -1): + return False + + valid = ( charCode>=range1[0] and charCode<=range1[1] ) or ( charCode>=range2[0] and charCode<=range2[1] ) or ( charCode>=range3[0] and charCode<=range3[1] ) or ( charCode>=range4[0] and charCode<=range4[1] ) or ( charCode>=range5[0] and charCode<=range5[1] ) or ( charCode>=range6[0] and charCode<=range6[1] ) + if valid == False: + return False + + return True + + def include_in_dictionary_transcription(transcription): + """ + Function applied to the Ground Truth transcriptions used in Word Spotting. It removes special characters or terminations + """ + #special case 's at final + if transcription[len(transcription)-2:]=="'s" or transcription[len(transcription)-2:]=="'S": + transcription = transcription[0:len(transcription)-2] + + #hypens at init or final of the word + transcription = transcription.strip('-'); + + specialCharacters = "'!?.:,*\"()·[]/"; + for character in specialCharacters: + transcription = transcription.replace(character,' ') + + transcription = transcription.strip() + + return transcription + + perSampleMetrics = {} + + matchedSum = 0 + + Rectangle = namedtuple('Rectangle', 'xmin ymin xmax ymax') + + gt = rrc_evaluation_funcs.load_zip_file(gtFilePath,evaluationParams['GT_SAMPLE_NAME_2_ID']) + subm = rrc_evaluation_funcs.load_zip_file(submFilePath,evaluationParams['DET_SAMPLE_NAME_2_ID'],True) + + numGlobalCareGt = 0; + numGlobalCareDet = 0; + + arrGlobalConfidences = []; + arrGlobalMatches = []; + + for resFile in gt: + + gtFile = rrc_evaluation_funcs.decode_utf8(gt[resFile]) + if (gtFile is None) : + raise Exception("The file %s is not UTF-8" %resFile) + + recall = 0 + precision = 0 + hmean = 0 + detCorrect = 0 + iouMat = np.empty([1,1]) + gtPols = [] + detPols = [] + gtTrans = [] + detTrans = [] + gtPolPoints = [] + detPolPoints = [] + gtDontCarePolsNum = [] #Array of Ground Truth Polygons' keys marked as don't Care + detDontCarePolsNum = [] #Array of Detected Polygons' matched with a don't Care GT + detMatchedNums = [] + pairs = [] + + arrSampleConfidences = []; + arrSampleMatch = []; + sampleAP = 0; + + evaluationLog = "" + + pointsList,_,transcriptionsList = rrc_evaluation_funcs.get_tl_line_values_from_file_contents(gtFile,evaluationParams['CRLF'],evaluationParams['LTRB'],True,False) + for n in range(len(pointsList)): + points = pointsList[n] + transcription = transcriptionsList[n] + dontCare = transcription == "###" + if evaluationParams['LTRB']: + gtRect = Rectangle(*points) + gtPol = rectangle_to_polygon(gtRect) + else: + gtPol = polygon_from_points(points) + gtPols.append(gtPol) + gtPolPoints.append(points) + + #On word spotting we will filter some transcriptions with special characters + if evaluationParams['WORD_SPOTTING'] : + if dontCare == False : + if include_in_dictionary(transcription) == False : + dontCare = True + else: + transcription = include_in_dictionary_transcription(transcription) + + gtTrans.append(transcription) + if dontCare: + gtDontCarePolsNum.append( len(gtPols)-1 ) + + evaluationLog += "GT polygons: " + str(len(gtPols)) + (" (" + str(len(gtDontCarePolsNum)) + " don't care)\n" if len(gtDontCarePolsNum)>0 else "\n") + + if resFile in subm: + + detFile = rrc_evaluation_funcs.decode_utf8(subm[resFile]) + + pointsList,confidencesList,transcriptionsList = rrc_evaluation_funcs.get_tl_line_values_from_file_contents(detFile,evaluationParams['CRLF'],evaluationParams['LTRB'],True,evaluationParams['CONFIDENCES']) + + for n in range(len(pointsList)): + points = pointsList[n] + transcription = transcriptionsList[n] + + if evaluationParams['LTRB']: + detRect = Rectangle(*points) + detPol = rectangle_to_polygon(detRect) + else: + detPol = polygon_from_points(points) + detPols.append(detPol) + detPolPoints.append(points) + detTrans.append(transcription) + + if len(gtDontCarePolsNum)>0 : + for dontCarePol in gtDontCarePolsNum: + dontCarePol = gtPols[dontCarePol] + intersected_area = get_intersection(dontCarePol,detPol) + pdDimensions = detPol.area() + precision = 0 if pdDimensions == 0 else intersected_area / pdDimensions + if (precision > evaluationParams['AREA_PRECISION_CONSTRAINT'] ): + detDontCarePolsNum.append( len(detPols)-1 ) + break + + evaluationLog += "DET polygons: " + str(len(detPols)) + (" (" + str(len(detDontCarePolsNum)) + " don't care)\n" if len(detDontCarePolsNum)>0 else "\n") + + if len(gtPols)>0 and len(detPols)>0: + #Calculate IoU and precision matrixs + outputShape=[len(gtPols),len(detPols)] + iouMat = np.empty(outputShape) + gtRectMat = np.zeros(len(gtPols),np.int8) + detRectMat = np.zeros(len(detPols),np.int8) + for gtNum in range(len(gtPols)): + for detNum in range(len(detPols)): + pG = gtPols[gtNum] + pD = detPols[detNum] + iouMat[gtNum,detNum] = get_intersection_over_union(pD,pG) + + for gtNum in range(len(gtPols)): + for detNum in range(len(detPols)): + if gtRectMat[gtNum] == 0 and detRectMat[detNum] == 0 and gtNum not in gtDontCarePolsNum and detNum not in detDontCarePolsNum : + if iouMat[gtNum,detNum]>evaluationParams['IOU_CONSTRAINT']: + gtRectMat[gtNum] = 1 + detRectMat[detNum] = 1 + #detection matched only if transcription is equal + if evaluationParams['WORD_SPOTTING']: + correct = gtTrans[gtNum].upper() == detTrans[detNum].upper() + else: + correct = transcription_match(gtTrans[gtNum].upper(),detTrans[detNum].upper(),evaluationParams['SPECIAL_CHARACTERS'],evaluationParams['ONLY_REMOVE_FIRST_LAST_CHARACTER'])==True + detCorrect += (1 if correct else 0) + if correct: + detMatchedNums.append(detNum) + pairs.append({'gt':gtNum,'det':detNum,'correct':correct}) + evaluationLog += "Match GT #" + str(gtNum) + " with Det #" + str(detNum) + " trans. correct: " + str(correct) + "\n" + + if evaluationParams['CONFIDENCES']: + for detNum in range(len(detPols)): + if detNum not in detDontCarePolsNum : + #we exclude the don't care detections + match = detNum in detMatchedNums + + arrSampleConfidences.append(confidencesList[detNum]) + arrSampleMatch.append(match) + + arrGlobalConfidences.append(confidencesList[detNum]); + arrGlobalMatches.append(match); + + numGtCare = (len(gtPols) - len(gtDontCarePolsNum)) + numDetCare = (len(detPols) - len(detDontCarePolsNum)) + if numGtCare == 0: + recall = float(1) + precision = float(0) if numDetCare >0 else float(1) + sampleAP = precision + else: + recall = float(detCorrect) / numGtCare + precision = 0 if numDetCare==0 else float(detCorrect) / numDetCare + if evaluationParams['CONFIDENCES']: + sampleAP = compute_ap(arrSampleConfidences, arrSampleMatch, numGtCare ) + + hmean = 0 if (precision + recall)==0 else 2.0 * precision * recall / (precision + recall) + + matchedSum += detCorrect + numGlobalCareGt += numGtCare + numGlobalCareDet += numDetCare + + perSampleMetrics[resFile] = { + 'precision':precision, + 'recall':recall, + 'hmean':hmean, + 'pairs':pairs, + 'AP':sampleAP, + 'iouMat':[] if len(detPols)>100 else iouMat.tolist(), + 'gtPolPoints':gtPolPoints, + 'detPolPoints':detPolPoints, + 'gtTrans':gtTrans, + 'detTrans':detTrans, + 'gtDontCare':gtDontCarePolsNum, + 'detDontCare':detDontCarePolsNum, + 'evaluationParams': evaluationParams, + 'evaluationLog': evaluationLog + } + + # Compute AP + AP = 0 + if evaluationParams['CONFIDENCES']: + AP = compute_ap(arrGlobalConfidences, arrGlobalMatches, numGlobalCareGt) + + methodRecall = 0 if numGlobalCareGt == 0 else float(matchedSum)/numGlobalCareGt + methodPrecision = 0 if numGlobalCareDet == 0 else float(matchedSum)/numGlobalCareDet + methodHmean = 0 if methodRecall + methodPrecision==0 else 2* methodRecall * methodPrecision / (methodRecall + methodPrecision) + + methodMetrics = {'precision':methodPrecision, 'recall':methodRecall,'hmean': methodHmean, 'AP': AP } + + resDict = {'calculated':True,'Message':'','method': methodMetrics,'per_sample': perSampleMetrics} + + + return resDict; + + + +if __name__=='__main__': + ''' + results_dir: result directory + score_det: score of detection bounding box + score_rec: score of the mask recognition branch + score_rec_seq: score of the sequence recognition branch + lexicon_type: 1 for generic; 2 for weak; 3 for strong + ''' + angle = 45 + results_dir = '../../../output/mixtrain/inference/rotated_ic13_test_' + str(angle) + '/model_0250000_1000_results/' + score_rec_seq = 0.9 + score_rec = 0.4 + score_det = 0.1 + evaluate_result_path = prepare_results_for_evaluation(results_dir, + use_lexicon=False, cache_dir='./cache_files', + score_det=score_det, score_rec=score_rec, score_rec_seq=score_rec_seq) + p = { + 'g': '../gt/gt_'+str(angle)+'.zip', + 's': evaluate_result_path + } + rrc_evaluation_funcs.main_evaluation(p,default_evaluation_params,validate_data,evaluate_method) \ No newline at end of file diff --git a/evaluation/rotated_icdar2013/gt/gt.zip b/evaluation/rotated_icdar2013/gt/gt.zip new file mode 100644 index 0000000000000000000000000000000000000000..c5701e2cbc92bb133f80a9d1e45399d5629dde9c Binary files /dev/null and b/evaluation/rotated_icdar2013/gt/gt.zip differ diff --git a/evaluation/rotated_icdar2013/gt/gt_-15.zip b/evaluation/rotated_icdar2013/gt/gt_-15.zip new file mode 100644 index 0000000000000000000000000000000000000000..44ff5e7dfe1f5b3a228fc42cec30c207ac90bc0f Binary files /dev/null and b/evaluation/rotated_icdar2013/gt/gt_-15.zip differ diff --git a/evaluation/rotated_icdar2013/gt/gt_-30.zip b/evaluation/rotated_icdar2013/gt/gt_-30.zip new file mode 100644 index 0000000000000000000000000000000000000000..f17592af8d6a802606676618ff44da4aead28fe9 Binary files /dev/null and b/evaluation/rotated_icdar2013/gt/gt_-30.zip differ diff --git a/evaluation/rotated_icdar2013/gt/gt_-45.zip b/evaluation/rotated_icdar2013/gt/gt_-45.zip new file mode 100644 index 0000000000000000000000000000000000000000..7c2b4f8e072c5b03bbc394905cee36378d3e4297 Binary files /dev/null and b/evaluation/rotated_icdar2013/gt/gt_-45.zip differ diff --git a/evaluation/rotated_icdar2013/gt/gt_-60.zip b/evaluation/rotated_icdar2013/gt/gt_-60.zip new file mode 100644 index 0000000000000000000000000000000000000000..e6381c8919be8da7003783c8f8d1a26eaba1972c Binary files /dev/null and b/evaluation/rotated_icdar2013/gt/gt_-60.zip differ diff --git a/evaluation/rotated_icdar2013/gt/gt_-75.zip b/evaluation/rotated_icdar2013/gt/gt_-75.zip new file mode 100644 index 0000000000000000000000000000000000000000..227c8cae805facbf645c77ffc1aebf7afd5dee19 Binary files /dev/null and b/evaluation/rotated_icdar2013/gt/gt_-75.zip differ diff --git a/evaluation/rotated_icdar2013/gt/gt_-90.zip b/evaluation/rotated_icdar2013/gt/gt_-90.zip new file mode 100644 index 0000000000000000000000000000000000000000..77069e78962076f9874707bb8ce0988983854f9c Binary files /dev/null and b/evaluation/rotated_icdar2013/gt/gt_-90.zip differ diff --git a/evaluation/rotated_icdar2013/gt/gt_0.zip b/evaluation/rotated_icdar2013/gt/gt_0.zip new file mode 100644 index 0000000000000000000000000000000000000000..e899e1400998e88a1b04193f22de95191b5ab7e7 Binary files /dev/null and b/evaluation/rotated_icdar2013/gt/gt_0.zip differ diff --git a/evaluation/rotated_icdar2013/gt/gt_15.zip b/evaluation/rotated_icdar2013/gt/gt_15.zip new file mode 100644 index 0000000000000000000000000000000000000000..349c7829cdc7661d5f7550dd5ae59a173fe7c859 Binary files /dev/null and b/evaluation/rotated_icdar2013/gt/gt_15.zip differ diff --git a/evaluation/rotated_icdar2013/gt/gt_30.zip b/evaluation/rotated_icdar2013/gt/gt_30.zip new file mode 100644 index 0000000000000000000000000000000000000000..b03463737b6b95a342105ad3918bec5a8482db08 Binary files /dev/null and b/evaluation/rotated_icdar2013/gt/gt_30.zip differ diff --git a/evaluation/rotated_icdar2013/gt/gt_45.zip b/evaluation/rotated_icdar2013/gt/gt_45.zip new file mode 100644 index 0000000000000000000000000000000000000000..f7cd7f8f21a480a552980f081076bfa78d6487cc Binary files /dev/null and b/evaluation/rotated_icdar2013/gt/gt_45.zip differ diff --git a/evaluation/rotated_icdar2013/gt/gt_60.zip b/evaluation/rotated_icdar2013/gt/gt_60.zip new file mode 100644 index 0000000000000000000000000000000000000000..4523de5eb8ea22a32b4007328afc56024bf126e6 Binary files /dev/null and b/evaluation/rotated_icdar2013/gt/gt_60.zip differ diff --git a/evaluation/rotated_icdar2013/gt/gt_75.zip b/evaluation/rotated_icdar2013/gt/gt_75.zip new file mode 100644 index 0000000000000000000000000000000000000000..c825e7e2112a2febaf3448cf7bfcfc77526ea2e5 Binary files /dev/null and b/evaluation/rotated_icdar2013/gt/gt_75.zip differ diff --git a/evaluation/rotated_icdar2013/gt/gt_85.zip b/evaluation/rotated_icdar2013/gt/gt_85.zip new file mode 100644 index 0000000000000000000000000000000000000000..6b11e24b67fe90130b92f756b41918f89319fa10 Binary files /dev/null and b/evaluation/rotated_icdar2013/gt/gt_85.zip differ diff --git a/evaluation/rotated_icdar2013/gt/gt_90.zip b/evaluation/rotated_icdar2013/gt/gt_90.zip new file mode 100644 index 0000000000000000000000000000000000000000..12d02644e6453be15b9f3c0b35170aed389a4088 Binary files /dev/null and b/evaluation/rotated_icdar2013/gt/gt_90.zip differ diff --git a/evaluation/totaltext/e2e/prepare_results.py b/evaluation/totaltext/e2e/prepare_results.py new file mode 100644 index 0000000000000000000000000000000000000000..9700cfbe9b1557692a741359ae2f0a44d1fa68a7 --- /dev/null +++ b/evaluation/totaltext/e2e/prepare_results.py @@ -0,0 +1,234 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +import sys +import os +import glob +sys.path.append('./') +import shapely +from shapely.geometry import Polygon,MultiPoint +import numpy as np +import editdistance +sys.path.append('../../') +from weighted_editdistance import weighted_edit_distance +from tqdm import tqdm +try: + import pickle +except ImportError: + import cPickle as pickle + +def list_from_str(st): + line = st.split(';') + segms = line[1].split(',') + scores = line[2].split(',') + new_line = [float(a) for a in segms]+[float(scores[-4])]+[scores[-5]]+[scores[-6]]+[float(scores[-3])]+[float(scores[-2])] + [scores[-1]] + return new_line + +def polygon_from_list(line): + """ + Create a shapely polygon object from gt or dt line. + """ + polygon_points = np.array(line).reshape(-1, 2) + polygon = Polygon(polygon_points).convex_hull + return polygon + +def polygon_iou(list1, list2): + """ + Intersection over union between two shapely polygons. + """ + polygon_points1 = np.array(list1).reshape(-1, 2) + poly1 = Polygon(polygon_points1).convex_hull + polygon_points2 = np.array(list2).reshape(-1, 2) + poly2 = Polygon(polygon_points2).convex_hull + union_poly = np.concatenate((polygon_points1,polygon_points2)) + if not poly1.intersects(poly2): # this test is fast and can accelerate calculation + iou = 0 + else: + try: + inter_area = poly1.intersection(poly2).area + #union_area = poly1.area + poly2.area - inter_area + union_area = MultiPoint(union_poly).convex_hull.area + iou = float(inter_area) / (union_area+1e-6) + except shapely.geos.TopologicalError: + print('shapely.geos.TopologicalError occured, iou set to 0') + iou = 0 + return iou + +def nms(boxes,overlap): + rec_scores = [b[-6] for b in boxes] + indices = sorted(range(len(rec_scores)), key=lambda k: -rec_scores[k]) + box_num = len(boxes) + nms_flag = [True]*box_num + for i in range(box_num): + ii = indices[i] + if not nms_flag[ii]: + continue + for j in range(box_num): + jj = indices[j] + if j == i: + continue + if not nms_flag[jj]: + continue + box1 = boxes[ii] + box2 = boxes[jj] + box1_score = rec_scores[ii] + box2_score = rec_scores[jj] + str1 = box1[9] + str2 = box2[9] + box_i = [box1[0],box1[1],box1[4],box1[5]] + box_j = [box2[0],box2[1],box2[4],box2[5]] + poly1 = polygon_from_list(box1[0:-6]) + poly2 = polygon_from_list(box2[0:-6]) + iou = polygon_iou(box1[0:-6],box2[0:-6]) + thresh = overlap + + if iou > thresh: + if box1_score > box2_score: + nms_flag[jj] = False + if box1_score == box2_score and poly1.area > poly2.area: + nms_flag[jj] = False + if box1_score == box2_score and poly1.area<=poly2.area: + nms_flag[ii] = False + break + + return nms_flag + +def packing(save_dir, cache_dir, pack_name): + files = os.listdir(save_dir) + if not os.path.exists(cache_dir): + os.mkdir(cache_dir) + os.system('zip -r -q -j '+os.path.join(cache_dir, pack_name+'.zip')+' '+save_dir+'/*') + +def test_single(results_dir,lexicon_type=3,cache_dir='./cache_dir',score_det=0.5,score_rec=0.5,score_rec_seq=0.5,overlap=0.2, use_lexicon=True, weighted_ed=True, use_seq=False, use_char=False, mix=False): + ''' + results_dir: result directory + score_det: score of detection bounding box + score_rec: score of the mask recognition branch + socre_rec_seq: score of the sequence recognition branch + overlap: overlap threshold used for nms + lexicon_type: 1 for generic; 2 for weak; 3 for strong + use_seq: use the recognition result of sequence branch + use_mix: use both the recognition result of the mask and sequence branches, selected by score + ''' + print('score_det:', 'score_det:', score_det, 'score_rec:', score_rec, 'score_rec_seq:', score_rec_seq, 'overlap:', overlap,'lexicon_type:', lexicon_type, 'weighted_ed:', weighted_ed, 'use_seq:', use_seq, 'use_char:', use_char, 'mix:', mix) + if not os.path.exists(cache_dir): + os.mkdir(cache_dir) + nms_dir = os.path.join(cache_dir,str(score_det)+'_'+str(score_rec)+'_'+str(score_rec_seq)) + if not os.path.exists(nms_dir): + os.mkdir(nms_dir) + if use_lexicon and lexicon_type==2: + # weak lexicon + lexicon_path = '../../lexicons/totaltext/weak_voc_new.txt' + lexicon_fid=open(lexicon_path, 'r') + pair_list = open('../../lexicons/totaltext/weak_voc_pair_list.txt', 'r') + pairs = dict() + for line in pair_list.readlines(): + line=line.strip() + word = line.split(' ')[0].upper() + word_gt = line[len(word)+1:] + pairs[word] = word_gt + lexicon_fid=open(lexicon_path, 'r') + lexicon=[] + for line in lexicon_fid.readlines(): + line=line.strip() + lexicon.append(line) + + for res_file in glob.glob("*.txt"): + result_path = os.path.join(results_dir,res_file) + if os.path.isfile(result_path): + with open(result_path,'r') as f: + dt_lines = [a.strip() for a in f.readlines()] + dt_lines = [list_from_str(dt) for dt in dt_lines] + else: + dt_lines = [] + dt_lines = [dt for dt in dt_lines if dt[-2]>score_rec_seq and dt[-3]>score_rec and dt[-6]>score_det] + nms_flag = nms(dt_lines,overlap) + boxes = [] + for k in range(len(dt_lines)): + dt = dt_lines[k] + if nms_flag[k]: + if dt not in boxes: + boxes.append(dt) + + with open(os.path.join(nms_dir,'gt_'+res_file.split('.')[0].split('_')[1]+'.txt'),'w') as f: + for g in boxes: + gt_coors = [int(b) for b in g[0:-6]] + with open('../../../' + g[-1], "rb") as input_file: + dict_scores = pickle.load(input_file) + if use_char and use_seq: + if g[-2]>g[-3]: + word = g[-5] + scores = dict_scores['seq_char_scores'][:,1:-1].swapaxes(0,1) + else: + word = g[-4] + scores = dict_scores['seg_char_scores'] + elif use_seq: + word = g[-5] + scores = dict_scores['seq_char_scores'][:,1:-1].swapaxes(0,1) + else: + word = g[-4] + scores = dict_scores['seg_char_scores'] + if not use_lexicon: + match_word = word + match_dist = 0. + else: + match_word, match_dist = find_match_word(word, pairs, scores, use_lexicon, weighted_ed, lexicon) + if match_dist<1.5 or lexicon_type==1: + gt_coor_strs = [str(a) for a in gt_coors]+ [match_word] + f.write(','.join(gt_coor_strs)+'\r\n') + + pack_name = str(score_det)+'_'+str(score_rec)+'_over'+str(overlap) + + packing(nms_dir,cache_dir,pack_name) + submit_file_path = os.path.join(cache_dir, pack_name+'.zip') + return submit_file_path + +def find_match_word(rec_str, pairs, scores_numpy, use_ed=True, weighted_ed=False, lexicon=None): + if not use_ed: + return rec_str + rec_str = rec_str.upper() + dist_min = 100 + dist_min_pre = 100 + match_word = '' + match_dist = 100 + if not weighted_ed: + for word in lexicon: + word = word.upper() + ed = editdistance.eval(rec_str, word) + length_dist = abs(len(word) - len(rec_str)) + # dist = ed + length_dist + dist = ed + if dist -s= [-o= -p=]' %sys.argv[0]) + sys.exit(2) + + +def load_zip_file_keys(file,fileNameRegExp=''): + """ + Returns an array with the entries of the ZIP file that match with the regular expression. + The key's are the names or the file or the capturing group definied in the fileNameRegExp + """ + try: + archive=zipfile.ZipFile(file, mode='r', allowZip64=True) + except : + raise Exception('Error loading the ZIP archive.') + + pairs = [] + + for name in archive.namelist(): + addFile = True + keyName = name + if fileNameRegExp!="": + m = re.match(fileNameRegExp,name) + if m == None: + addFile = False + else: + if len(m.groups())>0: + keyName = m.group(1) + + if addFile: + pairs.append( keyName ) + + return pairs + + +def load_zip_file(file,fileNameRegExp='',allEntries=False): + """ + Returns an array with the contents (filtered by fileNameRegExp) of a ZIP file. + The key's are the names or the file or the capturing group definied in the fileNameRegExp + allEntries validates that all entries in the ZIP file pass the fileNameRegExp + """ + try: + archive=zipfile.ZipFile(file, mode='r', allowZip64=True) + except : + raise Exception('Error loading the ZIP archive') + + pairs = [] + for name in archive.namelist(): + addFile = True + keyName = name + if fileNameRegExp!="": + m = re.match(fileNameRegExp,name) + if m == None: + addFile = False + else: + if len(m.groups())>0: + keyName = m.group(1) + + if addFile: + pairs.append( [ keyName , archive.read(name)] ) + else: + if allEntries: + raise Exception('ZIP entry not valid: %s' %name) + + return dict(pairs) + +def decode_utf8(raw): + """ + Returns a Unicode object on success, or None on failure + """ + try: + raw = codecs.decode(raw,'utf-8', 'replace') + #extracts BOM if exists + raw = raw.encode('utf8') + if raw.startswith(codecs.BOM_UTF8): + raw = raw.replace(codecs.BOM_UTF8, '', 1) + return raw.decode('utf-8') + except: + return None + +def validate_lines_in_file(fileName,file_contents,CRLF=True,LTRB=True,withTranscription=False,withConfidence=False,imWidth=0,imHeight=0): + """ + This function validates that all lines of the file calling the Line validation function for each line + """ + utf8File = decode_utf8(file_contents) + if (utf8File is None) : + raise Exception("The file %s is not UTF-8" %fileName) + + lines = utf8File.split( "\r\n" if CRLF else "\n" ) + for line in lines: + line = line.replace("\r","").replace("\n","") + if(line != ""): + try: + validate_tl_line(line,LTRB,withTranscription,withConfidence,imWidth,imHeight) + except Exception as e: + raise Exception(("Line in sample not valid. Sample: %s Line: %s Error: %s" %(fileName,line,str(e))).encode('utf-8', 'replace')) + + + +def validate_tl_line(line,LTRB=True,withTranscription=True,withConfidence=True,imWidth=0,imHeight=0): + """ + Validate the format of the line. If the line is not valid an exception will be raised. + If maxWidth and maxHeight are specified, all points must be inside the imgage bounds. + Posible values are: + LTRB=True: xmin,ymin,xmax,ymax[,confidence][,transcription] + LTRB=False: x1,y1,x2,y2,x3,y3,x4,y4[,confidence][,transcription] + """ + get_tl_line_values(line,LTRB,withTranscription,withConfidence,imWidth,imHeight) + + +def get_tl_line_values(line,LTRB=True,withTranscription=False,withConfidence=False,imWidth=0,imHeight=0): + """ + Validate the format of the line. If the line is not valid an exception will be raised. + If maxWidth and maxHeight are specified, all points must be inside the imgage bounds. + Posible values are: + LTRB=True: xmin,ymin,xmax,ymax[,confidence][,transcription] + LTRB=False: x1,y1,x2,y2,x3,y3,x4,y4[,confidence][,transcription] + Returns values from a textline. Points , [Confidences], [Transcriptions] + """ + confidence = 0.0 + transcription = ""; + points = [] + + numPoints = 4; + + if LTRB: + + numPoints = 4; + + if withTranscription and withConfidence: + m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-1].?[0-9]*)\s*,(.*)$',line) + if m == None : + m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-1].?[0-9]*)\s*,(.*)$',line) + raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax,confidence,transcription") + elif withConfidence: + m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-1].?[0-9]*)\s*$',line) + if m == None : + raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax,confidence") + elif withTranscription: + m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,(.*)$',line) + if m == None : + raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax,transcription") + else: + m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,?\s*$',line) + if m == None : + raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax") + + xmin = int(m.group(1)) + ymin = int(m.group(2)) + xmax = int(m.group(3)) + ymax = int(m.group(4)) + if(xmax0 and imHeight>0): + validate_point_inside_bounds(xmin,ymin,imWidth,imHeight); + validate_point_inside_bounds(xmax,ymax,imWidth,imHeight); + + else: + + numPoints = 8; + + if withTranscription and withConfidence: + m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-1].?[0-9]*)\s*,(.*)$',line) + if m == None : + raise Exception("Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4,confidence,transcription") + elif withConfidence: + m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-1].?[0-9]*)\s*$',line) + if m == None : + raise Exception("Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4,confidence") + elif withTranscription: + m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,(.*)$',line) + if m == None : + raise Exception("Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4,transcription") + else: + m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*$',line) + if m == None : + raise Exception("Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4") + + points = [ float(m.group(i)) for i in range(1, (numPoints+1) ) ] + + validate_clockwise_points(points) + + if (imWidth>0 and imHeight>0): + validate_point_inside_bounds(points[0],points[1],imWidth,imHeight); + validate_point_inside_bounds(points[2],points[3],imWidth,imHeight); + validate_point_inside_bounds(points[4],points[5],imWidth,imHeight); + validate_point_inside_bounds(points[6],points[7],imWidth,imHeight); + + + if withConfidence: + try: + confidence = float(m.group(numPoints+1)) + except ValueError: + raise Exception("Confidence value must be a float") + + if withTranscription: + posTranscription = numPoints + (2 if withConfidence else 1) + transcription = m.group(posTranscription) + m2 = re.match(r'^\s*\"(.*)\"\s*$',transcription) + if m2 != None : #Transcription with double quotes, we extract the value and replace escaped characters + transcription = m2.group(1).replace("\\\\", "\\").replace("\\\"", "\"") + + return points,confidence,transcription + + +def validate_point_inside_bounds(x,y,imWidth,imHeight): + if(x<0 or x>imWidth): + raise Exception("X value (%s) not valid. Image dimensions: (%s,%s)" %(xmin,imWidth,imHeight)) + if(y<0 or y>imHeight): + raise Exception("Y value (%s) not valid. Image dimensions: (%s,%s) Sample: %s Line:%s" %(ymin,imWidth,imHeight)) + +def validate_clockwise_points(points): + """ + Validates that the points that the 4 points that dlimite a polygon are in clockwise order. + """ + + if len(points) != 8: + raise Exception("Points list not valid." + str(len(points))) + + point = [ + [int(points[0]) , int(points[1])], + [int(points[2]) , int(points[3])], + [int(points[4]) , int(points[5])], + [int(points[6]) , int(points[7])] + ] + edge = [ + ( point[1][0] - point[0][0])*( point[1][1] + point[0][1]), + ( point[2][0] - point[1][0])*( point[2][1] + point[1][1]), + ( point[3][0] - point[2][0])*( point[3][1] + point[2][1]), + ( point[0][0] - point[3][0])*( point[0][1] + point[3][1]) + ] + + summatory = edge[0] + edge[1] + edge[2] + edge[3]; + if summatory>0: + raise Exception("Points are not clockwise. The coordinates of bounding quadrilaterals have to be given in clockwise order. Regarding the correct interpretation of 'clockwise' remember that the image coordinate system used is the standard one, with the image origin at the upper left, the X axis extending to the right and Y axis extending downwards.") + +def get_tl_line_values_from_file_contents(content,CRLF=True,LTRB=True,withTranscription=False,withConfidence=False,imWidth=0,imHeight=0,sort_by_confidences=True): + """ + Returns all points, confindences and transcriptions of a file in lists. Valid line formats: + xmin,ymin,xmax,ymax,[confidence],[transcription] + x1,y1,x2,y2,x3,y3,x4,y4,[confidence],[transcription] + """ + pointsList = [] + transcriptionsList = [] + confidencesList = [] + + lines = content.split( "\r\n" if CRLF else "\n" ) + for line in lines: + line = line.replace("\r","").replace("\n","") + if(line != "") : + points, confidence, transcription = get_tl_line_values(line,LTRB,withTranscription,withConfidence,imWidth,imHeight); + pointsList.append(points) + transcriptionsList.append(transcription) + confidencesList.append(confidence) + + if withConfidence and len(confidencesList)>0 and sort_by_confidences: + import numpy as np + sorted_ind = np.argsort(-np.array(confidencesList)) + confidencesList = [confidencesList[i] for i in sorted_ind] + pointsList = [pointsList[i] for i in sorted_ind] + transcriptionsList = [transcriptionsList[i] for i in sorted_ind] + + return pointsList,confidencesList,transcriptionsList + +def main_evaluation(p,default_evaluation_params_fn,validate_data_fn,evaluate_method_fn,show_result=True,per_sample=True): + """ + This process validates a method, evaluates it and if it succed generates a ZIP file with a JSON entry for each sample. + Params: + p: Dictionary of parmeters with the GT/submission locations. If None is passed, the parameters send by the system are used. + default_evaluation_params_fn: points to a function that returns a dictionary with the default parameters used for the evaluation + validate_data_fn: points to a method that validates the corrct format of the submission + evaluate_method_fn: points to a function that evaluated the submission and return a Dictionary with the results + """ + + if (p == None): + p = dict([s[1:].split('=') for s in sys.argv[1:]]) + if(len(sys.argv)<3): + print_help() + + evalParams = default_evaluation_params_fn() + if 'p' in p.keys(): + evalParams.update( p['p'] if isinstance(p['p'], dict) else json.loads(p['p'][1:-1]) ) + + resDict={'calculated':True,'Message':'','method':'{}','per_sample':'{}'} + try: + validate_data_fn(p['g'], p['s'], evalParams) + evalData = evaluate_method_fn(p['g'], p['s'], evalParams) + resDict.update(evalData) + + except Exception as e: + resDict['Message']= str(e) + resDict['calculated']=False + + if 'o' in p: + if not os.path.exists(p['o']): + os.makedirs(p['o']) + + resultsOutputname = p['o'] + '/results.zip' + outZip = zipfile.ZipFile(resultsOutputname, mode='w', allowZip64=True) + + del resDict['per_sample'] + if 'output_items' in resDict.keys(): + del resDict['output_items'] + + outZip.writestr('method.json',json.dumps(resDict)) + + if not resDict['calculated']: + if show_result: + sys.stderr.write('Error!\n'+ resDict['Message']+'\n\n') + if 'o' in p: + outZip.close() + return resDict + + if 'o' in p: + if per_sample == True: + for k,v in evalData['per_sample'].items(): + outZip.writestr( k + '.json',json.dumps(v)) + + if 'output_items' in evalData.keys(): + for k, v in evalData['output_items'].items(): + outZip.writestr( k,v) + + outZip.close() + + if show_result: + sys.stdout.write("Calculated!") + sys.stdout.write(json.dumps(resDict['method'])) + + return resDict + + +def main_validation(default_evaluation_params_fn,validate_data_fn): + """ + This process validates a method + Params: + default_evaluation_params_fn: points to a function that returns a dictionary with the default parameters used for the evaluation + validate_data_fn: points to a method that validates the corrct format of the submission + """ + try: + p = dict([s[1:].split('=') for s in sys.argv[1:]]) + evalParams = default_evaluation_params_fn() + if 'p' in p.keys(): + evalParams.update( p['p'] if isinstance(p['p'], dict) else json.loads(p['p'][1:-1]) ) + + validate_data_fn(p['g'], p['s'], evalParams) + print('SUCCESS') + sys.exit(0) + except Exception as e: + print(str(e)) + sys.exit(101) \ No newline at end of file diff --git a/evaluation/totaltext/e2e/rrc_evaluation_funcs_total_text.py b/evaluation/totaltext/e2e/rrc_evaluation_funcs_total_text.py new file mode 100644 index 0000000000000000000000000000000000000000..0fda51d491d31191c39f03f66ef08da9b60dc547 --- /dev/null +++ b/evaluation/totaltext/e2e/rrc_evaluation_funcs_total_text.py @@ -0,0 +1,363 @@ +#!/usr/bin/env python2 +#encoding: UTF-8 +import json +import sys;sys.path.append('./') +import zipfile +import re +import sys +import os +import codecs +import importlib +from io import StringIO + +def print_help(): + sys.stdout.write('Usage: python %s.py -g= -s= -o= [-i= -p=]' %sys.argv[0]) + sys.exit(2) + + +def load_zip_file_keys(file,fileNameRegExp=''): + """ + Returns an array with the entries of the ZIP file that match with the regular expression. + The key's are the names or the file or the capturing group definied in the fileNameRegExp + """ + try: + archive=zipfile.ZipFile(file, mode='r', allowZip64=True) + except : + raise Exception('Error loading the ZIP archive.') + + pairs = [] + + for name in archive.namelist(): + addFile = True + keyName = name + # if fileNameRegExp!="": + # m = re.match(fileNameRegExp,name) + # if m == None: + # addFile = False + # else: + # if len(m.groups())>0: + # keyName = m.group(1) + + if addFile: + pairs.append( keyName ) + + return pairs + + +def load_zip_file(file,fileNameRegExp='',allEntries=False): + """ + Returns an array with the contents (filtered by fileNameRegExp) of a ZIP file. + The key's are the names or the file or the capturing group definied in the fileNameRegExp + allEntries validates that all entries in the ZIP file pass the fileNameRegExp + """ + try: + archive=zipfile.ZipFile(file, mode='r', allowZip64=True) + except : + raise Exception('Error loading the ZIP archive') + + pairs = [] + for name in archive.namelist(): + addFile = True + keyName = name + # if fileNameRegExp!="": + # m = re.match(fileNameRegExp,name) + # if m == None: + # addFile = False + # else: + # if len(m.groups())>0: + # keyName = m.group(1) + + if addFile: + pairs.append( [ keyName , archive.read(name)] ) + else: + if allEntries: + raise Exception('ZIP entry not valid: %s' %name) + + return dict(pairs) + +def decode_utf8(raw): + """ + Returns a Unicode object on success, or None on failure + """ + try: + raw = codecs.decode(raw,'utf-8', 'replace') + #extracts BOM if exists + raw = raw.encode('utf8') + if raw.startswith(codecs.BOM_UTF8): + raw = raw.replace(codecs.BOM_UTF8, '', 1) + return raw.decode('utf-8') + except: + return None + +def validate_lines_in_file(fileName,file_contents,CRLF=True,LTRB=True,withTranscription=False,withConfidence=False,imWidth=0,imHeight=0): + """ + This function validates that all lines of the file calling the Line validation function for each line + """ + utf8File = decode_utf8(file_contents) + if (utf8File is None) : + raise Exception("The file %s is not UTF-8" %fileName) + + lines = utf8File.split( "\r\n" if CRLF else "\n" ) + for line in lines: + line = line.replace("\r","").replace("\n","") + if(line != ""): + try: + validate_tl_line(line,LTRB,withTranscription,withConfidence,imWidth,imHeight) + except Exception as e: + raise Exception(("Line in sample not valid. Sample: %s Line: %s Error: %s" %(fileName,line,str(e))).encode('utf-8', 'replace')) + + + +def validate_tl_line(line,LTRB=True,withTranscription=True,withConfidence=True,imWidth=0,imHeight=0): + """ + Validate the format of the line. If the line is not valid an exception will be raised. + If maxWidth and maxHeight are specified, all points must be inside the imgage bounds. + Posible values are: + LTRB=True: xmin,ymin,xmax,ymax[,confidence][,transcription] + LTRB=False: x1,y1,x2,y2,x3,y3,x4,y4[,confidence][,transcription] + """ + get_tl_line_values(line,LTRB,withTranscription,withConfidence,imWidth,imHeight) + + +def get_tl_line_values(line,LTRB=True,withTranscription=False,withConfidence=False,imWidth=0,imHeight=0): + """ + Validate the format of the line. If the line is not valid an exception will be raised. + If maxWidth and maxHeight are specified, all points must be inside the imgage bounds. + Posible values are: + LTRB=True: xmin,ymin,xmax,ymax[,confidence][,transcription] + LTRB=False: x1,y1,x2,y2,x3,y3,x4,y4[,confidence][,transcription] + Returns values from a textline. Points , [Confidences], [Transcriptions] + """ + confidence = 0.0 + transcription = ""; + points = [] + + numPoints = 4; + if LTRB: + + numPoints = 4; + + if withTranscription and withConfidence: + m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-1].?[0-9]*)\s*,(.*)$',line) + if m == None : + m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-1].?[0-9]*)\s*,(.*)$',line) + raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax,confidence,transcription") + elif withConfidence: + m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-1].?[0-9]*)\s*$',line) + if m == None : + raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax,confidence") + elif withTranscription: + m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,(.*)$',line) + if m == None : + raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax,transcription") + else: + m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,?\s*$',line) + if m == None : + raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax") + + xmin = int(m.group(1)) + ymin = int(m.group(2)) + xmax = int(m.group(3)) + ymax = int(m.group(4)) + if(xmax0 and imHeight>0): + validate_point_inside_bounds(xmin,ymin,imWidth,imHeight); + validate_point_inside_bounds(xmax,ymax,imWidth,imHeight); + + else: + line_split = line.split(',') + # print(line_split) + numPoints = int((len(line_split) - 1) / 2) + points = [ float(line_split[i]) for i in range(2 * numPoints) ] + # print(points) + transcription = line_split[-1] + # numPoints = 8; + + # if withTranscription and withConfidence: + # m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-1].?[0-9]*)\s*,(.*)$',line) + # if m == None : + # raise Exception("Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4,confidence,transcription") + # elif withConfidence: + # m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-1].?[0-9]*)\s*$',line) + # if m == None : + # raise Exception("Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4,confidence") + # elif withTranscription: + # m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,(.*)$',line) + # if m == None : + # raise Exception("Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4,transcription") + # else: + # m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*$',line) + # if m == None : + # raise Exception("Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4") + + # points = [ float(m.group(i)) for i in range(1, (numPoints+1) ) ] + + # validate_clockwise_points(points) + + # if (imWidth>0 and imHeight>0): + # validate_point_inside_bounds(points[0],points[1],imWidth,imHeight); + # validate_point_inside_bounds(points[2],points[3],imWidth,imHeight); + # validate_point_inside_bounds(points[4],points[5],imWidth,imHeight); + # validate_point_inside_bounds(points[6],points[7],imWidth,imHeight); + + + # if withConfidence: + # try: + # confidence = float(m.group(numPoints+1)) + # except ValueError: + # raise Exception("Confidence value must be a float") + + # if withTranscription: + # posTranscription = numPoints + (2 if withConfidence else 1) + # transcription = m.group(posTranscription) + # m2 = re.match(r'^\s*\"(.*)\"\s*$',transcription) + # if m2 != None : #Transcription with double quotes, we extract the value and replace escaped characters + # transcription = m2.group(1).replace("\\\\", "\\").replace("\\\"", "\"") + + return points,confidence,transcription + + +def validate_point_inside_bounds(x,y,imWidth,imHeight): + if(x<0 or x>imWidth): + raise Exception("X value (%s) not valid. Image dimensions: (%s,%s)" %(xmin,imWidth,imHeight)) + if(y<0 or y>imHeight): + raise Exception("Y value (%s) not valid. Image dimensions: (%s,%s) Sample: %s Line:%s" %(ymin,imWidth,imHeight)) + +def validate_clockwise_points(points): + """ + Validates that the points that the 4 points that dlimite a polygon are in clockwise order. + """ + + if len(points) != 8: + raise Exception("Points list not valid." + str(len(points))) + + point = [ + [int(points[0]) , int(points[1])], + [int(points[2]) , int(points[3])], + [int(points[4]) , int(points[5])], + [int(points[6]) , int(points[7])] + ] + edge = [ + ( point[1][0] - point[0][0])*( point[1][1] + point[0][1]), + ( point[2][0] - point[1][0])*( point[2][1] + point[1][1]), + ( point[3][0] - point[2][0])*( point[3][1] + point[2][1]), + ( point[0][0] - point[3][0])*( point[0][1] + point[3][1]) + ] + + summatory = edge[0] + edge[1] + edge[2] + edge[3]; + if summatory>0: + raise Exception("Points are not clockwise. The coordinates of bounding quadrilaterals have to be given in clockwise order. Regarding the correct interpretation of 'clockwise' remember that the image coordinate system used is the standard one, with the image origin at the upper left, the X axis extending to the right and Y axis extending downwards.") + +def get_tl_line_values_from_file_contents(content,CRLF=True,LTRB=True,withTranscription=False,withConfidence=False,imWidth=0,imHeight=0,sort_by_confidences=True): + """ + Returns all points, confindences and transcriptions of a file in lists. Valid line formats: + xmin,ymin,xmax,ymax,[confidence],[transcription] + x1,y1,x2,y2,x3,y3,x4,y4,[confidence],[transcription] + """ + pointsList = [] + transcriptionsList = [] + confidencesList = [] + + lines = content.split( "\r\n" if CRLF else "\n" ) + for line in lines: + line = line.replace("\r","").replace("\n","") + if(line != "") : + points, confidence, transcription = get_tl_line_values(line,LTRB,withTranscription,withConfidence,imWidth,imHeight); + pointsList.append(points) + transcriptionsList.append(transcription) + confidencesList.append(confidence) + + if withConfidence and len(confidencesList)>0 and sort_by_confidences: + confidencesList, pointsList,transcriptionsList = (list(t) for t in zip(*sorted(zip(confidencesList, pointsList, transcriptionsList), reverse=True))) + + return pointsList,confidencesList,transcriptionsList + +def main_evaluation(p,default_evaluation_params_fn,validate_data_fn,evaluate_method_fn,show_result=True,per_sample=True): + """ + This process validates a method, evaluates it and if it succed generates a ZIP file with a JSON entry for each sample. + Params: + p: Dictionary of parmeters with the GT/submission locations. If None is passed, the parameters send by the system are used. + default_evaluation_params_fn: points to a function that returns a dictionary with the default parameters used for the evaluation + validate_data_fn: points to a method that validates the corrct format of the submission + evaluate_method_fn: points to a function that evaluated the submission and return a Dictionary with the results + """ + + if (p == None): + p = dict([s[1:].split('=') for s in sys.argv[1:]]) + if(len(sys.argv)<2): + print_help() + + evalParams = default_evaluation_params_fn() + if 'p' in list(p.keys()): + evalParams.update( p['p'] if isinstance(p['p'], dict) else json.loads(p['p'][1:-1]) ) + + resDict={'calculated':True,'Message':'','method':'{}','per_sample':'{}'} + try: + validate_data_fn(p['g'], p['s'], evalParams) + evalData = evaluate_method_fn(p['g'], p['s'], evalParams) + resDict.update(evalData) + + except Exception as e: + resDict['Message']= str(e) + resDict['calculated']=False + + if not os.path.exists(p['o']): + os.makedirs(p['o']) + + resultsOutputname = p['o'] + '/results.zip' + outZip = zipfile.ZipFile(resultsOutputname, mode='w', allowZip64=True) + + del resDict['per_sample'] + if 'output_items' in list(resDict.keys()): + del resDict['output_items'] + + outZip.writestr('method.json',json.dumps(resDict)) + + if not resDict['calculated']: + if show_result: + sys.stderr.write('Error!\n'+ resDict['Message']+'\n\n') + outZip.close() + return resDict + + if per_sample == True: + for k,v in evalData['per_sample'].items(): + outZip.writestr( k + '.json',json.dumps(v)) + + if 'output_items' in list(evalData.keys()): + for k, v in evalData['output_items'].items(): + outZip.writestr( k,v) + + outZip.close() + + if show_result: + sys.stdout.write("Calculated!") + sys.stdout.write(json.dumps(resDict['method'])) + + return resDict + + +def main_validation(default_evaluation_params_fn,validate_data_fn): + """ + This process validates a method + Params: + default_evaluation_params_fn: points to a function that returns a dictionary with the default parameters used for the evaluation + validate_data_fn: points to a method that validates the corrct format of the submission + """ + try: + p = dict([s[1:].split('=') for s in sys.argv[1:]]) + evalParams = default_evaluation_params_fn() + if 'p' in list(p.keys()): + evalParams.update( p['p'] if isinstance(p['p'], dict) else json.loads(p['p'][1:-1]) ) + + validate_data_fn(p['g'], p['s'], evalParams) + print('SUCCESS') + sys.exit(0) + except Exception as e: + print(str(e)) + sys.exit(101) \ No newline at end of file diff --git a/evaluation/totaltext/e2e/script.py b/evaluation/totaltext/e2e/script.py new file mode 100644 index 0000000000000000000000000000000000000000..5663255701b0d7bc0815f29b75f1d6765a99e244 --- /dev/null +++ b/evaluation/totaltext/e2e/script.py @@ -0,0 +1,452 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# encoding=utf8 +from collections import namedtuple +import rrc_evaluation_funcs_total_text as rrc_evaluation_funcs +import importlib +from prepare_results import prepare_results_for_evaluation + +def evaluation_imports(): + """ + evaluation_imports: Dictionary ( key = module name , value = alias ) with python modules used in the evaluation. + """ + return { + 'Polygon':'plg', + 'numpy':'np' + } + +def default_evaluation_params(): + """ + default_evaluation_params: Default parameters to use for the validation and evaluation. + """ + return { + 'IOU_CONSTRAINT' :0.5, + 'AREA_PRECISION_CONSTRAINT' :0.5, + 'WORD_SPOTTING' :False, + 'MIN_LENGTH_CARE_WORD' :3, + 'GT_SAMPLE_NAME_2_ID':'gt_img_([0-9]+).txt', + 'DET_SAMPLE_NAME_2_ID':'res_img_([0-9]+).txt', + 'LTRB':False, #LTRB:2points(left,top,right,bottom) or 4 points(x1,y1,x2,y2,x3,y3,x4,y4) + 'CRLF':False, # Lines are delimited by Windows CRLF format + 'CONFIDENCES':False, #Detections must include confidence value. MAP and MAR will be calculated, + 'SPECIAL_CHARACTERS':'!?.:,*"()·[]/\'', + 'ONLY_REMOVE_FIRST_LAST_CHARACTER' : True + } + +def validate_data(gtFilePath, submFilePath, evaluationParams): + """ + Method validate_data: validates that all files in the results folder are correct (have the correct name contents). + Validates also that there are no missing files in the folder. + If some error detected, the method raises the error + """ + gt = rrc_evaluation_funcs.load_zip_file(gtFilePath, evaluationParams['GT_SAMPLE_NAME_2_ID']) + + subm = rrc_evaluation_funcs.load_zip_file(submFilePath, evaluationParams['DET_SAMPLE_NAME_2_ID'], True) + + #Validate format of GroundTruth + for k in gt: + rrc_evaluation_funcs.validate_lines_in_file(k,gt[k],evaluationParams['CRLF'],evaluationParams['LTRB'],True) + + #Validate format of results + for k in subm: + if (k in gt) == False : + raise Exception("The sample %s not present in GT" %k) + + rrc_evaluation_funcs.validate_lines_in_file(k,subm[k],evaluationParams['CRLF'],evaluationParams['LTRB'],True,evaluationParams['CONFIDENCES']) + + +def evaluate_method(gtFilePath, submFilePath, evaluationParams): + """ + Method evaluate_method: evaluate method and returns the results + Results. Dictionary with the following values: + - method (required) Global method metrics. Ex: { 'Precision':0.8,'Recall':0.9 } + - samples (optional) Per sample metrics. Ex: {'sample1' : { 'Precision':0.8,'Recall':0.9 } , 'sample2' : { 'Precision':0.8,'Recall':0.9 } + """ + for module,alias in evaluation_imports().items(): + globals()[alias] = importlib.import_module(module) + + def polygon_from_points(points,correctOffset=False): + """ + Returns a Polygon object to use with the Polygon2 class from a list of 8 points: x1,y1,x2,y2,x3,y3,x4,y4 + """ + resBoxes=np.empty([1,len(points)],dtype='int32') + for i in range(int(len(points) / 2)): + resBoxes[0, i] = int(points[2*i]) + resBoxes[0, int(len(points) / 2) + i] = int(points[2*i+1]) + + pointMat = resBoxes[0].reshape([2,-1]).T + return plg.Polygon( pointMat) + + def rectangle_to_polygon(rect): + resBoxes=np.empty([1,8],dtype='int32') + resBoxes[0,0]=int(rect.xmin) + resBoxes[0,4]=int(rect.ymax) + resBoxes[0,1]=int(rect.xmin) + resBoxes[0,5]=int(rect.ymin) + resBoxes[0,2]=int(rect.xmax) + resBoxes[0,6]=int(rect.ymin) + resBoxes[0,3]=int(rect.xmax) + resBoxes[0,7]=int(rect.ymax) + + pointMat = resBoxes[0].reshape([2,4]).T + + return plg.Polygon( pointMat) + + def rectangle_to_points(rect): + points = [int(rect.xmin), int(rect.ymax), int(rect.xmax), int(rect.ymax), int(rect.xmax), int(rect.ymin), int(rect.xmin), int(rect.ymin)] + return points + + def get_union(pD,pG): + areaA = pD.area(); + areaB = pG.area(); + return areaA + areaB - get_intersection(pD, pG); + + def get_intersection_over_union(pD,pG): + try: + return get_intersection(pD, pG) / get_union(pD, pG); + except: + return 0 + + def get_intersection(pD,pG): + pInt = pD & pG + if len(pInt) == 0: + return 0 + return pInt.area() + + def compute_ap(confList, matchList,numGtCare): + correct = 0 + AP = 0 + if len(confList)>0: + confList = np.array(confList) + matchList = np.array(matchList) + sorted_ind = np.argsort(-confList) + confList = confList[sorted_ind] + matchList = matchList[sorted_ind] + for n in range(len(confList)): + match = matchList[n] + if match: + correct += 1 + AP += float(correct)/(n + 1) + + if numGtCare>0: + AP /= numGtCare + + return AP + + def transcription_match(transGt,transDet,specialCharacters='!?.:,*"()·[]/\'',onlyRemoveFirstLastCharacterGT=True): + + if onlyRemoveFirstLastCharacterGT: + #special characters in GT are allowed only at initial or final position + if (transGt==transDet): + return True + + if specialCharacters.find(transGt[0])>-1: + if transGt[1:]==transDet: + return True + + if specialCharacters.find(transGt[-1])>-1: + if transGt[0:len(transGt)-1]==transDet: + return True + + if specialCharacters.find(transGt[0])>-1 and specialCharacters.find(transGt[-1])>-1: + if transGt[1:len(transGt)-1]==transDet: + return True + return False + else: + #Special characters are removed from the begining and the end of both Detection and GroundTruth + while len(transGt)>0 and specialCharacters.find(transGt[0])>-1: + transGt = transGt[1:] + + while len(transDet)>0 and specialCharacters.find(transDet[0])>-1: + transDet = transDet[1:] + + while len(transGt)>0 and specialCharacters.find(transGt[-1])>-1 : + transGt = transGt[0:len(transGt)-1] + + while len(transDet)>0 and specialCharacters.find(transDet[-1])>-1: + transDet = transDet[0:len(transDet)-1] + + return transGt == transDet + + + def include_in_dictionary(transcription): + """ + Function used in Word Spotting that finds if the Ground Truth transcription meets the rules to enter into the dictionary. If not, the transcription will be cared as don't care + """ + #special case 's at final + if transcription[len(transcription)-2:]=="'s" or transcription[len(transcription)-2:]=="'S": + transcription = transcription[0:len(transcription)-2] + + #hypens at init or final of the word + transcription = transcription.strip('-'); + + specialCharacters = "'!?.:,*\"()·[]/"; + for character in specialCharacters: + transcription = transcription.replace(character,' ') + + transcription = transcription.strip() + + if len(transcription) != len(transcription.replace(" ","")) : + return False; + + if len(transcription) < evaluationParams['MIN_LENGTH_CARE_WORD']: + return False; + + notAllowed = "×÷·"; + + range1 = [ ord(u'a'), ord(u'z') ] + range2 = [ ord(u'A'), ord(u'Z') ] + range3 = [ ord(u'À'), ord(u'ƿ') ] + range4 = [ ord(u'DŽ'), ord(u'ɿ') ] + range5 = [ ord(u'Ά'), ord(u'Ͽ') ] + range6 = [ ord(u'-'), ord(u'-') ] + + for char in transcription : + charCode = ord(char) + if(notAllowed.find(char) != -1): + return False + + valid = ( charCode>=range1[0] and charCode<=range1[1] ) or ( charCode>=range2[0] and charCode<=range2[1] ) or ( charCode>=range3[0] and charCode<=range3[1] ) or ( charCode>=range4[0] and charCode<=range4[1] ) or ( charCode>=range5[0] and charCode<=range5[1] ) or ( charCode>=range6[0] and charCode<=range6[1] ) + if valid == False: + return False + + return True + + def include_in_dictionary_transcription(transcription): + """ + Function applied to the Ground Truth transcriptions used in Word Spotting. It removes special characters or terminations + """ + #special case 's at final + if transcription[len(transcription)-2:]=="'s" or transcription[len(transcription)-2:]=="'S": + transcription = transcription[0:len(transcription)-2] + + #hypens at init or final of the word + transcription = transcription.strip('-'); + + specialCharacters = "'!?.:,*\"()·[]/"; + for character in specialCharacters: + transcription = transcription.replace(character,' ') + + transcription = transcription.strip() + + return transcription + + perSampleMetrics = {} + + matchedSum = 0 + + Rectangle = namedtuple('Rectangle', 'xmin ymin xmax ymax') + + gt = rrc_evaluation_funcs.load_zip_file(gtFilePath,evaluationParams['GT_SAMPLE_NAME_2_ID']) + subm = rrc_evaluation_funcs.load_zip_file(submFilePath,evaluationParams['DET_SAMPLE_NAME_2_ID'],True) + + numGlobalCareGt = 0; + numGlobalCareDet = 0; + + arrGlobalConfidences = []; + arrGlobalMatches = []; + + for resFile in gt: + + gtFile = rrc_evaluation_funcs.decode_utf8(gt[resFile]) + if (gtFile is None) : + raise Exception("The file %s is not UTF-8" %resFile) + + recall = 0 + precision = 0 + hmean = 0 + detCorrect = 0 + iouMat = np.empty([1,1]) + gtPols = [] + detPols = [] + gtTrans = [] + detTrans = [] + gtPolPoints = [] + detPolPoints = [] + gtDontCarePolsNum = [] #Array of Ground Truth Polygons' keys marked as don't Care + detDontCarePolsNum = [] #Array of Detected Polygons' matched with a don't Care GT + detMatchedNums = [] + pairs = [] + + arrSampleConfidences = []; + arrSampleMatch = []; + sampleAP = 0; + + evaluationLog = "" + + pointsList,_,transcriptionsList = rrc_evaluation_funcs.get_tl_line_values_from_file_contents(gtFile,evaluationParams['CRLF'],evaluationParams['LTRB'],True,False) + for n in range(len(pointsList)): + points = pointsList[n] + transcription = transcriptionsList[n] + dontCare = transcription == "###" + if evaluationParams['LTRB']: + gtRect = Rectangle(*points) + gtPol = rectangle_to_polygon(gtRect) + else: + gtPol = polygon_from_points(points) + gtPols.append(gtPol) + gtPolPoints.append(points) + + #On word spotting we will filter some transcriptions with special characters + if evaluationParams['WORD_SPOTTING'] : + if dontCare == False : + if include_in_dictionary(transcription) == False : + dontCare = True + else: + transcription = include_in_dictionary_transcription(transcription) + + gtTrans.append(transcription) + if dontCare: + gtDontCarePolsNum.append( len(gtPols)-1 ) + + evaluationLog += "GT polygons: " + str(len(gtPols)) + (" (" + str(len(gtDontCarePolsNum)) + " don't care)\n" if len(gtDontCarePolsNum)>0 else "\n") + + if resFile in subm: + + detFile = rrc_evaluation_funcs.decode_utf8(subm[resFile]) + + pointsList,confidencesList,transcriptionsList = rrc_evaluation_funcs.get_tl_line_values_from_file_contents(detFile,evaluationParams['CRLF'],evaluationParams['LTRB'],True,evaluationParams['CONFIDENCES']) + + for n in range(len(pointsList)): + points = pointsList[n] + transcription = transcriptionsList[n] + + if evaluationParams['LTRB']: + detRect = Rectangle(*points) + detPol = rectangle_to_polygon(detRect) + else: + detPol = polygon_from_points(points) + detPols.append(detPol) + detPolPoints.append(points) + detTrans.append(transcription) + + if len(gtDontCarePolsNum)>0 : + for dontCarePol in gtDontCarePolsNum: + dontCarePol = gtPols[dontCarePol] + intersected_area = get_intersection(dontCarePol,detPol) + pdDimensions = detPol.area() + precision = 0 if pdDimensions == 0 else intersected_area / pdDimensions + if (precision > evaluationParams['AREA_PRECISION_CONSTRAINT'] ): + detDontCarePolsNum.append( len(detPols)-1 ) + break + + evaluationLog += "DET polygons: " + str(len(detPols)) + (" (" + str(len(detDontCarePolsNum)) + " don't care)\n" if len(detDontCarePolsNum)>0 else "\n") + + if len(gtPols)>0 and len(detPols)>0: + #Calculate IoU and precision matrixs + outputShape=[len(gtPols),len(detPols)] + iouMat = np.empty(outputShape) + gtRectMat = np.zeros(len(gtPols),np.int8) + detRectMat = np.zeros(len(detPols),np.int8) + for gtNum in range(len(gtPols)): + for detNum in range(len(detPols)): + pG = gtPols[gtNum] + pD = detPols[detNum] + iouMat[gtNum,detNum] = get_intersection_over_union(pD,pG) + + for gtNum in range(len(gtPols)): + for detNum in range(len(detPols)): + if gtRectMat[gtNum] == 0 and detRectMat[detNum] == 0 and gtNum not in gtDontCarePolsNum and detNum not in detDontCarePolsNum : + if iouMat[gtNum,detNum]>evaluationParams['IOU_CONSTRAINT']: + gtRectMat[gtNum] = 1 + detRectMat[detNum] = 1 + #detection matched only if transcription is equal + if evaluationParams['WORD_SPOTTING']: + correct = gtTrans[gtNum].upper() == detTrans[detNum].upper() + else: + correct = transcription_match(gtTrans[gtNum].upper(),detTrans[detNum].upper(),evaluationParams['SPECIAL_CHARACTERS'],evaluationParams['ONLY_REMOVE_FIRST_LAST_CHARACTER'])==True + detCorrect += (1 if correct else 0) + if correct: + detMatchedNums.append(detNum) + pairs.append({'gt':gtNum,'det':detNum,'correct':correct}) + evaluationLog += "Match GT #" + str(gtNum) + " with Det #" + str(detNum) + " trans. correct: " + str(correct) + "\n" + + if evaluationParams['CONFIDENCES']: + for detNum in range(len(detPols)): + if detNum not in detDontCarePolsNum : + #we exclude the don't care detections + match = detNum in detMatchedNums + + arrSampleConfidences.append(confidencesList[detNum]) + arrSampleMatch.append(match) + + arrGlobalConfidences.append(confidencesList[detNum]); + arrGlobalMatches.append(match); + + numGtCare = (len(gtPols) - len(gtDontCarePolsNum)) + numDetCare = (len(detPols) - len(detDontCarePolsNum)) + if numGtCare == 0: + recall = float(1) + precision = float(0) if numDetCare >0 else float(1) + sampleAP = precision + else: + recall = float(detCorrect) / numGtCare + precision = 0 if numDetCare==0 else float(detCorrect) / numDetCare + if evaluationParams['CONFIDENCES']: + sampleAP = compute_ap(arrSampleConfidences, arrSampleMatch, numGtCare ) + + hmean = 0 if (precision + recall)==0 else 2.0 * precision * recall / (precision + recall) + + matchedSum += detCorrect + numGlobalCareGt += numGtCare + numGlobalCareDet += numDetCare + + perSampleMetrics[resFile] = { + 'precision':precision, + 'recall':recall, + 'hmean':hmean, + 'pairs':pairs, + 'AP':sampleAP, + 'iouMat':[] if len(detPols)>100 else iouMat.tolist(), + 'gtPolPoints':gtPolPoints, + 'detPolPoints':detPolPoints, + 'gtTrans':gtTrans, + 'detTrans':detTrans, + 'gtDontCare':gtDontCarePolsNum, + 'detDontCare':detDontCarePolsNum, + 'evaluationParams': evaluationParams, + 'evaluationLog': evaluationLog + } + + # Compute AP + AP = 0 + if evaluationParams['CONFIDENCES']: + AP = compute_ap(arrGlobalConfidences, arrGlobalMatches, numGlobalCareGt) + + methodRecall = 0 if numGlobalCareGt == 0 else float(matchedSum)/numGlobalCareGt + methodPrecision = 0 if numGlobalCareDet == 0 else float(matchedSum)/numGlobalCareDet + methodHmean = 0 if methodRecall + methodPrecision==0 else 2* methodRecall * methodPrecision / (methodRecall + methodPrecision) + + methodMetrics = {'precision':methodPrecision, 'recall':methodRecall,'hmean': methodHmean, 'AP': AP } + + resDict = {'calculated':True,'Message':'','method': methodMetrics,'per_sample': perSampleMetrics} + + + return resDict; + + + +if __name__=='__main__': + ''' + results_dir: result directory + score_det: score of detection bounding box + score_rec: score of the mask recognition branch + score_rec_seq: score of the sequence recognition branch + lexicon_type: 1 for generic; 2 for weak; 3 for strong + ''' + results_dir = '../../../output/mixtrain/inference/total_text_test/model_0250000_1000_results/' + score_det = 0.05 + score_rec = 0.5 + use_lexicon = False + score_rec_seq = 0.9 + # use_lexicon = True + # score_rec_seq = 0.8 + evaluate_result_path = prepare_results_for_evaluation(results_dir, + use_lexicon=use_lexicon, cache_dir='./cache_files', + score_det=score_det, score_rec=score_rec, score_rec_seq=score_rec_seq) + p = { + 'g': "../gt.zip", + 'o': "./cache_files", + 's': evaluate_result_path + } + rrc_evaluation_funcs.main_evaluation(p,default_evaluation_params,validate_data,evaluate_method) \ No newline at end of file diff --git a/evaluation/totaltext/gt.zip b/evaluation/totaltext/gt.zip new file mode 100644 index 0000000000000000000000000000000000000000..992f7f340c79347c76c76972b941c33e03ba5ee3 Binary files /dev/null and b/evaluation/totaltext/gt.zip differ diff --git a/evaluation/weighted_editdistance.py b/evaluation/weighted_editdistance.py new file mode 100644 index 0000000000000000000000000000000000000000..3477e54497715a537e2b531bded8141f856de5b5 --- /dev/null +++ b/evaluation/weighted_editdistance.py @@ -0,0 +1,55 @@ +def weighted_edit_distance(word1, word2, scores): + m = len(word1) + n = len(word2) + dp = [[0 for __ in range(m + 1)] for __ in range(n + 1)] + for j in range(m + 1): + dp[0][j] = j + for i in range(n + 1): + dp[i][0] = i + for i in range(1, n + 1): ## word2 + for j in range(1, m + 1): ## word1 + delect_cost = ed_delect_cost(j-1, i-1, word1, word2, scores) ## delect a[i] + insert_cost = ed_insert_cost(j-1, i-1, word1, word2, scores) ## insert b[j] + if word1[j - 1] != word2[i - 1]: + replace_cost = ed_replace_cost(j-1, i-1, word1, word2, scores) ## replace a[i] with b[j] + else: + replace_cost = 0 + dp[i][j] = min(dp[i-1][j] + insert_cost, dp[i][j-1] + delect_cost, dp[i-1][j-1] + replace_cost) + + return dp[n][m] + +def ed_delect_cost(j, i, word1, word2, scores): + ## delect a[i] + c = char2num(word1[j]) + return scores[c][j] + + +def ed_insert_cost(i, j, word1, word2, scores): + ## insert b[j] + if i < len(word1) - 1: + c1 = char2num(word1[i]) + c2 = char2num(word1[i + 1]) + return (scores[c1][i] + scores[c2][i+1])/2 + else: + c1 = char2num(word1[i]) + return scores[c1][i] + + +def ed_replace_cost(i, j, word1, word2, scores): + ## replace a[i] with b[j] + c1 = char2num(word1[i]) + c2 = char2num(word2[j]) + # if word1 == "eeatpisaababarait".upper(): + # print(scores[c2][i]/scores[c1][i]) + + return max(1 - scores[c2][i]/scores[c1][i]*5, 0) + +def char2num(char): + if char in '0123456789': + num = ord(char) - ord('0') + 1 + elif char in 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ': + num = ord(char.lower()) - ord('a') + 11 + else: + print('error symbol', char) + exit() + return num - 1 \ No newline at end of file diff --git a/example1.jpg b/example1.jpg new file mode 100644 index 0000000000000000000000000000000000000000..fb8063a68c4a7422810a9a83d323a78cd0faf67c Binary files /dev/null and b/example1.jpg differ diff --git a/example2.jpg b/example2.jpg new file mode 100644 index 0000000000000000000000000000000000000000..1903262693565fb8c70c63e454421857b575a55f Binary files /dev/null and b/example2.jpg differ diff --git a/example3.jpg b/example3.jpg new file mode 100644 index 0000000000000000000000000000000000000000..76e7f5c7b9866a4d43b246d213c0efa9cc56c333 Binary files /dev/null and b/example3.jpg differ diff --git a/maskrcnn_benchmark/config/__init__.py b/maskrcnn_benchmark/config/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..22a15023b1b06dad1f8c36924cdbb96bf1f5dc8d --- /dev/null +++ b/maskrcnn_benchmark/config/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +from .defaults import _C as cfg diff --git a/maskrcnn_benchmark/config/defaults.py b/maskrcnn_benchmark/config/defaults.py new file mode 100644 index 0000000000000000000000000000000000000000..6abf574d86e234c93b4e0162db49356a1f9e4fff --- /dev/null +++ b/maskrcnn_benchmark/config/defaults.py @@ -0,0 +1,373 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +import os + +from yacs.config import CfgNode as CN + + +# ----------------------------------------------------------------------------- +# Convention about Training / Test specific parameters +# ----------------------------------------------------------------------------- +# Whenever an argument can be either used for training or for testing, the +# corresponding name will be post-fixed by a _TRAIN for a training parameter, +# or _TEST for a test-specific parameter. +# For example, the number of images during training will be +# IMAGES_PER_BATCH_TRAIN, while the number of images for testing will be +# IMAGES_PER_BATCH_TEST + +# ----------------------------------------------------------------------------- +# Config definition +# ----------------------------------------------------------------------------- + +_C = CN() + +_C.MODEL = CN() +_C.MODEL.RPN_ONLY = False +_C.MODEL.MASK_ON = False +_C.MODEL.SEG_ON = False +_C.MODEL.CHAR_MASK_ON = False +_C.MODEL.DEVICE = "cuda" +_C.MODEL.META_ARCHITECTURE = "GeneralizedRCNN" +_C.MODEL.TRAIN_DETECTION_ONLY = False +_C.MODEL.RESNET34 = False + +# If the WEIGHT starts with a catalog://, like :R-50, the code will look for +# the path in paths_catalog. Else, it will use it as the specified absolute +# path +_C.MODEL.WEIGHT = "" + +_C.SEQUENCE = CN() +_C.SEQUENCE.SEQ_ON = False +_C.SEQUENCE.NUM_CHAR = 38 +_C.SEQUENCE.BOS_TOKEN = 0 +_C.SEQUENCE.MAX_LENGTH = 32 +_C.SEQUENCE.TEACHER_FORCE_RATIO = 1.0 +_C.SEQUENCE.TWO_CONV = False +_C.SEQUENCE.MEAN_SCORE = False +_C.SEQUENCE.RESIZE_HEIGHT = 16 +_C.SEQUENCE.RESIZE_WIDTH = 64 + + +# ----------------------------------------------------------------------------- +# INPUT +# ----------------------------------------------------------------------------- +_C.INPUT = CN() +# Size of the smallest side of the image during training +_C.INPUT.MIN_SIZE_TRAIN = (800,) # (800,) +# Maximum size of the side of the image during training +_C.INPUT.MAX_SIZE_TRAIN = 1333 +# Size of the smallest side of the image during testing +_C.INPUT.MIN_SIZE_TEST = 800 +# Maximum size of the side of the image during testing +_C.INPUT.MAX_SIZE_TEST = 1333 +# Values to be used for image normalization +_C.INPUT.PIXEL_MEAN = [102.9801, 115.9465, 122.7717] +# Values to be used for image normalization +_C.INPUT.PIXEL_STD = [1.0, 1.0, 1.0] +# Convert image to BGR format (for Caffe2 models), in range 0-255 +_C.INPUT.TO_BGR255 = True +_C.INPUT.STRICT_RESIZE = False + + +# ----------------------------------------------------------------------------- +# Dataset +# ----------------------------------------------------------------------------- +_C.DATASETS = CN() +# List of the dataset names for training, as present in paths_catalog.py +_C.DATASETS.TRAIN = () +# List of the dataset names for testing, as present in paths_catalog.py +_C.DATASETS.TEST = () + +_C.DATASETS.RATIOS = [] + +_C.DATASETS.AUG = False +_C.DATASETS.RANDOM_CROP_PROB = 0.0 +_C.DATASETS.IGNORE_DIFFICULT = False +_C.DATASETS.FIX_CROP = False +_C.DATASETS.CROP_SIZE = (512, 512) +_C.DATASETS.MAX_ROTATE_THETA = 30 +_C.DATASETS.FIX_ROTATE = False + +# ----------------------------------------------------------------------------- +# DataLoader +# ----------------------------------------------------------------------------- +_C.DATALOADER = CN() +# Number of data loading threads +_C.DATALOADER.NUM_WORKERS = 4 +# If > 0, this enforces that each collated batch should have a size divisible +# by SIZE_DIVISIBILITY +_C.DATALOADER.SIZE_DIVISIBILITY = 0 +# If True, each batch should contain only images for which the aspect ratio +# is compatible. This groups portrait images together, and landscape images +# are not batched with portrait images. +_C.DATALOADER.ASPECT_RATIO_GROUPING = True + +# ---------------------------------------------------------------------------- # +# Backbone options +# ---------------------------------------------------------------------------- # +_C.MODEL.BACKBONE = CN() + +# The backbone conv body to use +# The string must match a function that is imported in modeling.model_builder +# (e.g., 'FPN.add_fpn_ResNet101_conv5_body' to specify a ResNet-101-FPN +# backbone) +_C.MODEL.BACKBONE.CONV_BODY = "R-50-C4" + +# Add StopGrad at a specified stage so the bottom layers are frozen +_C.MODEL.BACKBONE.FREEZE_CONV_BODY_AT = 2 +_C.MODEL.BACKBONE.OUT_CHANNELS = 256 * 4 + +# ---------------------------------------------------------------------------- # +# ResNe[X]t options (ResNets = {ResNet, ResNeXt} +# Note that parts of a resnet may be used for both the backbone and the head +# These options apply to both +# ---------------------------------------------------------------------------- # +_C.MODEL.RESNETS = CN() + +# Number of groups to use; 1 ==> ResNet; > 1 ==> ResNeXt +_C.MODEL.RESNETS.NUM_GROUPS = 1 + +# Baseline width of each group +_C.MODEL.RESNETS.WIDTH_PER_GROUP = 64 + +# Place the stride 2 conv on the 1x1 filter +# Use True only for the original MSRA ResNet; use False for C2 and Torch models +_C.MODEL.RESNETS.STRIDE_IN_1X1 = True + +# Residual transformation function +_C.MODEL.RESNETS.TRANS_FUNC = "BottleneckWithFixedBatchNorm" +# ResNet's stem function (conv1 and pool1) +_C.MODEL.RESNETS.STEM_FUNC = "StemWithFixedBatchNorm" + +# Apply dilation in stage "res5" +_C.MODEL.RESNETS.RES5_DILATION = 1 + +_C.MODEL.RESNETS.BACKBONE_OUT_CHANNELS = 256 * 4 +_C.MODEL.RESNETS.RES2_OUT_CHANNELS = 256 +_C.MODEL.RESNETS.STEM_OUT_CHANNELS = 64 + +_C.MODEL.RESNETS.STAGE_WITH_DCN = (False, False, False, False) +_C.MODEL.RESNETS.WITH_MODULATED_DCN = False +_C.MODEL.RESNETS.DEFORMABLE_GROUPS = 1 +_C.MODEL.RESNETS.LAYERS = (3, 4, 6, 3) + +# ---------------------------------------------------------------------------- # +# FPN options +# ---------------------------------------------------------------------------- # +_C.MODEL.FPN = CN() +_C.MODEL.FPN.USE_GN = False +_C.MODEL.FPN.USE_RELU = False + +# ---------------------------------------------------------------------------- # +# RPN options +# ---------------------------------------------------------------------------- # +_C.MODEL.RPN = CN() +_C.MODEL.RPN.USE_FPN = False +# Base RPN anchor sizes given in absolute pixels w.r.t. the scaled network input +_C.MODEL.RPN.ANCHOR_SIZES = (32, 64, 128, 256, 512) +# Stride of the feature map that RPN is attached. +# For FPN, number of strides should match number of scales +_C.MODEL.RPN.ANCHOR_STRIDE = (16,) +# RPN anchor aspect ratios +_C.MODEL.RPN.ASPECT_RATIOS = (0.5, 1.0, 2.0) +# Remove RPN anchors that go outside the image by RPN_STRADDLE_THRESH pixels +# Set to -1 or a large value, e.g. 100000, to disable pruning anchors +_C.MODEL.RPN.STRADDLE_THRESH = 0 +# Minimum overlap required between an anchor and ground-truth box for the +# (anchor, gt box) pair to be a positive example (IoU >= FG_IOU_THRESHOLD +# ==> positive RPN example) +_C.MODEL.RPN.FG_IOU_THRESHOLD = 0.7 +# Maximum overlap allowed between an anchor and ground-truth box for the +# (anchor, gt box) pair to be a negative examples (IoU < BG_IOU_THRESHOLD +# ==> negative RPN example) +_C.MODEL.RPN.BG_IOU_THRESHOLD = 0.3 +# Total number of RPN examples per image +_C.MODEL.RPN.BATCH_SIZE_PER_IMAGE = 256 +# Target fraction of foreground (positive) examples per RPN minibatch +_C.MODEL.RPN.POSITIVE_FRACTION = 0.5 +# Number of top scoring RPN proposals to keep before applying NMS +# When FPN is used, this is *per FPN level* (not total) +_C.MODEL.RPN.PRE_NMS_TOP_N_TRAIN = 12000 +_C.MODEL.RPN.PRE_NMS_TOP_N_TEST = 6000 +# Number of top scoring RPN proposals to keep after applying NMS +_C.MODEL.RPN.POST_NMS_TOP_N_TRAIN = 2000 +_C.MODEL.RPN.POST_NMS_TOP_N_TEST = 1000 +# NMS threshold used on RPN proposals +_C.MODEL.RPN.NMS_THRESH = 0.7 +# Proposal height and width both need to be greater than RPN_MIN_SIZE +# (a the scale used during training or inference) +_C.MODEL.RPN.MIN_SIZE = 0 +# Number of top scoring RPN proposals to keep after combining proposals from +# all FPN levels +_C.MODEL.RPN.FPN_POST_NMS_TOP_N_TRAIN = 2000 +_C.MODEL.RPN.FPN_POST_NMS_TOP_N_TEST = 2000 + +_C.MODEL.SEG = CN() +_C.MODEL.SEG.USE_FPN = False +_C.MODEL.SEG.USE_FUSE_FEATURE = False +# Total number of SEG examples per image +_C.MODEL.SEG.BATCH_SIZE_PER_IMAGE = 256 +# Target fraction of foreground (positive) examples per SEG minibatch +_C.MODEL.SEG.POSITIVE_FRACTION = 0.5 +# NMS threshold used on SEG proposals +_C.MODEL.SEG.BINARY_THRESH = 0.5 +_C.MODEL.SEG.USE_MULTIPLE_THRESH = False +_C.MODEL.SEG.MULTIPLE_THRESH = (0.2, 0.3, 0.5, 0.7) +_C.MODEL.SEG.BOX_THRESH = 0.7 +# Proposal height and width both need to be greater than RPN_MIN_SIZE +# (a the scale used during training or inference) +_C.MODEL.SEG.MIN_SIZE = 0 +_C.MODEL.SEG.SHRINK_RATIO = 0.5 +# Number of top scoring RPN proposals to keep after combining proposals from +# all FPN levels +_C.MODEL.SEG.TOP_N_TRAIN = 1000 +_C.MODEL.SEG.TOP_N_TEST = 1000 +_C.MODEL.SEG.AUG_PROPOSALS = False +_C.MODEL.SEG.IGNORE_DIFFICULT = True +_C.MODEL.SEG.EXPAND_RATIO = 1.6 +_C.MODEL.SEG.BOX_EXPAND_RATIO = 1.5 +_C.MODEL.SEG.USE_SEG_POLY = False +_C.MODEL.SEG.USE_PPM = False + + +# ---------------------------------------------------------------------------- # +# ROI HEADS options +# ---------------------------------------------------------------------------- # +_C.MODEL.ROI_HEADS = CN() +_C.MODEL.ROI_HEADS.USE_FPN = False +# Overlap threshold for an RoI to be considered foreground (if >= FG_IOU_THRESHOLD) +_C.MODEL.ROI_HEADS.FG_IOU_THRESHOLD = 0.5 +# Overlap threshold for an RoI to be considered background +# (class = 0 if overlap in [0, BG_IOU_THRESHOLD)) +_C.MODEL.ROI_HEADS.BG_IOU_THRESHOLD = 0.5 +# Default weights on (dx, dy, dw, dh) for normalizing bbox regression targets +# These are empirically chosen to approximately lead to unit variance targets +_C.MODEL.ROI_HEADS.BBOX_REG_WEIGHTS = (10.0, 10.0, 5.0, 5.0) +# RoI minibatch size *per image* (number of regions of interest [ROIs]) +# Total number of RoIs per training minibatch = +# TRAIN.BATCH_SIZE_PER_IM * TRAIN.IMS_PER_BATCH * NUM_GPUS +# E.g., a common configuration is: 512 * 2 * 8 = 8192 +_C.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 512 +# Target fraction of RoI minibatch that is labeled foreground (i.e. class > 0) +_C.MODEL.ROI_HEADS.POSITIVE_FRACTION = 0.25 + +# Only used on test mode + +# Minimum score threshold (assuming scores in a [0, 1] range); a value chosen to +# balance obtaining high recall with not having too many low precision +# detections that will slow down inference post processing steps (like NMS) +# _C.MODEL.ROI_HEADS.SCORE_THRESH = 0.05 +_C.MODEL.ROI_HEADS.SCORE_THRESH = 0.0 +# Overlap threshold used for non-maximum suppression (suppress boxes with +# IoU >= this threshold) +_C.MODEL.ROI_HEADS.NMS = 0.5 +# Maximum number of detections to return per image (100 is based on the limit +# established for the COCO dataset) +_C.MODEL.ROI_HEADS.DETECTIONS_PER_IMG = 100 + + +_C.MODEL.ROI_BOX_HEAD = CN() +_C.MODEL.ROI_BOX_HEAD.FEATURE_EXTRACTOR = "ResNet50Conv5ROIFeatureExtractor" +_C.MODEL.ROI_BOX_HEAD.PREDICTOR = "FastRCNNPredictor" +_C.MODEL.ROI_BOX_HEAD.POOLER_RESOLUTION = 14 +_C.MODEL.ROI_BOX_HEAD.POOLER_SAMPLING_RATIO = 0 +_C.MODEL.ROI_BOX_HEAD.POOLER_SCALES = (1.0 / 16,) +_C.MODEL.ROI_BOX_HEAD.NUM_CLASSES = 81 +# Hidden layer dimension when using an MLP for the RoI box head +_C.MODEL.ROI_BOX_HEAD.MLP_HEAD_DIM = 1024 +_C.MODEL.ROI_BOX_HEAD.USE_REGRESSION = True +_C.MODEL.ROI_BOX_HEAD.INFERENCE_USE_BOX = True +_C.MODEL.ROI_BOX_HEAD.USE_MASKED_FEATURE = False +_C.MODEL.ROI_BOX_HEAD.SOFT_MASKED_FEATURE_RATIO = 0. +_C.MODEL.ROI_BOX_HEAD.MIX_OPTION = "" + + +_C.MODEL.ROI_MASK_HEAD = CN() +_C.MODEL.ROI_MASK_HEAD.FEATURE_EXTRACTOR = "ResNet50Conv5ROIFeatureExtractor" +_C.MODEL.ROI_MASK_HEAD.PREDICTOR = "MaskRCNNC4Predictor" +_C.MODEL.ROI_MASK_HEAD.POOLER_RESOLUTION = 14 +_C.MODEL.ROI_MASK_HEAD.POOLER_RESOLUTION_H = 32 +_C.MODEL.ROI_MASK_HEAD.POOLER_RESOLUTION_W = 128 +_C.MODEL.ROI_MASK_HEAD.POOLER_SAMPLING_RATIO = 0 +_C.MODEL.ROI_MASK_HEAD.POOLER_SCALES = (1.0 / 16,) +_C.MODEL.ROI_MASK_HEAD.MLP_HEAD_DIM = 1024 +_C.MODEL.ROI_MASK_HEAD.CONV_LAYERS = (256, 256, 256, 256) +_C.MODEL.ROI_MASK_HEAD.RESOLUTION = 14 +_C.MODEL.ROI_MASK_HEAD.RESOLUTION_H = 32 +_C.MODEL.ROI_MASK_HEAD.RESOLUTION_W = 128 +_C.MODEL.ROI_MASK_HEAD.SHARE_BOX_FEATURE_EXTRACTOR = True +_C.MODEL.ROI_MASK_HEAD.CHAR_NUM_CLASSES = 38 +_C.MODEL.ROI_MASK_HEAD.USE_WEIGHTED_CHAR_MASK = False +_C.MODEL.ROI_MASK_HEAD.MASK_BATCH_SIZE_PER_IM = 64 +_C.MODEL.ROI_MASK_HEAD.USE_MASKED_FEATURE = False +_C.MODEL.ROI_MASK_HEAD.SOFT_MASKED_FEATURE_RATIO = 0. +_C.MODEL.ROI_MASK_HEAD.MIX_OPTION = "" + +# ---------------------------------------------------------------------------- # +# Solver +# ---------------------------------------------------------------------------- # +_C.SOLVER = CN() +_C.SOLVER.MAX_ITER = 40000 + +_C.SOLVER.BASE_LR = 0.001 +_C.SOLVER.BIAS_LR_FACTOR = 2 + +_C.SOLVER.MOMENTUM = 0.9 + +_C.SOLVER.WEIGHT_DECAY = 0.0005 +_C.SOLVER.WEIGHT_DECAY_BIAS = 0 + +_C.SOLVER.GAMMA = 0.1 +_C.SOLVER.STEPS = (30000,) + +_C.SOLVER.WARMUP_FACTOR = 1.0 / 3 +_C.SOLVER.WARMUP_ITERS = 500 +_C.SOLVER.WARMUP_METHOD = "linear" + +_C.SOLVER.CHECKPOINT_PERIOD = 5000 + +# Number of images per batch +# This is global, so if we have 8 GPUs and IMS_PER_BATCH = 16, each GPU will +# see 2 images per batch +_C.SOLVER.IMS_PER_BATCH = 16 + +_C.SOLVER.RESUME = True + +_C.SOLVER.USE_ADAM = False + +_C.SOLVER.POW_SCHEDULE = False + +_C.SOLVER.DISPLAY_FREQ = 20 + +# ---------------------------------------------------------------------------- # +# Specific test options +# ---------------------------------------------------------------------------- # +_C.TEST = CN() +_C.TEST.EXPECTED_RESULTS = [] +_C.TEST.EXPECTED_RESULTS_SIGMA_TOL = 4 +# Number of images per batch +# This is global, so if we have 8 GPUs and IMS_PER_BATCH = 16, each GPU will +# see 2 images per batch +_C.TEST.IMS_PER_BATCH = 8 +_C.TEST.VIS = False +# from 0 to 255 +_C.TEST.CHAR_THRESH = 128 + + +# ---------------------------------------------------------------------------- # +# Misc options +# ---------------------------------------------------------------------------- # +_C.OUTPUT_DIR = "." + +_C.PATHS_CATALOG = os.path.join(os.path.dirname(__file__), "paths_catalog.py") + + +# ---------------------------------------------------------------------------- # +# Precision options +# ---------------------------------------------------------------------------- # + +# Precision of input, allowable: (float32, float16) +_C.DTYPE = "float32" + +# Enable verbosity in apex.amp +_C.AMP_VERBOSE = False \ No newline at end of file diff --git a/maskrcnn_benchmark/config/paths_catalog.py b/maskrcnn_benchmark/config/paths_catalog.py new file mode 100644 index 0000000000000000000000000000000000000000..fcb95f7535aa172a9800c078dff9d00777f3ea88 --- /dev/null +++ b/maskrcnn_benchmark/config/paths_catalog.py @@ -0,0 +1,237 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +"""Centralized catalog of paths.""" + +import os + + +class DatasetCatalog(object): + DATA_DIR = "datasets" + # DATA_DIR = "/share/mhliao/MaskTextSpotterV3/datasets/" + + DATASETS = { + "coco_2014_train": ( + "coco/train2014", + "coco/annotations/instances_train2014.json", + ), + "coco_2014_val": ("coco/val2014", "coco/annotations/instances_val2014.json"), + "coco_2014_minival": ( + "coco/val2014", + "coco/annotations/instances_minival2014.json", + ), + "coco_2014_valminusminival": ( + "coco/val2014", + "coco/annotations/instances_valminusminival2014.json", + ), + "icdar_2013_train": ("icdar2013/train_images", "icdar2013/train_gts"), + "icdar_2013_test": ("icdar2013/test_images", "icdar2013/test_gts"), + "rotated_ic13_test_0": ("icdar2013/rotated_test_images_0", "icdar2013/rotated_test_gts_0"), + "rotated_ic13_test_15": ("icdar2013/rotated_test_images_15", "icdar2013/rotated_test_gts_15"), + "rotated_ic13_test_30": ("icdar2013/rotated_test_images_30", "icdar2013/rotated_test_gts_30"), + "rotated_ic13_test_45": ("icdar2013/rotated_test_images_45", "icdar2013/rotated_test_gts_45"), + "rotated_ic13_test_60": ("icdar2013/rotated_test_images_60", "icdar2013/rotated_test_gts_60"), + "rotated_ic13_test_75": ("icdar2013/rotated_test_images_75", "icdar2013/rotated_test_gts_75"), + "rotated_ic13_test_85": ("icdar2013/rotated_test_images_85", "icdar2013/rotated_test_gts_85"), + "rotated_ic13_test_90": ("icdar2013/rotated_test_images_90", "icdar2013/rotated_test_gts_90"), + "rotated_ic13_test_-15": ("icdar2013/rotated_test_images_-15", "icdar2013/rotated_test_gts_-15"), + "rotated_ic13_test_-30": ("icdar2013/rotated_test_images_-30", "icdar2013/rotated_test_gts_-30"), + "rotated_ic13_test_-45": ("icdar2013/rotated_test_images_-45", "icdar2013/rotated_test_gts_-45"), + "rotated_ic13_test_-60": ("icdar2013/rotated_test_images_-60", "icdar2013/rotated_test_gts_-60"), + "rotated_ic13_test_-75": ("icdar2013/rotated_test_images_-75", "icdar2013/rotated_test_gts_-75"), + "rotated_ic13_test_-90": ("icdar2013/rotated_test_images_-90", "icdar2013/rotated_test_gts_-90"), + "icdar_2015_train": ("icdar2015/train_images", "icdar2015/train_gts"), + "icdar_2015_test": ( + "icdar2015/test_images", + # "icdar2015/test_gts", + ), + "synthtext_train": ("synthtext/train_images", "synthtext/train_gts"), + "synthtext_test": ("synthtext/test_images", "synthtext/test_gts"), + "total_text_train": ("total_text/train_images", "total_text/train_gts"), + "td500_train": ("TD_TR/TD500/train_images", "TD500/train_gts"), + "td500_test": ("TD_TR/TD500/test_images", ), + "tr400_train": ("TD_TR/TR400/train_images", "TR400/train_gts"), + "total_text_test": ( + "total_text/test_images", + # "total_text/test_gts", + ), + "scut-eng-char_train": ( + "scut-eng-char/train_images", + "scut-eng-char/train_gts", + ), + } + + @staticmethod + def get(name): + if "coco" in name: + data_dir = DatasetCatalog.DATA_DIR + attrs = DatasetCatalog.DATASETS[name] + args = dict( + root=os.path.join(data_dir, attrs[0]), + ann_file=os.path.join(data_dir, attrs[1]), + ) + return dict(factory="COCODataset", args=args) + elif "icdar_2013" in name: + data_dir = DatasetCatalog.DATA_DIR + attrs = DatasetCatalog.DATASETS[name] + args = dict( + use_charann=True, + imgs_dir=os.path.join(data_dir, attrs[0]), + gts_dir=os.path.join(data_dir, attrs[1]), + # imgs_dir='/tmp/icdar2013/icdar2013/train_images', + # gts_dir='/tmp/icdar2013/icdar2013/train_gts', + ) + return dict(args=args, factory="IcdarDataset") + elif "rotated_ic13" in name: + data_dir = DatasetCatalog.DATA_DIR + attrs = DatasetCatalog.DATASETS[name] + args = dict( + use_charann=True, + imgs_dir=os.path.join(data_dir, attrs[0]), + gts_dir=os.path.join(data_dir, attrs[1]), + ) + return dict(args=args, factory="IcdarDataset") + elif "icdar_2015" in name: + data_dir = DatasetCatalog.DATA_DIR + attrs = DatasetCatalog.DATASETS[name] + if len(attrs) > 1: + gts_dir = os.path.join(data_dir, attrs[1]) + else: + gts_dir = None + + args = dict( + use_charann=False, + imgs_dir=os.path.join(data_dir, attrs[0]), + gts_dir=gts_dir, + # imgs_dir='/tmp/icdar2015/icdar2015/train_images/', + # gts_dir='/tmp/icdar2015/icdar2015/train_gts/', + ) + return dict(args=args, factory="IcdarDataset") + elif "synthtext" in name: + data_dir = DatasetCatalog.DATA_DIR + attrs = DatasetCatalog.DATASETS[name] + args = dict( + use_charann=True, + list_file_path=os.path.join(data_dir, "synthtext/train_list.txt"), + imgs_dir=os.path.join(data_dir, attrs[0]), + gts_dir=os.path.join(data_dir, attrs[1]), + # imgs_dir='/tmp/synth/SynthText/', + # gts_dir='/tmp/synth_gt/SynthText_GT_E2E/', + ) + return dict(args=args, factory="SynthtextDataset") + elif "total_text" in name: + data_dir = DatasetCatalog.DATA_DIR + # data_dir = '/tmp/total_text/' + attrs = DatasetCatalog.DATASETS[name] + if len(attrs) > 1: + gts_dir = os.path.join(data_dir, attrs[1]) + else: + gts_dir = None + args = dict( + use_charann=False, + imgs_dir=os.path.join(data_dir, attrs[0]), + gts_dir=gts_dir, + # imgs_dir='/tmp/total_text/total_text/train_images/', + # gts_dir='/tmp/total_text/total_text/train_gts/', + ) + return dict(args=args, factory="TotaltextDataset") + elif "scut-eng-char" in name: + data_dir = DatasetCatalog.DATA_DIR + attrs = DatasetCatalog.DATASETS[name] + args = dict( + use_charann=True, + imgs_dir=os.path.join(data_dir, attrs[0]), + gts_dir=os.path.join(data_dir, attrs[1]), + # imgs_dir='/tmp/scut-eng-char/scut-eng-char/train_images/', + # gts_dir='/tmp/scut-eng-char/scut-eng-char/train_gts/', + ) + return dict(args=args, factory="ScutDataset") + elif "td500" in name: + data_dir = DatasetCatalog.DATA_DIR + attrs = DatasetCatalog.DATASETS[name] + if len(attrs) > 1: + gts_dir = os.path.join(data_dir, attrs[1]) + else: + gts_dir = None + args = dict( + use_charann=False, + imgs_dir=os.path.join(data_dir, attrs[0]), + gts_dir=gts_dir, + ) + return dict(args=args, factory="TotaltextDataset") + elif "tr400" in name: + data_dir = DatasetCatalog.DATA_DIR + attrs = DatasetCatalog.DATASETS[name] + if len(attrs) > 1: + gts_dir = os.path.join(data_dir, attrs[1]) + else: + gts_dir = None + args = dict( + use_charann=False, + imgs_dir=os.path.join(data_dir, attrs[0]), + gts_dir=gts_dir, + ) + return dict(args=args, factory="TotaltextDataset") + raise RuntimeError("Dataset not available: {}".format(name)) + + +class ModelCatalog(object): + S3_C2_DETECTRON_URL = "https://dl.fbaipublicfiles.com/detectron" + C2_IMAGENET_MODELS = { + 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', + 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', + "MSRA/R-50": "ImageNetPretrained/MSRA/R-50.pkl", + "MSRA/R-50-GN": "ImageNetPretrained/47261647/R-50-GN.pkl", + "MSRA/R-101": "ImageNetPretrained/MSRA/R-101.pkl", + "MSRA/R-101-GN": "ImageNetPretrained/47592356/R-101-GN.pkl", + "FAIR/20171220/X-101-32x8d": "ImageNetPretrained/20171220/X-101-32x8d.pkl", + } + + C2_DETECTRON_SUFFIX = "output/train/{}coco_2014_train%3A{}coco_2014_valminusminival/generalized_rcnn/model_final.pkl" + C2_DETECTRON_MODELS = { + "35857197/e2e_faster_rcnn_R-50-C4_1x": "01_33_49.iAX0mXvW", + "35857345/e2e_faster_rcnn_R-50-FPN_1x": "01_36_30.cUF7QR7I", + "35857890/e2e_faster_rcnn_R-101-FPN_1x": "01_38_50.sNxI7sX7", + "36761737/e2e_faster_rcnn_X-101-32x8d-FPN_1x": "06_31_39.5MIHi1fZ", + "35858791/e2e_mask_rcnn_R-50-C4_1x": "01_45_57.ZgkA7hPB", + "35858933/e2e_mask_rcnn_R-50-FPN_1x": "01_48_14.DzEQe4wC", + "35861795/e2e_mask_rcnn_R-101-FPN_1x": "02_31_37.KqyEK4tT", + "36761843/e2e_mask_rcnn_X-101-32x8d-FPN_1x": "06_35_59.RZotkLKI", + "37129812/e2e_mask_rcnn_X-152-32x8d-FPN-IN5k_1.44x": "09_35_36.8pzTQKYK", + # keypoints + "37697547/e2e_keypoint_rcnn_R-50-FPN_1x": "08_42_54.kdzV35ao" + } + + @staticmethod + def get(name): + if name.startswith("Caffe2Detectron/COCO"): + return ModelCatalog.get_c2_detectron_12_2017_baselines(name) + if name.startswith("ImageNetPretrained"): + return ModelCatalog.get_c2_imagenet_pretrained(name) + raise RuntimeError("model not present in the catalog {}".format(name)) + + @staticmethod + def get_c2_imagenet_pretrained(name): + prefix = ModelCatalog.S3_C2_DETECTRON_URL + name = name[len("ImageNetPretrained/") :] + name = ModelCatalog.C2_IMAGENET_MODELS[name] + if 'resnet34' in name or 'resnet18' in name: + return name + url = "/".join([prefix, name]) + return url + + @staticmethod + def get_c2_detectron_12_2017_baselines(name): + # Detectron C2 models are stored following the structure + # prefix//2012_2017_baselines/.yaml./suffix + # we use as identifiers in the catalog Caffe2Detectron/COCO// + prefix = ModelCatalog.S3_C2_DETECTRON_URL + suffix = ModelCatalog.C2_DETECTRON_SUFFIX + # remove identification prefix + name = name[len("Caffe2Detectron/COCO/") :] + # split in and + model_id, model_name = name.split("/") + # parsing to make it match the url address from the Caffe2 models + model_name = "{}.yaml".format(model_name) + signature = ModelCatalog.C2_DETECTRON_MODELS[name] + unique_name = ".".join([model_name, signature]) + url = "/".join([prefix, model_id, "12_2017_baselines", unique_name, suffix]) + return url diff --git a/maskrcnn_benchmark/csrc/ROIAlign.h b/maskrcnn_benchmark/csrc/ROIAlign.h new file mode 100644 index 0000000000000000000000000000000000000000..3907deab2a750a9f83f0f3ef38fee279c1445c61 --- /dev/null +++ b/maskrcnn_benchmark/csrc/ROIAlign.h @@ -0,0 +1,46 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +#pragma once + +#include "cpu/vision.h" + +#ifdef WITH_CUDA +#include "cuda/vision.h" +#endif + +// Interface for Python +at::Tensor ROIAlign_forward(const at::Tensor& input, + const at::Tensor& rois, + const float spatial_scale, + const int pooled_height, + const int pooled_width, + const int sampling_ratio) { + if (input.type().is_cuda()) { +#ifdef WITH_CUDA + return ROIAlign_forward_cuda(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio); +#else + AT_ERROR("Not compiled with GPU support"); +#endif + } + return ROIAlign_forward_cpu(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio); +} + +at::Tensor ROIAlign_backward(const at::Tensor& grad, + const at::Tensor& rois, + const float spatial_scale, + const int pooled_height, + const int pooled_width, + const int batch_size, + const int channels, + const int height, + const int width, + const int sampling_ratio) { + if (grad.type().is_cuda()) { +#ifdef WITH_CUDA + return ROIAlign_backward_cuda(grad, rois, spatial_scale, pooled_height, pooled_width, batch_size, channels, height, width, sampling_ratio); +#else + AT_ERROR("Not compiled with GPU support"); +#endif + } + AT_ERROR("Not implemented on the CPU"); +} + diff --git a/maskrcnn_benchmark/csrc/ROIPool.h b/maskrcnn_benchmark/csrc/ROIPool.h new file mode 100644 index 0000000000000000000000000000000000000000..200fd7390b4629747f0ea9e16c0823ac5f099ac1 --- /dev/null +++ b/maskrcnn_benchmark/csrc/ROIPool.h @@ -0,0 +1,48 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +#pragma once + +#include "cpu/vision.h" + +#ifdef WITH_CUDA +#include "cuda/vision.h" +#endif + + +std::tuple ROIPool_forward(const at::Tensor& input, + const at::Tensor& rois, + const float spatial_scale, + const int pooled_height, + const int pooled_width) { + if (input.type().is_cuda()) { +#ifdef WITH_CUDA + return ROIPool_forward_cuda(input, rois, spatial_scale, pooled_height, pooled_width); +#else + AT_ERROR("Not compiled with GPU support"); +#endif + } + AT_ERROR("Not implemented on the CPU"); +} + +at::Tensor ROIPool_backward(const at::Tensor& grad, + const at::Tensor& input, + const at::Tensor& rois, + const at::Tensor& argmax, + const float spatial_scale, + const int pooled_height, + const int pooled_width, + const int batch_size, + const int channels, + const int height, + const int width) { + if (grad.type().is_cuda()) { +#ifdef WITH_CUDA + return ROIPool_backward_cuda(grad, input, rois, argmax, spatial_scale, pooled_height, pooled_width, batch_size, channels, height, width); +#else + AT_ERROR("Not compiled with GPU support"); +#endif + } + AT_ERROR("Not implemented on the CPU"); +} + + + diff --git a/maskrcnn_benchmark/csrc/SigmoidFocalLoss.h b/maskrcnn_benchmark/csrc/SigmoidFocalLoss.h new file mode 100644 index 0000000000000000000000000000000000000000..308861e44774dffd89b3f5ebff7cc6c5491fe3a5 --- /dev/null +++ b/maskrcnn_benchmark/csrc/SigmoidFocalLoss.h @@ -0,0 +1,41 @@ +#pragma once + +#include "cpu/vision.h" + +#ifdef WITH_CUDA +#include "cuda/vision.h" +#endif + +// Interface for Python +at::Tensor SigmoidFocalLoss_forward( + const at::Tensor& logits, + const at::Tensor& targets, + const int num_classes, + const float gamma, + const float alpha) { + if (logits.type().is_cuda()) { +#ifdef WITH_CUDA + return SigmoidFocalLoss_forward_cuda(logits, targets, num_classes, gamma, alpha); +#else + AT_ERROR("Not compiled with GPU support"); +#endif + } + AT_ERROR("Not implemented on the CPU"); +} + +at::Tensor SigmoidFocalLoss_backward( + const at::Tensor& logits, + const at::Tensor& targets, + const at::Tensor& d_losses, + const int num_classes, + const float gamma, + const float alpha) { + if (logits.type().is_cuda()) { +#ifdef WITH_CUDA + return SigmoidFocalLoss_backward_cuda(logits, targets, d_losses, num_classes, gamma, alpha); +#else + AT_ERROR("Not compiled with GPU support"); +#endif + } + AT_ERROR("Not implemented on the CPU"); +} diff --git a/maskrcnn_benchmark/csrc/cpu/ROIAlign_cpu.cpp b/maskrcnn_benchmark/csrc/cpu/ROIAlign_cpu.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d531da623997781b599a2702cdf2ec04be583bac --- /dev/null +++ b/maskrcnn_benchmark/csrc/cpu/ROIAlign_cpu.cpp @@ -0,0 +1,257 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +#include "cpu/vision.h" + +// implementation taken from Caffe2 +template +struct PreCalc { + int pos1; + int pos2; + int pos3; + int pos4; + T w1; + T w2; + T w3; + T w4; +}; + +template +void pre_calc_for_bilinear_interpolate( + const int height, + const int width, + const int pooled_height, + const int pooled_width, + const int iy_upper, + const int ix_upper, + T roi_start_h, + T roi_start_w, + T bin_size_h, + T bin_size_w, + int roi_bin_grid_h, + int roi_bin_grid_w, + std::vector>& pre_calc) { + int pre_calc_index = 0; + for (int ph = 0; ph < pooled_height; ph++) { + for (int pw = 0; pw < pooled_width; pw++) { + for (int iy = 0; iy < iy_upper; iy++) { + const T yy = roi_start_h + ph * bin_size_h + + static_cast(iy + .5f) * bin_size_h / + static_cast(roi_bin_grid_h); // e.g., 0.5, 1.5 + for (int ix = 0; ix < ix_upper; ix++) { + const T xx = roi_start_w + pw * bin_size_w + + static_cast(ix + .5f) * bin_size_w / + static_cast(roi_bin_grid_w); + + T x = xx; + T y = yy; + // deal with: inverse elements are out of feature map boundary + if (y < -1.0 || y > height || x < -1.0 || x > width) { + // empty + PreCalc pc; + pc.pos1 = 0; + pc.pos2 = 0; + pc.pos3 = 0; + pc.pos4 = 0; + pc.w1 = 0; + pc.w2 = 0; + pc.w3 = 0; + pc.w4 = 0; + pre_calc[pre_calc_index] = pc; + pre_calc_index += 1; + continue; + } + + if (y <= 0) { + y = 0; + } + if (x <= 0) { + x = 0; + } + + int y_low = (int)y; + int x_low = (int)x; + int y_high; + int x_high; + + if (y_low >= height - 1) { + y_high = y_low = height - 1; + y = (T)y_low; + } else { + y_high = y_low + 1; + } + + if (x_low >= width - 1) { + x_high = x_low = width - 1; + x = (T)x_low; + } else { + x_high = x_low + 1; + } + + T ly = y - y_low; + T lx = x - x_low; + T hy = 1. - ly, hx = 1. - lx; + T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx; + + // save weights and indices + PreCalc pc; + pc.pos1 = y_low * width + x_low; + pc.pos2 = y_low * width + x_high; + pc.pos3 = y_high * width + x_low; + pc.pos4 = y_high * width + x_high; + pc.w1 = w1; + pc.w2 = w2; + pc.w3 = w3; + pc.w4 = w4; + pre_calc[pre_calc_index] = pc; + + pre_calc_index += 1; + } + } + } + } +} + +template +void ROIAlignForward_cpu_kernel( + const int nthreads, + const T* bottom_data, + const T& spatial_scale, + const int channels, + const int height, + const int width, + const int pooled_height, + const int pooled_width, + const int sampling_ratio, + const T* bottom_rois, + //int roi_cols, + T* top_data) { + //AT_ASSERT(roi_cols == 4 || roi_cols == 5); + int roi_cols = 5; + + int n_rois = nthreads / channels / pooled_width / pooled_height; + // (n, c, ph, pw) is an element in the pooled output + // can be parallelized using omp + // #pragma omp parallel for num_threads(32) + for (int n = 0; n < n_rois; n++) { + int index_n = n * channels * pooled_width * pooled_height; + + // roi could have 4 or 5 columns + const T* offset_bottom_rois = bottom_rois + n * roi_cols; + int roi_batch_ind = 0; + if (roi_cols == 5) { + roi_batch_ind = offset_bottom_rois[0]; + offset_bottom_rois++; + } + + // Do not using rounding; this implementation detail is critical + T roi_start_w = offset_bottom_rois[0] * spatial_scale; + T roi_start_h = offset_bottom_rois[1] * spatial_scale; + T roi_end_w = offset_bottom_rois[2] * spatial_scale; + T roi_end_h = offset_bottom_rois[3] * spatial_scale; + // T roi_start_w = round(offset_bottom_rois[0] * spatial_scale); + // T roi_start_h = round(offset_bottom_rois[1] * spatial_scale); + // T roi_end_w = round(offset_bottom_rois[2] * spatial_scale); + // T roi_end_h = round(offset_bottom_rois[3] * spatial_scale); + + // Force malformed ROIs to be 1x1 + T roi_width = std::max(roi_end_w - roi_start_w, (T)1.); + T roi_height = std::max(roi_end_h - roi_start_h, (T)1.); + T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); + T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); + + // We use roi_bin_grid to sample the grid and mimic integral + int roi_bin_grid_h = (sampling_ratio > 0) + ? sampling_ratio + : ceil(roi_height / pooled_height); // e.g., = 2 + int roi_bin_grid_w = + (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width); + + // We do average (integral) pooling inside a bin + const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4 + + // we want to precalculate indices and weights shared by all channels, + // this is the key point of optimization + std::vector> pre_calc( + roi_bin_grid_h * roi_bin_grid_w * pooled_width * pooled_height); + pre_calc_for_bilinear_interpolate( + height, + width, + pooled_height, + pooled_width, + roi_bin_grid_h, + roi_bin_grid_w, + roi_start_h, + roi_start_w, + bin_size_h, + bin_size_w, + roi_bin_grid_h, + roi_bin_grid_w, + pre_calc); + + for (int c = 0; c < channels; c++) { + int index_n_c = index_n + c * pooled_width * pooled_height; + const T* offset_bottom_data = + bottom_data + (roi_batch_ind * channels + c) * height * width; + int pre_calc_index = 0; + + for (int ph = 0; ph < pooled_height; ph++) { + for (int pw = 0; pw < pooled_width; pw++) { + int index = index_n_c + ph * pooled_width + pw; + + T output_val = 0.; + for (int iy = 0; iy < roi_bin_grid_h; iy++) { + for (int ix = 0; ix < roi_bin_grid_w; ix++) { + PreCalc pc = pre_calc[pre_calc_index]; + output_val += pc.w1 * offset_bottom_data[pc.pos1] + + pc.w2 * offset_bottom_data[pc.pos2] + + pc.w3 * offset_bottom_data[pc.pos3] + + pc.w4 * offset_bottom_data[pc.pos4]; + + pre_calc_index += 1; + } + } + output_val /= count; + + top_data[index] = output_val; + } // for pw + } // for ph + } // for c + } // for n +} + +at::Tensor ROIAlign_forward_cpu(const at::Tensor& input, + const at::Tensor& rois, + const float spatial_scale, + const int pooled_height, + const int pooled_width, + const int sampling_ratio) { + AT_ASSERTM(!input.type().is_cuda(), "input must be a CPU tensor"); + AT_ASSERTM(!rois.type().is_cuda(), "rois must be a CPU tensor"); + + auto num_rois = rois.size(0); + auto channels = input.size(1); + auto height = input.size(2); + auto width = input.size(3); + + auto output = at::empty({num_rois, channels, pooled_height, pooled_width}, input.options()); + auto output_size = num_rois * pooled_height * pooled_width * channels; + + if (output.numel() == 0) { + return output; + } + + AT_DISPATCH_FLOATING_TYPES(input.type(), "ROIAlign_forward", [&] { + ROIAlignForward_cpu_kernel( + output_size, + input.data(), + spatial_scale, + channels, + height, + width, + pooled_height, + pooled_width, + sampling_ratio, + rois.data(), + output.data()); + }); + return output; +} diff --git a/maskrcnn_benchmark/csrc/cpu/nms_cpu.cpp b/maskrcnn_benchmark/csrc/cpu/nms_cpu.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1153dea04f032c67c41bd0d2a285376a72c5a595 --- /dev/null +++ b/maskrcnn_benchmark/csrc/cpu/nms_cpu.cpp @@ -0,0 +1,75 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +#include "cpu/vision.h" + + +template +at::Tensor nms_cpu_kernel(const at::Tensor& dets, + const at::Tensor& scores, + const float threshold) { + AT_ASSERTM(!dets.type().is_cuda(), "dets must be a CPU tensor"); + AT_ASSERTM(!scores.type().is_cuda(), "scores must be a CPU tensor"); + AT_ASSERTM(dets.type() == scores.type(), "dets should have the same type as scores"); + + if (dets.numel() == 0) { + return at::empty({0}, dets.options().dtype(at::kLong).device(at::kCPU)); + } + + auto x1_t = dets.select(1, 0).contiguous(); + auto y1_t = dets.select(1, 1).contiguous(); + auto x2_t = dets.select(1, 2).contiguous(); + auto y2_t = dets.select(1, 3).contiguous(); + + at::Tensor areas_t = (x2_t - x1_t + 1) * (y2_t - y1_t + 1); + + auto order_t = std::get<1>(scores.sort(0, /* descending=*/true)); + + auto ndets = dets.size(0); + at::Tensor suppressed_t = at::zeros({ndets}, dets.options().dtype(at::kByte).device(at::kCPU)); + + auto suppressed = suppressed_t.data(); + auto order = order_t.data(); + auto x1 = x1_t.data(); + auto y1 = y1_t.data(); + auto x2 = x2_t.data(); + auto y2 = y2_t.data(); + auto areas = areas_t.data(); + + for (int64_t _i = 0; _i < ndets; _i++) { + auto i = order[_i]; + if (suppressed[i] == 1) + continue; + auto ix1 = x1[i]; + auto iy1 = y1[i]; + auto ix2 = x2[i]; + auto iy2 = y2[i]; + auto iarea = areas[i]; + + for (int64_t _j = _i + 1; _j < ndets; _j++) { + auto j = order[_j]; + if (suppressed[j] == 1) + continue; + auto xx1 = std::max(ix1, x1[j]); + auto yy1 = std::max(iy1, y1[j]); + auto xx2 = std::min(ix2, x2[j]); + auto yy2 = std::min(iy2, y2[j]); + + auto w = std::max(static_cast(0), xx2 - xx1 + 1); + auto h = std::max(static_cast(0), yy2 - yy1 + 1); + auto inter = w * h; + auto ovr = inter / (iarea + areas[j] - inter); + if (ovr >= threshold) + suppressed[j] = 1; + } + } + return at::nonzero(suppressed_t == 0).squeeze(1); +} + +at::Tensor nms_cpu(const at::Tensor& dets, + const at::Tensor& scores, + const float threshold) { + at::Tensor result; + AT_DISPATCH_FLOATING_TYPES(dets.type(), "nms", [&] { + result = nms_cpu_kernel(dets, scores, threshold); + }); + return result; +} diff --git a/maskrcnn_benchmark/csrc/cpu/vision.h b/maskrcnn_benchmark/csrc/cpu/vision.h new file mode 100644 index 0000000000000000000000000000000000000000..92611253616c16efdbed66318da9930b233ae09c --- /dev/null +++ b/maskrcnn_benchmark/csrc/cpu/vision.h @@ -0,0 +1,16 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +#pragma once +#include + + +at::Tensor ROIAlign_forward_cpu(const at::Tensor& input, + const at::Tensor& rois, + const float spatial_scale, + const int pooled_height, + const int pooled_width, + const int sampling_ratio); + + +at::Tensor nms_cpu(const at::Tensor& dets, + const at::Tensor& scores, + const float threshold); diff --git a/maskrcnn_benchmark/csrc/cuda/ROIAlign_cuda.cu b/maskrcnn_benchmark/csrc/cuda/ROIAlign_cuda.cu new file mode 100644 index 0000000000000000000000000000000000000000..1142fb37597141122ee63161d0abd7beac510a74 --- /dev/null +++ b/maskrcnn_benchmark/csrc/cuda/ROIAlign_cuda.cu @@ -0,0 +1,346 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +#include +#include + +#include +#include +#include + +// TODO make it in a common file +#define CUDA_1D_KERNEL_LOOP(i, n) \ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \ + i += blockDim.x * gridDim.x) + + +template +__device__ T bilinear_interpolate(const T* bottom_data, + const int height, const int width, + T y, T x, + const int index /* index for debug only*/) { + + // deal with cases that inverse elements are out of feature map boundary + if (y < -1.0 || y > height || x < -1.0 || x > width) { + //empty + return 0; + } + + if (y <= 0) y = 0; + if (x <= 0) x = 0; + + int y_low = (int) y; + int x_low = (int) x; + int y_high; + int x_high; + + if (y_low >= height - 1) { + y_high = y_low = height - 1; + y = (T) y_low; + } else { + y_high = y_low + 1; + } + + if (x_low >= width - 1) { + x_high = x_low = width - 1; + x = (T) x_low; + } else { + x_high = x_low + 1; + } + + T ly = y - y_low; + T lx = x - x_low; + T hy = 1. - ly, hx = 1. - lx; + // do bilinear interpolation + T v1 = bottom_data[y_low * width + x_low]; + T v2 = bottom_data[y_low * width + x_high]; + T v3 = bottom_data[y_high * width + x_low]; + T v4 = bottom_data[y_high * width + x_high]; + T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx; + + T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + + return val; +} + +template +__global__ void RoIAlignForward(const int nthreads, const T* bottom_data, + const T spatial_scale, const int channels, + const int height, const int width, + const int pooled_height, const int pooled_width, + const int sampling_ratio, + const T* bottom_rois, T* top_data) { + CUDA_1D_KERNEL_LOOP(index, nthreads) { + // (n, c, ph, pw) is an element in the pooled output + int pw = index % pooled_width; + int ph = (index / pooled_width) % pooled_height; + int c = (index / pooled_width / pooled_height) % channels; + int n = index / pooled_width / pooled_height / channels; + + const T* offset_bottom_rois = bottom_rois + n * 5; + int roi_batch_ind = offset_bottom_rois[0]; + + // Do not using rounding; this implementation detail is critical + T roi_start_w = offset_bottom_rois[1] * spatial_scale; + T roi_start_h = offset_bottom_rois[2] * spatial_scale; + T roi_end_w = offset_bottom_rois[3] * spatial_scale; + T roi_end_h = offset_bottom_rois[4] * spatial_scale; + // T roi_start_w = round(offset_bottom_rois[1] * spatial_scale); + // T roi_start_h = round(offset_bottom_rois[2] * spatial_scale); + // T roi_end_w = round(offset_bottom_rois[3] * spatial_scale); + // T roi_end_h = round(offset_bottom_rois[4] * spatial_scale); + + // Force malformed ROIs to be 1x1 + T roi_width = max(roi_end_w - roi_start_w, (T)1.); + T roi_height = max(roi_end_h - roi_start_h, (T)1.); + T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); + T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); + + const T* offset_bottom_data = bottom_data + (roi_batch_ind * channels + c) * height * width; + + // We use roi_bin_grid to sample the grid and mimic integral + int roi_bin_grid_h = (sampling_ratio > 0) ? sampling_ratio : ceil(roi_height / pooled_height); // e.g., = 2 + int roi_bin_grid_w = (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width); + + // We do average (integral) pooling inside a bin + const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4 + + T output_val = 0.; + for (int iy = 0; iy < roi_bin_grid_h; iy ++) // e.g., iy = 0, 1 + { + const T y = roi_start_h + ph * bin_size_h + static_cast(iy + .5f) * bin_size_h / static_cast(roi_bin_grid_h); // e.g., 0.5, 1.5 + for (int ix = 0; ix < roi_bin_grid_w; ix ++) + { + const T x = roi_start_w + pw * bin_size_w + static_cast(ix + .5f) * bin_size_w / static_cast(roi_bin_grid_w); + + T val = bilinear_interpolate(offset_bottom_data, height, width, y, x, index); + output_val += val; + } + } + output_val /= count; + + top_data[index] = output_val; + } +} + + +template +__device__ void bilinear_interpolate_gradient( + const int height, const int width, + T y, T x, + T & w1, T & w2, T & w3, T & w4, + int & x_low, int & x_high, int & y_low, int & y_high, + const int index /* index for debug only*/) { + + // deal with cases that inverse elements are out of feature map boundary + if (y < -1.0 || y > height || x < -1.0 || x > width) { + //empty + w1 = w2 = w3 = w4 = 0.; + x_low = x_high = y_low = y_high = -1; + return; + } + + if (y <= 0) y = 0; + if (x <= 0) x = 0; + + y_low = (int) y; + x_low = (int) x; + + if (y_low >= height - 1) { + y_high = y_low = height - 1; + y = (T) y_low; + } else { + y_high = y_low + 1; + } + + if (x_low >= width - 1) { + x_high = x_low = width - 1; + x = (T) x_low; + } else { + x_high = x_low + 1; + } + + T ly = y - y_low; + T lx = x - x_low; + T hy = 1. - ly, hx = 1. - lx; + + // reference in forward + // T v1 = bottom_data[y_low * width + x_low]; + // T v2 = bottom_data[y_low * width + x_high]; + // T v3 = bottom_data[y_high * width + x_low]; + // T v4 = bottom_data[y_high * width + x_high]; + // T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + + w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx; + + return; +} + +template +__global__ void RoIAlignBackwardFeature(const int nthreads, const T* top_diff, + const int num_rois, const T spatial_scale, + const int channels, const int height, const int width, + const int pooled_height, const int pooled_width, + const int sampling_ratio, + T* bottom_diff, + const T* bottom_rois) { + CUDA_1D_KERNEL_LOOP(index, nthreads) { + // (n, c, ph, pw) is an element in the pooled output + int pw = index % pooled_width; + int ph = (index / pooled_width) % pooled_height; + int c = (index / pooled_width / pooled_height) % channels; + int n = index / pooled_width / pooled_height / channels; + + const T* offset_bottom_rois = bottom_rois + n * 5; + int roi_batch_ind = offset_bottom_rois[0]; + + // Do not using rounding; this implementation detail is critical + T roi_start_w = offset_bottom_rois[1] * spatial_scale; + T roi_start_h = offset_bottom_rois[2] * spatial_scale; + T roi_end_w = offset_bottom_rois[3] * spatial_scale; + T roi_end_h = offset_bottom_rois[4] * spatial_scale; + // T roi_start_w = round(offset_bottom_rois[1] * spatial_scale); + // T roi_start_h = round(offset_bottom_rois[2] * spatial_scale); + // T roi_end_w = round(offset_bottom_rois[3] * spatial_scale); + // T roi_end_h = round(offset_bottom_rois[4] * spatial_scale); + + // Force malformed ROIs to be 1x1 + T roi_width = max(roi_end_w - roi_start_w, (T)1.); + T roi_height = max(roi_end_h - roi_start_h, (T)1.); + T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); + T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); + + T* offset_bottom_diff = bottom_diff + (roi_batch_ind * channels + c) * height * width; + + int top_offset = (n * channels + c) * pooled_height * pooled_width; + const T* offset_top_diff = top_diff + top_offset; + const T top_diff_this_bin = offset_top_diff[ph * pooled_width + pw]; + + // We use roi_bin_grid to sample the grid and mimic integral + int roi_bin_grid_h = (sampling_ratio > 0) ? sampling_ratio : ceil(roi_height / pooled_height); // e.g., = 2 + int roi_bin_grid_w = (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width); + + // We do average (integral) pooling inside a bin + const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4 + + for (int iy = 0; iy < roi_bin_grid_h; iy ++) // e.g., iy = 0, 1 + { + const T y = roi_start_h + ph * bin_size_h + static_cast(iy + .5f) * bin_size_h / static_cast(roi_bin_grid_h); // e.g., 0.5, 1.5 + for (int ix = 0; ix < roi_bin_grid_w; ix ++) + { + const T x = roi_start_w + pw * bin_size_w + static_cast(ix + .5f) * bin_size_w / static_cast(roi_bin_grid_w); + + T w1, w2, w3, w4; + int x_low, x_high, y_low, y_high; + + bilinear_interpolate_gradient(height, width, y, x, + w1, w2, w3, w4, + x_low, x_high, y_low, y_high, + index); + + T g1 = top_diff_this_bin * w1 / count; + T g2 = top_diff_this_bin * w2 / count; + T g3 = top_diff_this_bin * w3 / count; + T g4 = top_diff_this_bin * w4 / count; + + if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) + { + atomicAdd(offset_bottom_diff + y_low * width + x_low, static_cast(g1)); + atomicAdd(offset_bottom_diff + y_low * width + x_high, static_cast(g2)); + atomicAdd(offset_bottom_diff + y_high * width + x_low, static_cast(g3)); + atomicAdd(offset_bottom_diff + y_high * width + x_high, static_cast(g4)); + } // if + } // ix + } // iy + } // CUDA_1D_KERNEL_LOOP +} // RoIAlignBackward + + +at::Tensor ROIAlign_forward_cuda(const at::Tensor& input, + const at::Tensor& rois, + const float spatial_scale, + const int pooled_height, + const int pooled_width, + const int sampling_ratio) { + AT_ASSERTM(input.type().is_cuda(), "input must be a CUDA tensor"); + AT_ASSERTM(rois.type().is_cuda(), "rois must be a CUDA tensor"); + + auto num_rois = rois.size(0); + auto channels = input.size(1); + auto height = input.size(2); + auto width = input.size(3); + + auto output = at::empty({num_rois, channels, pooled_height, pooled_width}, input.options()); + auto output_size = num_rois * pooled_height * pooled_width * channels; + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + dim3 grid(std::min(THCCeilDiv((long)output_size, 512L), 4096L)); + dim3 block(512); + + if (output.numel() == 0) { + THCudaCheck(cudaGetLastError()); + return output; + } + + AT_DISPATCH_FLOATING_TYPES(input.type(), "ROIAlign_forward", [&] { + RoIAlignForward<<>>( + output_size, + input.contiguous().data(), + spatial_scale, + channels, + height, + width, + pooled_height, + pooled_width, + sampling_ratio, + rois.contiguous().data(), + output.data()); + }); + THCudaCheck(cudaGetLastError()); + return output; +} + +// TODO remove the dependency on input and use instead its sizes -> save memory +at::Tensor ROIAlign_backward_cuda(const at::Tensor& grad, + const at::Tensor& rois, + const float spatial_scale, + const int pooled_height, + const int pooled_width, + const int batch_size, + const int channels, + const int height, + const int width, + const int sampling_ratio) { + AT_ASSERTM(grad.type().is_cuda(), "grad must be a CUDA tensor"); + AT_ASSERTM(rois.type().is_cuda(), "rois must be a CUDA tensor"); + + auto num_rois = rois.size(0); + auto grad_input = at::zeros({batch_size, channels, height, width}, grad.options()); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + dim3 grid(std::min(THCCeilDiv((long)grad.numel(), 512L), 4096L)); + dim3 block(512); + + // handle possibly empty gradients + if (grad.numel() == 0) { + THCudaCheck(cudaGetLastError()); + return grad_input; + } + + AT_DISPATCH_FLOATING_TYPES(grad.type(), "ROIAlign_backward", [&] { + RoIAlignBackwardFeature<<>>( + grad.numel(), + grad.contiguous().data(), + num_rois, + spatial_scale, + channels, + height, + width, + pooled_height, + pooled_width, + sampling_ratio, + grad_input.data(), + rois.contiguous().data()); + }); + THCudaCheck(cudaGetLastError()); + return grad_input; +} diff --git a/maskrcnn_benchmark/csrc/cuda/ROIPool_cuda.cu b/maskrcnn_benchmark/csrc/cuda/ROIPool_cuda.cu new file mode 100644 index 0000000000000000000000000000000000000000..8f072ffc2bd6de310f0d92c8c513dd9cfcc80dbc --- /dev/null +++ b/maskrcnn_benchmark/csrc/cuda/ROIPool_cuda.cu @@ -0,0 +1,202 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +#include +#include + +#include +#include +#include + + +// TODO make it in a common file +#define CUDA_1D_KERNEL_LOOP(i, n) \ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \ + i += blockDim.x * gridDim.x) + + +template +__global__ void RoIPoolFForward(const int nthreads, const T* bottom_data, + const T spatial_scale, const int channels, const int height, + const int width, const int pooled_height, const int pooled_width, + const T* bottom_rois, T* top_data, int* argmax_data) { + CUDA_1D_KERNEL_LOOP(index, nthreads) { + // (n, c, ph, pw) is an element in the pooled output + int pw = index % pooled_width; + int ph = (index / pooled_width) % pooled_height; + int c = (index / pooled_width / pooled_height) % channels; + int n = index / pooled_width / pooled_height / channels; + + const T* offset_bottom_rois = bottom_rois + n * 5; + int roi_batch_ind = offset_bottom_rois[0]; + int roi_start_w = round(offset_bottom_rois[1] * spatial_scale); + int roi_start_h = round(offset_bottom_rois[2] * spatial_scale); + int roi_end_w = round(offset_bottom_rois[3] * spatial_scale); + int roi_end_h = round(offset_bottom_rois[4] * spatial_scale); + + // Force malformed ROIs to be 1x1 + int roi_width = max(roi_end_w - roi_start_w + 1, 1); + int roi_height = max(roi_end_h - roi_start_h + 1, 1); + T bin_size_h = static_cast(roi_height) + / static_cast(pooled_height); + T bin_size_w = static_cast(roi_width) + / static_cast(pooled_width); + + int hstart = static_cast(floor(static_cast(ph) + * bin_size_h)); + int wstart = static_cast(floor(static_cast(pw) + * bin_size_w)); + int hend = static_cast(ceil(static_cast(ph + 1) + * bin_size_h)); + int wend = static_cast(ceil(static_cast(pw + 1) + * bin_size_w)); + + // Add roi offsets and clip to input boundaries + hstart = min(max(hstart + roi_start_h, 0), height); + hend = min(max(hend + roi_start_h, 0), height); + wstart = min(max(wstart + roi_start_w, 0), width); + wend = min(max(wend + roi_start_w, 0), width); + bool is_empty = (hend <= hstart) || (wend <= wstart); + + // Define an empty pooling region to be zero + T maxval = is_empty ? 0 : -FLT_MAX; + // If nothing is pooled, argmax = -1 causes nothing to be backprop'd + int maxidx = -1; + const T* offset_bottom_data = + bottom_data + (roi_batch_ind * channels + c) * height * width; + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + int bottom_index = h * width + w; + if (offset_bottom_data[bottom_index] > maxval) { + maxval = offset_bottom_data[bottom_index]; + maxidx = bottom_index; + } + } + } + top_data[index] = maxval; + argmax_data[index] = maxidx; + } +} + +template +__global__ void RoIPoolFBackward(const int nthreads, const T* top_diff, + const int* argmax_data, const int num_rois, const T spatial_scale, + const int channels, const int height, const int width, + const int pooled_height, const int pooled_width, T* bottom_diff, + const T* bottom_rois) { + CUDA_1D_KERNEL_LOOP(index, nthreads) { + // (n, c, ph, pw) is an element in the pooled output + int pw = index % pooled_width; + int ph = (index / pooled_width) % pooled_height; + int c = (index / pooled_width / pooled_height) % channels; + int n = index / pooled_width / pooled_height / channels; + + const T* offset_bottom_rois = bottom_rois + n * 5; + int roi_batch_ind = offset_bottom_rois[0]; + int bottom_offset = (roi_batch_ind * channels + c) * height * width; + int top_offset = (n * channels + c) * pooled_height * pooled_width; + const T* offset_top_diff = top_diff + top_offset; + T* offset_bottom_diff = bottom_diff + bottom_offset; + const int* offset_argmax_data = argmax_data + top_offset; + + int argmax = offset_argmax_data[ph * pooled_width + pw]; + if (argmax != -1) { + atomicAdd( + offset_bottom_diff + argmax, + static_cast(offset_top_diff[ph * pooled_width + pw])); + + } + } +} + +std::tuple ROIPool_forward_cuda(const at::Tensor& input, + const at::Tensor& rois, + const float spatial_scale, + const int pooled_height, + const int pooled_width) { + AT_ASSERTM(input.type().is_cuda(), "input must be a CUDA tensor"); + AT_ASSERTM(rois.type().is_cuda(), "rois must be a CUDA tensor"); + + auto num_rois = rois.size(0); + auto channels = input.size(1); + auto height = input.size(2); + auto width = input.size(3); + + auto output = at::empty({num_rois, channels, pooled_height, pooled_width}, input.options()); + auto output_size = num_rois * pooled_height * pooled_width * channels; + auto argmax = at::zeros({num_rois, channels, pooled_height, pooled_width}, input.options().dtype(at::kInt)); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + dim3 grid(std::min(THCCeilDiv((long)output_size, 512L), 4096L)); + dim3 block(512); + + if (output.numel() == 0) { + THCudaCheck(cudaGetLastError()); + return std::make_tuple(output, argmax); + } + + AT_DISPATCH_FLOATING_TYPES(input.type(), "ROIPool_forward", [&] { + RoIPoolFForward<<>>( + output_size, + input.contiguous().data(), + spatial_scale, + channels, + height, + width, + pooled_height, + pooled_width, + rois.contiguous().data(), + output.data(), + argmax.data()); + }); + THCudaCheck(cudaGetLastError()); + return std::make_tuple(output, argmax); +} + +// TODO remove the dependency on input and use instead its sizes -> save memory +at::Tensor ROIPool_backward_cuda(const at::Tensor& grad, + const at::Tensor& input, + const at::Tensor& rois, + const at::Tensor& argmax, + const float spatial_scale, + const int pooled_height, + const int pooled_width, + const int batch_size, + const int channels, + const int height, + const int width) { + AT_ASSERTM(grad.type().is_cuda(), "grad must be a CUDA tensor"); + AT_ASSERTM(rois.type().is_cuda(), "rois must be a CUDA tensor"); + // TODO add more checks + + auto num_rois = rois.size(0); + auto grad_input = at::zeros({batch_size, channels, height, width}, grad.options()); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + dim3 grid(std::min(THCCeilDiv((long)grad.numel(), 512L), 4096L)); + dim3 block(512); + + // handle possibly empty gradients + if (grad.numel() == 0) { + THCudaCheck(cudaGetLastError()); + return grad_input; + } + + AT_DISPATCH_FLOATING_TYPES(grad.type(), "ROIPool_backward", [&] { + RoIPoolFBackward<<>>( + grad.numel(), + grad.contiguous().data(), + argmax.data(), + num_rois, + spatial_scale, + channels, + height, + width, + pooled_height, + pooled_width, + grad_input.data(), + rois.contiguous().data()); + }); + THCudaCheck(cudaGetLastError()); + return grad_input; +} diff --git a/maskrcnn_benchmark/csrc/cuda/SigmoidFocalLoss_cuda.cu b/maskrcnn_benchmark/csrc/cuda/SigmoidFocalLoss_cuda.cu new file mode 100644 index 0000000000000000000000000000000000000000..456a5f2354ed6ab14f439ab6b3f5f8d21aff24f5 --- /dev/null +++ b/maskrcnn_benchmark/csrc/cuda/SigmoidFocalLoss_cuda.cu @@ -0,0 +1,189 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +// This file is modified from https://github.com/pytorch/pytorch/blob/master/modules/detectron/sigmoid_focal_loss_op.cu +// Cheng-Yang Fu +// cyfu@cs.unc.edu +#include +#include + +#include +#include +#include + +#include + +// TODO make it in a common file +#define CUDA_1D_KERNEL_LOOP(i, n) \ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \ + i += blockDim.x * gridDim.x) + + +template +__global__ void SigmoidFocalLossForward(const int nthreads, + const T* logits, + const int* targets, + const int num_classes, + const float gamma, + const float alpha, + const int num, + T* losses) { + CUDA_1D_KERNEL_LOOP(i, nthreads) { + + int n = i / num_classes; + int d = i % num_classes; // current class[0~79]; + int t = targets[n]; // target class [1~80]; + + // Decide it is positive or negative case. + T c1 = (t == (d+1)); + T c2 = (t>=0 & t != (d+1)); + + T zn = (1.0 - alpha); + T zp = (alpha); + + // p = 1. / 1. + expf(-x); p = sigmoid(x) + T p = 1. / (1. + expf(-logits[i])); + + // (1-p)**gamma * log(p) where + T term1 = powf((1. - p), gamma) * logf(max(p, FLT_MIN)); + + // p**gamma * log(1-p) + T term2 = powf(p, gamma) * + (-1. * logits[i] * (logits[i] >= 0) - + logf(1. + expf(logits[i] - 2. * logits[i] * (logits[i] >= 0)))); + + losses[i] = 0.0; + losses[i] += -c1 * term1 * zp; + losses[i] += -c2 * term2 * zn; + + } // CUDA_1D_KERNEL_LOOP +} // SigmoidFocalLossForward + + +template +__global__ void SigmoidFocalLossBackward(const int nthreads, + const T* logits, + const int* targets, + const T* d_losses, + const int num_classes, + const float gamma, + const float alpha, + const int num, + T* d_logits) { + CUDA_1D_KERNEL_LOOP(i, nthreads) { + + int n = i / num_classes; + int d = i % num_classes; // current class[0~79]; + int t = targets[n]; // target class [1~80], 0 is background; + + // Decide it is positive or negative case. + T c1 = (t == (d+1)); + T c2 = (t>=0 & t != (d+1)); + + T zn = (1.0 - alpha); + T zp = (alpha); + // p = 1. / 1. + expf(-x); p = sigmoid(x) + T p = 1. / (1. + expf(-logits[i])); + + // (1-p)**g * (1 - p - g*p*log(p) + T term1 = powf((1. - p), gamma) * + (1. - p - (p * gamma * logf(max(p, FLT_MIN)))); + + // (p**g) * (g*(1-p)*log(1-p) - p) + T term2 = powf(p, gamma) * + ((-1. * logits[i] * (logits[i] >= 0) - + logf(1. + expf(logits[i] - 2. * logits[i] * (logits[i] >= 0)))) * + (1. - p) * gamma - p); + d_logits[i] = 0.0; + d_logits[i] += -c1 * term1 * zp; + d_logits[i] += -c2 * term2 * zn; + d_logits[i] = d_logits[i] * d_losses[i]; + + } // CUDA_1D_KERNEL_LOOP +} // SigmoidFocalLossBackward + + +at::Tensor SigmoidFocalLoss_forward_cuda( + const at::Tensor& logits, + const at::Tensor& targets, + const int num_classes, + const float gamma, + const float alpha) { + AT_ASSERTM(logits.type().is_cuda(), "logits must be a CUDA tensor"); + AT_ASSERTM(targets.type().is_cuda(), "targets must be a CUDA tensor"); + AT_ASSERTM(logits.dim() == 2, "logits should be NxClass"); + + const int num_samples = logits.size(0); + + auto losses = at::empty({num_samples, logits.size(1)}, logits.options()); + auto losses_size = num_samples * logits.size(1); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + dim3 grid(std::min(THCCeilDiv((long)losses_size, 512L), 4096L)); + + dim3 block(512); + + if (losses.numel() == 0) { + THCudaCheck(cudaGetLastError()); + return losses; + } + + AT_DISPATCH_FLOATING_TYPES(logits.type(), "SigmoidFocalLoss_forward", [&] { + SigmoidFocalLossForward<<>>( + losses_size, + logits.contiguous().data(), + targets.contiguous().data(), + num_classes, + gamma, + alpha, + num_samples, + losses.data()); + }); + THCudaCheck(cudaGetLastError()); + return losses; +} + + +at::Tensor SigmoidFocalLoss_backward_cuda( + const at::Tensor& logits, + const at::Tensor& targets, + const at::Tensor& d_losses, + const int num_classes, + const float gamma, + const float alpha) { + AT_ASSERTM(logits.type().is_cuda(), "logits must be a CUDA tensor"); + AT_ASSERTM(targets.type().is_cuda(), "targets must be a CUDA tensor"); + AT_ASSERTM(d_losses.type().is_cuda(), "d_losses must be a CUDA tensor"); + + AT_ASSERTM(logits.dim() == 2, "logits should be NxClass"); + + const int num_samples = logits.size(0); + AT_ASSERTM(logits.size(1) == num_classes, "logits.size(1) should be num_classes"); + + auto d_logits = at::zeros({num_samples, num_classes}, logits.options()); + auto d_logits_size = num_samples * logits.size(1); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + dim3 grid(std::min(THCCeilDiv((long)d_logits_size, 512L), 4096L)); + dim3 block(512); + + if (d_logits.numel() == 0) { + THCudaCheck(cudaGetLastError()); + return d_logits; + } + + AT_DISPATCH_FLOATING_TYPES(logits.type(), "SigmoidFocalLoss_backward", [&] { + SigmoidFocalLossBackward<<>>( + d_logits_size, + logits.contiguous().data(), + targets.contiguous().data(), + d_losses.contiguous().data(), + num_classes, + gamma, + alpha, + num_samples, + d_logits.data()); + }); + + THCudaCheck(cudaGetLastError()); + return d_logits; +} + diff --git a/maskrcnn_benchmark/csrc/cuda/deform_conv_cuda.cu b/maskrcnn_benchmark/csrc/cuda/deform_conv_cuda.cu new file mode 100644 index 0000000000000000000000000000000000000000..74f7d339900fd77f2bd63e8a8481cab372ba9e1c --- /dev/null +++ b/maskrcnn_benchmark/csrc/cuda/deform_conv_cuda.cu @@ -0,0 +1,691 @@ +// modify from +// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda.c + +#include +#include + +#include +#include + +#include +#include +#include + + +void deformable_im2col(const at::Tensor data_im, const at::Tensor data_offset, + const int channels, const int height, const int width, + const int ksize_h, const int ksize_w, const int pad_h, + const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int parallel_imgs, const int deformable_group, + at::Tensor data_col); + +void deformable_col2im(const at::Tensor data_col, const at::Tensor data_offset, + const int channels, const int height, const int width, + const int ksize_h, const int ksize_w, const int pad_h, + const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int parallel_imgs, const int deformable_group, + at::Tensor grad_im); + +void deformable_col2im_coord( + const at::Tensor data_col, const at::Tensor data_im, + const at::Tensor data_offset, const int channels, const int height, + const int width, const int ksize_h, const int ksize_w, const int pad_h, + const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, const int parallel_imgs, + const int deformable_group, at::Tensor grad_offset); + +void modulated_deformable_im2col_cuda( + const at::Tensor data_im, const at::Tensor data_offset, + const at::Tensor data_mask, const int batch_size, const int channels, + const int height_im, const int width_im, const int height_col, + const int width_col, const int kernel_h, const int kenerl_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, const int deformable_group, + at::Tensor data_col); + +void modulated_deformable_col2im_cuda( + const at::Tensor data_col, const at::Tensor data_offset, + const at::Tensor data_mask, const int batch_size, const int channels, + const int height_im, const int width_im, const int height_col, + const int width_col, const int kernel_h, const int kenerl_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, const int deformable_group, + at::Tensor grad_im); + +void modulated_deformable_col2im_coord_cuda( + const at::Tensor data_col, const at::Tensor data_im, + const at::Tensor data_offset, const at::Tensor data_mask, + const int batch_size, const int channels, const int height_im, + const int width_im, const int height_col, const int width_col, + const int kernel_h, const int kenerl_w, const int pad_h, const int pad_w, + const int stride_h, const int stride_w, const int dilation_h, + const int dilation_w, const int deformable_group, at::Tensor grad_offset, + at::Tensor grad_mask); + +void shape_check(at::Tensor input, at::Tensor offset, at::Tensor *gradOutput, + at::Tensor weight, int kH, int kW, int dH, int dW, int padH, + int padW, int dilationH, int dilationW, int group, + int deformable_group) +{ + AT_CHECK(weight.ndimension() == 4, + "4D weight tensor (nOutputPlane,nInputPlane,kH,kW) expected, " + "but got: %s", + weight.ndimension()); + + AT_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous"); + + AT_CHECK(kW > 0 && kH > 0, + "kernel size should be greater than zero, but got kH: %d kW: %d", kH, + kW); + + AT_CHECK((weight.size(2) == kH && weight.size(3) == kW), + "kernel size should be consistent with weight, ", + "but got kH: %d kW: %d weight.size(2): %d, weight.size(3): %d", kH, + kW, weight.size(2), weight.size(3)); + + AT_CHECK(dW > 0 && dH > 0, + "stride should be greater than zero, but got dH: %d dW: %d", dH, dW); + + AT_CHECK( + dilationW > 0 && dilationH > 0, + "dilation should be greater than 0, but got dilationH: %d dilationW: %d", + dilationH, dilationW); + + int ndim = input.ndimension(); + int dimf = 0; + int dimh = 1; + int dimw = 2; + + if (ndim == 4) { + dimf++; + dimh++; + dimw++; + } + + AT_CHECK(ndim == 3 || ndim == 4, "3D or 4D input tensor expected but got: %s", + ndim); + + long nInputPlane = weight.size(1) * group; + long inputHeight = input.size(dimh); + long inputWidth = input.size(dimw); + long nOutputPlane = weight.size(0); + long outputHeight = + (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; + long outputWidth = + (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; + + AT_CHECK(nInputPlane % deformable_group == 0, + "input channels must divide deformable group size"); + + if (outputWidth < 1 || outputHeight < 1) + AT_ERROR( + "Given input size: (%ld x %ld x %ld). " + "Calculated output size: (%ld x %ld x %ld). Output size is too small", + nInputPlane, inputHeight, inputWidth, nOutputPlane, outputHeight, + outputWidth); + + AT_CHECK(input.size(1) == nInputPlane, + "invalid number of input planes, expected: %d, but got: %d", + nInputPlane, input.size(1)); + + AT_CHECK((inputHeight >= kH && inputWidth >= kW), + "input image is smaller than kernel"); + + AT_CHECK((offset.size(2) == outputHeight && offset.size(3) == outputWidth), + "invalid spatial size of offset, expected height: %d width: %d, but " + "got height: %d width: %d", + outputHeight, outputWidth, offset.size(2), offset.size(3)); + + AT_CHECK((offset.size(1) == deformable_group * 2 * kH * kW), + "invalid number of channels of offset"); + + if (gradOutput != NULL) { + AT_CHECK(gradOutput->size(dimf) == nOutputPlane, + "invalid number of gradOutput planes, expected: %d, but got: %d", + nOutputPlane, gradOutput->size(dimf)); + + AT_CHECK((gradOutput->size(dimh) == outputHeight && + gradOutput->size(dimw) == outputWidth), + "invalid size of gradOutput, expected height: %d width: %d , but " + "got height: %d width: %d", + outputHeight, outputWidth, gradOutput->size(dimh), + gradOutput->size(dimw)); + } +} + +int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight, + at::Tensor offset, at::Tensor output, + at::Tensor columns, at::Tensor ones, int kW, + int kH, int dW, int dH, int padW, int padH, + int dilationW, int dilationH, int group, + int deformable_group, int im2col_step) +{ + // todo: resize columns to include im2col: done + // todo: add im2col_step as input + // todo: add new output buffer and transpose it to output (or directly + // transpose output) todo: possibly change data indexing because of + // parallel_imgs + + shape_check(input, offset, NULL, weight, kH, kW, dH, dW, padH, padW, + dilationH, dilationW, group, deformable_group); + + input = input.contiguous(); + offset = offset.contiguous(); + weight = weight.contiguous(); + + int batch = 1; + if (input.ndimension() == 3) { + // Force batch + batch = 0; + input.unsqueeze_(0); + offset.unsqueeze_(0); + } + + // todo: assert batchsize dividable by im2col_step + + long batchSize = input.size(0); + long nInputPlane = input.size(1); + long inputHeight = input.size(2); + long inputWidth = input.size(3); + + long nOutputPlane = weight.size(0); + + long outputWidth = + (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; + long outputHeight = + (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; + + AT_CHECK((offset.size(0) == batchSize), "invalid batch size of offset"); + + output = output.view({batchSize / im2col_step, im2col_step, nOutputPlane, + outputHeight, outputWidth}); + columns = at::zeros( + {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth}, + input.options()); + + if (ones.ndimension() != 2 || + ones.size(0) * ones.size(1) < outputHeight * outputWidth) { + ones = at::ones({outputHeight, outputWidth}, input.options()); + } + + input = input.view({batchSize / im2col_step, im2col_step, nInputPlane, + inputHeight, inputWidth}); + offset = + offset.view({batchSize / im2col_step, im2col_step, + deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + + at::Tensor output_buffer = + at::zeros({batchSize / im2col_step, nOutputPlane, + im2col_step * outputHeight, outputWidth}, + output.options()); + + output_buffer = output_buffer.view( + {output_buffer.size(0), group, output_buffer.size(1) / group, + output_buffer.size(2), output_buffer.size(3)}); + + for (int elt = 0; elt < batchSize / im2col_step; elt++) { + deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight, + inputWidth, kH, kW, padH, padW, dH, dW, dilationH, + dilationW, im2col_step, deformable_group, columns); + + columns = columns.view({group, columns.size(0) / group, columns.size(1)}); + weight = weight.view({group, weight.size(0) / group, weight.size(1), + weight.size(2), weight.size(3)}); + + for (int g = 0; g < group; g++) { + output_buffer[elt][g] = output_buffer[elt][g] + .flatten(1) + .addmm_(weight[g].flatten(1), columns[g]) + .view_as(output_buffer[elt][g]); + } + } + + output_buffer = output_buffer.view( + {output_buffer.size(0), output_buffer.size(1) * output_buffer.size(2), + output_buffer.size(3), output_buffer.size(4)}); + + output_buffer = output_buffer.view({batchSize / im2col_step, nOutputPlane, + im2col_step, outputHeight, outputWidth}); + output_buffer.transpose_(1, 2); + output.copy_(output_buffer); + output = output.view({batchSize, nOutputPlane, outputHeight, outputWidth}); + + input = input.view({batchSize, nInputPlane, inputHeight, inputWidth}); + offset = offset.view( + {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + + if (batch == 0) { + output = output.view({nOutputPlane, outputHeight, outputWidth}); + input = input.view({nInputPlane, inputHeight, inputWidth}); + offset = offset.view({offset.size(1), offset.size(2), offset.size(3)}); + } + + return 1; +} + +int deform_conv_backward_input_cuda(at::Tensor input, at::Tensor offset, + at::Tensor gradOutput, at::Tensor gradInput, + at::Tensor gradOffset, at::Tensor weight, + at::Tensor columns, int kW, int kH, int dW, + int dH, int padW, int padH, int dilationW, + int dilationH, int group, + int deformable_group, int im2col_step) +{ + shape_check(input, offset, &gradOutput, weight, kH, kW, dH, dW, padH, padW, + dilationH, dilationW, group, deformable_group); + + input = input.contiguous(); + offset = offset.contiguous(); + gradOutput = gradOutput.contiguous(); + weight = weight.contiguous(); + + int batch = 1; + + if (input.ndimension() == 3) { + // Force batch + batch = 0; + input = input.view({1, input.size(0), input.size(1), input.size(2)}); + offset = offset.view({1, offset.size(0), offset.size(1), offset.size(2)}); + gradOutput = gradOutput.view( + {1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)}); + } + + long batchSize = input.size(0); + long nInputPlane = input.size(1); + long inputHeight = input.size(2); + long inputWidth = input.size(3); + + long nOutputPlane = weight.size(0); + + long outputWidth = + (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; + long outputHeight = + (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; + + AT_CHECK((offset.size(0) == batchSize), 3, "invalid batch size of offset"); + gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth}); + columns = at::zeros( + {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth}, + input.options()); + + // change order of grad output + gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step, + nOutputPlane, outputHeight, outputWidth}); + gradOutput.transpose_(1, 2); + + gradInput = gradInput.view({batchSize / im2col_step, im2col_step, nInputPlane, + inputHeight, inputWidth}); + input = input.view({batchSize / im2col_step, im2col_step, nInputPlane, + inputHeight, inputWidth}); + gradOffset = gradOffset.view({batchSize / im2col_step, im2col_step, + deformable_group * 2 * kH * kW, outputHeight, + outputWidth}); + offset = + offset.view({batchSize / im2col_step, im2col_step, + deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + + for (int elt = 0; elt < batchSize / im2col_step; elt++) { + // divide into groups + columns = columns.view({group, columns.size(0) / group, columns.size(1)}); + weight = weight.view({group, weight.size(0) / group, weight.size(1), + weight.size(2), weight.size(3)}); + gradOutput = gradOutput.view( + {gradOutput.size(0), group, gradOutput.size(1) / group, + gradOutput.size(2), gradOutput.size(3), gradOutput.size(4)}); + + for (int g = 0; g < group; g++) { + columns[g] = columns[g].addmm_(weight[g].flatten(1).transpose(0, 1), + gradOutput[elt][g].flatten(1), 0.0f, 1.0f); + } + + columns = + columns.view({columns.size(0) * columns.size(1), columns.size(2)}); + gradOutput = gradOutput.view( + {gradOutput.size(0), gradOutput.size(1) * gradOutput.size(2), + gradOutput.size(3), gradOutput.size(4), gradOutput.size(5)}); + + deformable_col2im_coord(columns, input[elt], offset[elt], nInputPlane, + inputHeight, inputWidth, kH, kW, padH, padW, dH, dW, + dilationH, dilationW, im2col_step, deformable_group, + gradOffset[elt]); + + deformable_col2im(columns, offset[elt], nInputPlane, inputHeight, + inputWidth, kH, kW, padH, padW, dH, dW, dilationH, + dilationW, im2col_step, deformable_group, gradInput[elt]); + } + + gradOutput.transpose_(1, 2); + gradOutput = + gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth}); + + gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth}); + input = input.view({batchSize, nInputPlane, inputHeight, inputWidth}); + gradOffset = gradOffset.view( + {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + offset = offset.view( + {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + + if (batch == 0) { + gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth}); + input = input.view({nInputPlane, inputHeight, inputWidth}); + gradInput = gradInput.view({nInputPlane, inputHeight, inputWidth}); + offset = offset.view({offset.size(1), offset.size(2), offset.size(3)}); + gradOffset = + gradOffset.view({offset.size(1), offset.size(2), offset.size(3)}); + } + + return 1; +} + +int deform_conv_backward_parameters_cuda( + at::Tensor input, at::Tensor offset, at::Tensor gradOutput, + at::Tensor gradWeight, // at::Tensor gradBias, + at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH, + int padW, int padH, int dilationW, int dilationH, int group, + int deformable_group, float scale, int im2col_step) +{ + // todo: transpose and reshape outGrad + // todo: reshape columns + // todo: add im2col_step as input + + shape_check(input, offset, &gradOutput, gradWeight, kH, kW, dH, dW, padH, + padW, dilationH, dilationW, group, deformable_group); + + input = input.contiguous(); + offset = offset.contiguous(); + gradOutput = gradOutput.contiguous(); + + int batch = 1; + + if (input.ndimension() == 3) { + // Force batch + batch = 0; + input = input.view( + at::IntList({1, input.size(0), input.size(1), input.size(2)})); + gradOutput = gradOutput.view( + {1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)}); + } + + long batchSize = input.size(0); + long nInputPlane = input.size(1); + long inputHeight = input.size(2); + long inputWidth = input.size(3); + + long nOutputPlane = gradWeight.size(0); + + long outputWidth = + (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; + long outputHeight = + (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; + + AT_CHECK((offset.size(0) == batchSize), "invalid batch size of offset"); + + columns = at::zeros( + {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth}, + input.options()); + + gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step, + nOutputPlane, outputHeight, outputWidth}); + gradOutput.transpose_(1, 2); + + at::Tensor gradOutputBuffer = at::zeros_like(gradOutput); + gradOutputBuffer = + gradOutputBuffer.view({batchSize / im2col_step, nOutputPlane, im2col_step, + outputHeight, outputWidth}); + gradOutputBuffer.copy_(gradOutput); + gradOutputBuffer = + gradOutputBuffer.view({batchSize / im2col_step, nOutputPlane, + im2col_step * outputHeight, outputWidth}); + + gradOutput.transpose_(1, 2); + gradOutput = + gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth}); + + input = input.view({batchSize / im2col_step, im2col_step, nInputPlane, + inputHeight, inputWidth}); + offset = + offset.view({batchSize / im2col_step, im2col_step, + deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + + for (int elt = 0; elt < batchSize / im2col_step; elt++) { + deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight, + inputWidth, kH, kW, padH, padW, dH, dW, dilationH, + dilationW, im2col_step, deformable_group, columns); + + // divide into group + gradOutputBuffer = gradOutputBuffer.view( + {gradOutputBuffer.size(0), group, gradOutputBuffer.size(1) / group, + gradOutputBuffer.size(2), gradOutputBuffer.size(3)}); + columns = columns.view({group, columns.size(0) / group, columns.size(1)}); + gradWeight = + gradWeight.view({group, gradWeight.size(0) / group, gradWeight.size(1), + gradWeight.size(2), gradWeight.size(3)}); + + for (int g = 0; g < group; g++) { + gradWeight[g] = gradWeight[g] + .flatten(1) + .addmm_(gradOutputBuffer[elt][g].flatten(1), + columns[g].transpose(1, 0), 1.0, scale) + .view_as(gradWeight[g]); + } + gradOutputBuffer = gradOutputBuffer.view( + {gradOutputBuffer.size(0), + gradOutputBuffer.size(1) * gradOutputBuffer.size(2), + gradOutputBuffer.size(3), gradOutputBuffer.size(4)}); + columns = + columns.view({columns.size(0) * columns.size(1), columns.size(2)}); + gradWeight = gradWeight.view({gradWeight.size(0) * gradWeight.size(1), + gradWeight.size(2), gradWeight.size(3), + gradWeight.size(4)}); + } + + input = input.view({batchSize, nInputPlane, inputHeight, inputWidth}); + offset = offset.view( + {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + + if (batch == 0) { + gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth}); + input = input.view({nInputPlane, inputHeight, inputWidth}); + } + + return 1; +} + +void modulated_deform_conv_cuda_forward( + at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones, + at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns, + int kernel_h, int kernel_w, const int stride_h, const int stride_w, + const int pad_h, const int pad_w, const int dilation_h, + const int dilation_w, const int group, const int deformable_group, + const bool with_bias) +{ + AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous"); + AT_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous"); + + const int batch = input.size(0); + const int channels = input.size(1); + const int height = input.size(2); + const int width = input.size(3); + + const int channels_out = weight.size(0); + const int channels_kernel = weight.size(1); + const int kernel_h_ = weight.size(2); + const int kernel_w_ = weight.size(3); + + if (kernel_h_ != kernel_h || kernel_w_ != kernel_w) + AT_ERROR("Input shape and kernel shape wont match: (%d x %d vs %d x %d).", + kernel_h_, kernel_w, kernel_h_, kernel_w_); + if (channels != channels_kernel * group) + AT_ERROR("Input shape and kernel channels wont match: (%d vs %d).", + channels, channels_kernel * group); + + const int height_out = + (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1; + const int width_out = + (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1; + + if (ones.ndimension() != 2 || + ones.size(0) * ones.size(1) < height_out * width_out) { + // Resize plane and fill with ones... + ones = at::ones({height_out, width_out}, input.options()); + } + + // resize output + output = output.view({batch, channels_out, height_out, width_out}).zero_(); + // resize temporary columns + columns = + at::zeros({channels * kernel_h * kernel_w, 1 * height_out * width_out}, + input.options()); + + output = output.view({output.size(0), group, output.size(1) / group, + output.size(2), output.size(3)}); + + for (int b = 0; b < batch; b++) { + modulated_deformable_im2col_cuda( + input[b], offset[b], mask[b], 1, channels, height, width, height_out, + width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, deformable_group, columns); + + // divide into group + weight = weight.view({group, weight.size(0) / group, weight.size(1), + weight.size(2), weight.size(3)}); + columns = columns.view({group, columns.size(0) / group, columns.size(1)}); + + for (int g = 0; g < group; g++) { + output[b][g] = output[b][g] + .flatten(1) + .addmm_(weight[g].flatten(1), columns[g]) + .view_as(output[b][g]); + } + + weight = weight.view({weight.size(0) * weight.size(1), weight.size(2), + weight.size(3), weight.size(4)}); + columns = + columns.view({columns.size(0) * columns.size(1), columns.size(2)}); + } + + output = output.view({output.size(0), output.size(1) * output.size(2), + output.size(3), output.size(4)}); + + if (with_bias) { + output += bias.view({1, bias.size(0), 1, 1}); + } +} + +void modulated_deform_conv_cuda_backward( + at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones, + at::Tensor offset, at::Tensor mask, at::Tensor columns, + at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias, + at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output, + int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h, + int pad_w, int dilation_h, int dilation_w, int group, int deformable_group, + const bool with_bias) +{ + AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous"); + AT_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous"); + + const int batch = input.size(0); + const int channels = input.size(1); + const int height = input.size(2); + const int width = input.size(3); + + const int channels_kernel = weight.size(1); + const int kernel_h_ = weight.size(2); + const int kernel_w_ = weight.size(3); + if (kernel_h_ != kernel_h || kernel_w_ != kernel_w) + AT_ERROR("Input shape and kernel shape wont match: (%d x %d vs %d x %d).", + kernel_h_, kernel_w, kernel_h_, kernel_w_); + if (channels != channels_kernel * group) + AT_ERROR("Input shape and kernel channels wont match: (%d vs %d).", + channels, channels_kernel * group); + + const int height_out = + (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1; + const int width_out = + (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1; + + if (ones.ndimension() != 2 || + ones.size(0) * ones.size(1) < height_out * width_out) { + // Resize plane and fill with ones... + ones = at::ones({height_out, width_out}, input.options()); + } + + grad_input = grad_input.view({batch, channels, height, width}); + columns = at::zeros({channels * kernel_h * kernel_w, height_out * width_out}, + input.options()); + + grad_output = + grad_output.view({grad_output.size(0), group, grad_output.size(1) / group, + grad_output.size(2), grad_output.size(3)}); + + for (int b = 0; b < batch; b++) { + // divide int group + columns = columns.view({group, columns.size(0) / group, columns.size(1)}); + weight = weight.view({group, weight.size(0) / group, weight.size(1), + weight.size(2), weight.size(3)}); + + for (int g = 0; g < group; g++) { + columns[g].addmm_(weight[g].flatten(1).transpose(0, 1), + grad_output[b][g].flatten(1), 0.0f, 1.0f); + } + + columns = + columns.view({columns.size(0) * columns.size(1), columns.size(2)}); + weight = weight.view({weight.size(0) * weight.size(1), weight.size(2), + weight.size(3), weight.size(4)}); + + // gradient w.r.t. input coordinate data + modulated_deformable_col2im_coord_cuda( + columns, input[b], offset[b], mask[b], 1, channels, height, width, + height_out, width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, + stride_w, dilation_h, dilation_w, deformable_group, grad_offset[b], + grad_mask[b]); + // gradient w.r.t. input data + modulated_deformable_col2im_cuda( + columns, offset[b], mask[b], 1, channels, height, width, height_out, + width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, deformable_group, grad_input[b]); + + // gradient w.r.t. weight, dWeight should accumulate across the batch and + // group + modulated_deformable_im2col_cuda( + input[b], offset[b], mask[b], 1, channels, height, width, height_out, + width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, deformable_group, columns); + + columns = columns.view({group, columns.size(0) / group, columns.size(1)}); + grad_weight = grad_weight.view({group, grad_weight.size(0) / group, + grad_weight.size(1), grad_weight.size(2), + grad_weight.size(3)}); + if (with_bias) + grad_bias = grad_bias.view({group, grad_bias.size(0) / group}); + + for (int g = 0; g < group; g++) { + grad_weight[g] = + grad_weight[g] + .flatten(1) + .addmm_(grad_output[b][g].flatten(1), columns[g].transpose(0, 1)) + .view_as(grad_weight[g]); + if (with_bias) { + grad_bias[g] = + grad_bias[g] + .view({-1, 1}) + .addmm_(grad_output[b][g].flatten(1), ones.view({-1, 1})) + .view(-1); + } + } + + columns = + columns.view({columns.size(0) * columns.size(1), columns.size(2)}); + grad_weight = grad_weight.view({grad_weight.size(0) * grad_weight.size(1), + grad_weight.size(2), grad_weight.size(3), + grad_weight.size(4)}); + if (with_bias) + grad_bias = grad_bias.view({grad_bias.size(0) * grad_bias.size(1)}); + } + grad_output = grad_output.view({grad_output.size(0) * grad_output.size(1), + grad_output.size(2), grad_output.size(3), + grad_output.size(4)}); +} diff --git a/maskrcnn_benchmark/csrc/cuda/deform_conv_kernel_cuda.cu b/maskrcnn_benchmark/csrc/cuda/deform_conv_kernel_cuda.cu new file mode 100644 index 0000000000000000000000000000000000000000..b4f8813b431dbd67fcfab634b21e942135a10fbd --- /dev/null +++ b/maskrcnn_benchmark/csrc/cuda/deform_conv_kernel_cuda.cu @@ -0,0 +1,874 @@ +/*! + ******************* BEGIN Caffe Copyright Notice and Disclaimer **************** + * + * COPYRIGHT + * + * All contributions by the University of California: + * Copyright (c) 2014-2017 The Regents of the University of California (Regents) + * All rights reserved. + * + * All other contributions: + * Copyright (c) 2014-2017, the respective contributors + * All rights reserved. + * + * Caffe uses a shared copyright model: each contributor holds copyright over + * their contributions to Caffe. The project versioning records all such + * contribution and copyright details. If a contributor wants to further mark + * their specific copyright on a particular contribution, they should indicate + * their copyright solely in the commit message of the change when it is + * committed. + * + * LICENSE + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR + * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * CONTRIBUTION AGREEMENT + * + * By contributing to the BVLC/caffe repository through pull-request, comment, + * or otherwise, the contributor releases their content to the + * license and copyright terms herein. + * + ***************** END Caffe Copyright Notice and Disclaimer ******************** + * + * Copyright (c) 2018 Microsoft + * Licensed under The MIT License [see LICENSE for details] + * \file modulated_deformable_im2col.cuh + * \brief Function definitions of converting an image to + * column matrix based on kernel, padding, dilation, and offset. + * These functions are mainly used in deformable convolution operators. + * \ref: https://arxiv.org/abs/1703.06211 + * \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu, Dazhi Cheng + */ + +// modify from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda_kernel.cu + + +#include +#include +#include +#include +#include + +using namespace at; + +#define CUDA_KERNEL_LOOP(i, n) \ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ + i += blockDim.x * gridDim.x) + +const int CUDA_NUM_THREADS = 1024; +const int kMaxGridNum = 65535; +inline int GET_BLOCKS(const int N) +{ + return std::min(kMaxGridNum, (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS); +} + +/* +const int CUDA_NUM_THREADS = 1024; + +inline int GET_BLOCKS(const int N) +{ + return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS; +}*/ + +template +__device__ scalar_t deformable_im2col_bilinear(const scalar_t *bottom_data, const int data_width, + const int height, const int width, scalar_t h, scalar_t w) +{ + + int h_low = floor(h); + int w_low = floor(w); + int h_high = h_low + 1; + int w_high = w_low + 1; + + scalar_t lh = h - h_low; + scalar_t lw = w - w_low; + scalar_t hh = 1 - lh, hw = 1 - lw; + + scalar_t v1 = 0; + if (h_low >= 0 && w_low >= 0) + v1 = bottom_data[h_low * data_width + w_low]; + scalar_t v2 = 0; + if (h_low >= 0 && w_high <= width - 1) + v2 = bottom_data[h_low * data_width + w_high]; + scalar_t v3 = 0; + if (h_high <= height - 1 && w_low >= 0) + v3 = bottom_data[h_high * data_width + w_low]; + scalar_t v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) + v4 = bottom_data[h_high * data_width + w_high]; + + scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + + scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + return val; +} + +template +__device__ scalar_t get_gradient_weight(scalar_t argmax_h, scalar_t argmax_w, + const int h, const int w, const int height, const int width) +{ + + if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width) + { + //empty + return 0; + } + + int argmax_h_low = floor(argmax_h); + int argmax_w_low = floor(argmax_w); + int argmax_h_high = argmax_h_low + 1; + int argmax_w_high = argmax_w_low + 1; + + scalar_t weight = 0; + if (h == argmax_h_low && w == argmax_w_low) + weight = (h + 1 - argmax_h) * (w + 1 - argmax_w); + if (h == argmax_h_low && w == argmax_w_high) + weight = (h + 1 - argmax_h) * (argmax_w + 1 - w); + if (h == argmax_h_high && w == argmax_w_low) + weight = (argmax_h + 1 - h) * (w + 1 - argmax_w); + if (h == argmax_h_high && w == argmax_w_high) + weight = (argmax_h + 1 - h) * (argmax_w + 1 - w); + return weight; +} + +template +__device__ scalar_t get_coordinate_weight(scalar_t argmax_h, scalar_t argmax_w, + const int height, const int width, const scalar_t *im_data, + const int data_width, const int bp_dir) +{ + + if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width) + { + //empty + return 0; + } + + int argmax_h_low = floor(argmax_h); + int argmax_w_low = floor(argmax_w); + int argmax_h_high = argmax_h_low + 1; + int argmax_w_high = argmax_w_low + 1; + + scalar_t weight = 0; + + if (bp_dir == 0) + { + if (argmax_h_low >= 0 && argmax_w_low >= 0) + weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low]; + if (argmax_h_low >= 0 && argmax_w_high <= width - 1) + weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high]; + if (argmax_h_high <= height - 1 && argmax_w_low >= 0) + weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low]; + if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) + weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high]; + } + else if (bp_dir == 1) + { + if (argmax_h_low >= 0 && argmax_w_low >= 0) + weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low]; + if (argmax_h_low >= 0 && argmax_w_high <= width - 1) + weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high]; + if (argmax_h_high <= height - 1 && argmax_w_low >= 0) + weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low]; + if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) + weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high]; + } + + return weight; +} + +template +__global__ void deformable_im2col_gpu_kernel(const int n, const scalar_t *data_im, const scalar_t *data_offset, + const int height, const int width, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, const int channel_per_deformable_group, + const int batch_size, const int num_channels, const int deformable_group, + const int height_col, const int width_col, + scalar_t *data_col) +{ + CUDA_KERNEL_LOOP(index, n) + { + // index index of output matrix + const int w_col = index % width_col; + const int h_col = (index / width_col) % height_col; + const int b_col = (index / width_col / height_col) % batch_size; + const int c_im = (index / width_col / height_col) / batch_size; + const int c_col = c_im * kernel_h * kernel_w; + + // compute deformable group index + const int deformable_group_index = c_im / channel_per_deformable_group; + + const int h_in = h_col * stride_h - pad_h; + const int w_in = w_col * stride_w - pad_w; + scalar_t *data_col_ptr = data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col; + //const scalar_t* data_im_ptr = data_im + ((b_col * num_channels + c_im) * height + h_in) * width + w_in; + const scalar_t *data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width; + const scalar_t *data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col; + + for (int i = 0; i < kernel_h; ++i) + { + for (int j = 0; j < kernel_w; ++j) + { + const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col; + const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col; + const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr]; + const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr]; + scalar_t val = static_cast(0); + const scalar_t h_im = h_in + i * dilation_h + offset_h; + const scalar_t w_im = w_in + j * dilation_w + offset_w; + if (h_im > -1 && w_im > -1 && h_im < height && w_im < width) + { + //const scalar_t map_h = i * dilation_h + offset_h; + //const scalar_t map_w = j * dilation_w + offset_w; + //const int cur_height = height - h_in; + //const int cur_width = width - w_in; + //val = deformable_im2col_bilinear(data_im_ptr, width, cur_height, cur_width, map_h, map_w); + val = deformable_im2col_bilinear(data_im_ptr, width, height, width, h_im, w_im); + } + *data_col_ptr = val; + data_col_ptr += batch_size * height_col * width_col; + } + } + } +} + +void deformable_im2col( + const at::Tensor data_im, const at::Tensor data_offset, const int channels, + const int height, const int width, const int ksize_h, const int ksize_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, const int parallel_imgs, + const int deformable_group, at::Tensor data_col) +{ + // num_axes should be smaller than block size + // todo: check parallel_imgs is correctly passed in + int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1; + int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1; + int num_kernels = channels * height_col * width_col * parallel_imgs; + int channel_per_deformable_group = channels / deformable_group; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + data_im.type(), "deformable_im2col_gpu", ([&] { + const scalar_t *data_im_ = data_im.data(); + const scalar_t *data_offset_ = data_offset.data(); + scalar_t *data_col_ = data_col.data(); + + deformable_im2col_gpu_kernel<<>>( + num_kernels, data_im_, data_offset_, height, width, ksize_h, ksize_w, + pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, + channel_per_deformable_group, parallel_imgs, channels, deformable_group, + height_col, width_col, data_col_); + })); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in deformable_im2col: %s\n", cudaGetErrorString(err)); + } +} + +template +__global__ void deformable_col2im_gpu_kernel( + const int n, const scalar_t *data_col, const scalar_t *data_offset, + const int channels, const int height, const int width, + const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, const int deformable_group, + const int height_col, const int width_col, + scalar_t *grad_im) +{ + CUDA_KERNEL_LOOP(index, n) + { + const int j = (index / width_col / height_col / batch_size) % kernel_w; + const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h; + const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h; + // compute the start and end of the output + + const int deformable_group_index = c / channel_per_deformable_group; + + int w_out = index % width_col; + int h_out = (index / width_col) % height_col; + int b = (index / width_col / height_col) % batch_size; + int w_in = w_out * stride_w - pad_w; + int h_in = h_out * stride_h - pad_h; + + const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * + 2 * kernel_h * kernel_w * height_col * width_col; + const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out; + const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out; + const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr]; + const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr]; + const scalar_t cur_inv_h_data = h_in + i * dilation_h + offset_h; + const scalar_t cur_inv_w_data = w_in + j * dilation_w + offset_w; + + const scalar_t cur_top_grad = data_col[index]; + const int cur_h = (int)cur_inv_h_data; + const int cur_w = (int)cur_inv_w_data; + for (int dy = -2; dy <= 2; dy++) + { + for (int dx = -2; dx <= 2; dx++) + { + if (cur_h + dy >= 0 && cur_h + dy < height && + cur_w + dx >= 0 && cur_w + dx < width && + abs(cur_inv_h_data - (cur_h + dy)) < 1 && + abs(cur_inv_w_data - (cur_w + dx)) < 1) + { + int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx; + scalar_t weight = get_gradient_weight(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width); + atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad); + } + } + } + } +} + +void deformable_col2im( + const at::Tensor data_col, const at::Tensor data_offset, const int channels, + const int height, const int width, const int ksize_h, + const int ksize_w, const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int parallel_imgs, const int deformable_group, + at::Tensor grad_im) +{ + + // todo: make sure parallel_imgs is passed in correctly + int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1; + int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1; + int num_kernels = channels * ksize_h * ksize_w * height_col * width_col * parallel_imgs; + int channel_per_deformable_group = channels / deformable_group; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + data_col.type(), "deformable_col2im_gpu", ([&] { + const scalar_t *data_col_ = data_col.data(); + const scalar_t *data_offset_ = data_offset.data(); + scalar_t *grad_im_ = grad_im.data(); + + deformable_col2im_gpu_kernel<<>>( + num_kernels, data_col_, data_offset_, channels, height, width, ksize_h, + ksize_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, channel_per_deformable_group, + parallel_imgs, deformable_group, height_col, width_col, grad_im_); + })); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in deformable_col2im: %s\n", cudaGetErrorString(err)); + } +} + +template +__global__ void deformable_col2im_coord_gpu_kernel(const int n, const scalar_t *data_col, + const scalar_t *data_im, const scalar_t *data_offset, + const int channels, const int height, const int width, + const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, const int offset_channels, const int deformable_group, + const int height_col, const int width_col, scalar_t *grad_offset) +{ + CUDA_KERNEL_LOOP(index, n) + { + scalar_t val = 0; + int w = index % width_col; + int h = (index / width_col) % height_col; + int c = (index / width_col / height_col) % offset_channels; + int b = (index / width_col / height_col) / offset_channels; + // compute the start and end of the output + + const int deformable_group_index = c / (2 * kernel_h * kernel_w); + const int col_step = kernel_h * kernel_w; + int cnt = 0; + const scalar_t *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group * + batch_size * width_col * height_col; + const scalar_t *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) * + channel_per_deformable_group / kernel_h / kernel_w * height * width; + const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * + kernel_h * kernel_w * height_col * width_col; + + const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w; + + for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step) + { + const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w; + const int bp_dir = offset_c % 2; + + int j = (col_pos / width_col / height_col / batch_size) % kernel_w; + int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h; + int w_out = col_pos % width_col; + int h_out = (col_pos / width_col) % height_col; + int w_in = w_out * stride_w - pad_w; + int h_in = h_out * stride_h - pad_h; + const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out); + const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out); + const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr]; + const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr]; + scalar_t inv_h = h_in + i * dilation_h + offset_h; + scalar_t inv_w = w_in + j * dilation_w + offset_w; + if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width) + { + inv_h = inv_w = -2; + } + const scalar_t weight = get_coordinate_weight( + inv_h, inv_w, + height, width, data_im_ptr + cnt * height * width, width, bp_dir); + val += weight * data_col_ptr[col_pos]; + cnt += 1; + } + + grad_offset[index] = val; + } +} + +void deformable_col2im_coord( + const at::Tensor data_col, const at::Tensor data_im, const at::Tensor data_offset, + const int channels, const int height, const int width, const int ksize_h, + const int ksize_w, const int pad_h, const int pad_w, const int stride_h, + const int stride_w, const int dilation_h, const int dilation_w, + const int parallel_imgs, const int deformable_group, at::Tensor grad_offset) +{ + + int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1; + int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1; + int num_kernels = height_col * width_col * 2 * ksize_h * ksize_w * deformable_group * parallel_imgs; + int channel_per_deformable_group = channels * ksize_h * ksize_w / deformable_group; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + data_col.type(), "deformable_col2im_coord_gpu", ([&] { + const scalar_t *data_col_ = data_col.data(); + const scalar_t *data_im_ = data_im.data(); + const scalar_t *data_offset_ = data_offset.data(); + scalar_t *grad_offset_ = grad_offset.data(); + + deformable_col2im_coord_gpu_kernel<<>>( + num_kernels, data_col_, data_im_, data_offset_, channels, height, width, + ksize_h, ksize_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, channel_per_deformable_group, + parallel_imgs, 2 * ksize_h * ksize_w * deformable_group, deformable_group, + height_col, width_col, grad_offset_); + })); +} + +template +__device__ scalar_t dmcn_im2col_bilinear(const scalar_t *bottom_data, const int data_width, + const int height, const int width, scalar_t h, scalar_t w) +{ + int h_low = floor(h); + int w_low = floor(w); + int h_high = h_low + 1; + int w_high = w_low + 1; + + scalar_t lh = h - h_low; + scalar_t lw = w - w_low; + scalar_t hh = 1 - lh, hw = 1 - lw; + + scalar_t v1 = 0; + if (h_low >= 0 && w_low >= 0) + v1 = bottom_data[h_low * data_width + w_low]; + scalar_t v2 = 0; + if (h_low >= 0 && w_high <= width - 1) + v2 = bottom_data[h_low * data_width + w_high]; + scalar_t v3 = 0; + if (h_high <= height - 1 && w_low >= 0) + v3 = bottom_data[h_high * data_width + w_low]; + scalar_t v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) + v4 = bottom_data[h_high * data_width + w_high]; + + scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + + scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + return val; +} + +template +__device__ scalar_t dmcn_get_gradient_weight(scalar_t argmax_h, scalar_t argmax_w, + const int h, const int w, const int height, const int width) +{ + if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width) + { + //empty + return 0; + } + + int argmax_h_low = floor(argmax_h); + int argmax_w_low = floor(argmax_w); + int argmax_h_high = argmax_h_low + 1; + int argmax_w_high = argmax_w_low + 1; + + scalar_t weight = 0; + if (h == argmax_h_low && w == argmax_w_low) + weight = (h + 1 - argmax_h) * (w + 1 - argmax_w); + if (h == argmax_h_low && w == argmax_w_high) + weight = (h + 1 - argmax_h) * (argmax_w + 1 - w); + if (h == argmax_h_high && w == argmax_w_low) + weight = (argmax_h + 1 - h) * (w + 1 - argmax_w); + if (h == argmax_h_high && w == argmax_w_high) + weight = (argmax_h + 1 - h) * (argmax_w + 1 - w); + return weight; +} + +template +__device__ scalar_t dmcn_get_coordinate_weight(scalar_t argmax_h, scalar_t argmax_w, + const int height, const int width, const scalar_t *im_data, + const int data_width, const int bp_dir) +{ + if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width) + { + //empty + return 0; + } + + int argmax_h_low = floor(argmax_h); + int argmax_w_low = floor(argmax_w); + int argmax_h_high = argmax_h_low + 1; + int argmax_w_high = argmax_w_low + 1; + + scalar_t weight = 0; + + if (bp_dir == 0) + { + if (argmax_h_low >= 0 && argmax_w_low >= 0) + weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low]; + if (argmax_h_low >= 0 && argmax_w_high <= width - 1) + weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high]; + if (argmax_h_high <= height - 1 && argmax_w_low >= 0) + weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low]; + if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) + weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high]; + } + else if (bp_dir == 1) + { + if (argmax_h_low >= 0 && argmax_w_low >= 0) + weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low]; + if (argmax_h_low >= 0 && argmax_w_high <= width - 1) + weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high]; + if (argmax_h_high <= height - 1 && argmax_w_low >= 0) + weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low]; + if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) + weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high]; + } + + return weight; +} + +template +__global__ void modulated_deformable_im2col_gpu_kernel(const int n, + const scalar_t *data_im, const scalar_t *data_offset, const scalar_t *data_mask, + const int height, const int width, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, const int num_channels, const int deformable_group, + const int height_col, const int width_col, + scalar_t *data_col) +{ + CUDA_KERNEL_LOOP(index, n) + { + // index index of output matrix + const int w_col = index % width_col; + const int h_col = (index / width_col) % height_col; + const int b_col = (index / width_col / height_col) % batch_size; + const int c_im = (index / width_col / height_col) / batch_size; + const int c_col = c_im * kernel_h * kernel_w; + + // compute deformable group index + const int deformable_group_index = c_im / channel_per_deformable_group; + + const int h_in = h_col * stride_h - pad_h; + const int w_in = w_col * stride_w - pad_w; + + scalar_t *data_col_ptr = data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col; + //const float* data_im_ptr = data_im + ((b_col * num_channels + c_im) * height + h_in) * width + w_in; + const scalar_t *data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width; + const scalar_t *data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col; + + const scalar_t *data_mask_ptr = data_mask + (b_col * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col; + + for (int i = 0; i < kernel_h; ++i) + { + for (int j = 0; j < kernel_w; ++j) + { + const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col; + const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col; + const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_col) * width_col + w_col; + const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr]; + const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr]; + const scalar_t mask = data_mask_ptr[data_mask_hw_ptr]; + scalar_t val = static_cast(0); + const scalar_t h_im = h_in + i * dilation_h + offset_h; + const scalar_t w_im = w_in + j * dilation_w + offset_w; + //if (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) { + if (h_im > -1 && w_im > -1 && h_im < height && w_im < width) + { + //const float map_h = i * dilation_h + offset_h; + //const float map_w = j * dilation_w + offset_w; + //const int cur_height = height - h_in; + //const int cur_width = width - w_in; + //val = dmcn_im2col_bilinear(data_im_ptr, width, cur_height, cur_width, map_h, map_w); + val = dmcn_im2col_bilinear(data_im_ptr, width, height, width, h_im, w_im); + } + *data_col_ptr = val * mask; + data_col_ptr += batch_size * height_col * width_col; + //data_col_ptr += height_col * width_col; + } + } + } +} + +template +__global__ void modulated_deformable_col2im_gpu_kernel(const int n, + const scalar_t *data_col, const scalar_t *data_offset, const scalar_t *data_mask, + const int channels, const int height, const int width, + const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, const int deformable_group, + const int height_col, const int width_col, + scalar_t *grad_im) +{ + CUDA_KERNEL_LOOP(index, n) + { + const int j = (index / width_col / height_col / batch_size) % kernel_w; + const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h; + const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h; + // compute the start and end of the output + + const int deformable_group_index = c / channel_per_deformable_group; + + int w_out = index % width_col; + int h_out = (index / width_col) % height_col; + int b = (index / width_col / height_col) % batch_size; + int w_in = w_out * stride_w - pad_w; + int h_in = h_out * stride_h - pad_h; + + const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col; + const scalar_t *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col; + const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out; + const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out; + const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_out) * width_col + w_out; + const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr]; + const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr]; + const scalar_t mask = data_mask_ptr[data_mask_hw_ptr]; + const scalar_t cur_inv_h_data = h_in + i * dilation_h + offset_h; + const scalar_t cur_inv_w_data = w_in + j * dilation_w + offset_w; + + const scalar_t cur_top_grad = data_col[index] * mask; + const int cur_h = (int)cur_inv_h_data; + const int cur_w = (int)cur_inv_w_data; + for (int dy = -2; dy <= 2; dy++) + { + for (int dx = -2; dx <= 2; dx++) + { + if (cur_h + dy >= 0 && cur_h + dy < height && + cur_w + dx >= 0 && cur_w + dx < width && + abs(cur_inv_h_data - (cur_h + dy)) < 1 && + abs(cur_inv_w_data - (cur_w + dx)) < 1) + { + int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx; + scalar_t weight = dmcn_get_gradient_weight(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width); + atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad); + } + } + } + } +} + +template +__global__ void modulated_deformable_col2im_coord_gpu_kernel(const int n, + const scalar_t *data_col, const scalar_t *data_im, + const scalar_t *data_offset, const scalar_t *data_mask, + const int channels, const int height, const int width, + const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, const int offset_channels, const int deformable_group, + const int height_col, const int width_col, + scalar_t *grad_offset, scalar_t *grad_mask) +{ + CUDA_KERNEL_LOOP(index, n) + { + scalar_t val = 0, mval = 0; + int w = index % width_col; + int h = (index / width_col) % height_col; + int c = (index / width_col / height_col) % offset_channels; + int b = (index / width_col / height_col) / offset_channels; + // compute the start and end of the output + + const int deformable_group_index = c / (2 * kernel_h * kernel_w); + const int col_step = kernel_h * kernel_w; + int cnt = 0; + const scalar_t *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group * batch_size * width_col * height_col; + const scalar_t *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) * channel_per_deformable_group / kernel_h / kernel_w * height * width; + const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col; + const scalar_t *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col; + + const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w; + + for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step) + { + const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w; + const int bp_dir = offset_c % 2; + + int j = (col_pos / width_col / height_col / batch_size) % kernel_w; + int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h; + int w_out = col_pos % width_col; + int h_out = (col_pos / width_col) % height_col; + int w_in = w_out * stride_w - pad_w; + int h_in = h_out * stride_h - pad_h; + const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out); + const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out); + const int data_mask_hw_ptr = (((i * kernel_w + j) * height_col + h_out) * width_col + w_out); + const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr]; + const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr]; + const scalar_t mask = data_mask_ptr[data_mask_hw_ptr]; + scalar_t inv_h = h_in + i * dilation_h + offset_h; + scalar_t inv_w = w_in + j * dilation_w + offset_w; + if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width) + { + inv_h = inv_w = -2; + } + else + { + mval += data_col_ptr[col_pos] * dmcn_im2col_bilinear(data_im_ptr + cnt * height * width, width, height, width, inv_h, inv_w); + } + const scalar_t weight = dmcn_get_coordinate_weight( + inv_h, inv_w, + height, width, data_im_ptr + cnt * height * width, width, bp_dir); + val += weight * data_col_ptr[col_pos] * mask; + cnt += 1; + } + // KERNEL_ASSIGN(grad_offset[index], offset_req, val); + grad_offset[index] = val; + if (offset_c % 2 == 0) + // KERNEL_ASSIGN(grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w], mask_req, mval); + grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w] = mval; + } +} + +void modulated_deformable_im2col_cuda( + const at::Tensor data_im, const at::Tensor data_offset, const at::Tensor data_mask, + const int batch_size, const int channels, const int height_im, const int width_im, + const int height_col, const int width_col, const int kernel_h, const int kenerl_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int deformable_group, at::Tensor data_col) +{ + // num_axes should be smaller than block size + const int channel_per_deformable_group = channels / deformable_group; + const int num_kernels = channels * batch_size * height_col * width_col; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + data_im.type(), "modulated_deformable_im2col_gpu", ([&] { + const scalar_t *data_im_ = data_im.data(); + const scalar_t *data_offset_ = data_offset.data(); + const scalar_t *data_mask_ = data_mask.data(); + scalar_t *data_col_ = data_col.data(); + + modulated_deformable_im2col_gpu_kernel<<>>( + num_kernels, data_im_, data_offset_, data_mask_, height_im, width_im, kernel_h, kenerl_w, + pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, channel_per_deformable_group, + batch_size, channels, deformable_group, height_col, width_col, data_col_); + })); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in modulated_deformable_im2col_cuda: %s\n", cudaGetErrorString(err)); + } +} + +void modulated_deformable_col2im_cuda( + const at::Tensor data_col, const at::Tensor data_offset, const at::Tensor data_mask, + const int batch_size, const int channels, const int height_im, const int width_im, + const int height_col, const int width_col, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int deformable_group, at::Tensor grad_im) +{ + + const int channel_per_deformable_group = channels / deformable_group; + const int num_kernels = channels * kernel_h * kernel_w * batch_size * height_col * width_col; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + data_col.type(), "modulated_deformable_col2im_gpu", ([&] { + const scalar_t *data_col_ = data_col.data(); + const scalar_t *data_offset_ = data_offset.data(); + const scalar_t *data_mask_ = data_mask.data(); + scalar_t *grad_im_ = grad_im.data(); + + modulated_deformable_col2im_gpu_kernel<<>>( + num_kernels, data_col_, data_offset_, data_mask_, channels, height_im, width_im, + kernel_h, kernel_w, pad_h, pad_h, stride_h, stride_w, + dilation_h, dilation_w, channel_per_deformable_group, + batch_size, deformable_group, height_col, width_col, grad_im_); + })); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in modulated_deformable_col2im_cuda: %s\n", cudaGetErrorString(err)); + } +} + +void modulated_deformable_col2im_coord_cuda( + const at::Tensor data_col, const at::Tensor data_im, const at::Tensor data_offset, const at::Tensor data_mask, + const int batch_size, const int channels, const int height_im, const int width_im, + const int height_col, const int width_col, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int deformable_group, + at::Tensor grad_offset, at::Tensor grad_mask) +{ + const int num_kernels = batch_size * height_col * width_col * 2 * kernel_h * kernel_w * deformable_group; + const int channel_per_deformable_group = channels * kernel_h * kernel_w / deformable_group; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + data_col.type(), "modulated_deformable_col2im_coord_gpu", ([&] { + const scalar_t *data_col_ = data_col.data(); + const scalar_t *data_im_ = data_im.data(); + const scalar_t *data_offset_ = data_offset.data(); + const scalar_t *data_mask_ = data_mask.data(); + scalar_t *grad_offset_ = grad_offset.data(); + scalar_t *grad_mask_ = grad_mask.data(); + + modulated_deformable_col2im_coord_gpu_kernel<<>>( + num_kernels, data_col_, data_im_, data_offset_, data_mask_, channels, height_im, width_im, + kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, channel_per_deformable_group, + batch_size, 2 * kernel_h * kernel_w * deformable_group, deformable_group, height_col, width_col, + grad_offset_, grad_mask_); + })); + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in modulated_deformable_col2im_coord_cuda: %s\n", cudaGetErrorString(err)); + } +} diff --git a/maskrcnn_benchmark/csrc/cuda/deform_pool_cuda.cu b/maskrcnn_benchmark/csrc/cuda/deform_pool_cuda.cu new file mode 100644 index 0000000000000000000000000000000000000000..71f305af96037cd545b4feed501deb7ec285f574 --- /dev/null +++ b/maskrcnn_benchmark/csrc/cuda/deform_pool_cuda.cu @@ -0,0 +1,87 @@ +// modify from +// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/modulated_dcn_cuda.c + +// based on +// author: Charles Shang +// https://github.com/torch/cunn/blob/master/lib/THCUNN/generic/SpatialConvolutionMM.cu + +#include +#include + +#include +#include + +#include +#include +#include + + +void DeformablePSROIPoolForward( + const at::Tensor data, const at::Tensor bbox, const at::Tensor trans, + at::Tensor out, at::Tensor top_count, const int batch, const int channels, + const int height, const int width, const int num_bbox, + const int channels_trans, const int no_trans, const float spatial_scale, + const int output_dim, const int group_size, const int pooled_size, + const int part_size, const int sample_per_part, const float trans_std); + +void DeformablePSROIPoolBackwardAcc( + const at::Tensor out_grad, const at::Tensor data, const at::Tensor bbox, + const at::Tensor trans, const at::Tensor top_count, at::Tensor in_grad, + at::Tensor trans_grad, const int batch, const int channels, + const int height, const int width, const int num_bbox, + const int channels_trans, const int no_trans, const float spatial_scale, + const int output_dim, const int group_size, const int pooled_size, + const int part_size, const int sample_per_part, const float trans_std); + +void deform_psroi_pooling_cuda_forward( + at::Tensor input, at::Tensor bbox, at::Tensor trans, at::Tensor out, + at::Tensor top_count, const int no_trans, const float spatial_scale, + const int output_dim, const int group_size, const int pooled_size, + const int part_size, const int sample_per_part, const float trans_std) +{ + AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous"); + + const int batch = input.size(0); + const int channels = input.size(1); + const int height = input.size(2); + const int width = input.size(3); + const int channels_trans = no_trans ? 2 : trans.size(1); + + const int num_bbox = bbox.size(0); + if (num_bbox != out.size(0)) + AT_ERROR("Output shape and bbox number wont match: (%d vs %d).", + out.size(0), num_bbox); + + DeformablePSROIPoolForward( + input, bbox, trans, out, top_count, batch, channels, height, width, + num_bbox, channels_trans, no_trans, spatial_scale, output_dim, group_size, + pooled_size, part_size, sample_per_part, trans_std); +} + +void deform_psroi_pooling_cuda_backward( + at::Tensor out_grad, at::Tensor input, at::Tensor bbox, at::Tensor trans, + at::Tensor top_count, at::Tensor input_grad, at::Tensor trans_grad, + const int no_trans, const float spatial_scale, const int output_dim, + const int group_size, const int pooled_size, const int part_size, + const int sample_per_part, const float trans_std) +{ + AT_CHECK(out_grad.is_contiguous(), "out_grad tensor has to be contiguous"); + AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous"); + + const int batch = input.size(0); + const int channels = input.size(1); + const int height = input.size(2); + const int width = input.size(3); + const int channels_trans = no_trans ? 2 : trans.size(1); + + const int num_bbox = bbox.size(0); + if (num_bbox != out_grad.size(0)) + AT_ERROR("Output shape and bbox number wont match: (%d vs %d).", + out_grad.size(0), num_bbox); + + DeformablePSROIPoolBackwardAcc( + out_grad, input, bbox, trans, top_count, input_grad, trans_grad, batch, + channels, height, width, num_bbox, channels_trans, no_trans, + spatial_scale, output_dim, group_size, pooled_size, part_size, + sample_per_part, trans_std); +} diff --git a/maskrcnn_benchmark/csrc/cuda/deform_pool_kernel_cuda.cu b/maskrcnn_benchmark/csrc/cuda/deform_pool_kernel_cuda.cu new file mode 100644 index 0000000000000000000000000000000000000000..127899ec68006e8b0d5c6a7fb420b34449e0b31c --- /dev/null +++ b/maskrcnn_benchmark/csrc/cuda/deform_pool_kernel_cuda.cu @@ -0,0 +1,365 @@ +/*! + * Copyright (c) 2017 Microsoft + * Licensed under The MIT License [see LICENSE for details] + * \file deformable_psroi_pooling.cu + * \brief + * \author Yi Li, Guodong Zhang, Jifeng Dai +*/ +/***************** Adapted by Charles Shang *********************/ +// modify from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/cuda/deform_psroi_pooling_cuda.cu + + +#include +#include +#include +#include +#include + +using namespace at; + +#define CUDA_KERNEL_LOOP(i, n) \ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; \ + i < (n); \ + i += blockDim.x * gridDim.x) + +const int CUDA_NUM_THREADS = 1024; +inline int GET_BLOCKS(const int N) +{ + return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS; +} + +template +__device__ scalar_t bilinear_interp( + const scalar_t *data, + const scalar_t x, + const scalar_t y, + const int width, + const int height) +{ + int x1 = floor(x); + int x2 = ceil(x); + int y1 = floor(y); + int y2 = ceil(y); + scalar_t dist_x = (scalar_t)(x - x1); + scalar_t dist_y = (scalar_t)(y - y1); + scalar_t value11 = data[y1 * width + x1]; + scalar_t value12 = data[y2 * width + x1]; + scalar_t value21 = data[y1 * width + x2]; + scalar_t value22 = data[y2 * width + x2]; + scalar_t value = (1 - dist_x) * (1 - dist_y) * value11 + (1 - dist_x) * dist_y * value12 + dist_x * (1 - dist_y) * value21 + dist_x * dist_y * value22; + return value; +} + +template +__global__ void DeformablePSROIPoolForwardKernel( + const int count, + const scalar_t *bottom_data, + const scalar_t spatial_scale, + const int channels, + const int height, const int width, + const int pooled_height, const int pooled_width, + const scalar_t *bottom_rois, const scalar_t *bottom_trans, + const int no_trans, + const scalar_t trans_std, + const int sample_per_part, + const int output_dim, + const int group_size, + const int part_size, + const int num_classes, + const int channels_each_class, + scalar_t *top_data, + scalar_t *top_count) +{ + CUDA_KERNEL_LOOP(index, count) + { + // The output is in order (n, ctop, ph, pw) + int pw = index % pooled_width; + int ph = (index / pooled_width) % pooled_height; + int ctop = (index / pooled_width / pooled_height) % output_dim; + int n = index / pooled_width / pooled_height / output_dim; + + // [start, end) interval for spatial sampling + const scalar_t *offset_bottom_rois = bottom_rois + n * 5; + int roi_batch_ind = offset_bottom_rois[0]; + scalar_t roi_start_w = (scalar_t)(round(offset_bottom_rois[1])) * spatial_scale - 0.5; + scalar_t roi_start_h = (scalar_t)(round(offset_bottom_rois[2])) * spatial_scale - 0.5; + scalar_t roi_end_w = (scalar_t)(round(offset_bottom_rois[3]) + 1.) * spatial_scale - 0.5; + scalar_t roi_end_h = (scalar_t)(round(offset_bottom_rois[4]) + 1.) * spatial_scale - 0.5; + + // Force too small ROIs to be 1x1 + scalar_t roi_width = max(roi_end_w - roi_start_w, 0.1); //avoid 0 + scalar_t roi_height = max(roi_end_h - roi_start_h, 0.1); + + // Compute w and h at bottom + scalar_t bin_size_h = roi_height / (scalar_t)(pooled_height); + scalar_t bin_size_w = roi_width / (scalar_t)(pooled_width); + + scalar_t sub_bin_size_h = bin_size_h / (scalar_t)(sample_per_part); + scalar_t sub_bin_size_w = bin_size_w / (scalar_t)(sample_per_part); + + int part_h = floor((scalar_t)(ph) / pooled_height * part_size); + int part_w = floor((scalar_t)(pw) / pooled_width * part_size); + int class_id = ctop / channels_each_class; + scalar_t trans_x = no_trans ? (scalar_t)(0) : bottom_trans[(((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w] * (scalar_t)trans_std; + scalar_t trans_y = no_trans ? (scalar_t)(0) : bottom_trans[(((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w] * (scalar_t)trans_std; + + scalar_t wstart = (scalar_t)(pw)*bin_size_w + roi_start_w; + wstart += trans_x * roi_width; + scalar_t hstart = (scalar_t)(ph)*bin_size_h + roi_start_h; + hstart += trans_y * roi_height; + + scalar_t sum = 0; + int count = 0; + int gw = floor((scalar_t)(pw)*group_size / pooled_width); + int gh = floor((scalar_t)(ph)*group_size / pooled_height); + gw = min(max(gw, 0), group_size - 1); + gh = min(max(gh, 0), group_size - 1); + + const scalar_t *offset_bottom_data = bottom_data + (roi_batch_ind * channels) * height * width; + for (int ih = 0; ih < sample_per_part; ih++) + { + for (int iw = 0; iw < sample_per_part; iw++) + { + scalar_t w = wstart + iw * sub_bin_size_w; + scalar_t h = hstart + ih * sub_bin_size_h; + // bilinear interpolation + if (w < -0.5 || w > width - 0.5 || h < -0.5 || h > height - 0.5) + { + continue; + } + w = min(max(w, 0.), width - 1.); + h = min(max(h, 0.), height - 1.); + int c = (ctop * group_size + gh) * group_size + gw; + scalar_t val = bilinear_interp(offset_bottom_data + c * height * width, w, h, width, height); + sum += val; + count++; + } + } + top_data[index] = count == 0 ? (scalar_t)(0) : sum / count; + top_count[index] = count; + } +} + +template +__global__ void DeformablePSROIPoolBackwardAccKernel( + const int count, + const scalar_t *top_diff, + const scalar_t *top_count, + const int num_rois, + const scalar_t spatial_scale, + const int channels, + const int height, const int width, + const int pooled_height, const int pooled_width, + const int output_dim, + scalar_t *bottom_data_diff, scalar_t *bottom_trans_diff, + const scalar_t *bottom_data, + const scalar_t *bottom_rois, + const scalar_t *bottom_trans, + const int no_trans, + const scalar_t trans_std, + const int sample_per_part, + const int group_size, + const int part_size, + const int num_classes, + const int channels_each_class) +{ + CUDA_KERNEL_LOOP(index, count) + { + // The output is in order (n, ctop, ph, pw) + int pw = index % pooled_width; + int ph = (index / pooled_width) % pooled_height; + int ctop = (index / pooled_width / pooled_height) % output_dim; + int n = index / pooled_width / pooled_height / output_dim; + + // [start, end) interval for spatial sampling + const scalar_t *offset_bottom_rois = bottom_rois + n * 5; + int roi_batch_ind = offset_bottom_rois[0]; + scalar_t roi_start_w = (scalar_t)(round(offset_bottom_rois[1])) * spatial_scale - 0.5; + scalar_t roi_start_h = (scalar_t)(round(offset_bottom_rois[2])) * spatial_scale - 0.5; + scalar_t roi_end_w = (scalar_t)(round(offset_bottom_rois[3]) + 1.) * spatial_scale - 0.5; + scalar_t roi_end_h = (scalar_t)(round(offset_bottom_rois[4]) + 1.) * spatial_scale - 0.5; + + // Force too small ROIs to be 1x1 + scalar_t roi_width = max(roi_end_w - roi_start_w, 0.1); //avoid 0 + scalar_t roi_height = max(roi_end_h - roi_start_h, 0.1); + + // Compute w and h at bottom + scalar_t bin_size_h = roi_height / (scalar_t)(pooled_height); + scalar_t bin_size_w = roi_width / (scalar_t)(pooled_width); + + scalar_t sub_bin_size_h = bin_size_h / (scalar_t)(sample_per_part); + scalar_t sub_bin_size_w = bin_size_w / (scalar_t)(sample_per_part); + + int part_h = floor((scalar_t)(ph) / pooled_height * part_size); + int part_w = floor((scalar_t)(pw) / pooled_width * part_size); + int class_id = ctop / channels_each_class; + scalar_t trans_x = no_trans ? (scalar_t)(0) : bottom_trans[(((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w] * (scalar_t)trans_std; + scalar_t trans_y = no_trans ? (scalar_t)(0) : bottom_trans[(((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w] * (scalar_t)trans_std; + + scalar_t wstart = (scalar_t)(pw)*bin_size_w + roi_start_w; + wstart += trans_x * roi_width; + scalar_t hstart = (scalar_t)(ph)*bin_size_h + roi_start_h; + hstart += trans_y * roi_height; + + if (top_count[index] <= 0) + { + continue; + } + scalar_t diff_val = top_diff[index] / top_count[index]; + const scalar_t *offset_bottom_data = bottom_data + roi_batch_ind * channels * height * width; + scalar_t *offset_bottom_data_diff = bottom_data_diff + roi_batch_ind * channels * height * width; + int gw = floor((scalar_t)(pw)*group_size / pooled_width); + int gh = floor((scalar_t)(ph)*group_size / pooled_height); + gw = min(max(gw, 0), group_size - 1); + gh = min(max(gh, 0), group_size - 1); + + for (int ih = 0; ih < sample_per_part; ih++) + { + for (int iw = 0; iw < sample_per_part; iw++) + { + scalar_t w = wstart + iw * sub_bin_size_w; + scalar_t h = hstart + ih * sub_bin_size_h; + // bilinear interpolation + if (w < -0.5 || w > width - 0.5 || h < -0.5 || h > height - 0.5) + { + continue; + } + w = min(max(w, 0.), width - 1.); + h = min(max(h, 0.), height - 1.); + int c = (ctop * group_size + gh) * group_size + gw; + // backward on feature + int x0 = floor(w); + int x1 = ceil(w); + int y0 = floor(h); + int y1 = ceil(h); + scalar_t dist_x = w - x0, dist_y = h - y0; + scalar_t q00 = (1 - dist_x) * (1 - dist_y); + scalar_t q01 = (1 - dist_x) * dist_y; + scalar_t q10 = dist_x * (1 - dist_y); + scalar_t q11 = dist_x * dist_y; + int bottom_index_base = c * height * width; + atomicAdd(offset_bottom_data_diff + bottom_index_base + y0 * width + x0, q00 * diff_val); + atomicAdd(offset_bottom_data_diff + bottom_index_base + y1 * width + x0, q01 * diff_val); + atomicAdd(offset_bottom_data_diff + bottom_index_base + y0 * width + x1, q10 * diff_val); + atomicAdd(offset_bottom_data_diff + bottom_index_base + y1 * width + x1, q11 * diff_val); + + if (no_trans) + { + continue; + } + scalar_t U00 = offset_bottom_data[bottom_index_base + y0 * width + x0]; + scalar_t U01 = offset_bottom_data[bottom_index_base + y1 * width + x0]; + scalar_t U10 = offset_bottom_data[bottom_index_base + y0 * width + x1]; + scalar_t U11 = offset_bottom_data[bottom_index_base + y1 * width + x1]; + scalar_t diff_x = (U11 * dist_y + U10 * (1 - dist_y) - U01 * dist_y - U00 * (1 - dist_y)) * trans_std * diff_val; + diff_x *= roi_width; + scalar_t diff_y = (U11 * dist_x + U01 * (1 - dist_x) - U10 * dist_x - U00 * (1 - dist_x)) * trans_std * diff_val; + diff_y *= roi_height; + + atomicAdd(bottom_trans_diff + (((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w, diff_x); + atomicAdd(bottom_trans_diff + (((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w, diff_y); + } + } + } +} + +void DeformablePSROIPoolForward(const at::Tensor data, + const at::Tensor bbox, + const at::Tensor trans, + at::Tensor out, + at::Tensor top_count, + const int batch, + const int channels, + const int height, + const int width, + const int num_bbox, + const int channels_trans, + const int no_trans, + const float spatial_scale, + const int output_dim, + const int group_size, + const int pooled_size, + const int part_size, + const int sample_per_part, + const float trans_std) +{ + const int pooled_height = pooled_size; + const int pooled_width = pooled_size; + const int count = num_bbox * output_dim * pooled_height * pooled_width; + const int num_classes = no_trans ? 1 : channels_trans / 2; + const int channels_each_class = no_trans ? output_dim : output_dim / num_classes; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + data.type(), "deformable_psroi_pool_forward", ([&] { + const scalar_t *bottom_data = data.data(); + const scalar_t *bottom_rois = bbox.data(); + const scalar_t *bottom_trans = no_trans ? NULL : trans.data(); + scalar_t *top_data = out.data(); + scalar_t *top_count_data = top_count.data(); + + DeformablePSROIPoolForwardKernel<<>>( + count, bottom_data, (scalar_t)spatial_scale, channels, height, width, pooled_height, pooled_width, + bottom_rois, bottom_trans, no_trans, (scalar_t)trans_std, sample_per_part, output_dim, + group_size, part_size, num_classes, channels_each_class, top_data, top_count_data); + })); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in DeformablePSROIPoolForward: %s\n", cudaGetErrorString(err)); + } +} + +void DeformablePSROIPoolBackwardAcc(const at::Tensor out_grad, + const at::Tensor data, + const at::Tensor bbox, + const at::Tensor trans, + const at::Tensor top_count, + at::Tensor in_grad, + at::Tensor trans_grad, + const int batch, + const int channels, + const int height, + const int width, + const int num_bbox, + const int channels_trans, + const int no_trans, + const float spatial_scale, + const int output_dim, + const int group_size, + const int pooled_size, + const int part_size, + const int sample_per_part, + const float trans_std) +{ + // LOG(INFO) << "DeformablePSROIPoolBackward"; + const int num_rois = num_bbox; + const int pooled_height = pooled_size; + const int pooled_width = pooled_size; + const int count = num_bbox * output_dim * pooled_height * pooled_width; + const int num_classes = no_trans ? 1 : channels_trans / 2; + const int channels_each_class = no_trans ? output_dim : output_dim / num_classes; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + out_grad.type(), "deformable_psroi_pool_backward_acc", ([&] { + const scalar_t *top_diff = out_grad.data(); + const scalar_t *bottom_data = data.data(); + const scalar_t *bottom_rois = bbox.data(); + const scalar_t *bottom_trans = no_trans ? NULL : trans.data(); + scalar_t *bottom_data_diff = in_grad.data(); + scalar_t *bottom_trans_diff = no_trans ? NULL : trans_grad.data(); + const scalar_t *top_count_data = top_count.data(); + + DeformablePSROIPoolBackwardAccKernel<<>>( + count, top_diff, top_count_data, num_rois, (scalar_t)spatial_scale, channels, height, width, + pooled_height, pooled_width, output_dim, bottom_data_diff, bottom_trans_diff, + bottom_data, bottom_rois, bottom_trans, no_trans, (scalar_t)trans_std, sample_per_part, + group_size, part_size, num_classes, channels_each_class); + })); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in DeformablePSROIPoolForward: %s\n", cudaGetErrorString(err)); + } +} \ No newline at end of file diff --git a/maskrcnn_benchmark/csrc/cuda/nms.cu b/maskrcnn_benchmark/csrc/cuda/nms.cu new file mode 100644 index 0000000000000000000000000000000000000000..833d8523a5809d99a1078a144a384c864a9d8df9 --- /dev/null +++ b/maskrcnn_benchmark/csrc/cuda/nms.cu @@ -0,0 +1,131 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +#include +#include + +#include +#include + +#include +#include + +int const threadsPerBlock = sizeof(unsigned long long) * 8; + +__device__ inline float devIoU(float const * const a, float const * const b) { + float left = max(a[0], b[0]), right = min(a[2], b[2]); + float top = max(a[1], b[1]), bottom = min(a[3], b[3]); + float width = max(right - left + 1, 0.f), height = max(bottom - top + 1, 0.f); + float interS = width * height; + float Sa = (a[2] - a[0] + 1) * (a[3] - a[1] + 1); + float Sb = (b[2] - b[0] + 1) * (b[3] - b[1] + 1); + return interS / (Sa + Sb - interS); +} + +__global__ void nms_kernel(const int n_boxes, const float nms_overlap_thresh, + const float *dev_boxes, unsigned long long *dev_mask) { + const int row_start = blockIdx.y; + const int col_start = blockIdx.x; + + // if (row_start > col_start) return; + + const int row_size = + min(n_boxes - row_start * threadsPerBlock, threadsPerBlock); + const int col_size = + min(n_boxes - col_start * threadsPerBlock, threadsPerBlock); + + __shared__ float block_boxes[threadsPerBlock * 5]; + if (threadIdx.x < col_size) { + block_boxes[threadIdx.x * 5 + 0] = + dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 0]; + block_boxes[threadIdx.x * 5 + 1] = + dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 1]; + block_boxes[threadIdx.x * 5 + 2] = + dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 2]; + block_boxes[threadIdx.x * 5 + 3] = + dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 3]; + block_boxes[threadIdx.x * 5 + 4] = + dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 4]; + } + __syncthreads(); + + if (threadIdx.x < row_size) { + const int cur_box_idx = threadsPerBlock * row_start + threadIdx.x; + const float *cur_box = dev_boxes + cur_box_idx * 5; + int i = 0; + unsigned long long t = 0; + int start = 0; + if (row_start == col_start) { + start = threadIdx.x + 1; + } + for (i = start; i < col_size; i++) { + if (devIoU(cur_box, block_boxes + i * 5) > nms_overlap_thresh) { + t |= 1ULL << i; + } + } + const int col_blocks = THCCeilDiv(n_boxes, threadsPerBlock); + dev_mask[cur_box_idx * col_blocks + col_start] = t; + } +} + +// boxes is a N x 5 tensor +at::Tensor nms_cuda(const at::Tensor boxes, float nms_overlap_thresh) { + using scalar_t = float; + AT_ASSERTM(boxes.type().is_cuda(), "boxes must be a CUDA tensor"); + auto scores = boxes.select(1, 4); + auto order_t = std::get<1>(scores.sort(0, /* descending=*/true)); + auto boxes_sorted = boxes.index_select(0, order_t); + + int boxes_num = boxes.size(0); + + const int col_blocks = THCCeilDiv(boxes_num, threadsPerBlock); + + scalar_t* boxes_dev = boxes_sorted.data(); + + THCState *state = at::globalContext().lazyInitCUDA(); // TODO replace with getTHCState + + unsigned long long* mask_dev = NULL; + //THCudaCheck(THCudaMalloc(state, (void**) &mask_dev, + // boxes_num * col_blocks * sizeof(unsigned long long))); + + mask_dev = (unsigned long long*) THCudaMalloc(state, boxes_num * col_blocks * sizeof(unsigned long long)); + + dim3 blocks(THCCeilDiv(boxes_num, threadsPerBlock), + THCCeilDiv(boxes_num, threadsPerBlock)); + dim3 threads(threadsPerBlock); + nms_kernel<<>>(boxes_num, + nms_overlap_thresh, + boxes_dev, + mask_dev); + + std::vector mask_host(boxes_num * col_blocks); + THCudaCheck(cudaMemcpy(&mask_host[0], + mask_dev, + sizeof(unsigned long long) * boxes_num * col_blocks, + cudaMemcpyDeviceToHost)); + + std::vector remv(col_blocks); + memset(&remv[0], 0, sizeof(unsigned long long) * col_blocks); + + at::Tensor keep = at::empty({boxes_num}, boxes.options().dtype(at::kLong).device(at::kCPU)); + int64_t* keep_out = keep.data(); + + int num_to_keep = 0; + for (int i = 0; i < boxes_num; i++) { + int nblock = i / threadsPerBlock; + int inblock = i % threadsPerBlock; + + if (!(remv[nblock] & (1ULL << inblock))) { + keep_out[num_to_keep++] = i; + unsigned long long *p = &mask_host[0] + i * col_blocks; + for (int j = nblock; j < col_blocks; j++) { + remv[j] |= p[j]; + } + } + } + + THCudaFree(state, mask_dev); + // TODO improve this part + return std::get<0>(order_t.index({ + keep.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep).to( + order_t.device(), keep.scalar_type()) + }).sort(0, false)); +} diff --git a/maskrcnn_benchmark/csrc/cuda/vision.h b/maskrcnn_benchmark/csrc/cuda/vision.h new file mode 100644 index 0000000000000000000000000000000000000000..32d3c695605eb7c36aec01740075f4d0e76a75ef --- /dev/null +++ b/maskrcnn_benchmark/csrc/cuda/vision.h @@ -0,0 +1,116 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +#pragma once +#include + + +at::Tensor SigmoidFocalLoss_forward_cuda( + const at::Tensor& logits, + const at::Tensor& targets, + const int num_classes, + const float gamma, + const float alpha); + +at::Tensor SigmoidFocalLoss_backward_cuda( + const at::Tensor& logits, + const at::Tensor& targets, + const at::Tensor& d_losses, + const int num_classes, + const float gamma, + const float alpha); + +at::Tensor ROIAlign_forward_cuda(const at::Tensor& input, + const at::Tensor& rois, + const float spatial_scale, + const int pooled_height, + const int pooled_width, + const int sampling_ratio); + +at::Tensor ROIAlign_backward_cuda(const at::Tensor& grad, + const at::Tensor& rois, + const float spatial_scale, + const int pooled_height, + const int pooled_width, + const int batch_size, + const int channels, + const int height, + const int width, + const int sampling_ratio); + + +std::tuple ROIPool_forward_cuda(const at::Tensor& input, + const at::Tensor& rois, + const float spatial_scale, + const int pooled_height, + const int pooled_width); + +at::Tensor ROIPool_backward_cuda(const at::Tensor& grad, + const at::Tensor& input, + const at::Tensor& rois, + const at::Tensor& argmax, + const float spatial_scale, + const int pooled_height, + const int pooled_width, + const int batch_size, + const int channels, + const int height, + const int width); + +at::Tensor nms_cuda(const at::Tensor boxes, float nms_overlap_thresh); + + +int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight, + at::Tensor offset, at::Tensor output, + at::Tensor columns, at::Tensor ones, int kW, + int kH, int dW, int dH, int padW, int padH, + int dilationW, int dilationH, int group, + int deformable_group, int im2col_step); + +int deform_conv_backward_input_cuda(at::Tensor input, at::Tensor offset, + at::Tensor gradOutput, at::Tensor gradInput, + at::Tensor gradOffset, at::Tensor weight, + at::Tensor columns, int kW, int kH, int dW, + int dH, int padW, int padH, int dilationW, + int dilationH, int group, + int deformable_group, int im2col_step); + +int deform_conv_backward_parameters_cuda( + at::Tensor input, at::Tensor offset, at::Tensor gradOutput, + at::Tensor gradWeight, // at::Tensor gradBias, + at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH, + int padW, int padH, int dilationW, int dilationH, int group, + int deformable_group, float scale, int im2col_step); + +void modulated_deform_conv_cuda_forward( + at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones, + at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns, + int kernel_h, int kernel_w, const int stride_h, const int stride_w, + const int pad_h, const int pad_w, const int dilation_h, + const int dilation_w, const int group, const int deformable_group, + const bool with_bias); + +void modulated_deform_conv_cuda_backward( + at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones, + at::Tensor offset, at::Tensor mask, at::Tensor columns, + at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias, + at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output, + int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h, + int pad_w, int dilation_h, int dilation_w, int group, int deformable_group, + const bool with_bias); + +void deform_psroi_pooling_cuda_forward( + at::Tensor input, at::Tensor bbox, at::Tensor trans, at::Tensor out, + at::Tensor top_count, const int no_trans, const float spatial_scale, + const int output_dim, const int group_size, const int pooled_size, + const int part_size, const int sample_per_part, const float trans_std); + +void deform_psroi_pooling_cuda_backward( + at::Tensor out_grad, at::Tensor input, at::Tensor bbox, at::Tensor trans, + at::Tensor top_count, at::Tensor input_grad, at::Tensor trans_grad, + const int no_trans, const float spatial_scale, const int output_dim, + const int group_size, const int pooled_size, const int part_size, + const int sample_per_part, const float trans_std); + + +at::Tensor compute_flow_cuda(const at::Tensor& boxes, + const int height, + const int width); diff --git a/maskrcnn_benchmark/csrc/deform_conv.h b/maskrcnn_benchmark/csrc/deform_conv.h new file mode 100644 index 0000000000000000000000000000000000000000..a5930e390518c67d1a618d8be3f3eb11a2a964c0 --- /dev/null +++ b/maskrcnn_benchmark/csrc/deform_conv.h @@ -0,0 +1,191 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +#pragma once +#include "cpu/vision.h" + +#ifdef WITH_CUDA +#include "cuda/vision.h" +#endif + + +// Interface for Python +int deform_conv_forward( + at::Tensor input, + at::Tensor weight, + at::Tensor offset, + at::Tensor output, + at::Tensor columns, + at::Tensor ones, + int kW, + int kH, + int dW, + int dH, + int padW, + int padH, + int dilationW, + int dilationH, + int group, + int deformable_group, + int im2col_step) +{ + if (input.type().is_cuda()) { +#ifdef WITH_CUDA + return deform_conv_forward_cuda( + input, weight, offset, output, columns, ones, + kW, kH, dW, dH, padW, padH, dilationW, dilationH, + group, deformable_group, im2col_step + ); +#else + AT_ERROR("Not compiled with GPU support"); +#endif + } + AT_ERROR("Not implemented on the CPU"); +} + + +int deform_conv_backward_input( + at::Tensor input, + at::Tensor offset, + at::Tensor gradOutput, + at::Tensor gradInput, + at::Tensor gradOffset, + at::Tensor weight, + at::Tensor columns, + int kW, + int kH, + int dW, + int dH, + int padW, + int padH, + int dilationW, + int dilationH, + int group, + int deformable_group, + int im2col_step) +{ + if (input.type().is_cuda()) { +#ifdef WITH_CUDA + return deform_conv_backward_input_cuda( + input, offset, gradOutput, gradInput, gradOffset, weight, columns, + kW, kH, dW, dH, padW, padH, dilationW, dilationH, + group, deformable_group, im2col_step + ); +#else + AT_ERROR("Not compiled with GPU support"); +#endif + } + AT_ERROR("Not implemented on the CPU"); +} + + +int deform_conv_backward_parameters( + at::Tensor input, + at::Tensor offset, + at::Tensor gradOutput, + at::Tensor gradWeight, // at::Tensor gradBias, + at::Tensor columns, + at::Tensor ones, + int kW, + int kH, + int dW, + int dH, + int padW, + int padH, + int dilationW, + int dilationH, + int group, + int deformable_group, + float scale, + int im2col_step) +{ + if (input.type().is_cuda()) { +#ifdef WITH_CUDA + return deform_conv_backward_parameters_cuda( + input, offset, gradOutput, gradWeight, columns, ones, + kW, kH, dW, dH, padW, padH, dilationW, dilationH, + group, deformable_group, scale, im2col_step + ); +#else + AT_ERROR("Not compiled with GPU support"); +#endif + } + AT_ERROR("Not implemented on the CPU"); +} + + +void modulated_deform_conv_forward( + at::Tensor input, + at::Tensor weight, + at::Tensor bias, + at::Tensor ones, + at::Tensor offset, + at::Tensor mask, + at::Tensor output, + at::Tensor columns, + int kernel_h, + int kernel_w, + const int stride_h, + const int stride_w, + const int pad_h, + const int pad_w, + const int dilation_h, + const int dilation_w, + const int group, + const int deformable_group, + const bool with_bias) +{ + if (input.type().is_cuda()) { +#ifdef WITH_CUDA + return modulated_deform_conv_cuda_forward( + input, weight, bias, ones, offset, mask, output, columns, + kernel_h, kernel_w, stride_h, stride_w, + pad_h, pad_w, dilation_h, dilation_w, + group, deformable_group, with_bias + ); +#else + AT_ERROR("Not compiled with GPU support"); +#endif + } + AT_ERROR("Not implemented on the CPU"); +} + + +void modulated_deform_conv_backward( + at::Tensor input, + at::Tensor weight, + at::Tensor bias, + at::Tensor ones, + at::Tensor offset, + at::Tensor mask, + at::Tensor columns, + at::Tensor grad_input, + at::Tensor grad_weight, + at::Tensor grad_bias, + at::Tensor grad_offset, + at::Tensor grad_mask, + at::Tensor grad_output, + int kernel_h, + int kernel_w, + int stride_h, + int stride_w, + int pad_h, + int pad_w, + int dilation_h, + int dilation_w, + int group, + int deformable_group, + const bool with_bias) +{ + if (input.type().is_cuda()) { +#ifdef WITH_CUDA + return modulated_deform_conv_cuda_backward( + input, weight, bias, ones, offset, mask, columns, + grad_input, grad_weight, grad_bias, grad_offset, grad_mask, grad_output, + kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w, dilation_h, dilation_w, + group, deformable_group, with_bias + ); +#else + AT_ERROR("Not compiled with GPU support"); +#endif + } + AT_ERROR("Not implemented on the CPU"); +} \ No newline at end of file diff --git a/maskrcnn_benchmark/csrc/deform_pool.h b/maskrcnn_benchmark/csrc/deform_pool.h new file mode 100644 index 0000000000000000000000000000000000000000..234223809bad726a8ecf71697b9281e75eec5288 --- /dev/null +++ b/maskrcnn_benchmark/csrc/deform_pool.h @@ -0,0 +1,70 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +#pragma once +#include "cpu/vision.h" + +#ifdef WITH_CUDA +#include "cuda/vision.h" +#endif + + +// Interface for Python +void deform_psroi_pooling_forward( + at::Tensor input, + at::Tensor bbox, + at::Tensor trans, + at::Tensor out, + at::Tensor top_count, + const int no_trans, + const float spatial_scale, + const int output_dim, + const int group_size, + const int pooled_size, + const int part_size, + const int sample_per_part, + const float trans_std) +{ + if (input.type().is_cuda()) { +#ifdef WITH_CUDA + return deform_psroi_pooling_cuda_forward( + input, bbox, trans, out, top_count, + no_trans, spatial_scale, output_dim, group_size, + pooled_size, part_size, sample_per_part, trans_std + ); +#else + AT_ERROR("Not compiled with GPU support"); +#endif + } + AT_ERROR("Not implemented on the CPU"); +} + + +void deform_psroi_pooling_backward( + at::Tensor out_grad, + at::Tensor input, + at::Tensor bbox, + at::Tensor trans, + at::Tensor top_count, + at::Tensor input_grad, + at::Tensor trans_grad, + const int no_trans, + const float spatial_scale, + const int output_dim, + const int group_size, + const int pooled_size, + const int part_size, + const int sample_per_part, + const float trans_std) +{ + if (input.type().is_cuda()) { +#ifdef WITH_CUDA + return deform_psroi_pooling_cuda_backward( + out_grad, input, bbox, trans, top_count, input_grad, trans_grad, + no_trans, spatial_scale, output_dim, group_size, pooled_size, + part_size, sample_per_part, trans_std + ); +#else + AT_ERROR("Not compiled with GPU support"); +#endif + } + AT_ERROR("Not implemented on the CPU"); +} diff --git a/maskrcnn_benchmark/csrc/nms.h b/maskrcnn_benchmark/csrc/nms.h new file mode 100644 index 0000000000000000000000000000000000000000..312fed4a7cb7c1bc6c2345b5e5d678cc6c1a7141 --- /dev/null +++ b/maskrcnn_benchmark/csrc/nms.h @@ -0,0 +1,28 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +#pragma once +#include "cpu/vision.h" + +#ifdef WITH_CUDA +#include "cuda/vision.h" +#endif + + +at::Tensor nms(const at::Tensor& dets, + const at::Tensor& scores, + const float threshold) { + + if (dets.type().is_cuda()) { +#ifdef WITH_CUDA + // TODO raise error if not compiled with CUDA + if (dets.numel() == 0) + return at::empty({0}, dets.options().dtype(at::kLong).device(at::kCPU)); + auto b = at::cat({dets, scores.unsqueeze(1)}, 1); + return nms_cuda(b, threshold); +#else + AT_ERROR("Not compiled with GPU support"); +#endif + } + + at::Tensor result = nms_cpu(dets, scores, threshold); + return result; +} diff --git a/maskrcnn_benchmark/csrc/vision.cpp b/maskrcnn_benchmark/csrc/vision.cpp new file mode 100644 index 0000000000000000000000000000000000000000..30971995d4aafa3bbd4f1cb00138098c139163f6 --- /dev/null +++ b/maskrcnn_benchmark/csrc/vision.cpp @@ -0,0 +1,25 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +#include "nms.h" +#include "ROIAlign.h" +#include "ROIPool.h" +#include "SigmoidFocalLoss.h" +#include "deform_conv.h" +#include "deform_pool.h" + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("nms", &nms, "non-maximum suppression"); + m.def("roi_align_forward", &ROIAlign_forward, "ROIAlign_forward"); + m.def("roi_align_backward", &ROIAlign_backward, "ROIAlign_backward"); + m.def("roi_pool_forward", &ROIPool_forward, "ROIPool_forward"); + m.def("roi_pool_backward", &ROIPool_backward, "ROIPool_backward"); + m.def("sigmoid_focalloss_forward", &SigmoidFocalLoss_forward, "SigmoidFocalLoss_forward"); + m.def("sigmoid_focalloss_backward", &SigmoidFocalLoss_backward, "SigmoidFocalLoss_backward"); + // dcn-v2 + m.def("deform_conv_forward", &deform_conv_forward, "deform_conv_forward"); + m.def("deform_conv_backward_input", &deform_conv_backward_input, "deform_conv_backward_input"); + m.def("deform_conv_backward_parameters", &deform_conv_backward_parameters, "deform_conv_backward_parameters"); + m.def("modulated_deform_conv_forward", &modulated_deform_conv_forward, "modulated_deform_conv_forward"); + m.def("modulated_deform_conv_backward", &modulated_deform_conv_backward, "modulated_deform_conv_backward"); + m.def("deform_psroi_pooling_forward", &deform_psroi_pooling_forward, "deform_psroi_pooling_forward"); + m.def("deform_psroi_pooling_backward", &deform_psroi_pooling_backward, "deform_psroi_pooling_backward"); +} \ No newline at end of file diff --git a/maskrcnn_benchmark/data/__init__.py b/maskrcnn_benchmark/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2ba1e52473f97615cc41f82aef279fff4d194527 --- /dev/null +++ b/maskrcnn_benchmark/data/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +from .build import make_data_loader diff --git a/maskrcnn_benchmark/data/build.py b/maskrcnn_benchmark/data/build.py new file mode 100644 index 0000000000000000000000000000000000000000..8c45724e6a2fafac2e8c3bf54bb4b1fe4e1ffe85 --- /dev/null +++ b/maskrcnn_benchmark/data/build.py @@ -0,0 +1,175 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +import bisect +import logging + +import torch.utils.data +from maskrcnn_benchmark.utils.comm import get_world_size +from maskrcnn_benchmark.utils.imports import import_file + +from . import datasets as D +from . import samplers + +from .collate_batch import BatchCollator +from .transforms import build_transforms + + +def build_dataset(cfg, dataset_list, transforms, dataset_catalog, is_train=True): + """ + Arguments: + dataset_list (list[str]): Contains the names of the datasets, i.e., + coco_2014_trian, coco_2014_val, etc + transforms (callable): transforms to apply to each (image, target) sample + dataset_catalog (DatasetCatalog): contains the information on how to + construct a dataset. + is_train (bool): whether to setup the dataset for training or testing + """ + if not isinstance(dataset_list, (list, tuple)): + raise RuntimeError( + "dataset_list should be a list of strings, got {}".format(dataset_list)) + datasets = [] + for dataset_name in dataset_list: + data = dataset_catalog.get(dataset_name) + factory = getattr(D, data["factory"]) + args = data["args"] + # for COCODataset, we want to remove images without annotations + # during training + if data["factory"] == "COCODataset": + args["remove_images_without_annotations"] = is_train + args["transforms"] = transforms + args["ignore_difficult"] = cfg.DATASETS.IGNORE_DIFFICULT + # make dataset from factory + dataset = factory(**args) + datasets.append(dataset) + + # for testing, return a list of datasets + if not is_train: + return datasets + + # for training, concatenate all datasets into a single one + dataset = datasets[0] + if len(datasets) > 1: + dataset = D.MixDataset(datasets, cfg.DATASETS.RATIOS) + # dataset = D.ConcatDataset(datasets) + + return [dataset] + + +def make_data_sampler(dataset, shuffle, distributed): + if distributed: + return samplers.DistributedSampler(dataset, shuffle=shuffle) + if shuffle: + sampler = torch.utils.data.sampler.RandomSampler(dataset) + else: + sampler = torch.utils.data.sampler.SequentialSampler(dataset) + return sampler + + +def _quantize(x, bins): + bins = sorted(bins.copy()) + quantized = list(map(lambda y: bisect.bisect_right(bins, y), x)) + return quantized + + +def _compute_aspect_ratios(dataset): + aspect_ratios = [] + for i in range(len(dataset)): + img_info = dataset.get_img_info(i) + aspect_ratio = float(img_info["height"]) / float(img_info["width"]) + aspect_ratios.append(aspect_ratio) + return aspect_ratios + + +def make_batch_data_sampler( + dataset, sampler, aspect_grouping, images_per_batch, num_iters=None, start_iter=0 +): + if aspect_grouping: + if not isinstance(aspect_grouping, (list, tuple)): + aspect_grouping = [aspect_grouping] + aspect_ratios = _compute_aspect_ratios(dataset) + group_ids = _quantize(aspect_ratios, aspect_grouping) + batch_sampler = samplers.GroupedBatchSampler( + sampler, group_ids, images_per_batch, drop_uneven=False + ) + else: + batch_sampler = torch.utils.data.sampler.BatchSampler( + sampler, images_per_batch, drop_last=False + ) + if num_iters is not None: + batch_sampler = samplers.IterationBasedBatchSampler(batch_sampler, num_iters, start_iter) + return batch_sampler + + +def make_data_loader(cfg, is_train=True, is_distributed=False, start_iter=0): + num_gpus = get_world_size() + if is_train: + images_per_batch = cfg.SOLVER.IMS_PER_BATCH + assert ( + images_per_batch % num_gpus == 0 + ), "SOLVER.IMS_PER_BATCH ({}) must be divisible by the number " + "of GPUs ({}) used.".format(images_per_batch, num_gpus) + images_per_gpu = images_per_batch // num_gpus + shuffle = True + num_iters = cfg.SOLVER.MAX_ITER + else: + images_per_batch = cfg.TEST.IMS_PER_BATCH + assert ( + images_per_batch % num_gpus == 0 + ), "TEST.IMS_PER_BATCH ({}) must be divisible by the number " + "of GPUs ({}) used.".format(images_per_batch, num_gpus) + images_per_gpu = images_per_batch // num_gpus + shuffle = False if not is_distributed else True + num_iters = None + start_iter = 0 + + if images_per_gpu > 1: + logger = logging.getLogger(__name__) + logger.warning( + "When using more than one image per GPU you may encounter " + "an out-of-memory (OOM) error if your GPU does not have " + "sufficient memory. If this happens, you can reduce " + "SOLVER.IMS_PER_BATCH (for training) or " + "TEST.IMS_PER_BATCH (for inference). For training, you must " + "also adjust the learning rate and schedule length according " + "to the linear scaling rule. See for example: " + "https://github.com/facebookresearch/Detectron/blob/master/configs/getting_started/tutorial_1gpu_e2e_faster_rcnn_R-50-FPN.yaml#L14" + ) + + # group images which have similar aspect ratio. In this case, we only + # group in two cases: those with width / height > 1, and the other way around, + # but the code supports more general grouping strategy + aspect_grouping = [1] if cfg.DATALOADER.ASPECT_RATIO_GROUPING else [] + + paths_catalog = import_file( + "maskrcnn_benchmark.config.paths_catalog", cfg.PATHS_CATALOG, True + ) + DatasetCatalog = paths_catalog.DatasetCatalog + dataset_list = cfg.DATASETS.TRAIN if is_train else cfg.DATASETS.TEST + + transforms = build_transforms(cfg, is_train) + datasets = build_dataset(cfg,dataset_list, transforms, DatasetCatalog, is_train) + + data_loaders = [] + for dataset in datasets: + ''' + for i in range(20): + a=dataset[i] + ipdb.set_trace() + ''' + sampler = make_data_sampler(dataset, shuffle, is_distributed) + batch_sampler = make_batch_data_sampler( + dataset, sampler, aspect_grouping, images_per_gpu, num_iters, start_iter + ) + collator = BatchCollator(cfg.DATALOADER.SIZE_DIVISIBILITY) + num_workers = cfg.DATALOADER.NUM_WORKERS + data_loader = torch.utils.data.DataLoader( + dataset, + num_workers=num_workers, + batch_sampler=batch_sampler, + collate_fn=collator, + ) + data_loaders.append(data_loader) + if is_train: + # during training, a single (possibly concatenated) data_loader is returned + assert len(data_loaders) == 1 + return data_loaders[0] + return data_loaders diff --git a/maskrcnn_benchmark/data/collate_batch.py b/maskrcnn_benchmark/data/collate_batch.py new file mode 100644 index 0000000000000000000000000000000000000000..c10ec95008f91b72b413e287d515918ed1e9746f --- /dev/null +++ b/maskrcnn_benchmark/data/collate_batch.py @@ -0,0 +1,27 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +from maskrcnn_benchmark.structures.image_list import to_image_list, to_image_target_list + + +class BatchCollator(object): + """ + From a list of samples from the dataset, + returns the batched images and targets. + This should be passed to the DataLoader + """ + + def __init__(self, size_divisible=0): + self.size_divisible = size_divisible + + def __call__(self, batch): + transposed_batch = list(zip(*batch)) + images = to_image_list(transposed_batch[0], self.size_divisible) + targets = transposed_batch[1] + img_ids = transposed_batch[2] + # if transposed_batch[1] is None: + # images = to_image_list(transposed_batch[0], self.size_divisible) + # targets = transposed_batch[1] + # img_ids = transposed_batch[2] + # else: + # images, targets = to_image_target_list(transposed_batch[0], self.size_divisible, transposed_batch[1]) + # img_ids = transposed_batch[2] + return images, targets, img_ids diff --git a/maskrcnn_benchmark/data/datasets/__init__.py b/maskrcnn_benchmark/data/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f1982b8919f642b7b5775b0c9964e94e21747ade --- /dev/null +++ b/maskrcnn_benchmark/data/datasets/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +from .coco import COCODataset +from .concat_dataset import ConcatDataset, MixDataset +from .icdar import IcdarDataset +from .scut import ScutDataset +from .synthtext import SynthtextDataset +from .total_text import TotaltextDataset + +__all__ = [ + "COCODataset", + "ConcatDataset", + "IcdarDataset", + "SynthtextDataset", + "MixDataset", + "ScutDataset", + "TotaltextDataset", +] diff --git a/maskrcnn_benchmark/data/datasets/coco.py b/maskrcnn_benchmark/data/datasets/coco.py new file mode 100644 index 0000000000000000000000000000000000000000..030238d1a6f3b8b78f1dd21474ed5c0eaa7003d7 --- /dev/null +++ b/maskrcnn_benchmark/data/datasets/coco.py @@ -0,0 +1,65 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +import torch +import torchvision + +from maskrcnn_benchmark.structures.bounding_box import BoxList +from maskrcnn_benchmark.structures.segmentation_mask import SegmentationMask + + +class COCODataset(torchvision.datasets.coco.CocoDetection): + def __init__( + self, ann_file, root, remove_images_without_annotations, transforms=None + ): + super(COCODataset, self).__init__(root, ann_file) + + # sort indices for reproducible results + self.ids = sorted(self.ids) + + # filter images without detection annotations + if remove_images_without_annotations: + self.ids = [ + img_id + for img_id in self.ids + if len(self.coco.getAnnIds(imgIds=img_id, iscrowd=None)) > 0 + ] + + self.json_category_id_to_contiguous_id = { + v: i + 1 for i, v in enumerate(self.coco.getCatIds()) + } + self.contiguous_category_id_to_json_id = { + v: k for k, v in self.json_category_id_to_contiguous_id.items() + } + self.id_to_img_map = {k: v for k, v in enumerate(self.ids)} + self.transforms = transforms + + def __getitem__(self, idx): + img, anno = super(COCODataset, self).__getitem__(idx) + + # filter crowd annotations + # TODO might be better to add an extra field + anno = [obj for obj in anno if obj["iscrowd"] == 0] + + boxes = [obj["bbox"] for obj in anno] + boxes = torch.as_tensor(boxes).reshape(-1, 4) # guard against no boxes + target = BoxList(boxes, img.size, mode="xywh",use_char_ann=False).convert("xyxy") + + classes = [obj["category_id"] for obj in anno] + classes = [self.json_category_id_to_contiguous_id[c] for c in classes] + classes = torch.tensor(classes) + target.add_field("labels", classes) + + masks = [obj["segmentation"] for obj in anno] + masks = SegmentationMask(masks, img.size) + target.add_field("masks", masks) + + target = target.clip_to_image(remove_empty=True) + + if self.transforms is not None: + img, target = self.transforms(img, target) + + return img, target, idx + + def get_img_info(self, index): + img_id = self.id_to_img_map[index] + img_data = self.coco.imgs[img_id] + return img_data diff --git a/maskrcnn_benchmark/data/datasets/concat_dataset.py b/maskrcnn_benchmark/data/datasets/concat_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..52bf9f09d4ffa8e2b44f8a5b71bd19341d692948 --- /dev/null +++ b/maskrcnn_benchmark/data/datasets/concat_dataset.py @@ -0,0 +1,54 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +import bisect +import numpy as np +from torch.utils.data.dataset import ConcatDataset as _ConcatDataset + + +class ConcatDataset(_ConcatDataset): + """ + Same as torch.utils.data.dataset.ConcatDataset, but exposes an extra + method for querying the sizes of the image + """ + + def get_idxs(self, idx): + dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) + if dataset_idx == 0: + sample_idx = idx + else: + sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] + return dataset_idx, sample_idx + + def get_img_info(self, idx): + dataset_idx, sample_idx = self.get_idxs(idx) + return self.datasets[dataset_idx].get_img_info(sample_idx) + +class MixDataset(object): + def __init__(self, datasets, ratios): + self.datasets = datasets + self.ratios = ratios + self.lengths = [] + for dataset in self.datasets: + self.lengths.append(len(dataset)) + self.lengths = np.array(self.lengths) + self.seperate_inds = [] + s = 0 + for i in self.ratios[:-1]: + s += i + self.seperate_inds.append(s) + + def __len__(self): + return self.lengths.sum() + + def __getitem__(self, item): + i = np.random.rand() + ind = bisect.bisect_right(self.seperate_inds, i) + b_ind = np.random.randint(self.lengths[ind]) + return self.datasets[ind][b_ind] + + + + + + + + diff --git a/maskrcnn_benchmark/data/datasets/icdar.py b/maskrcnn_benchmark/data/datasets/icdar.py new file mode 100644 index 0000000000000000000000000000000000000000..e8c4538de715bae0315260311916e648fdd3bf45 --- /dev/null +++ b/maskrcnn_benchmark/data/datasets/icdar.py @@ -0,0 +1,274 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +""" +Simple dataset class that wraps a list of path names +""" + +import os + +import numpy as np +import torch +from maskrcnn_benchmark.structures.bounding_box import BoxList +from maskrcnn_benchmark.structures.segmentation_mask import ( + SegmentationCharMask, + SegmentationMask, +) +from PIL import Image, ImageDraw + + +class IcdarDataset(object): + def __init__(self, use_charann, imgs_dir, gts_dir, transforms=None, ignore_difficult=False): + self.use_charann = use_charann + self.image_lists = [os.path.join(imgs_dir, img) for img in os.listdir(imgs_dir)] + self.gts_dir = gts_dir + self.transforms = transforms + self.min_proposal_size = 2 + self.char_classes = "_0123456789abcdefghijklmnopqrstuvwxyz" + self.vis = False + self.ignore_difficult = ignore_difficult + if self.ignore_difficult and self.gts_dir is not None and 'train' in self.gts_dir: + self.image_lists = self.filter_image_lists() + + def filter_image_lists(self): + new_image_lists = [] + for img_path in self.image_lists: + has_positive = False + im_name = os.path.basename(img_path) + gt_path = os.path.join(self.gts_dir, im_name + ".txt") + if not os.path.isfile(gt_path): + gt_path = os.path.join( + self.gts_dir, "gt_" + im_name.split(".")[0] + ".txt" + ) + lines = open(gt_path, 'r').readlines() + for line in lines: + charbbs = [] + strs, loc = self.line2boxes(line) + word = strs[0] + if word == "###": + continue + else: + has_positive = True + if has_positive: + new_image_lists.append(img_path) + return new_image_lists + + def __getitem__(self, item): + im_name = os.path.basename(self.image_lists[item]) + img = Image.open(self.image_lists[item]).convert("RGB") + width, height = img.size + if self.gts_dir is not None: + gt_path = os.path.join(self.gts_dir, im_name + ".txt") + if not os.path.isfile(gt_path): + gt_path = os.path.join( + self.gts_dir, "gt_" + im_name.split(".")[0] + ".txt" + ) + words, boxes, charsbbs, segmentations, labels = self.load_gt_from_txt( + gt_path, height, width + ) + target = BoxList( + boxes[:, :4], img.size, mode="xyxy", use_char_ann=self.use_charann + ) + if self.ignore_difficult: + labels = torch.from_numpy(np.array(labels)) + else: + labels = torch.ones(len(boxes)) + target.add_field("labels", labels) + masks = SegmentationMask(segmentations, img.size) + target.add_field("masks", masks) + if words[0] == "": + use_char_ann = False + else: + use_char_ann = True + if not self.use_charann: + use_char_ann = False + char_masks = SegmentationCharMask( + charsbbs, words=words, use_char_ann=use_char_ann, size=img.size, char_num_classes=len(self.char_classes) + ) + target.add_field("char_masks", char_masks) + else: + target = None + if self.transforms is not None: + img, target = self.transforms(img, target) + if self.vis: + new_im = img.numpy().copy().transpose([1, 2, 0]) + [ + 102.9801, + 115.9465, + 122.7717, + ] + new_im = Image.fromarray(new_im.astype(np.uint8)).convert("RGB") + mask = target.extra_fields["masks"].polygons[0].convert("mask") + mask = Image.fromarray((mask.numpy() * 255).astype(np.uint8)).convert("RGB") + if self.use_charann: + m, _ = ( + target.extra_fields["char_masks"] + .chars_boxes[0] + .convert("char_mask") + ) + color = self.creat_color_map(37, 255) + color_map = color[m.numpy().astype(np.uint8)] + char = Image.fromarray(color_map.astype(np.uint8)).convert("RGB") + char = Image.blend(char, new_im, 0.5) + else: + char = new_im + new = Image.blend(char, mask, 0.5) + img_draw = ImageDraw.Draw(new) + for box in target.bbox.numpy(): + box = list(box) + box = box[:2] + [box[2], box[1]] + box[2:] + [box[0], box[3]] + box[:2] + img_draw.line(box, fill=(255, 0, 0), width=2) + new.save("./vis/char_" + im_name) + return img, target, self.image_lists[item] + + def creat_color_map(self, n_class, width): + splits = int(np.ceil(np.power((n_class * 1.0), 1.0 / 3))) + maps = [] + for i in range(splits): + r = int(i * width * 1.0 / (splits - 1)) + for j in range(splits): + g = int(j * width * 1.0 / (splits - 1)) + for k in range(splits - 1): + b = int(k * width * 1.0 / (splits - 1)) + maps.append([r, g, b]) + return np.array(maps) + + def __len__(self): + return len(self.image_lists) + + def load_gt_from_txt(self, gt_path, height=None, width=None): + words, boxes, charsboxes, segmentations, labels = [], [], [], [], [] + lines = open(gt_path).readlines() + for line in lines: + charbbs = [] + strs, loc = self.line2boxes(line) + word = strs[0] + if word == "###": + if self.ignore_difficult: + rect = list(loc[0]) + min_x = min(rect[::2]) - 1 + min_y = min(rect[1::2]) - 1 + max_x = max(rect[::2]) - 1 + max_y = max(rect[1::2]) - 1 + box = [min_x, min_y, max_x, max_y] + segmentations.append([loc[0, :]]) + tindex = len(boxes) + boxes.append(box) + words.append(word) + labels.append(-1) + charbbs = np.zeros((10,), dtype=np.float32) + if loc.shape[0] > 1: + for i in range(1, loc.shape[0]): + charbb[9] = tindex + charbbs.append(charbb.copy()) + charsboxes.append(charbbs) + else: + continue + else: + rect = list(loc[0]) + min_x = min(rect[::2]) - 1 + min_y = min(rect[1::2]) - 1 + max_x = max(rect[::2]) - 1 + max_y = max(rect[1::2]) - 1 + box = [min_x, min_y, max_x, max_y] + segmentations.append([loc[0, :]]) + tindex = len(boxes) + boxes.append(box) + words.append(word) + labels.append(1) + c_class = self.char2num(strs[1:]) + charbb = np.zeros((10,), dtype=np.float32) + if loc.shape[0] > 1: + for i in range(1, loc.shape[0]): + charbb[:8] = loc[i, :] + charbb[8] = c_class[i - 1] + charbb[9] = tindex + charbbs.append(charbb.copy()) + charsboxes.append(charbbs) + num_boxes = len(boxes) + if len(boxes) > 0: + keep_boxes = np.zeros((num_boxes, 5)) + keep_boxes[:, :4] = np.array(boxes) + keep_boxes[:, 4] = range( + num_boxes + ) + # the 5th column is the box label, + # same as the 10th column of all charsboxes which belong to the box + if self.use_charann: + return words, np.array(keep_boxes), charsboxes, segmentations, labels + else: + charbbs = np.zeros((10,), dtype=np.float32) + if len(charsboxes) == 0: + for _ in range(len(words)): + charsboxes.append([charbbs]) + return words, np.array(keep_boxes), charsboxes, segmentations, labels + else: + words.append("") + charbbs = np.zeros((10,), dtype=np.float32) + return ( + words, + np.zeros((1, 5), dtype=np.float32), + [[charbbs]], + [[np.zeros((8,), dtype=np.float32)]], + [1] + ) + + def line2boxes(self, line): + parts = line.strip().split(",") + if "\xef\xbb\xbf" in parts[0]: + parts[0] = parts[0][3:] + if "\ufeff" in parts[0]: + parts[0] = parts[0].replace("\ufeff", "") + x1 = np.array([int(float(x)) for x in parts[::9]]) + y1 = np.array([int(float(x)) for x in parts[1::9]]) + x2 = np.array([int(float(x)) for x in parts[2::9]]) + y2 = np.array([int(float(x)) for x in parts[3::9]]) + x3 = np.array([int(float(x)) for x in parts[4::9]]) + y3 = np.array([int(float(x)) for x in parts[5::9]]) + x4 = np.array([int(float(x)) for x in parts[6::9]]) + y4 = np.array([int(float(x)) for x in parts[7::9]]) + strs = parts[8::9] + loc = np.vstack((x1, y1, x2, y2, x3, y3, x4, y4)).transpose() + return strs, loc + + def check_charbbs(self, charbbs): + xmins = np.minimum.reduce( + [charbbs[:, 0], charbbs[:, 2], charbbs[:, 4], charbbs[:, 6]] + ) + xmaxs = np.maximum.reduce( + [charbbs[:, 0], charbbs[:, 2], charbbs[:, 4], charbbs[:, 6]] + ) + ymins = np.minimum.reduce( + [charbbs[:, 1], charbbs[:, 3], charbbs[:, 5], charbbs[:, 7]] + ) + ymaxs = np.maximum.reduce( + [charbbs[:, 1], charbbs[:, 3], charbbs[:, 5], charbbs[:, 7]] + ) + return np.logical_and( + xmaxs - xmins > self.min_proposal_size, + ymaxs - ymins > self.min_proposal_size, + ) + + def check_charbb(self, charbb): + xmins = min(charbb[0], charbb[2], charbb[4], charbb[6]) + xmaxs = max(charbb[0], charbb[2], charbb[4], charbb[6]) + ymins = min(charbb[1], charbb[3], charbb[5], charbb[7]) + ymaxs = max(charbb[1], charbb[3], charbb[5], charbb[7]) + return ( + xmaxs - xmins > self.min_proposal_size + and ymaxs - ymins > self.min_proposal_size + ) + + def char2num(self, chars): + ## chars ['h', 'e', 'l', 'l', 'o'] + nums = [self.char_classes.index(c.lower()) for c in chars] + return nums + + def get_img_info(self, item): + """ + Return the image dimensions for the image, without + loading and pre-processing it + """ + + im_name = os.path.basename(self.image_lists[item]) + img = Image.open(self.image_lists[item]) + width, height = img.size + img_info = {"im_name": im_name, "height": height, "width": width} + return img_info diff --git a/maskrcnn_benchmark/data/datasets/list_dataset.py b/maskrcnn_benchmark/data/datasets/list_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..9058d35b3d4279048732074f4a8dbb6edd4c9ed0 --- /dev/null +++ b/maskrcnn_benchmark/data/datasets/list_dataset.py @@ -0,0 +1,36 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +""" +Simple dataset class that wraps a list of path names +""" + +from PIL import Image + +from maskrcnn_benchmark.structures.bounding_box import BoxList + + +class ListDataset(object): + def __init__(self, image_lists, transforms=None): + self.image_lists = image_lists + self.transforms = transforms + + def __getitem__(self, item): + img = Image.open(self.image_lists[item]).convert("RGB") + + # dummy target + w, h = img.size + target = BoxList([[0, 0, w, h]], img.size, mode="xyxy") + + if self.transforms is not None: + img, target = self.transforms(img, target) + + return img, target + + def __len__(self): + return len(self.image_lists) + + def get_img_info(self, item): + """ + Return the image dimensions for the image, without + loading and pre-processing it + """ + pass diff --git a/maskrcnn_benchmark/data/datasets/scut.py b/maskrcnn_benchmark/data/datasets/scut.py new file mode 100644 index 0000000000000000000000000000000000000000..74bfa28a0eded6bca78bd1ebb060fad216c1c3b7 --- /dev/null +++ b/maskrcnn_benchmark/data/datasets/scut.py @@ -0,0 +1,328 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +""" +Simple dataset class that wraps a list of path names +""" + +import os + +import numpy as np +import torch +from maskrcnn_benchmark.structures.bounding_box import BoxList +from maskrcnn_benchmark.structures.segmentation_mask import ( + CharPolygons, + SegmentationCharMask, + SegmentationMask, +) +from PIL import Image, ImageDraw, ImageFile + + +ImageFile.LOAD_TRUNCATED_IMAGES = True + + +class ScutDataset(object): + def __init__(self, use_charann, imgs_dir, gts_dir, transforms=None, ignore_difficult=False): + self.use_charann = use_charann + self.image_lists = [os.path.join(imgs_dir, img) for img in os.listdir(imgs_dir)] + self.gts_dir = gts_dir + self.transforms = transforms + self.min_proposal_size = 2 + self.char_classes = "_0123456789abcdefghijklmnopqrstuvwxyz" + self.vis = False + self.ignore_difficult = ignore_difficult + if self.ignore_difficult and 'train' in self.gts_dir: + self.image_lists = self.filter_image_lists() + + def filter_image_lists(self): + new_image_lists = [] + for img_path in self.image_lists: + has_positive = False + im_name = os.path.basename(img_path) + gt_path = os.path.join(self.gts_dir, im_name + ".txt") + if not os.path.isfile(gt_path): + gt_path = os.path.join( + self.gts_dir, "gt_" + im_name.split(".")[0] + ".txt" + ) + lines = open(gt_path, 'r').readlines() + for line in lines: + charbbs = [] + strs, loc = self.line2boxes(line) + word = strs[0] + if word == "###": + continue + else: + has_positive = True + if has_positive: + new_image_lists.append(img_path) + return new_image_lists + + def __getitem__(self, item): + im_name = os.path.basename(self.image_lists[item]) + # print(self.image_lists[item]) + img = Image.open(self.image_lists[item]).convert("RGB") + width, height = img.size + gt_path = os.path.join(self.gts_dir, im_name + ".txt") + words, boxes, charsbbs, segmentations, labels = self.load_gt_from_txt( + gt_path, height, width + ) + if words[0] == "": + use_char_ann = False + else: + use_char_ann = True + if not self.use_charann: + use_char_ann = False + target = BoxList(boxes[:, :4], img.size, mode="xyxy", use_char_ann=use_char_ann) + if self.ignore_difficult: + labels = torch.from_numpy(np.array(labels)) + else: + labels = torch.ones(len(boxes)) + target.add_field("labels", labels) + masks = SegmentationMask(segmentations, img.size) + target.add_field("masks", masks) + char_masks = SegmentationCharMask( + charsbbs, words=words, use_char_ann=use_char_ann, size=img.size, char_num_classes=len(self.char_classes) + ) + target.add_field("char_masks", char_masks) + if self.transforms is not None: + img, target = self.transforms(img, target) + if self.vis: + new_im = img.numpy().copy().transpose([1, 2, 0]) + [ + 102.9801, + 115.9465, + 122.7717, + ] + new_im = Image.fromarray(new_im.astype(np.uint8)).convert("RGB") + mask = target.extra_fields["masks"].polygons[0].convert("mask") + mask = Image.fromarray((mask.numpy() * 255).astype(np.uint8)).convert("RGB") + if self.use_charann: + m, _ = ( + target.extra_fields["char_masks"] + .chars_boxes[0] + .convert("char_mask") + ) + color = self.creat_color_map(37, 255) + color_map = color[m.numpy().astype(np.uint8)] + char = Image.fromarray(color_map.astype(np.uint8)).convert("RGB") + char = Image.blend(char, new_im, 0.5) + else: + char = new_im + new = Image.blend(char, mask, 0.5) + img_draw = ImageDraw.Draw(new) + for box in target.bbox.numpy(): + box = list(box) + box = box[:2] + [box[2], box[1]] + box[2:] + [box[0], box[3]] + box[:2] + img_draw.line(box, fill=(255, 0, 0), width=2) + new.save("./vis/char_" + im_name) + return img, target, self.image_lists[item] + + def creat_color_map(self, n_class, width): + splits = int(np.ceil(np.power((n_class * 1.0), 1.0 / 3))) + maps = [] + for i in range(splits): + r = int(i * width * 1.0 / (splits - 1)) + for j in range(splits): + g = int(j * width * 1.0 / (splits - 1)) + for k in range(splits - 1): + b = int(k * width * 1.0 / (splits - 1)) + maps.append([r, g, b]) + return np.array(maps) + + def __len__(self): + return len(self.image_lists) + + # def load_gt_from_txt(self, gt_path, height=None, width=None): + # words, boxes, charsboxes, segmentations, labels = [], [], [], [], [] + # lines = open(gt_path).readlines() + # for line in lines: + # charbbs = [] + # strs, loc = self.line2boxes(line) + # word = strs[0] + # if word == "###": + # labels.append(-1) + # continue + # else: + # labels.append(1) + # rect = list(loc[0]) + # min_x = min(rect[::2]) - 1 + # min_y = min(rect[1::2]) - 1 + # max_x = max(rect[::2]) - 1 + # max_y = max(rect[1::2]) - 1 + # box = [min_x, min_y, max_x, max_y] + # segmentations.append([loc[0, :]]) + # tindex = len(boxes) + # boxes.append(box) + # words.append(word) + # c_class = self.char2num(strs[1:]) + # charbb = np.zeros((10,), dtype=np.float32) + # if loc.shape[0] > 1: + # for i in range(1, loc.shape[0]): + # charbb[:8] = loc[i, :] + # charbb[8] = c_class[i - 1] + # charbb[9] = tindex + # charbbs.append(charbb.copy()) + # charsboxes.append(charbbs) + # num_boxes = len(boxes) + # if len(boxes) > 0: + # keep_boxes = np.zeros((num_boxes, 5)) + # keep_boxes[:, :4] = np.array(boxes) + # keep_boxes[:, 4] = range( + # num_boxes + # ) # the 5th column is the box label,same as the 10th column of all charsboxes which belong to the box + # if self.use_charann: + # return words, np.array(keep_boxes), charsboxes, segmentations, labels + # else: + # charbbs = np.zeros((10,), dtype=np.float32) + # if len(charsboxes) == 0: + # for i in range(len(words)): + # charsboxes.append([charbbs]) + # return words, np.array(keep_boxes), charsboxes, segmentations, labels + # else: + # words.append("") + # charbbs = np.zeros((10,), dtype=np.float32) + # return ( + # words, + # np.zeros((1, 5), dtype=np.float32), + # [[charbbs]], + # [[np.zeros((8,), dtype=np.float32)]], + # labels + # ) + + def load_gt_from_txt(self, gt_path, height=None, width=None): + words, boxes, charsboxes, segmentations, labels = [], [], [], [], [] + lines = open(gt_path).readlines() + for line in lines: + charbbs = [] + strs, loc = self.line2boxes(line) + word = strs[0] + if word == "###": + if self.ignore_difficult: + rect = list(loc[0]) + min_x = min(rect[::2]) - 1 + min_y = min(rect[1::2]) - 1 + max_x = max(rect[::2]) - 1 + max_y = max(rect[1::2]) - 1 + box = [min_x, min_y, max_x, max_y] + segmentations.append([loc[0, :]]) + tindex = len(boxes) + boxes.append(box) + words.append(word) + labels.append(-1) + charbbs = np.zeros((10,), dtype=np.float32) + if loc.shape[0] > 1: + for i in range(1, loc.shape[0]): + charbb[9] = tindex + charbbs.append(charbb.copy()) + charsboxes.append(charbbs) + else: + continue + else: + rect = list(loc[0]) + min_x = min(rect[::2]) - 1 + min_y = min(rect[1::2]) - 1 + max_x = max(rect[::2]) - 1 + max_y = max(rect[1::2]) - 1 + box = [min_x, min_y, max_x, max_y] + segmentations.append([loc[0, :]]) + tindex = len(boxes) + boxes.append(box) + words.append(word) + labels.append(1) + c_class = self.char2num(strs[1:]) + charbb = np.zeros((10,), dtype=np.float32) + if loc.shape[0] > 1: + for i in range(1, loc.shape[0]): + charbb[:8] = loc[i, :] + charbb[8] = c_class[i - 1] + charbb[9] = tindex + charbbs.append(charbb.copy()) + charsboxes.append(charbbs) + num_boxes = len(boxes) + if len(boxes) > 0: + keep_boxes = np.zeros((num_boxes, 5)) + keep_boxes[:, :4] = np.array(boxes) + keep_boxes[:, 4] = range( + num_boxes + ) + # the 5th column is the box label, + # same as the 10th column of all charsboxes which belong to the box + if self.use_charann: + return words, np.array(keep_boxes), charsboxes, segmentations, labels + else: + charbbs = np.zeros((10,), dtype=np.float32) + if len(charsboxes) == 0: + for _ in range(len(words)): + charsboxes.append([charbbs]) + return words, np.array(keep_boxes), charsboxes, segmentations, labels + else: + words.append("") + charbbs = np.zeros((10,), dtype=np.float32) + return ( + words, + np.zeros((1, 5), dtype=np.float32), + [[charbbs]], + [[np.zeros((8,), dtype=np.float32)]], + [1] + ) + + + def line2boxes(self, line): + parts = line.strip().split(",") + if "\xef\xbb\xbf" in parts[0]: + parts[0] = parts[0][3:] + if "\ufeff" in parts[0]: + parts[0] = parts[0].replace("\ufeff", "") + x1 = np.array([int(float(x)) for x in parts[::9]]) + y1 = np.array([int(float(x)) for x in parts[1::9]]) + x2 = np.array([int(float(x)) for x in parts[2::9]]) + y2 = np.array([int(float(x)) for x in parts[3::9]]) + x3 = np.array([int(float(x)) for x in parts[4::9]]) + y3 = np.array([int(float(x)) for x in parts[5::9]]) + x4 = np.array([int(float(x)) for x in parts[6::9]]) + y4 = np.array([int(float(x)) for x in parts[7::9]]) + strs = parts[8::9] + loc = np.vstack((x1, y1, x2, y2, x3, y3, x4, y4)).transpose() + return strs, loc + + def check_charbbs(self, charbbs): + xmins = np.minimum.reduce( + [charbbs[:, 0], charbbs[:, 2], charbbs[:, 4], charbbs[:, 6]] + ) + xmaxs = np.maximum.reduce( + [charbbs[:, 0], charbbs[:, 2], charbbs[:, 4], charbbs[:, 6]] + ) + ymins = np.minimum.reduce( + [charbbs[:, 1], charbbs[:, 3], charbbs[:, 5], charbbs[:, 7]] + ) + ymaxs = np.maximum.reduce( + [charbbs[:, 1], charbbs[:, 3], charbbs[:, 5], charbbs[:, 7]] + ) + return np.logical_and( + xmaxs - xmins > self.min_proposal_size, + ymaxs - ymins > self.min_proposal_size, + ) + + def check_charbb(self, charbb): + xmins = min(charbb[0], charbb[2], charbb[4], charbb[6]) + xmaxs = max(charbb[0], charbb[2], charbb[4], charbb[6]) + ymins = min(charbb[1], charbb[3], charbb[5], charbb[7]) + ymaxs = max(charbb[1], charbb[3], charbb[5], charbb[7]) + return ( + xmaxs - xmins > self.min_proposal_size + and ymaxs - ymins > self.min_proposal_size + ) + + def char2num(self, chars): + ## chars ['h', 'e', 'l', 'l', 'o'] + nums = [self.char_classes.index(c.lower()) for c in chars] + return nums + + def get_img_info(self, item): + """ + Return the image dimensions for the image, without + loading and pre-processing it + """ + + im_name = os.path.basename(self.image_lists[item]) + img = Image.open(self.image_lists[item]) + width, height = img.size + img_info = {"im_name": im_name, "height": height, "width": width} + return img_info diff --git a/maskrcnn_benchmark/data/datasets/synthtext.py b/maskrcnn_benchmark/data/datasets/synthtext.py new file mode 100644 index 0000000000000000000000000000000000000000..9d8fa1ce049a96497cfd0ed32558bc6cababb600 --- /dev/null +++ b/maskrcnn_benchmark/data/datasets/synthtext.py @@ -0,0 +1,232 @@ + +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +""" +Simple dataset class that wraps a list of path names +""" + +import os +import numpy as np +import torch +from maskrcnn_benchmark.structures.bounding_box import BoxList +from maskrcnn_benchmark.structures.segmentation_mask import ( + SegmentationCharMask, + SegmentationMask, +) +from PIL import Image, ImageDraw + + +class SynthtextDataset(object): + def __init__(self, use_charann, list_file_path, imgs_dir, gts_dir, transforms=None, ignore_difficult=False): + self.use_charann = use_charann + with open(list_file_path, "r") as list_file: + image_lines = list_file.readlines() + self.image_lists = [ + os.path.join(imgs_dir, line.strip()) for line in image_lines + ] + self.gt_lists = [ + os.path.join(gts_dir, line.strip() + ".txt") for line in image_lines + ] + self.filtered_gts = [] + self.transforms = transforms + self.min_proposal_size = 2 + self.char_classes = "_0123456789abcdefghijklmnopqrstuvwxyz" + self.vis = False + self.ignore_difficult = ignore_difficult + + def __getitem__(self, item): + while True: + img_path = self.image_lists[item] + try: + img = Image.open(img_path).convert("RGB") + break + except BaseException: + item += 1 + im_name = os.path.basename(img_path) + width, height = img.size + gt_path = self.gt_lists[item] + words, boxes, charsbbs, segmentations = self.load_gt_from_txt( + gt_path, height, width + ) + target = BoxList( + boxes[:, :4], img.size, mode="xyxy", use_char_ann=self.use_charann + ) + classes = torch.ones(len(boxes)) + target.add_field("labels", classes) + masks = SegmentationMask(segmentations, img.size) + target.add_field("masks", masks) + if words[0] == "": + use_char_ann = False + else: + use_char_ann = True + if not self.use_charann: + use_char_ann = False + char_masks = SegmentationCharMask( + charsbbs, words=words, use_char_ann=use_char_ann, size=img.size, char_num_classes=len(self.char_classes) + ) + target.add_field("char_masks", char_masks) + if self.transforms is not None: + img, target = self.transforms(img, target) + if self.vis: + new_im = img.numpy().copy().transpose([1, 2, 0]) + [ + 102.9801, + 115.9465, + 122.7717, + ] + new_im = Image.fromarray(new_im.astype(np.uint8)).convert("RGB") + mask = target.extra_fields["masks"].polygons[0].convert("mask") + mask = Image.fromarray((mask.numpy() * 255).astype(np.uint8)).convert("RGB") + if self.use_charann: + m, _ = ( + target.extra_fields["char_masks"] + .chars_boxes[0] + .convert("char_mask") + ) + color = self.creat_color_map(37, 255) + color_map = color[m.numpy().astype(np.uint8)] + char = Image.fromarray(color_map.astype(np.uint8)).convert("RGB") + char = Image.blend(char, new_im, 0.5) + else: + char = new_im + new = Image.blend(char, mask, 0.5) + img_draw = ImageDraw.Draw(new) + for box in target.bbox.numpy(): + box = list(box) + box = box[:2] + [box[2], box[1]] + box[2:] + [box[0], box[3]] + box[:2] + img_draw.line(box, fill=(255, 0, 0), width=2) + new.save("./vis/char_" + im_name) + return img, target, self.image_lists[item] + + def creat_color_map(self, n_class, width): + splits = int(np.ceil(np.power((n_class * 1.0), 1.0 / 3))) + maps = [] + for i in range(splits): + r = int(i * width * 1.0 / (splits - 1)) + for j in range(splits): + g = int(j * width * 1.0 / (splits - 1)) + for k in range(splits - 1): + b = int(k * width * 1.0 / (splits - 1)) + maps.append([r, g, b]) + return np.array(maps) + + def __len__(self): + return len(self.image_lists) + + def load_gt_from_txt(self, gt_path, height=None, width=None): + words, boxes, charsboxes, segmentations = [], [], [], [] + lines = open(gt_path).readlines() + for line in lines: + charbbs = [] + strs, loc = self.line2boxes(line) + word = strs[0] + if word == "###": + continue + else: + rect = list(loc[0]) + min_x = min(rect[::2]) - 1 + min_y = min(rect[1::2]) - 1 + max_x = max(rect[::2]) - 1 + max_y = max(rect[1::2]) - 1 + box = [min_x, min_y, max_x, max_y] + segmentations.append([loc[0, :]]) + tindex = len(boxes) + boxes.append(box) + words.append(word) + c_class = self.char2num(strs[1:]) + charbb = np.zeros((10,), dtype=np.float32) + if loc.shape[0] > 1: + for i in range(1, loc.shape[0]): + charbb[:8] = loc[i, :] + charbb[8] = c_class[i - 1] + charbb[9] = tindex + charbbs.append(charbb.copy()) + else: + charbbs.append(charbb.copy()) + charsboxes.append(charbbs) + num_boxes = len(boxes) + if len(boxes) > 0: + keep_boxes = np.zeros((num_boxes, 5)) + keep_boxes[:, :4] = np.array(boxes) + keep_boxes[:, 4] = range( + num_boxes + ) + # the 5th column is the box label, + # same as the 10th column of all charsboxes which belong to the box + if self.use_charann: + return words, np.array(keep_boxes), charsboxes, segmentations + else: + charbbs = np.zeros((10,), dtype=np.float32) + for _ in range(len(words)): + charsboxes.append([charbbs]) + return words, np.array(keep_boxes), [[charbbs]], segmentations + else: + words.append("") + charbbs = np.zeros((10,), dtype=np.float32) + return ( + words, + np.zeros((1, 5), dtype=np.float32), + [[charbbs]], + [[np.zeros((8,), dtype=np.float32)]], + ) + + def line2boxes(self, line): + parts = line.strip().split(",") + if "\xef\xbb\xbf" in parts[0]: + parts[0] = parts[0][3:] + if "\ufeff" in parts[0]: + parts[0] = parts[0].replace("\ufeff", "") + x1 = np.array([int(float(x)) for x in parts[::9]]) + y1 = np.array([int(float(x)) for x in parts[1::9]]) + x2 = np.array([int(float(x)) for x in parts[2::9]]) + y2 = np.array([int(float(x)) for x in parts[3::9]]) + x3 = np.array([int(float(x)) for x in parts[4::9]]) + y3 = np.array([int(float(x)) for x in parts[5::9]]) + x4 = np.array([int(float(x)) for x in parts[6::9]]) + y4 = np.array([int(float(x)) for x in parts[7::9]]) + strs = parts[8::9] + loc = np.vstack((x1, y1, x2, y2, x3, y3, x4, y4)).transpose() + return strs, loc + + def check_charbbs(self, charbbs): + xmins = np.minimum.reduce( + [charbbs[:, 0], charbbs[:, 2], charbbs[:, 4], charbbs[:, 6]] + ) + xmaxs = np.maximum.reduce( + [charbbs[:, 0], charbbs[:, 2], charbbs[:, 4], charbbs[:, 6]] + ) + ymins = np.minimum.reduce( + [charbbs[:, 1], charbbs[:, 3], charbbs[:, 5], charbbs[:, 7]] + ) + ymaxs = np.maximum.reduce( + [charbbs[:, 1], charbbs[:, 3], charbbs[:, 5], charbbs[:, 7]] + ) + return np.logical_and( + xmaxs - xmins > self.min_proposal_size, + ymaxs - ymins > self.min_proposal_size, + ) + + def check_charbb(self, charbb): + xmins = min(charbb[0], charbb[2], charbb[4], charbb[6]) + xmaxs = max(charbb[0], charbb[2], charbb[4], charbb[6]) + ymins = min(charbb[1], charbb[3], charbb[5], charbb[7]) + ymaxs = max(charbb[1], charbb[3], charbb[5], charbb[7]) + return ( + xmaxs - xmins > self.min_proposal_size + and ymaxs - ymins > self.min_proposal_size + ) + + def char2num(self, chars): + ## chars ['h', 'e', 'l', 'l', 'o'] + nums = [self.char_classes.index(c.lower()) for c in chars] + return nums + + def get_img_info(self, item): + """ + Return the image dimensions for the image, without + loading and pre-processing it + """ + + im_name = os.path.basename(self.image_lists[item]) + img = Image.open(self.image_lists[item]) + width, height = img.size + img_info = {"im_name": im_name, "height": height, "width": width} + return img_info diff --git a/maskrcnn_benchmark/data/datasets/tdtr.py b/maskrcnn_benchmark/data/datasets/tdtr.py new file mode 100644 index 0000000000000000000000000000000000000000..4103c086b3fc6109315c63910d955eba85d5e4ba --- /dev/null +++ b/maskrcnn_benchmark/data/datasets/tdtr.py @@ -0,0 +1,297 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +""" +Simple dataset class that wraps a list of path names +""" + +import os + +import numpy as np +import torch +from maskrcnn_benchmark.structures.bounding_box import BoxList +from maskrcnn_benchmark.structures.segmentation_mask import ( + CharPolygons, + SegmentationCharMask, + SegmentationMask, +) +from PIL import Image, ImageDraw + + +class Tdtr(object): + def __init__(self, use_charann, imgs_dir, gts_dir, transforms=None, ignore_difficult=False): + self.use_charann = use_charann + self.image_lists = [os.path.join(imgs_dir, img) for img in os.listdir(imgs_dir)] + self.gts_dir = gts_dir + self.transforms = transforms + self.min_proposal_size = 2 + self.char_classes = "_0123456789abcdefghijklmnopqrstuvwxyz" + self.vis = False + self.ignore_difficult = ignore_difficult + if self.ignore_difficult and (self.gts_dir is not None) and 'train' in self.gts_dir: + self.image_lists = self.filter_image_lists() + + def filter_image_lists(self): + new_image_lists = [] + for img_path in self.image_lists: + has_positive = False + im_name = os.path.basename(img_path) + gt_path = os.path.join(self.gts_dir, im_name + ".txt") + if not os.path.isfile(gt_path): + gt_path = os.path.join( + self.gts_dir, "gt_" + im_name.split(".")[0] + ".txt" + ) + lines = open(gt_path, 'r').readlines() + for line in lines: + charbbs = [] + strs, loc = self.line2boxes(line) + word = strs[0] + if word == "1": + continue + else: + has_positive = True + if has_positive: + new_image_lists.append(img_path) + return new_image_lists + + def __getitem__(self, item): + im_name = os.path.basename(self.image_lists[item]) + # print(self.image_lists[item]) + img = Image.open(self.image_lists[item]).convert("RGB") + width, height = img.size + if self.gts_dir is not None: + gt_path = os.path.join(self.gts_dir, im_name + ".txt") + words, boxes, charsbbs, segmentations, labels = self.load_gt_from_txt( + gt_path, height, width + ) + if words[0] == "": + use_char_ann = False + else: + use_char_ann = True + if not self.use_charann: + use_char_ann = False + target = BoxList( + boxes[:, :4], img.size, mode="xyxy", use_char_ann=use_char_ann + ) + if self.ignore_difficult: + labels = torch.from_numpy(np.array(labels)) + else: + labels = torch.ones(len(boxes)) + target.add_field("labels", labels) + masks = SegmentationMask(segmentations, img.size) + target.add_field("masks", masks) + char_masks = SegmentationCharMask( + charsbbs, words=words, use_char_ann=use_char_ann, size=img.size, char_num_classes=len(self.char_classes) + ) + target.add_field("char_masks", char_masks) + else: + target = None + if self.transforms is not None: + img, target = self.transforms(img, target) + if self.vis: + new_im = img.numpy().copy().transpose([1, 2, 0]) + [ + 102.9801, + 115.9465, + 122.7717, + ] + new_im = Image.fromarray(new_im.astype(np.uint8)).convert("RGB") + mask = target.extra_fields["masks"].polygons[0].convert("mask") + mask = Image.fromarray((mask.numpy() * 255).astype(np.uint8)).convert("RGB") + if self.use_charann: + m, _ = ( + target.extra_fields["char_masks"] + .chars_boxes[0] + .convert("char_mask") + ) + color = self.creat_color_map(37, 255) + color_map = color[m.numpy().astype(np.uint8)] + char = Image.fromarray(color_map.astype(np.uint8)).convert("RGB") + char = Image.blend(char, new_im, 0.5) + else: + char = new_im + new = Image.blend(char, mask, 0.5) + img_draw = ImageDraw.Draw(new) + for box in target.bbox.numpy(): + box = list(box) + box = box[:2] + [box[2], box[1]] + box[2:] + [box[0], box[3]] + box[:2] + img_draw.line(box, fill=(255, 0, 0), width=2) + new.save("./vis/char_" + im_name) + return img, target, self.image_lists[item] + + def creat_color_map(self, n_class, width): + splits = int(np.ceil(np.power((n_class * 1.0), 1.0 / 3))) + maps = [] + for i in range(splits): + r = int(i * width * 1.0 / (splits - 1)) + for j in range(splits): + g = int(j * width * 1.0 / (splits - 1)) + for k in range(splits - 1): + b = int(k * width * 1.0 / (splits - 1)) + maps.append([r, g, b]) + return np.array(maps) + + def __len__(self): + return len(self.image_lists) + + # def load_gt_from_txt(self, gt_path, height=None, width=None): + # words, boxes, charsboxes, segmentations, labels = [], [], [], [], [] + # lines = open(gt_path).readlines() + # for line in lines: + # charbbs = [] + # strs, loc = self.line2boxes(line) + # word = strs[0] + # if word == "###": + # labels.append(-1) + # continue + # else: + # labels.append(1) + # rect = list(loc[0]) + # min_x = min(rect[::2]) - 1 + # min_y = min(rect[1::2]) - 1 + # max_x = max(rect[::2]) - 1 + # max_y = max(rect[1::2]) - 1 + # box = [min_x, min_y, max_x, max_y] + # segmentations.append([loc[0, :]]) + # tindex = len(boxes) + # boxes.append(box) + # words.append(word) + # c_class = self.char2num(strs[1:]) + # charbb = np.zeros((10,), dtype=np.float32) + # if loc.shape[0] > 1: + # for i in range(1, loc.shape[0]): + # charbb[:8] = loc[i, :] + # charbb[8] = c_class[i - 1] + # charbb[9] = tindex + # charbbs.append(charbb.copy()) + # charsboxes.append(charbbs) + # num_boxes = len(boxes) + # if len(boxes) > 0: + # keep_boxes = np.zeros((num_boxes, 5)) + # keep_boxes[:, :4] = np.array(boxes) + # keep_boxes[:, 4] = range( + # num_boxes + # ) # the 5th column is the box label,same as the 10th column of all charsboxes which belong to the box + # if self.use_charann: + # return words, np.array(keep_boxes), charsboxes, segmentations, labels + # else: + # charbbs = np.zeros((10,), dtype=np.float32) + # for i in range(len(words)): + # charsboxes.append([charbbs]) + # return words, np.array(keep_boxes), charsboxes, segmentations, labels + # else: + # words.append("") + # charbbs = np.zeros((10,), dtype=np.float32) + # return ( + # words, + # np.zeros((1, 5), dtype=np.float32), + # [[charbbs]], + # [[np.zeros((8,), dtype=np.float32)]], + # labels + # ) + + def load_gt_from_txt(self, gt_path, height=None, width=None): + words, boxes, charsboxes, segmentations, labels = [], [], [], [], [] + lines = open(gt_path).readlines() + for line in lines: + charbbs = [] + strs, loc = self.line2boxes(line) + word = strs[0] + if self.ignore_difficult: + rect = list(loc[0]) + min_x = min(rect[::2]) - 1 + min_y = min(rect[1::2]) - 1 + max_x = max(rect[::2]) - 1 + max_y = max(rect[1::2]) - 1 + box = [min_x, min_y, max_x, max_y] + # segmentations.append([loc[0, :]]) + segmentations.append([[min_x, min_y, max_x, min_y, max_x, max_y, min_x, max_y]]) + tindex = len(boxes) + boxes.append(box) + # words.append(word) + if word =='1': + labels.append(-1) + else: + labels.append(1) + charbb = np.zeros((10,), dtype=np.float32) + if loc.shape[0] > 1: + for i in range(1, loc.shape[0]): + charbb[9] = tindex + charbbs.append(charbb.copy()) + charsboxes.append(charbbs) + else: + continue + num_boxes = len(boxes) + if len(boxes) > 0: + keep_boxes = np.zeros((num_boxes, 5)) + keep_boxes[:, :4] = np.array(boxes) + keep_boxes[:, 4] = range( + num_boxes + ) + # the 5th column is the box label, + # same as the 10th column of all charsboxes which belong to the box + if self.use_charann: + return words, np.array(keep_boxes), charsboxes, segmentations, labels + else: + charbbs = np.zeros((10,), dtype=np.float32) + if len(charsboxes) == 0: + for _ in range(len(words)): + charsboxes.append([charbbs]) + return words, np.array(keep_boxes), charsboxes, segmentations, labels + else: + words.append("") + charbbs = np.zeros((10,), dtype=np.float32) + return ( + words, + np.zeros((1, 5), dtype=np.float32), + [[charbbs]], + [[np.zeros((8,), dtype=np.float32)]], + [1] + ) + + + def line2boxes(self, line): + parts = line.strip().split(",") + return [parts[-1]], np.array([[float(x) for x in parts[:-1]]]) + + def check_charbbs(self, charbbs): + xmins = np.minimum.reduce( + [charbbs[:, 0], charbbs[:, 2], charbbs[:, 4], charbbs[:, 6]] + ) + xmaxs = np.maximum.reduce( + [charbbs[:, 0], charbbs[:, 2], charbbs[:, 4], charbbs[:, 6]] + ) + ymins = np.minimum.reduce( + [charbbs[:, 1], charbbs[:, 3], charbbs[:, 5], charbbs[:, 7]] + ) + ymaxs = np.maximum.reduce( + [charbbs[:, 1], charbbs[:, 3], charbbs[:, 5], charbbs[:, 7]] + ) + return np.logical_and( + xmaxs - xmins > self.min_proposal_size, + ymaxs - ymins > self.min_proposal_size, + ) + + def check_charbb(self, charbb): + xmins = min(charbb[0], charbb[2], charbb[4], charbb[6]) + xmaxs = max(charbb[0], charbb[2], charbb[4], charbb[6]) + ymins = min(charbb[1], charbb[3], charbb[5], charbb[7]) + ymaxs = max(charbb[1], charbb[3], charbb[5], charbb[7]) + return ( + xmaxs - xmins > self.min_proposal_size + and ymaxs - ymins > self.min_proposal_size + ) + + def char2num(self, chars): + ## chars ['h', 'e', 'l', 'l', 'o'] + nums = [self.char_classes.index(c.lower()) for c in chars] + return nums + + def get_img_info(self, item): + """ + Return the image dimensions for the image, without + loading and pre-processing it + """ + + im_name = os.path.basename(self.image_lists[item]) + img = Image.open(self.image_lists[item]) + width, height = img.size + img_info = {"im_name": im_name, "height": height, "width": width} + return img_info diff --git a/maskrcnn_benchmark/data/datasets/total_text.py b/maskrcnn_benchmark/data/datasets/total_text.py new file mode 100644 index 0000000000000000000000000000000000000000..390d01c8d8aa88dfb9b5dcdc50c3816a113b942f --- /dev/null +++ b/maskrcnn_benchmark/data/datasets/total_text.py @@ -0,0 +1,316 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +""" +Simple dataset class that wraps a list of path names +""" + +import os + +import numpy as np +import torch +from maskrcnn_benchmark.structures.bounding_box import BoxList +from maskrcnn_benchmark.structures.segmentation_mask import ( + CharPolygons, + SegmentationCharMask, + SegmentationMask, +) +from PIL import Image, ImageDraw + + +class TotaltextDataset(object): + def __init__(self, use_charann, imgs_dir, gts_dir, transforms=None, ignore_difficult=False): + self.use_charann = use_charann + self.image_lists = [os.path.join(imgs_dir, img) for img in os.listdir(imgs_dir)] + self.gts_dir = gts_dir + self.transforms = transforms + self.min_proposal_size = 2 + self.char_classes = "_0123456789abcdefghijklmnopqrstuvwxyz" + self.vis = False + self.ignore_difficult = ignore_difficult + if self.ignore_difficult and (self.gts_dir is not None) and 'train' in self.gts_dir: + self.image_lists = self.filter_image_lists() + + def filter_image_lists(self): + new_image_lists = [] + for img_path in self.image_lists: + has_positive = False + im_name = os.path.basename(img_path) + gt_path = os.path.join(self.gts_dir, im_name + ".txt") + if not os.path.isfile(gt_path): + gt_path = os.path.join( + self.gts_dir, "gt_" + im_name.split(".")[0] + ".txt" + ) + lines = open(gt_path, 'r').readlines() + for line in lines: + charbbs = [] + strs, loc = self.line2boxes(line) + word = strs[0] + if word == "###": + continue + else: + has_positive = True + if has_positive: + new_image_lists.append(img_path) + return new_image_lists + + def __getitem__(self, item): + im_name = os.path.basename(self.image_lists[item]) + # print(self.image_lists[item]) + img = Image.open(self.image_lists[item]).convert("RGB") + width, height = img.size + if self.gts_dir is not None: + gt_path = os.path.join(self.gts_dir, im_name + ".txt") + words, boxes, charsbbs, segmentations, labels = self.load_gt_from_txt( + gt_path, height, width + ) + if words[0] == "": + use_char_ann = False + else: + use_char_ann = True + if not self.use_charann: + use_char_ann = False + target = BoxList( + boxes[:, :4], img.size, mode="xyxy", use_char_ann=use_char_ann + ) + if self.ignore_difficult: + labels = torch.from_numpy(np.array(labels)) + else: + labels = torch.ones(len(boxes)) + target.add_field("labels", labels) + masks = SegmentationMask(segmentations, img.size) + target.add_field("masks", masks) + char_masks = SegmentationCharMask( + charsbbs, words=words, use_char_ann=use_char_ann, size=img.size, char_num_classes=len(self.char_classes) + ) + target.add_field("char_masks", char_masks) + else: + target = None + if self.transforms is not None: + img, target = self.transforms(img, target) + if self.vis: + new_im = img.numpy().copy().transpose([1, 2, 0]) + [ + 102.9801, + 115.9465, + 122.7717, + ] + new_im = Image.fromarray(new_im.astype(np.uint8)).convert("RGB") + mask = target.extra_fields["masks"].polygons[0].convert("mask") + mask = Image.fromarray((mask.numpy() * 255).astype(np.uint8)).convert("RGB") + if self.use_charann: + m, _ = ( + target.extra_fields["char_masks"] + .chars_boxes[0] + .convert("char_mask") + ) + color = self.creat_color_map(37, 255) + color_map = color[m.numpy().astype(np.uint8)] + char = Image.fromarray(color_map.astype(np.uint8)).convert("RGB") + char = Image.blend(char, new_im, 0.5) + else: + char = new_im + new = Image.blend(char, mask, 0.5) + img_draw = ImageDraw.Draw(new) + for box in target.bbox.numpy(): + box = list(box) + box = box[:2] + [box[2], box[1]] + box[2:] + [box[0], box[3]] + box[:2] + img_draw.line(box, fill=(255, 0, 0), width=2) + new.save("./vis/char_" + im_name) + return img, target, self.image_lists[item] + + def creat_color_map(self, n_class, width): + splits = int(np.ceil(np.power((n_class * 1.0), 1.0 / 3))) + maps = [] + for i in range(splits): + r = int(i * width * 1.0 / (splits - 1)) + for j in range(splits): + g = int(j * width * 1.0 / (splits - 1)) + for k in range(splits - 1): + b = int(k * width * 1.0 / (splits - 1)) + maps.append([r, g, b]) + return np.array(maps) + + def __len__(self): + return len(self.image_lists) + + # def load_gt_from_txt(self, gt_path, height=None, width=None): + # words, boxes, charsboxes, segmentations, labels = [], [], [], [], [] + # lines = open(gt_path).readlines() + # for line in lines: + # charbbs = [] + # strs, loc = self.line2boxes(line) + # word = strs[0] + # if word == "###": + # labels.append(-1) + # continue + # else: + # labels.append(1) + # rect = list(loc[0]) + # min_x = min(rect[::2]) - 1 + # min_y = min(rect[1::2]) - 1 + # max_x = max(rect[::2]) - 1 + # max_y = max(rect[1::2]) - 1 + # box = [min_x, min_y, max_x, max_y] + # segmentations.append([loc[0, :]]) + # tindex = len(boxes) + # boxes.append(box) + # words.append(word) + # c_class = self.char2num(strs[1:]) + # charbb = np.zeros((10,), dtype=np.float32) + # if loc.shape[0] > 1: + # for i in range(1, loc.shape[0]): + # charbb[:8] = loc[i, :] + # charbb[8] = c_class[i - 1] + # charbb[9] = tindex + # charbbs.append(charbb.copy()) + # charsboxes.append(charbbs) + # num_boxes = len(boxes) + # if len(boxes) > 0: + # keep_boxes = np.zeros((num_boxes, 5)) + # keep_boxes[:, :4] = np.array(boxes) + # keep_boxes[:, 4] = range( + # num_boxes + # ) # the 5th column is the box label,same as the 10th column of all charsboxes which belong to the box + # if self.use_charann: + # return words, np.array(keep_boxes), charsboxes, segmentations, labels + # else: + # charbbs = np.zeros((10,), dtype=np.float32) + # for i in range(len(words)): + # charsboxes.append([charbbs]) + # return words, np.array(keep_boxes), charsboxes, segmentations, labels + # else: + # words.append("") + # charbbs = np.zeros((10,), dtype=np.float32) + # return ( + # words, + # np.zeros((1, 5), dtype=np.float32), + # [[charbbs]], + # [[np.zeros((8,), dtype=np.float32)]], + # labels + # ) + + def load_gt_from_txt(self, gt_path, height=None, width=None): + words, boxes, charsboxes, segmentations, labels = [], [], [], [], [] + lines = open(gt_path).readlines() + for line in lines: + charbbs = [] + strs, loc = self.line2boxes(line) + word = strs[0] + if word == "###": + if self.ignore_difficult: + rect = list(loc[0]) + min_x = min(rect[::2]) - 1 + min_y = min(rect[1::2]) - 1 + max_x = max(rect[::2]) - 1 + max_y = max(rect[1::2]) - 1 + box = [min_x, min_y, max_x, max_y] + # segmentations.append([loc[0, :]]) + segmentations.append([[min_x, min_y, max_x, min_y, max_x, max_y, min_x, max_y]]) + tindex = len(boxes) + boxes.append(box) + words.append(word) + labels.append(-1) + charbbs = np.zeros((10,), dtype=np.float32) + if loc.shape[0] > 1: + for i in range(1, loc.shape[0]): + charbb[9] = tindex + charbbs.append(charbb.copy()) + charsboxes.append(charbbs) + else: + continue + else: + rect = list(loc[0]) + min_x = min(rect[::2]) - 1 + min_y = min(rect[1::2]) - 1 + max_x = max(rect[::2]) - 1 + max_y = max(rect[1::2]) - 1 + box = [min_x, min_y, max_x, max_y] + segmentations.append([loc[0, :]]) + tindex = len(boxes) + boxes.append(box) + words.append(word) + labels.append(1) + c_class = self.char2num(strs[1:]) + charbb = np.zeros((10,), dtype=np.float32) + if loc.shape[0] > 1: + for i in range(1, loc.shape[0]): + charbb[:8] = loc[i, :] + charbb[8] = c_class[i - 1] + charbb[9] = tindex + charbbs.append(charbb.copy()) + charsboxes.append(charbbs) + num_boxes = len(boxes) + if len(boxes) > 0: + keep_boxes = np.zeros((num_boxes, 5)) + keep_boxes[:, :4] = np.array(boxes) + keep_boxes[:, 4] = range( + num_boxes + ) + # the 5th column is the box label, + # same as the 10th column of all charsboxes which belong to the box + if self.use_charann: + return words, np.array(keep_boxes), charsboxes, segmentations, labels + else: + charbbs = np.zeros((10,), dtype=np.float32) + if len(charsboxes) == 0: + for _ in range(len(words)): + charsboxes.append([charbbs]) + return words, np.array(keep_boxes), charsboxes, segmentations, labels + else: + words.append("") + charbbs = np.zeros((10,), dtype=np.float32) + return ( + words, + np.zeros((1, 5), dtype=np.float32), + [[charbbs]], + [[np.zeros((8,), dtype=np.float32)]], + [1] + ) + + + def line2boxes(self, line): + parts = line.strip().split(",") + return [parts[-1]], np.array([[float(x) for x in parts[:-1]]]) + + def check_charbbs(self, charbbs): + xmins = np.minimum.reduce( + [charbbs[:, 0], charbbs[:, 2], charbbs[:, 4], charbbs[:, 6]] + ) + xmaxs = np.maximum.reduce( + [charbbs[:, 0], charbbs[:, 2], charbbs[:, 4], charbbs[:, 6]] + ) + ymins = np.minimum.reduce( + [charbbs[:, 1], charbbs[:, 3], charbbs[:, 5], charbbs[:, 7]] + ) + ymaxs = np.maximum.reduce( + [charbbs[:, 1], charbbs[:, 3], charbbs[:, 5], charbbs[:, 7]] + ) + return np.logical_and( + xmaxs - xmins > self.min_proposal_size, + ymaxs - ymins > self.min_proposal_size, + ) + + def check_charbb(self, charbb): + xmins = min(charbb[0], charbb[2], charbb[4], charbb[6]) + xmaxs = max(charbb[0], charbb[2], charbb[4], charbb[6]) + ymins = min(charbb[1], charbb[3], charbb[5], charbb[7]) + ymaxs = max(charbb[1], charbb[3], charbb[5], charbb[7]) + return ( + xmaxs - xmins > self.min_proposal_size + and ymaxs - ymins > self.min_proposal_size + ) + + def char2num(self, chars): + ## chars ['h', 'e', 'l', 'l', 'o'] + nums = [self.char_classes.index(c.lower()) for c in chars] + return nums + + def get_img_info(self, item): + """ + Return the image dimensions for the image, without + loading and pre-processing it + """ + + im_name = os.path.basename(self.image_lists[item]) + img = Image.open(self.image_lists[item]) + width, height = img.size + img_info = {"im_name": im_name, "height": height, "width": width} + return img_info diff --git a/maskrcnn_benchmark/data/datasets/utils.py b/maskrcnn_benchmark/data/datasets/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..fd89e059945399c81deff84b5e57522cca78d81c --- /dev/null +++ b/maskrcnn_benchmark/data/datasets/utils.py @@ -0,0 +1,38 @@ +#!/usr/bin/env python3 + +import os +import shlex +import shutil +import subprocess + + +def extract_archive(dataset_archive, tmp_data_path): + if not os.path.isfile(dataset_archive): + return False + + dataset_ext = os.path.splitext(dataset_archive)[1] + if dataset_ext != ".gz" and dataset_ext != ".tar": + return False + + if os.path.isdir(tmp_data_path): + shutil.rmtree(tmp_data_path, ignore_errors=True) + os.makedirs(tmp_data_path) + + if dataset_ext == ".gz": + tar_opt = "-xzf" + else: + tar_opt = "-xf" + + extract_cmd = ("tar {} {} -C {}").format(tar_opt, dataset_archive, tmp_data_path) + + subprocess.call(shlex.split(extract_cmd)) + + return True + + +def tar_file(tar_path, tmp_path): + tar_name = tar_path.split('/')[-1] + if extract_archive(tar_path, tmp_path): + print('extract ' + tar_name + 'successfully!') + else: + print("fail to extract " + tar_name) diff --git a/maskrcnn_benchmark/data/samplers/__init__.py b/maskrcnn_benchmark/data/samplers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..27982cbe68c6173a911e700273f25973acbf04bd --- /dev/null +++ b/maskrcnn_benchmark/data/samplers/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +from .distributed import DistributedSampler +from .grouped_batch_sampler import GroupedBatchSampler +from .iteration_based_batch_sampler import IterationBasedBatchSampler + +__all__ = ["DistributedSampler", "GroupedBatchSampler", "IterationBasedBatchSampler"] diff --git a/maskrcnn_benchmark/data/samplers/distributed.py b/maskrcnn_benchmark/data/samplers/distributed.py new file mode 100644 index 0000000000000000000000000000000000000000..b71cdf9a17648feeac43f3b19714046b77e66980 --- /dev/null +++ b/maskrcnn_benchmark/data/samplers/distributed.py @@ -0,0 +1,70 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +# Code is copy-pasted exactly as in torch.utils.data.distributed, +# with a modification in the import to use the deprecated backend +# FIXME remove this once c10d fixes the bug it has +import math +import torch +import torch.distributed as dist +from torch.utils.data.sampler import Sampler +from maskrcnn_benchmark.utils.comm import get_rank, get_world_size + + +class DistributedSampler(Sampler): + """Sampler that restricts data loading to a subset of the dataset. + It is especially useful in conjunction with + :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each + process can pass a DistributedSampler instance as a DataLoader sampler, + and load a subset of the original dataset that is exclusive to it. + .. note:: + Dataset is assumed to be of constant size. + Arguments: + dataset: Dataset used for sampling. + num_replicas (optional): Number of processes participating in + distributed training. + rank (optional): Rank of the current process within num_replicas. + """ + + def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True): + if num_replicas is None: + if not dist.is_available(): + raise RuntimeError("Requires distributed package to be available") + # num_replicas = dist.get_world_size() + num_replicas = get_world_size() + if rank is None: + if not dist.is_available(): + raise RuntimeError("Requires distributed package to be available") + # rank = dist.get_rank() + rank = get_rank() + self.dataset = dataset + self.num_replicas = num_replicas + self.rank = rank + self.epoch = 0 + self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) + self.total_size = self.num_samples * self.num_replicas + self.shuffle = True + + def __iter__(self): + if self.shuffle: + # deterministically shuffle based on epoch + g = torch.Generator() + g.manual_seed(self.epoch) + indices = torch.randperm(len(self.dataset), generator=g).tolist() + else: + indices = torch.arange(len(self.dataset)).tolist() + + # add extra samples to make it evenly divisible + indices += indices[: (self.total_size - len(indices))] + assert len(indices) == self.total_size + + # subsample + offset = self.num_samples * self.rank + indices = indices[offset : offset + self.num_samples] + assert len(indices) == self.num_samples + + return iter(indices) + + def __len__(self): + return self.num_samples + + def set_epoch(self, epoch): + self.epoch = epoch diff --git a/maskrcnn_benchmark/data/samplers/grouped_batch_sampler.py b/maskrcnn_benchmark/data/samplers/grouped_batch_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..4915316027da880f56aa414754099b889aa26e2d --- /dev/null +++ b/maskrcnn_benchmark/data/samplers/grouped_batch_sampler.py @@ -0,0 +1,114 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +import itertools + +import torch +from torch.utils.data.sampler import BatchSampler +from torch.utils.data.sampler import Sampler + + +class GroupedBatchSampler(BatchSampler): + """ + Wraps another sampler to yield a mini-batch of indices. + It enforces that elements from the same group should appear in groups of batch_size. + It also tries to provide mini-batches which follows an ordering which is + as close as possible to the ordering from the original sampler. + + Arguments: + sampler (Sampler): Base sampler. + batch_size (int): Size of mini-batch. + drop_uneven (bool): If ``True``, the sampler will drop the batches whose + size is less than ``batch_size`` + + """ + + def __init__(self, sampler, group_ids, batch_size, drop_uneven=False): + if not isinstance(sampler, Sampler): + raise ValueError( + "sampler should be an instance of " + "torch.utils.data.Sampler, but got sampler={}".format(sampler) + ) + self.sampler = sampler + self.group_ids = torch.as_tensor(group_ids) + assert self.group_ids.dim() == 1 + self.batch_size = batch_size + self.drop_uneven = drop_uneven + + self.groups = torch.unique(self.group_ids).sort(0)[0] + + self._can_reuse_batches = False + + def _prepare_batches(self): + dataset_size = len(self.group_ids) + # get the sampled indices from the sampler + sampled_ids = torch.as_tensor(list(self.sampler)) + # potentially not all elements of the dataset were sampled + # by the sampler (e.g., DistributedSampler). + # construct a tensor which contains -1 if the element was + # not sampled, and a non-negative number indicating the + # order where the element was sampled. + # for example. if sampled_ids = [3, 1] and dataset_size = 5, + # the order is [-1, 1, -1, 0, -1] + order = torch.full((dataset_size,), -1, dtype=torch.int64) + order[sampled_ids] = torch.arange(len(sampled_ids)) + + # get a mask with the elements that were sampled + mask = order >= 0 + + # find the elements that belong to each individual cluster + clusters = [(self.group_ids == i) & mask for i in self.groups] + # get relative order of the elements inside each cluster + # that follows the order from the sampler + relative_order = [order[cluster] for cluster in clusters] + # with the relative order, find the absolute order in the + # sampled space + permutation_ids = [s[s.sort()[1]] for s in relative_order] + # permute each cluster so that they follow the order from + # the sampler + permuted_clusters = [sampled_ids[idx] for idx in permutation_ids] + + # splits each cluster in batch_size, and merge as a list of tensors + splits = [c.split(self.batch_size) for c in permuted_clusters] + merged = tuple(itertools.chain.from_iterable(splits)) + # now each batch internally has the right order, but + # they are grouped by clusters. Find the permutation between + # different batches that brings them as close as possible to + # the order that we have in the sampler. For that, we will consider the + # ordering as coming from the first element of each batch, and sort + # correspondingly + first_element_of_batch = [t[0].item() for t in merged] + # get and inverse mapping from sampled indices and the position where + # they occur (as returned by the sampler) + inv_sampled_ids_map = {v: k for k, v in enumerate(sampled_ids.tolist())} + # from the first element in each batch, get a relative ordering + first_index_of_batch = torch.as_tensor( + [inv_sampled_ids_map[s] for s in first_element_of_batch] + ) + + # permute the batches so that they approximately follow the order + # from the sampler + permutation_order = first_index_of_batch.sort(0)[1].tolist() + # finally, permute the batches + batches = [merged[i].tolist() for i in permutation_order] + + if self.drop_uneven: + kept = [] + for batch in batches: + if len(batch) == self.batch_size: + kept.append(batch) + batches = kept + return batches + + def __iter__(self): + if self._can_reuse_batches: + batches = self._batches + self._can_reuse_batches = False + else: + batches = self._prepare_batches() + self._batches = batches + return iter(batches) + + def __len__(self): + if not hasattr(self, "_batches"): + self._batches = self._prepare_batches() + self._can_reuse_batches = True + return len(self._batches) diff --git a/maskrcnn_benchmark/data/samplers/iteration_based_batch_sampler.py b/maskrcnn_benchmark/data/samplers/iteration_based_batch_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..93452b64696dc9b2cd2a347b8051729864bf9510 --- /dev/null +++ b/maskrcnn_benchmark/data/samplers/iteration_based_batch_sampler.py @@ -0,0 +1,31 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +from torch.utils.data.sampler import BatchSampler + + +class IterationBasedBatchSampler(BatchSampler): + """ + Wraps a BatchSampler, resampling from it until + a specified number of iterations have been sampled + """ + + def __init__(self, batch_sampler, num_iterations, start_iter=0): + self.batch_sampler = batch_sampler + self.num_iterations = num_iterations + self.start_iter = start_iter + + def __iter__(self): + iteration = self.start_iter + while iteration <= self.num_iterations: + # if the underlying sampler has a set_epoch method, like + # DistributedSampler, used for making each process see + # a different split of the dataset, then set it + if hasattr(self.batch_sampler.sampler, "set_epoch"): + self.batch_sampler.sampler.set_epoch(iteration) + for batch in self.batch_sampler: + iteration += 1 + if iteration > self.num_iterations: + break + yield batch + + def __len__(self): + return self.num_iterations diff --git a/maskrcnn_benchmark/data/transforms/__init__.py b/maskrcnn_benchmark/data/transforms/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..076f8e98f7d852ef28094ec9790eb69e56d8c68c --- /dev/null +++ b/maskrcnn_benchmark/data/transforms/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +from .transforms import Compose +from .transforms import Resize +from .transforms import RandomHorizontalFlip +from .transforms import ToTensor +from .transforms import Normalize + +from .build import build_transforms + diff --git a/maskrcnn_benchmark/data/transforms/build.py b/maskrcnn_benchmark/data/transforms/build.py new file mode 100644 index 0000000000000000000000000000000000000000..118fe8b93a361164db23fb2738d214e21fbf4574 --- /dev/null +++ b/maskrcnn_benchmark/data/transforms/build.py @@ -0,0 +1,70 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +from . import transforms as T + + +def build_transforms(cfg, is_train=True): + to_bgr255 = cfg.INPUT.TO_BGR255 + normalize_transform = T.Normalize( + mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD, to_bgr255=to_bgr255 + ) + if is_train: + min_size = cfg.INPUT.MIN_SIZE_TRAIN + max_size = cfg.INPUT.MAX_SIZE_TRAIN + # flip_prob = 0.5 # cfg.INPUT.FLIP_PROB_TRAIN + # flip_prob = 0 + # rotate_prob = 0.5 + rotate_prob = 0.5 + pixel_aug_prob = 0.2 + random_crop_prob = cfg.DATASETS.RANDOM_CROP_PROB + else: + min_size = cfg.INPUT.MIN_SIZE_TEST + max_size = cfg.INPUT.MAX_SIZE_TEST + # flip_prob = 0 + rotate_prob = 0 + pixel_aug_prob = 0 + random_crop_prob = 0 + + to_bgr255 = cfg.INPUT.TO_BGR255 + normalize_transform = T.Normalize( + mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD, to_bgr255=to_bgr255 + ) + if cfg.DATASETS.AUG and is_train: + if cfg.DATASETS.FIX_CROP: + transform = T.Compose( + [ + T.RandomCrop(1.0, crop_min_size=512, crop_max_size=640, max_trys=50), + T.RandomBrightness(pixel_aug_prob), + T.RandomContrast(pixel_aug_prob), + T.RandomHue(pixel_aug_prob), + T.RandomSaturation(pixel_aug_prob), + T.RandomGamma(pixel_aug_prob), + T.RandomRotate(rotate_prob), + T.Resize(min_size, max_size, cfg.INPUT.STRICT_RESIZE), + T.ToTensor(), + normalize_transform, + ] + ) + else: + transform = T.Compose( + [ + T.RandomCrop(random_crop_prob), + T.RandomBrightness(pixel_aug_prob), + T.RandomContrast(pixel_aug_prob), + T.RandomHue(pixel_aug_prob), + T.RandomSaturation(pixel_aug_prob), + T.RandomGamma(pixel_aug_prob), + T.RandomRotate(rotate_prob, max_theta=cfg.DATASETS.MAX_ROTATE_THETA, fix_rotate=cfg.DATASETS.FIX_ROTATE), + T.Resize(min_size, max_size, cfg.INPUT.STRICT_RESIZE), + T.ToTensor(), + normalize_transform, + ] + ) + else: + transform = T.Compose( + [ + T.Resize(min_size, max_size, cfg.INPUT.STRICT_RESIZE), + T.ToTensor(), + normalize_transform, + ] + ) + return transform diff --git a/maskrcnn_benchmark/data/transforms/transforms.py b/maskrcnn_benchmark/data/transforms/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..e880a87c9cc95d81cd6910c5191afd0482407760 --- /dev/null +++ b/maskrcnn_benchmark/data/transforms/transforms.py @@ -0,0 +1,376 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +import random + +import cv2 +import numpy as np +from PIL import Image +from shapely import affinity +from shapely.geometry import Polygon +from torchvision.transforms import functional as F + + +class Compose(object): + def __init__(self, transforms): + self.transforms = transforms + + def __call__(self, image, target): + for t in self.transforms: + image, target = t(image, target) + return image, target + + def __repr__(self): + format_string = self.__class__.__name__ + "(" + for t in self.transforms: + format_string += "\n" + format_string += " {0}".format(t) + format_string += "\n)" + return format_string + + +class Resize(object): + def __init__(self, min_size, max_size, strict_resize): + self.min_size = min_size + self.max_size = max_size + self.strict_resize = strict_resize + + # modified from torchvision to add support for max size + def get_size(self, image_size): + w, h = image_size + if isinstance(self.min_size, tuple): + if len(self.min_size) == 1: + size = self.min_size[0] + else: + random_size_index = random.randint(0, len(self.min_size) - 1) + size = self.min_size[random_size_index] + else: + size = self.min_size + max_size = self.max_size + if max_size is not None: + min_original_size = float(min((w, h))) + max_original_size = float(max((w, h))) + if max_original_size / min_original_size * size > max_size: + size = int(round(max_size * min_original_size / max_original_size)) + + if (w <= h and w == size) or (h <= w and h == size): + if self.strict_resize: + h = h if h % 32 == 0 else (h // 32) * 32 + w = w if w % 32 == 0 else (w // 32) * 32 + return (h, w) + + if w < h: + ow = size + oh = int(size * h / w) + else: + oh = size + ow = int(size * w / h) + if self.strict_resize: + oh = oh if oh % 32 == 0 else (oh // 32) * 32 + ow = ow if ow % 32 == 0 else (ow // 32) * 32 + + return (oh, ow) + + def __call__(self, image, target): + size = self.get_size(image.size) + image = F.resize(image, size) + if target is not None: + target = target.resize(image.size) + return image, target + + +class RandomCrop(object): + def __init__(self, prob, crop_min_size=500, crop_max_size=1000, max_trys=50): + self.min_size = crop_min_size + self.max_size = crop_max_size + self.max_trys = max_trys + self.prob = prob + + def __call__(self, image, target): + if random.random() < self.prob: + im = np.array(image) + w, h = image.size + h_array = np.zeros((h), dtype=np.int32) + w_array = np.zeros((w), dtype=np.int32) + boxes = target.bbox.numpy() + if len(boxes) == 0: + return image, target + for box in boxes: + box = np.round(box, decimals=0).astype(np.int32) + minx = box[0] + maxx = box[2] + w_array[minx:maxx] = 1 + miny = box[1] + maxy = box[3] + h_array[miny:maxy] = 1 + h_axis = np.where(h_array == 0)[0] + w_axis = np.where(w_array == 0)[0] + if len(h_axis) == 0 or len(w_axis) == 0: + return image, target + for _ in range(self.max_trys): + xx = np.random.choice(w_axis, size=2) + xmin = min(xx) + xmax = max(xx) + x_size = xmax - xmin + if x_size > self.max_size or x_size < self.min_size: + continue + yy = np.random.choice(h_axis, size=2) + ymin = min(yy) + ymax = max(yy) + y_size = ymax - ymin + if y_size > self.max_size or y_size < self.min_size: + continue + box_in_area = ( + (boxes[:, 0] >= xmin) + & (boxes[:, 1] >= ymin) + & (boxes[:, 2] <= xmax) + & (boxes[:, 3] <= ymax) + ) + if len(np.where(box_in_area)[0]) == 0: + continue + im = im[ymin:ymax, xmin:xmax] + target = target.crop([xmin, ymin, xmax, ymax]) + return Image.fromarray(im), target + return image, target + else: + return image, target + + +# class RandomCropFixSize(object): +# def __init__(self, prob, crop_size=512, max_trys=50): +# self.crop_size = crop_size +# self.max_trys = max_trys +# self.prob = prob + +# def __call__(self, image, target): +# if random.random() < self.prob: +# im = np.array(image) +# w, h = image.size +# h_array = np.zeros((h), dtype=np.int32) +# w_array = np.zeros((w), dtype=np.int32) +# boxes = target.bbox.numpy() +# if len(boxes) == 0: +# return image, target +# for box in boxes: +# box = np.round(box, decimals=0).astype(np.int32) +# minx = box[0] +# maxx = box[2] +# w_array[minx:maxx] = 1 +# miny = box[1] +# maxy = box[3] +# h_array[miny:maxy] = 1 +# h_axis = np.where(h_array == 0)[0] +# w_axis = np.where(w_array == 0)[0] +# if len(h_axis) == 0 or len(w_axis) == 0: +# return image, target +# for _ in range(self.max_trys): +# xx = np.random.choice(w_axis, size=2) +# xmin = min(xx) +# xmax = max(xx) +# x_size = xmax - xmin +# if x_size > self.max_size or x_size < self.min_size: +# continue +# yy = np.random.choice(h_axis, size=2) +# ymin = min(yy) +# ymax = max(yy) +# y_size = ymax - ymin +# if y_size > self.max_size or y_size < self.min_size: +# continue +# box_in_area = ( +# (boxes[:, 0] >= xmin) +# & (boxes[:, 1] >= ymin) +# & (boxes[:, 2] <= xmax) +# & (boxes[:, 3] <= ymax) +# ) +# if len(np.where(box_in_area)[0]) == 0: +# continue +# im = im[ymin:ymax, xmin:xmax] +# target = target.crop([xmin, ymin, xmax, ymax]) +# return Image.fromarray(im), target +# return image, target +# else: +# return image, target + + +class RandomHorizontalFlip(object): + def __init__(self, prob=0.5): + self.prob = prob + + def __call__(self, image, target): + if random.random() < self.prob: + image = F.hflip(image) + target = target.transpose(0) + return image, target + + +class ToTensor(object): + def __call__(self, image, target): + return F.to_tensor(image), target + + +class Normalize(object): + def __init__(self, mean, std, to_bgr255=True): + self.mean = mean + self.std = std + self.to_bgr255 = to_bgr255 + + def __call__(self, image, target): + if self.to_bgr255: + image = image[[2, 1, 0]] * 255 + image = F.normalize(image, mean=self.mean, std=self.std) + return image, target + + +class RandomBrightness(object): + def __init__(self, prob=0.5): + self.prob = prob + + def __call__(self, image, target): + if random.random() < self.prob: + brightness_factor = random.uniform(0.5, 2) + image = F.adjust_brightness(image, brightness_factor) + return image, target + + +class RandomContrast(object): + def __init__(self, prob=0.5): + self.prob = prob + + def __call__(self, image, target): + if random.random() < self.prob: + contrast_factor = random.uniform(0.5, 2) + image = F.adjust_contrast(image, contrast_factor) + return image, target + + +class RandomHue(object): + def __init__(self, prob=0.5): + self.prob = prob + + def __call__(self, image, target): + if random.random() < self.prob: + hue_factor = random.uniform(-0.25, 0.25) + image = F.adjust_hue(image, hue_factor) + return image, target + + +class RandomSaturation(object): + def __init__(self, prob=0.5): + self.prob = prob + + def __call__(self, image, target): + if random.random() < self.prob: + saturation_factor = random.uniform(0.5, 2) + image = F.adjust_saturation(image, saturation_factor) + return image, target + + +class RandomGamma(object): + def __init__(self, prob=0.5): + self.prob = prob + + def __call__(self, image, target): + if random.random() < self.prob: + gamma_factor = random.uniform(0.5, 2) + image = F.adjust_gamma(image, gamma_factor) + return image, target + + +class RandomRotate(object): + def __init__(self, prob, max_theta=30, fix_rotate=False): + self.prob = prob + self.max_theta = max_theta + self.fix_rotate = fix_rotate + + def __call__(self, image, target): + if random.random() < self.prob and target is not None: + # try: + if self.fix_rotate: + delta = 30 + else: + delta = random.uniform(-1 * self.max_theta, self.max_theta) + width, height = image.size + ## get the minimal rect to cover the rotated image + img_box = [[[0, 0], [width, 0], [width, height], [0, height]]] + rotated_img_box = _quad2minrect( + _rotate_polygons(img_box, delta, (width / 2, height / 2)) + ) + r_height = int( + max(rotated_img_box[0][3], rotated_img_box[0][1]) + - min(rotated_img_box[0][3], rotated_img_box[0][1]) + ) + r_width = int( + max(rotated_img_box[0][2], rotated_img_box[0][0]) + - min(rotated_img_box[0][2], rotated_img_box[0][0]) + ) + r_height = max(r_height, height + 1) + r_width = max(r_width, width + 1) + + ## padding im + im_padding = np.zeros((r_height, r_width, 3)) + start_h, start_w = ( + int((r_height - height) / 2.0), + int((r_width - width) / 2.0), + ) + end_h, end_w = start_h + height, start_w + width + im_padding[start_h:end_h, start_w:end_w, :] = image + + M = cv2.getRotationMatrix2D((r_width / 2, r_height / 2), delta, 1) + im = cv2.warpAffine(im_padding, M, (r_width, r_height)) + im = Image.fromarray(im.astype(np.uint8)) + target = target.rotate( + -delta, (r_width / 2, r_height / 2), start_h, start_w + ) + return im, target + # except: + # return image, target + else: + return image, target + + +def _quad2minrect(boxes): + ## trans a quad(N*4) to a rectangle(N*4) which has miniual area to cover it + return np.hstack( + ( + boxes[:, ::2].min(axis=1).reshape((-1, 1)), + boxes[:, 1::2].min(axis=1).reshape((-1, 1)), + boxes[:, ::2].max(axis=1).reshape((-1, 1)), + boxes[:, 1::2].max(axis=1).reshape((-1, 1)), + ) + ) + + +def _boxlist2quads(boxlist): + res = np.zeros((len(boxlist), 8)) + for i, box in enumerate(boxlist): + # print(box) + res[i] = np.array( + [ + box[0][0], + box[0][1], + box[1][0], + box[1][1], + box[2][0], + box[2][1], + box[3][0], + box[3][1], + ] + ) + return res + + +def _rotate_polygons(polygons, angle, r_c): + ## polygons: N*8 + ## r_x: rotate center x + ## r_y: rotate center y + ## angle: -15~15 + + rotate_boxes_list = [] + for poly in polygons: + box = Polygon(poly) + rbox = affinity.rotate(box, angle, r_c) + if len(list(rbox.exterior.coords)) < 5: + print("img_box_ori:", poly) + print("img_box_rotated:", rbox) + # assert(len(list(rbox.exterior.coords))>=5) + rotate_boxes_list.append(rbox.boundary.coords[:-1]) + res = _boxlist2quads(rotate_boxes_list) + return res diff --git a/maskrcnn_benchmark/engine/text_inference.py b/maskrcnn_benchmark/engine/text_inference.py new file mode 100644 index 0000000000000000000000000000000000000000..a15f5b67019107c5a94e20ce753af2b678ae3e31 --- /dev/null +++ b/maskrcnn_benchmark/engine/text_inference.py @@ -0,0 +1,560 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +import datetime +import logging +import os +import pickle +import subprocess +import time + +import cv2 +import numpy as np +import torch +from maskrcnn_benchmark.utils.chars import char2num, get_tight_rect, getstr_grid +from PIL import Image, ImageDraw +from tqdm import tqdm + +from ..utils.comm import is_main_process, scatter_gather, synchronize +import pdb + +# TO DO: format output with dictionnary +def compute_on_dataset(model, data_loader, device, cfg): + model.eval() + results_dict = {} + seg_results = [] + cpu_device = torch.device("cpu") + total_time = 0 + for _, batch in tqdm(enumerate(data_loader)): + images, targets, image_paths = batch + images = images.to(device) + with torch.no_grad(): + if cfg.MODEL.SEG_ON: + predictions, proposals, seg_results_dict = model( + images + ) + seg_results.append( + [image_paths, proposals, seg_results_dict['rotated_boxes'], seg_results_dict['polygons'], seg_results_dict['preds'], seg_results_dict['scores']] + ) + # if cfg.MODEL.MASK_ON and predictions is not None: + if predictions is not None: + if cfg.MODEL.CHAR_MASK_ON or cfg.SEQUENCE.SEQ_ON: + global_predictions = predictions[0] + char_predictions = predictions[1] + char_mask = char_predictions["char_mask"] + boxes = char_predictions["boxes"] + seq_words = char_predictions["seq_outputs"] + seq_scores = char_predictions["seq_scores"] + detailed_seq_scores = char_predictions["detailed_seq_scores"] + global_predictions = [o.to(cpu_device) for o in global_predictions] + results_dict.update( + { + image_paths[0]: [ + global_predictions[0], + char_mask, + boxes, + seq_words, + seq_scores, + detailed_seq_scores, + ] + } + ) + else: + global_predictions = [o.to(cpu_device) for o in predictions] + results_dict.update( + { + image_paths[0]: [ + global_predictions[0], + ] + } + ) + else: + predictions = model(images) + if predictions is not None: + if not (cfg.MODEL.CHAR_MASK_ON and cfg.SEQUENCE.SEQ_ON): + global_predictions = predictions + global_predictions = [o.to(cpu_device) for o in global_predictions] + results_dict.update( + { + image_paths[0]: [ + global_predictions[0], + ] + } + ) + else: + global_predictions = predictions[0] + char_predictions = predictions[1] + if cfg.MODEL.CHAR_MASK_ON: + char_mask = char_predictions["char_mask"] + else: + char_mask = None + boxes = char_predictions["boxes"] + seq_words = char_predictions["seq_outputs"] + seq_scores = char_predictions["seq_scores"] + detailed_seq_scores = char_predictions["detailed_seq_scores"] + global_predictions = [o.to(cpu_device) for o in global_predictions] + results_dict.update( + { + image_paths[0]: [ + global_predictions[0], + char_mask, + boxes, + seq_words, + seq_scores, + detailed_seq_scores, + ] + } + ) + return results_dict, seg_results + + +def polygon2rbox(polygon, image_height, image_width): + poly = np.array(polygon).reshape((-1, 2)) + rect = cv2.minAreaRect(poly) + corners = cv2.boxPoints(rect) + corners = np.array(corners, dtype="int") + pts = get_tight_rect(corners, 0, 0, image_height, image_width, 1) + pts = list(map(int, pts)) + return pts + + +def mask2polygon(mask, box, im_size, threshold=0.5, output_folder=None): + # mask 32*128 + image_width, image_height = im_size[0], im_size[1] + box_h = box[3] - box[1] + box_w = box[2] - box[0] + cls_polys = (mask * 255).astype(np.uint8) + poly_map = np.array(Image.fromarray(cls_polys).resize((box_w, box_h))) + poly_map = poly_map.astype(np.float32) / 255 + poly_map = cv2.GaussianBlur(poly_map, (3, 3), sigmaX=3) + ret, poly_map = cv2.threshold(poly_map, threshold, 1, cv2.THRESH_BINARY) + if "total_text" in output_folder or "cute80" in output_folder: + SE1 = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3)) + poly_map = cv2.erode(poly_map, SE1) + poly_map = cv2.dilate(poly_map, SE1) + poly_map = cv2.morphologyEx(poly_map, cv2.MORPH_CLOSE, SE1) + try: + _, contours, _ = cv2.findContours( + (poly_map * 255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_NONE + ) + except: + contours, _ = cv2.findContours( + (poly_map * 255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_NONE + ) + if len(contours) == 0: + # print(contours) + # print(len(contours)) + return None + max_area = 0 + max_cnt = contours[0] + for cnt in contours: + area = cv2.contourArea(cnt) + if area > max_area: + max_area = area + max_cnt = cnt + # perimeter = cv2.arcLength(max_cnt, True) + epsilon = 0.01 * cv2.arcLength(max_cnt, True) + approx = cv2.approxPolyDP(max_cnt, epsilon, True) + pts = approx.reshape((-1, 2)) + pts[:, 0] = pts[:, 0] + box[0] + pts[:, 1] = pts[:, 1] + box[1] + polygon = list(pts.reshape((-1,))) + polygon = list(map(int, polygon)) + if len(polygon) < 6: + return None + else: + SE1 = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3)) + poly_map = cv2.erode(poly_map, SE1) + poly_map = cv2.dilate(poly_map, SE1) + poly_map = cv2.morphologyEx(poly_map, cv2.MORPH_CLOSE, SE1) + idy, idx = np.where(poly_map == 1) + xy = np.vstack((idx, idy)) + xy = np.transpose(xy) + hull = cv2.convexHull(xy, clockwise=True) + # reverse order of points. + if hull is None: + return None + hull = hull[::-1] + # find minimum area bounding box. + rect = cv2.minAreaRect(hull) + corners = cv2.boxPoints(rect) + corners = np.array(corners, dtype="int") + pts = get_tight_rect(corners, box[0], box[1], image_height, image_width, 1) + polygon = [x * 1.0 for x in pts] + polygon = list(map(int, polygon)) + return polygon + + +def _accumulate_predictions_from_multiple_gpus(predictions_per_gpu): + all_predictions = scatter_gather(predictions_per_gpu) + if not is_main_process(): + return + # merge the list of dicts + predictions = {} + for p in all_predictions: + predictions.update(p) + return predictions + + +def format_output(out_dir, boxes, img_name): + with open( + os.path.join(out_dir, "res_" + img_name.split(".")[0] + ".txt"), "wt" + ) as res: + ## char score save dir + ssur_name = os.path.join(out_dir, "res_" + img_name.split(".")[0]) + for i, box in enumerate(boxes): + save_name = ssur_name + "_" + str(i) + ".pkl" + save_dict = {} + if "total_text" in out_dir or "cute80" in out_dir: + # np.save(save_name, box[-2]) + save_dict["seg_char_scores"] = box[-3] + save_dict["seq_char_scores"] = box[-2] + box = ( + ",".join([str(x) for x in box[:4]]) + + ";" + + ",".join([str(x) for x in box[4 : 4 + int(box[-1])]]) + + ";" + + ",".join([str(x) for x in box[4 + int(box[-1]) : -3]]) + + "," + + save_name + ) + else: + save_dict["seg_char_scores"] = box[-2] + save_dict["seq_char_scores"] = box[-1] + np.save(save_name, box[-1]) + box = ",".join([str(x) for x in box[:-2]]) + "," + save_name + with open(save_name, "wb") as f: + pickle.dump(save_dict, f, protocol=2) + res.write(box + "\n") + +def format_seg_output(results_dir, rotated_boxes_this_image, polygons_this_image, scores, img_name, ratio): + height_ratio, width_ratio = ratio + with open( + os.path.join(results_dir, "res_" + img_name.split(".")[0] + ".txt"), "wt" + ) as res: + if "total_text" in results_dir or "cute80" in results_dir: + for i, box in enumerate(polygons_this_image): + box = box[0] + box[0::2] = box[0::2] * width_ratio + box[1::2] = box[1::2] * height_ratio + save_dict = {} + # result = ",".join([str(int(x[0])) + ',' +str(int(x[1])) for x in box]) + result = ",".join([str(int(x)) for x in box]) + score = scores[i].item() + res.write(result + ',' + str(score) + "\n") + else: + for i, box in enumerate(rotated_boxes_this_image): + box[0::2] = box[0::2] * width_ratio + box[1::2] = box[1::2] * height_ratio + save_dict = {} + result = ",".join([str(int(x[0])) + ',' +str(int(x[1])) for x in box]) + score = scores[i].item() + res.write(result + ',' + str(score) + "\n") + + + +def process_char_mask(char_masks, boxes, threshold=192): + texts, rec_scores, rec_char_scores, char_polygons = [], [], [], [] + for index in range(char_masks.shape[0]): + box = list(boxes[index]) + box = list(map(int, box)) + text, rec_score, rec_char_score, char_polygon = getstr_grid( + char_masks[index, :, :, :].copy(), box, threshold=threshold + ) + texts.append(text) + rec_scores.append(rec_score) + rec_char_scores.append(rec_char_score) + char_polygons.append(char_polygon) + # segmss.append(segms) + return texts, rec_scores, rec_char_scores, char_polygons + + +def creat_color_map(n_class, width): + splits = int(np.ceil(np.power((n_class * 1.0), 1.0 / 3))) + maps = [] + for i in range(splits): + r = int(i * width * 1.0 / (splits - 1)) + for j in range(splits): + g = int(j * width * 1.0 / (splits - 1)) + for k in range(splits - 1): + b = int(k * width * 1.0 / (splits - 1)) + maps.append((r, g, b, 200)) + return maps + + +def visualization(image, polygons, resize_ratio, colors, char_polygons=None, words=None): + draw = ImageDraw.Draw(image, "RGBA") + for polygon in polygons: + # draw.polygon(polygon, fill=None, outline=(0, 255, 0, 255)) + # print(polygon) + polygon.append(polygon[0]) + polygon.append(polygon[1]) + # print(polygon) + color = '#33FF33' + draw.line(polygon, fill=color, width=5) + # if char_polygons is not None: + # for i, char_polygon in enumerate(char_polygons): + # for j, polygon in enumerate(char_polygon): + # polygon = [int(x * resize_ratio) for x in polygon] + # char = words[i][j] + # color = colors[char2num(char)] + # draw.polygon(polygon, fill=color, outline=color) + + +def vis_seg_map(image_path, seg_map, rotated_boxes, polygons_this_image, proposals, vis_dir): + img_name = image_path.split("/")[-1] + image = cv2.imread(image_path) + height, width, _ = image.shape + seg_map = seg_map.data.cpu().numpy() + img = Image.fromarray(image).convert("RGB") + # height_ratio = height / seg_map.shape[1] + # width_ratio = width / seg_map.shape[2] + # print('seg_map.shape:', seg_map.shape) + # print('image.shape:', image.shape) + seg_image = ( + Image.fromarray((seg_map[0, :proposals.size[1], :proposals.size[0]] * 255).astype(np.uint8)) + .convert("RGB") + .resize((width, height)) + ) + visu_image = Image.blend(seg_image, img, 0.5) + img_draw = ImageDraw.Draw(visu_image) + if "total_text" in vis_dir or "cute80" in vis_dir: + for box in polygons_this_image: + # box[:, 0] = box[:, 0] + # box[:, 1] = box[:, 1] + tuple_box = [tuple(x) for x in box[0].reshape(-1, 2).tolist()] + tuple_box.append(tuple_box[0]) + img_draw.line(tuple_box, fill=(0, 255, 0), width=5) + else: + for box in rotated_boxes: + # box[:, 0] = box[:, 0] + # box[:, 1] = box[:, 1] + tuple_box = [tuple(x) for x in box.tolist()] + tuple_box.append(tuple_box[0]) + img_draw.line(tuple_box, fill=(0, 255, 0), width=5) + visu_image.save(vis_dir + "/seg_" + img_name) + + +def prepare_results_for_evaluation( + predictions, output_folder, model_name, seg_predictions=None, vis=False, cfg=None +): + results_dir = os.path.join(output_folder, model_name + "_results") + if not os.path.isdir(results_dir): + os.mkdir(results_dir) + seg_results_dir = os.path.join(output_folder, model_name + "_seg_results") + if not os.path.isdir(seg_results_dir): + os.mkdir(seg_results_dir) + if vis: + visu_dir = os.path.join(output_folder, model_name + "_visu") + if not os.path.isdir(visu_dir): + os.mkdir(visu_dir) + seg_visu_dir = os.path.join(output_folder, model_name + "_seg_visu") + if not os.path.isdir(seg_visu_dir): + os.mkdir(seg_visu_dir) + if len(seg_predictions) > 0: + for seg_prediction in seg_predictions: + image_paths, proposals, rotated_boxes, polygons, seg_maps, seg_scores = ( + seg_prediction[0], + seg_prediction[1], + seg_prediction[2], + seg_prediction[3], + seg_prediction[4], + seg_prediction[5], + ) + for batch_id in range(len(image_paths)): + image_path = image_paths[batch_id] + im_name = image_path.split("/")[-1] + image = cv2.imread(image_path) + height, width, _ = image.shape + rotated_boxes_this_image = rotated_boxes[batch_id] + polygons_this_image = polygons[batch_id] + proposals_this_image = proposals[batch_id] + seg_map = seg_maps[batch_id] + seg_score = seg_scores[batch_id] + height, width, _ = image.shape + height_ratio = height / proposals_this_image.size[1] + width_ratio = width / proposals_this_image.size[0] + format_seg_output(seg_results_dir, rotated_boxes_this_image, polygons_this_image, seg_score, im_name, (height_ratio, width_ratio)) + if vis: + vis_seg_map(image_path, seg_map, rotated_boxes_this_image, polygons_this_image, proposals_this_image, seg_visu_dir) + if (not cfg.MODEL.TRAIN_DETECTION_ONLY): + for image_path, prediction in predictions.items(): + im_name = image_path.split("/")[-1] + if cfg.MODEL.CHAR_MASK_ON or cfg.SEQUENCE.SEQ_ON: + global_prediction, char_mask, boxes_char, seq_words, seq_scores, detailed_seq_scores = ( + prediction[0], + prediction[1], + prediction[2], + prediction[3], + prediction[4], + prediction[5], + ) + if char_mask is not None: + words, rec_scores, rec_char_scoress, char_polygons = process_char_mask( + char_mask, boxes_char + ) + else: + global_prediction = prediction[0] + test_image_width, test_image_height = global_prediction.size + img = Image.open(image_path) + width, height = img.size + resize_ratio = float(height) / test_image_height + global_prediction = global_prediction.resize((width, height)) + boxes = global_prediction.bbox.tolist() + if cfg.MODEL.ROI_BOX_HEAD.INFERENCE_USE_BOX: + scores = global_prediction.get_field("scores").tolist() + if not cfg.MODEL.SEG.USE_SEG_POLY: + masks = global_prediction.get_field("mask").cpu().numpy() + else: + masks = global_prediction.get_field("masks").get_polygons() + result_logs = [] + polygons = [] + for k, box in enumerate(boxes): + if box[2] - box[0] < 1 or box[3] - box[1] < 1: + continue + box = list(map(int, box)) + if not cfg.MODEL.SEG.USE_SEG_POLY: + mask = masks[k, 0, :, :] + polygon = mask2polygon( + mask, box, img.size, threshold=0.5, output_folder=output_folder + ) + else: + polygon = list(masks[k].get_polygons()[0].cpu().numpy()) + if not ("total_text" in output_folder or "cute80" in output_folder): + polygon = polygon2rbox(polygon, height, width) + if polygon is None: + polygon = [ + box[0], + box[1], + box[2], + box[1], + box[2], + box[3], + box[0], + box[3], + ] + continue + polygons.append(polygon) + if cfg.MODEL.ROI_BOX_HEAD.INFERENCE_USE_BOX: + score = scores[k] + else: + score = 1.0 + if cfg.MODEL.CHAR_MASK_ON or cfg.SEQUENCE.SEQ_ON: + if char_mask is None: + word = 'aaa' + rec_score = 1.0 + char_score = None + else: + word = words[k] + rec_score = rec_scores[k] + char_score = rec_char_scoress[k] + seq_word = seq_words[k] + seq_char_scores = seq_scores[k] + seq_score = sum(seq_char_scores) / float(len(seq_char_scores)) + detailed_seq_score = detailed_seq_scores[k] + detailed_seq_score = np.squeeze(np.array(detailed_seq_score), axis=1) + else: + word = 'aaa' + rec_score = 1.0 + char_score = [1.0, 1.0, 1.0] + seq_word = 'aaa' + seq_char_scores = [1.0, 1.0, 1.0] + seq_score = 1.0 + detailed_seq_score = None + if "total_text" in output_folder or "cute80" in output_folder: + result_log = ( + [int(x * 1.0) for x in box[:4]] + + polygon + + [word] + + [seq_word] + + [score] + + [rec_score] + + [seq_score] + + [char_score] + + [detailed_seq_score] + + [len(polygon)] + ) + else: + result_log = ( + [int(x * 1.0) for x in box[:4]] + + polygon + + [word] + + [seq_word] + + [score] + + [rec_score] + + [seq_score] + + [char_score] + + [detailed_seq_score] + ) + result_logs.append(result_log) + if vis: + colors = creat_color_map(37, 255) + if cfg.MODEL.CHAR_MASK_ON: + visualization(img, polygons, resize_ratio, colors, char_polygons, words) + else: + visualization(img, polygons, resize_ratio, colors) + img.save(os.path.join(visu_dir, im_name)) + format_output(results_dir, result_logs, im_name) + + +def inference( + model, + data_loader, + iou_types=("bbox",), + box_only=False, + device="cuda", + expected_results=(), + expected_results_sigma_tol=4, + output_folder=None, + model_name=None, + cfg=None, +): + + # convert to a torch.device for efficiency + model_name = model_name.split(".")[0] + "_" + str(cfg.INPUT.MIN_SIZE_TEST) + predictions_path = os.path.join(output_folder, model_name + "_predictions.pth") + seg_predictions_path = os.path.join( + output_folder, model_name + "_seg_predictions.pth" + ) + # if os.path.isfile(predictions_path) and os.path.isfile(seg_predictions_path): + if False: + predictions = torch.load(predictions_path) + seg_predictions = torch.load(seg_predictions_path) + else: + device = torch.device(device) + num_devices = ( + torch.distributed.get_world_size() + if torch.distributed.is_initialized() + else 1 + ) + logger = logging.getLogger("maskrcnn_benchmark.inference") + dataset = data_loader.dataset + logger.info("Start evaluation on {} images".format(len(dataset))) + start_time = time.time() + predictions, seg_predictions = compute_on_dataset( + model, data_loader, device, cfg + ) + # wait for all processes to complete before measuring the time + synchronize() + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=total_time)) + logger.info( + "Total inference time: {} ({} s / img per device, on {} devices)".format( + total_time_str, total_time * num_devices / len(dataset), num_devices + ) + ) + + # predictions = _accumulate_predictions_from_multiple_gpus(predictions) + # if not is_main_process(): + # return + + if output_folder: + torch.save(predictions, predictions_path) + torch.save(seg_predictions, seg_predictions_path) + + prepare_results_for_evaluation( + predictions, + output_folder, + model_name, + seg_predictions=seg_predictions, + vis=cfg.TEST.VIS, + cfg=cfg + ) diff --git a/maskrcnn_benchmark/engine/trainer.py b/maskrcnn_benchmark/engine/trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..552e6a98a7b45b79ede60c7e76507c03f07990ef --- /dev/null +++ b/maskrcnn_benchmark/engine/trainer.py @@ -0,0 +1,124 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +import datetime +import logging +import time + +import torch +from maskrcnn_benchmark.utils.comm import get_world_size, is_main_process +from maskrcnn_benchmark.utils.metric_logger import MetricLogger +import torch.distributed as dist +from apex import amp + + +def reduce_loss_dict(loss_dict): + """ + Reduce the loss dictionary from all processes so that process with rank + 0 has the averaged results. Returns a dict with the same fields as + loss_dict, after reduction. + """ + world_size = get_world_size() + if world_size < 2: + return loss_dict + with torch.no_grad(): + loss_names = [] + all_losses = [] + for k, v in loss_dict.items(): + loss_names.append(k) + all_losses.append(v) + all_losses = torch.stack(all_losses, dim=0) + dist.reduce(all_losses, dst=0) + if dist.get_rank() == 0: + # only main process gets accumulated, so only divide by + # world_size in this case + all_losses /= world_size + reduced_losses = {k: v for k, v in zip(loss_names, all_losses)} + return reduced_losses + + +def do_train( + model, + data_loader, + optimizer, + scheduler, + checkpointer, + device, + checkpoint_period, + arguments, + tb_logger, + cfg, + local_rank, +): + logger = logging.getLogger("maskrcnn_benchmark.trainer") + logger.info("Start training") + meters = MetricLogger(delimiter=" ") + max_iter = len(data_loader) + start_iter = arguments["iteration"] + model.train() + start_training_time = time.time() + end = time.time() + for iteration, (images, targets, _) in enumerate(data_loader, start_iter): + data_time = time.time() - end + arguments["iteration"] = iteration + + scheduler.step() + + images = images.to(device) + targets = [target.to(device) for target in targets] + + loss_dict = model(images, targets) + losses = sum(loss for loss in loss_dict.values()) + # reduce losses over all GPUs for logging purposes + loss_dict_reduced = reduce_loss_dict(loss_dict) + losses_reduced = sum(loss for loss in loss_dict_reduced.values()) + meters.update(loss=losses_reduced, **loss_dict_reduced) + + optimizer.zero_grad() + # losses.backward() + # Note: If mixed precision is not used, this ends up doing nothing + # Otherwise apply loss scaling for mixed-precision recipe + with amp.scale_loss(losses, optimizer) as scaled_losses: + scaled_losses.backward() + if cfg.SOLVER.USE_ADAM: + torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5) + optimizer.step() + + batch_time = time.time() - end + end = time.time() + meters.update(time=batch_time, data=data_time) + + eta_seconds = meters.time.global_avg * (max_iter - iteration) + eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) + + if local_rank == 0 and (iteration % cfg.SOLVER.DISPLAY_FREQ == 0 or iteration == (max_iter - 1)): + logger.info( + meters.delimiter.join( + [ + "eta: {eta}", + "iter: {iter}", + "{meters}", + "lr: {lr:.6f}", + "max mem: {memory:.0f}", + ] + ).format( + eta=eta_string, + iter=iteration, + meters=str(meters), + lr=optimizer.param_groups[0]["lr"], + memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0, + ) + ) + for tag, value in loss_dict_reduced.items(): + tb_logger.scalar_summary(tag, value.item(), iteration) + if local_rank == 0 and iteration % checkpoint_period == 0 and iteration > 0: + checkpointer.save("model_{:07d}".format(iteration), **arguments) + + if local_rank == 0: + checkpointer.save("model_{:07d}".format(iteration), **arguments) + total_training_time = time.time() - start_training_time + total_time_str = str(datetime.timedelta(seconds=total_training_time)) + logger.info( + "Total training time: {} ({:.4f} s / it)".format( + total_time_str, total_training_time / (max_iter) + ) + ) diff --git a/maskrcnn_benchmark/layers/__init__.py b/maskrcnn_benchmark/layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..420bc1d2c4f9d93bf72388f7be2bd4d557e6c26b --- /dev/null +++ b/maskrcnn_benchmark/layers/__init__.py @@ -0,0 +1,42 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +import torch + +from .batch_norm import FrozenBatchNorm2d +from .misc import Conv2d +from .misc import DFConv2d +from .misc import ConvTranspose2d +from .misc import interpolate +from .nms import nms +from .roi_align import ROIAlign +from .roi_align import roi_align +from .roi_pool import ROIPool +from .roi_pool import roi_pool +from .smooth_l1_loss import smooth_l1_loss +from .dcn.deform_conv_func import deform_conv, modulated_deform_conv +from .dcn.deform_conv_module import DeformConv, ModulatedDeformConv, ModulatedDeformConvPack +from .dcn.deform_pool_func import deform_roi_pooling +from .dcn.deform_pool_module import DeformRoIPooling, DeformRoIPoolingPack, ModulatedDeformRoIPoolingPack +__all__ = [ + "nms", + "roi_align", + "ROIAlign", + "roi_pool", + "ROIPool", + "smooth_l1_loss", + "Conv2d", + "DFConv2d", + "ConvTranspose2d", + "interpolate", + "BatchNorm2d", + "FrozenBatchNorm2d", + 'deform_conv', + 'modulated_deform_conv', + 'DeformConv', + 'ModulatedDeformConv', + 'ModulatedDeformConvPack', + 'deform_roi_pooling', + 'DeformRoIPooling', + 'DeformRoIPoolingPack', + 'ModulatedDeformRoIPoolingPack', +] +# __all__ = ["nms", "roi_align", "ROIAlign", "roi_pool", "ROIPool", "smooth_l1_loss", "Conv2d", "ConvTranspose2d", "interpolate", "FrozenBatchNorm2d"] diff --git a/maskrcnn_benchmark/layers/_utils.py b/maskrcnn_benchmark/layers/_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..3dabc127b221d67eae7587ab4905416fa5fcf121 --- /dev/null +++ b/maskrcnn_benchmark/layers/_utils.py @@ -0,0 +1,39 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +import glob +import os.path + +import torch + +try: + from torch.utils.cpp_extension import load as load_ext + from torch.utils.cpp_extension import CUDA_HOME +except ImportError: + raise ImportError("The cpp layer extensions requires PyTorch 0.4 or higher") + + +def _load_C_extensions(): + this_dir = os.path.dirname(os.path.abspath(__file__)) + this_dir = os.path.dirname(this_dir) + this_dir = os.path.join(this_dir, "csrc") + + main_file = glob.glob(os.path.join(this_dir, "*.cpp")) + source_cpu = glob.glob(os.path.join(this_dir, "cpu", "*.cpp")) + source_cuda = glob.glob(os.path.join(this_dir, "cuda", "*.cu")) + + source = main_file + source_cpu + + extra_cflags = [] + if torch.cuda.is_available() and CUDA_HOME is not None: + source.extend(source_cuda) + extra_cflags = ["-DWITH_CUDA"] + source = [os.path.join(this_dir, s) for s in source] + extra_include_paths = [this_dir] + return load_ext( + "torchvision", + source, + extra_cflags=extra_cflags, + extra_include_paths=extra_include_paths, + ) + + +_C = _load_C_extensions() diff --git a/maskrcnn_benchmark/layers/batch_norm.py b/maskrcnn_benchmark/layers/batch_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..3cc9ec6a5b2bff625930b66aa8acb173efe49731 --- /dev/null +++ b/maskrcnn_benchmark/layers/batch_norm.py @@ -0,0 +1,30 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +import torch +from torch import nn + + +class FrozenBatchNorm2d(nn.Module): + """ + BatchNorm2d where the batch statistics and the affine parameters + are fixed + """ + + def __init__(self, n): + super(FrozenBatchNorm2d, self).__init__() + self.register_buffer("weight", torch.ones(n)) + self.register_buffer("bias", torch.zeros(n)) + self.register_buffer("running_mean", torch.zeros(n)) + self.register_buffer("running_var", torch.ones(n)) + + def forward(self, x): + # Cast all fixed parameters to half() if necessary + if x.dtype == torch.float16: + self.weight = self.weight.half() + self.bias = self.bias.half() + self.running_mean = self.running_mean.half() + self.running_var = self.running_var.half() + scale = self.weight * self.running_var.rsqrt() + bias = self.bias - self.running_mean * scale + scale = scale.reshape(1, -1, 1, 1) + bias = bias.reshape(1, -1, 1, 1) + return x * scale + bias diff --git a/maskrcnn_benchmark/layers/dcn/__init__.py b/maskrcnn_benchmark/layers/dcn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bb5af25d45fd8b80a347566ecef1f9cf77d3da48 --- /dev/null +++ b/maskrcnn_benchmark/layers/dcn/__init__.py @@ -0,0 +1,3 @@ +# +# Copied From [mmdetection](https://github.com/open-mmlab/mmdetection/tree/master/mmdet/ops/dcn) +# diff --git a/maskrcnn_benchmark/layers/dcn/deform_conv_func.py b/maskrcnn_benchmark/layers/dcn/deform_conv_func.py new file mode 100644 index 0000000000000000000000000000000000000000..388bacf12d860c4d056dde0076400209802bb4e1 --- /dev/null +++ b/maskrcnn_benchmark/layers/dcn/deform_conv_func.py @@ -0,0 +1,262 @@ +import torch +from torch.autograd import Function +from torch.autograd.function import once_differentiable +from torch.nn.modules.utils import _pair + +from maskrcnn_benchmark import _C + + +class DeformConvFunction(Function): + + @staticmethod + def forward( + ctx, + input, + offset, + weight, + stride=1, + padding=0, + dilation=1, + groups=1, + deformable_groups=1, + im2col_step=64 + ): + if input is not None and input.dim() != 4: + raise ValueError( + "Expected 4D tensor as input, got {}D tensor instead.".format( + input.dim())) + ctx.stride = _pair(stride) + ctx.padding = _pair(padding) + ctx.dilation = _pair(dilation) + ctx.groups = groups + ctx.deformable_groups = deformable_groups + ctx.im2col_step = im2col_step + + ctx.save_for_backward(input, offset, weight) + + output = input.new_empty( + DeformConvFunction._output_size(input, weight, ctx.padding, + ctx.dilation, ctx.stride)) + + ctx.bufs_ = [input.new_empty(0), input.new_empty(0)] # columns, ones + + if not input.is_cuda: + raise NotImplementedError + else: + cur_im2col_step = min(ctx.im2col_step, input.shape[0]) + assert (input.shape[0] % + cur_im2col_step) == 0, 'im2col step must divide batchsize' + _C.deform_conv_forward( + input, + weight, + offset, + output, + ctx.bufs_[0], + ctx.bufs_[1], + weight.size(3), + weight.size(2), + ctx.stride[1], + ctx.stride[0], + ctx.padding[1], + ctx.padding[0], + ctx.dilation[1], + ctx.dilation[0], + ctx.groups, + ctx.deformable_groups, + cur_im2col_step + ) + return output + + @staticmethod + @once_differentiable + def backward(ctx, grad_output): + input, offset, weight = ctx.saved_tensors + + grad_input = grad_offset = grad_weight = None + + if not grad_output.is_cuda: + raise NotImplementedError + else: + cur_im2col_step = min(ctx.im2col_step, input.shape[0]) + assert (input.shape[0] % + cur_im2col_step) == 0, 'im2col step must divide batchsize' + + if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]: + grad_input = torch.zeros_like(input) + grad_offset = torch.zeros_like(offset) + _C.deform_conv_backward_input( + input, + offset, + grad_output, + grad_input, + grad_offset, + weight, + ctx.bufs_[0], + weight.size(3), + weight.size(2), + ctx.stride[1], + ctx.stride[0], + ctx.padding[1], + ctx.padding[0], + ctx.dilation[1], + ctx.dilation[0], + ctx.groups, + ctx.deformable_groups, + cur_im2col_step + ) + + if ctx.needs_input_grad[2]: + grad_weight = torch.zeros_like(weight) + _C.deform_conv_backward_parameters( + input, + offset, + grad_output, + grad_weight, + ctx.bufs_[0], + ctx.bufs_[1], + weight.size(3), + weight.size(2), + ctx.stride[1], + ctx.stride[0], + ctx.padding[1], + ctx.padding[0], + ctx.dilation[1], + ctx.dilation[0], + ctx.groups, + ctx.deformable_groups, + 1, + cur_im2col_step + ) + + return (grad_input, grad_offset, grad_weight, None, None, None, None, None) + + @staticmethod + def _output_size(input, weight, padding, dilation, stride): + channels = weight.size(0) + output_size = (input.size(0), channels) + for d in range(input.dim() - 2): + in_size = input.size(d + 2) + pad = padding[d] + kernel = dilation[d] * (weight.size(d + 2) - 1) + 1 + stride_ = stride[d] + output_size += ((in_size + (2 * pad) - kernel) // stride_ + 1, ) + if not all(map(lambda s: s > 0, output_size)): + raise ValueError( + "convolution input is too small (output would be {})".format( + 'x'.join(map(str, output_size)))) + return output_size + + +class ModulatedDeformConvFunction(Function): + + @staticmethod + def forward( + ctx, + input, + offset, + mask, + weight, + bias=None, + stride=1, + padding=0, + dilation=1, + groups=1, + deformable_groups=1 + ): + ctx.stride = stride + ctx.padding = padding + ctx.dilation = dilation + ctx.groups = groups + ctx.deformable_groups = deformable_groups + ctx.with_bias = bias is not None + if not ctx.with_bias: + bias = input.new_empty(1) # fake tensor + if not input.is_cuda: + raise NotImplementedError + if weight.requires_grad or mask.requires_grad or offset.requires_grad \ + or input.requires_grad: + ctx.save_for_backward(input, offset, mask, weight, bias) + output = input.new_empty( + ModulatedDeformConvFunction._infer_shape(ctx, input, weight)) + ctx._bufs = [input.new_empty(0), input.new_empty(0)] + _C.modulated_deform_conv_forward( + input, + weight, + bias, + ctx._bufs[0], + offset, + mask, + output, + ctx._bufs[1], + weight.shape[2], + weight.shape[3], + ctx.stride, + ctx.stride, + ctx.padding, + ctx.padding, + ctx.dilation, + ctx.dilation, + ctx.groups, + ctx.deformable_groups, + ctx.with_bias + ) + return output + + @staticmethod + @once_differentiable + def backward(ctx, grad_output): + if not grad_output.is_cuda: + raise NotImplementedError + input, offset, mask, weight, bias = ctx.saved_tensors + grad_input = torch.zeros_like(input) + grad_offset = torch.zeros_like(offset) + grad_mask = torch.zeros_like(mask) + grad_weight = torch.zeros_like(weight) + grad_bias = torch.zeros_like(bias) + _C.modulated_deform_conv_backward( + input, + weight, + bias, + ctx._bufs[0], + offset, + mask, + ctx._bufs[1], + grad_input, + grad_weight, + grad_bias, + grad_offset, + grad_mask, + grad_output, + weight.shape[2], + weight.shape[3], + ctx.stride, + ctx.stride, + ctx.padding, + ctx.padding, + ctx.dilation, + ctx.dilation, + ctx.groups, + ctx.deformable_groups, + ctx.with_bias + ) + if not ctx.with_bias: + grad_bias = None + + return (grad_input, grad_offset, grad_mask, grad_weight, grad_bias, + None, None, None, None, None) + + @staticmethod + def _infer_shape(ctx, input, weight): + n = input.size(0) + channels_out = weight.size(0) + height, width = input.shape[2:4] + kernel_h, kernel_w = weight.shape[2:4] + height_out = (height + 2 * ctx.padding - + (ctx.dilation * (kernel_h - 1) + 1)) // ctx.stride + 1 + width_out = (width + 2 * ctx.padding - + (ctx.dilation * (kernel_w - 1) + 1)) // ctx.stride + 1 + return n, channels_out, height_out, width_out + + +deform_conv = DeformConvFunction.apply +modulated_deform_conv = ModulatedDeformConvFunction.apply diff --git a/maskrcnn_benchmark/layers/dcn/deform_conv_module.py b/maskrcnn_benchmark/layers/dcn/deform_conv_module.py new file mode 100644 index 0000000000000000000000000000000000000000..e6b58c8407978e020821436398c6faaf1f10940f --- /dev/null +++ b/maskrcnn_benchmark/layers/dcn/deform_conv_module.py @@ -0,0 +1,177 @@ +import math + +import torch +import torch.nn as nn +from torch.nn.modules.utils import _pair + +from .deform_conv_func import deform_conv, modulated_deform_conv + + +class DeformConv(nn.Module): + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + deformable_groups=1, + bias=False + ): + assert not bias + super(DeformConv, self).__init__() + self.with_bias = bias + + assert in_channels % groups == 0, \ + 'in_channels {} cannot be divisible by groups {}'.format( + in_channels, groups) + assert out_channels % groups == 0, \ + 'out_channels {} cannot be divisible by groups {}'.format( + out_channels, groups) + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = _pair(kernel_size) + self.stride = _pair(stride) + self.padding = _pair(padding) + self.dilation = _pair(dilation) + self.groups = groups + self.deformable_groups = deformable_groups + + self.weight = nn.Parameter( + torch.Tensor(out_channels, in_channels // self.groups, + *self.kernel_size)) + + self.reset_parameters() + + def reset_parameters(self): + n = self.in_channels + for k in self.kernel_size: + n *= k + stdv = 1. / math.sqrt(n) + self.weight.data.uniform_(-stdv, stdv) + + def forward(self, input, offset): + return deform_conv(input, offset, self.weight, self.stride, + self.padding, self.dilation, self.groups, + self.deformable_groups) + + def __repr__(self): + return "".join([ + "{}(".format(self.__class__.__name__), + "in_channels={}, ".format(self.in_channels), + "out_channels={}, ".format(self.out_channels), + "kernel_size={}, ".format(self.kernel_size), + "stride={}, ".format(self.stride), + "dilation={}, ".format(self.dilation), + "padding={}, ".format(self.padding), + "groups={}, ".format(self.groups), + "deformable_groups={}, ".format(self.deformable_groups), + "bias={})".format(self.with_bias), + ]) + + +class ModulatedDeformConv(nn.Module): + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + deformable_groups=1, + bias=True + ): + super(ModulatedDeformConv, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = _pair(kernel_size) + self.stride = stride + self.padding = padding + self.dilation = dilation + self.groups = groups + self.deformable_groups = deformable_groups + self.with_bias = bias + + self.weight = nn.Parameter(torch.Tensor( + out_channels, + in_channels // groups, + *self.kernel_size + )) + if bias: + self.bias = nn.Parameter(torch.Tensor(out_channels)) + else: + self.register_parameter('bias', None) + self.reset_parameters() + + def reset_parameters(self): + n = self.in_channels + for k in self.kernel_size: + n *= k + stdv = 1. / math.sqrt(n) + self.weight.data.uniform_(-stdv, stdv) + if self.bias is not None: + self.bias.data.zero_() + + def forward(self, input, offset, mask): + return modulated_deform_conv( + input, offset, mask, self.weight, self.bias, self.stride, + self.padding, self.dilation, self.groups, self.deformable_groups) + + def __repr__(self): + return "".join([ + "{}(".format(self.__class__.__name__), + "in_channels={}, ".format(self.in_channels), + "out_channels={}, ".format(self.out_channels), + "kernel_size={}, ".format(self.kernel_size), + "stride={}, ".format(self.stride), + "dilation={}, ".format(self.dilation), + "padding={}, ".format(self.padding), + "groups={}, ".format(self.groups), + "deformable_groups={}, ".format(self.deformable_groups), + "bias={})".format(self.with_bias), + ]) + +class ModulatedDeformConvPack(ModulatedDeformConv): + + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + deformable_groups=1, + bias=True): + super(ModulatedDeformConvPack, self).__init__( + in_channels, out_channels, kernel_size, stride, padding, dilation, + groups, deformable_groups, bias) + + self.conv_offset_mask = nn.Conv2d( + self.in_channels // self.groups, + self.deformable_groups * 3 * self.kernel_size[0] * + self.kernel_size[1], + kernel_size=self.kernel_size, + stride=_pair(self.stride), + padding=_pair(self.padding), + bias=True) + self.init_offset() + + def init_offset(self): + self.conv_offset_mask.weight.data.zero_() + self.conv_offset_mask.bias.data.zero_() + + def forward(self, input): + out = self.conv_offset_mask(input) + o1, o2, mask = torch.chunk(out, 3, dim=1) + offset = torch.cat((o1, o2), dim=1) + mask = torch.sigmoid(mask) + return modulated_deform_conv( + input, offset, mask, self.weight, self.bias, self.stride, + self.padding, self.dilation, self.groups, self.deformable_groups) diff --git a/maskrcnn_benchmark/layers/dcn/deform_pool_func.py b/maskrcnn_benchmark/layers/dcn/deform_pool_func.py new file mode 100644 index 0000000000000000000000000000000000000000..e083b002ec13fda353b98f513fce96911eab0c75 --- /dev/null +++ b/maskrcnn_benchmark/layers/dcn/deform_pool_func.py @@ -0,0 +1,95 @@ +import torch +from torch.autograd import Function +from torch.autograd.function import once_differentiable + +from maskrcnn_benchmark import _C + + +class DeformRoIPoolingFunction(Function): + + @staticmethod + def forward( + ctx, + data, + rois, + offset, + spatial_scale, + out_size, + out_channels, + no_trans, + group_size=1, + part_size=None, + sample_per_part=4, + trans_std=.0 + ): + ctx.spatial_scale = spatial_scale + ctx.out_size = out_size + ctx.out_channels = out_channels + ctx.no_trans = no_trans + ctx.group_size = group_size + ctx.part_size = out_size if part_size is None else part_size + ctx.sample_per_part = sample_per_part + ctx.trans_std = trans_std + + assert 0.0 <= ctx.trans_std <= 1.0 + if not data.is_cuda: + raise NotImplementedError + + n = rois.shape[0] + output = data.new_empty(n, out_channels, out_size, out_size) + output_count = data.new_empty(n, out_channels, out_size, out_size) + _C.deform_psroi_pooling_forward( + data, + rois, + offset, + output, + output_count, + ctx.no_trans, + ctx.spatial_scale, + ctx.out_channels, + ctx.group_size, + ctx.out_size, + ctx.part_size, + ctx.sample_per_part, + ctx.trans_std + ) + + if data.requires_grad or rois.requires_grad or offset.requires_grad: + ctx.save_for_backward(data, rois, offset) + ctx.output_count = output_count + + return output + + @staticmethod + @once_differentiable + def backward(ctx, grad_output): + if not grad_output.is_cuda: + raise NotImplementedError + + data, rois, offset = ctx.saved_tensors + output_count = ctx.output_count + grad_input = torch.zeros_like(data) + grad_rois = None + grad_offset = torch.zeros_like(offset) + + _C.deform_psroi_pooling_backward( + grad_output, + data, + rois, + offset, + output_count, + grad_input, + grad_offset, + ctx.no_trans, + ctx.spatial_scale, + ctx.out_channels, + ctx.group_size, + ctx.out_size, + ctx.part_size, + ctx.sample_per_part, + ctx.trans_std + ) + return (grad_input, grad_rois, grad_offset, None, None, None, None, None, None, None, None) + + +deform_roi_pooling = DeformRoIPoolingFunction.apply diff --git a/maskrcnn_benchmark/layers/dcn/deform_pool_module.py b/maskrcnn_benchmark/layers/dcn/deform_pool_module.py new file mode 100644 index 0000000000000000000000000000000000000000..bab6c2604da89430f759c67c3659b6d9f474ed04 --- /dev/null +++ b/maskrcnn_benchmark/layers/dcn/deform_pool_module.py @@ -0,0 +1,150 @@ +from torch import nn + +from .deform_pool_func import deform_roi_pooling + + +class DeformRoIPooling(nn.Module): + + def __init__(self, + spatial_scale, + out_size, + out_channels, + no_trans, + group_size=1, + part_size=None, + sample_per_part=4, + trans_std=.0): + super(DeformRoIPooling, self).__init__() + self.spatial_scale = spatial_scale + self.out_size = out_size + self.out_channels = out_channels + self.no_trans = no_trans + self.group_size = group_size + self.part_size = out_size if part_size is None else part_size + self.sample_per_part = sample_per_part + self.trans_std = trans_std + + def forward(self, data, rois, offset): + if self.no_trans: + offset = data.new_empty(0) + return deform_roi_pooling( + data, rois, offset, self.spatial_scale, self.out_size, + self.out_channels, self.no_trans, self.group_size, self.part_size, + self.sample_per_part, self.trans_std) + + +class DeformRoIPoolingPack(DeformRoIPooling): + + def __init__(self, + spatial_scale, + out_size, + out_channels, + no_trans, + group_size=1, + part_size=None, + sample_per_part=4, + trans_std=.0, + deform_fc_channels=1024): + super(DeformRoIPoolingPack, + self).__init__(spatial_scale, out_size, out_channels, no_trans, + group_size, part_size, sample_per_part, trans_std) + + self.deform_fc_channels = deform_fc_channels + + if not no_trans: + self.offset_fc = nn.Sequential( + nn.Linear(self.out_size * self.out_size * self.out_channels, + self.deform_fc_channels), + nn.ReLU(inplace=True), + nn.Linear(self.deform_fc_channels, self.deform_fc_channels), + nn.ReLU(inplace=True), + nn.Linear(self.deform_fc_channels, + self.out_size * self.out_size * 2)) + self.offset_fc[-1].weight.data.zero_() + self.offset_fc[-1].bias.data.zero_() + + def forward(self, data, rois): + assert data.size(1) == self.out_channels + if self.no_trans: + offset = data.new_empty(0) + return deform_roi_pooling( + data, rois, offset, self.spatial_scale, self.out_size, + self.out_channels, self.no_trans, self.group_size, + self.part_size, self.sample_per_part, self.trans_std) + else: + n = rois.shape[0] + offset = data.new_empty(0) + x = deform_roi_pooling(data, rois, offset, self.spatial_scale, + self.out_size, self.out_channels, True, + self.group_size, self.part_size, + self.sample_per_part, self.trans_std) + offset = self.offset_fc(x.view(n, -1)) + offset = offset.view(n, 2, self.out_size, self.out_size) + return deform_roi_pooling( + data, rois, offset, self.spatial_scale, self.out_size, + self.out_channels, self.no_trans, self.group_size, + self.part_size, self.sample_per_part, self.trans_std) + + +class ModulatedDeformRoIPoolingPack(DeformRoIPooling): + + def __init__(self, + spatial_scale, + out_size, + out_channels, + no_trans, + group_size=1, + part_size=None, + sample_per_part=4, + trans_std=.0, + deform_fc_channels=1024): + super(ModulatedDeformRoIPoolingPack, self).__init__( + spatial_scale, out_size, out_channels, no_trans, group_size, + part_size, sample_per_part, trans_std) + + self.deform_fc_channels = deform_fc_channels + + if not no_trans: + self.offset_fc = nn.Sequential( + nn.Linear(self.out_size * self.out_size * self.out_channels, + self.deform_fc_channels), + nn.ReLU(inplace=True), + nn.Linear(self.deform_fc_channels, self.deform_fc_channels), + nn.ReLU(inplace=True), + nn.Linear(self.deform_fc_channels, + self.out_size * self.out_size * 2)) + self.offset_fc[-1].weight.data.zero_() + self.offset_fc[-1].bias.data.zero_() + self.mask_fc = nn.Sequential( + nn.Linear(self.out_size * self.out_size * self.out_channels, + self.deform_fc_channels), + nn.ReLU(inplace=True), + nn.Linear(self.deform_fc_channels, + self.out_size * self.out_size * 1), + nn.Sigmoid()) + self.mask_fc[2].weight.data.zero_() + self.mask_fc[2].bias.data.zero_() + + def forward(self, data, rois): + assert data.size(1) == self.out_channels + if self.no_trans: + offset = data.new_empty(0) + return deform_roi_pooling( + data, rois, offset, self.spatial_scale, self.out_size, + self.out_channels, self.no_trans, self.group_size, + self.part_size, self.sample_per_part, self.trans_std) + else: + n = rois.shape[0] + offset = data.new_empty(0) + x = deform_roi_pooling(data, rois, offset, self.spatial_scale, + self.out_size, self.out_channels, True, + self.group_size, self.part_size, + self.sample_per_part, self.trans_std) + offset = self.offset_fc(x.view(n, -1)) + offset = offset.view(n, 2, self.out_size, self.out_size) + mask = self.mask_fc(x.view(n, -1)) + mask = mask.view(n, 1, self.out_size, self.out_size) + return deform_roi_pooling( + data, rois, offset, self.spatial_scale, self.out_size, + self.out_channels, self.no_trans, self.group_size, + self.part_size, self.sample_per_part, self.trans_std) * mask diff --git a/maskrcnn_benchmark/layers/misc.py b/maskrcnn_benchmark/layers/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..0a044b46b1dcad25a5810b764e78e318f1065f49 --- /dev/null +++ b/maskrcnn_benchmark/layers/misc.py @@ -0,0 +1,185 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +""" +helper class that supports empty tensors on some nn functions. + +Ideally, add support directly in PyTorch to empty tensors in +those functions. + +This can be removed once https://github.com/pytorch/pytorch/issues/12013 +is implemented +""" + +import math +import torch +from torch import nn +from torch.nn.modules.utils import _ntuple + + +class _NewEmptyTensorOp(torch.autograd.Function): + @staticmethod + def forward(ctx, x, new_shape): + ctx.shape = x.shape + return x.new_empty(new_shape) + + @staticmethod + def backward(ctx, grad): + shape = ctx.shape + return _NewEmptyTensorOp.apply(grad, shape), None + + + +class Conv2d(torch.nn.Conv2d): + def forward(self, x): + if x.numel() > 0: + return super(Conv2d, self).forward(x) + # get output shape + + output_shape = [ + (i + 2 * p - (di * (k - 1) + 1)) // d + 1 + for i, p, di, k, d in zip( + x.shape[-2:], self.padding, self.dilation, self.kernel_size, self.stride + ) + ] + output_shape = [x.shape[0], self.weight.shape[0]] + output_shape + return _NewEmptyTensorOp.apply(x, output_shape) + + +class ConvTranspose2d(torch.nn.ConvTranspose2d): + def forward(self, x): + if x.numel() > 0: + return super(ConvTranspose2d, self).forward(x) + # get output shape + + output_shape = [ + (i - 1) * d - 2 * p + (di * (k - 1) + 1) + op + for i, p, di, k, d, op in zip( + x.shape[-2:], + self.padding, + self.dilation, + self.kernel_size, + self.stride, + self.output_padding, + ) + ] + output_shape = [x.shape[0], self.bias.shape[0]] + output_shape + return _NewEmptyTensorOp.apply(x, output_shape) + + +def interpolate( + input, size=None, scale_factor=None, mode="nearest", align_corners=None +): + if input.numel() > 0: + return torch.nn.functional.interpolate( + input, size, scale_factor, mode, align_corners + ) + + def _check_size_scale_factor(dim): + if size is None and scale_factor is None: + raise ValueError("either size or scale_factor should be defined") + if size is not None and scale_factor is not None: + raise ValueError("only one of size or scale_factor should be defined") + if ( + scale_factor is not None + and isinstance(scale_factor, tuple) + and len(scale_factor) != dim + ): + raise ValueError( + "scale_factor shape must match input shape. " + "Input is {}D, scale_factor size is {}".format(dim, len(scale_factor)) + ) + + def _output_size(dim): + _check_size_scale_factor(dim) + if size is not None: + return size + scale_factors = _ntuple(dim)(scale_factor) + # math.floor might return float in py2.7 + return [ + int(math.floor(input.size(i + 2) * scale_factors[i])) for i in range(dim) + ] + + output_shape = tuple(_output_size(2)) + output_shape = input.shape[:-2] + output_shape + return _NewEmptyTensorOp.apply(input, output_shape) + +class DFConv2d(nn.Module): + """Deformable convolutional layer""" + def __init__( + self, + in_channels, + out_channels, + with_modulated_dcn=True, + kernel_size=3, + stride=1, + groups=1, + dilation=1, + deformable_groups=1, + bias=False + ): + super(DFConv2d, self).__init__() + if isinstance(kernel_size, (list, tuple)): + assert len(kernel_size) == 2 + offset_base_channels = kernel_size[0] * kernel_size[1] + else: + offset_base_channels = kernel_size * kernel_size + if with_modulated_dcn: + from maskrcnn_benchmark.layers import ModulatedDeformConv + offset_channels = offset_base_channels * 3 #default: 27 + conv_block = ModulatedDeformConv + else: + from maskrcnn_benchmark.layers import DeformConv + offset_channels = offset_base_channels * 2 #default: 18 + conv_block = DeformConv + self.offset = Conv2d( + in_channels, + deformable_groups * offset_channels, + kernel_size=kernel_size, + stride= stride, + padding= dilation, + groups=1, + dilation=dilation + ) + for l in [self.offset,]: + nn.init.kaiming_uniform_(l.weight, a=1) + torch.nn.init.constant_(l.bias, 0.) + self.conv = conv_block( + in_channels, + out_channels, + kernel_size=kernel_size, + stride= stride, + padding=dilation, + dilation=dilation, + groups=groups, + deformable_groups=deformable_groups, + bias=bias + ) + self.with_modulated_dcn = with_modulated_dcn + self.kernel_size = kernel_size + self.stride = stride + self.padding = dilation + self.dilation = dilation + + def forward(self, x): + if x.numel() > 0: + if not self.with_modulated_dcn: + offset = self.offset(x) + x = self.conv(x, offset) + else: + offset_mask = self.offset(x) + offset = offset_mask[:, :18, :, :] + mask = offset_mask[:, -9:, :, :].sigmoid() + x = self.conv(x, offset, mask) + return x + # get output shape + output_shape = [ + (i + 2 * p - (di * (k - 1) + 1)) // d + 1 + for i, p, di, k, d in zip( + x.shape[-2:], + self.padding, + self.dilation, + self.kernel_size, + self.stride + ) + ] + output_shape = [x.shape[0], self.conv.weight.shape[0]] + output_shape + return _NewEmptyTensorOp.apply(x, output_shape) diff --git a/maskrcnn_benchmark/layers/nms.py b/maskrcnn_benchmark/layers/nms.py new file mode 100644 index 0000000000000000000000000000000000000000..855d032664b67b46b3646042a5c707f7bf2234eb --- /dev/null +++ b/maskrcnn_benchmark/layers/nms.py @@ -0,0 +1,11 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +# from ._utils import _C +from maskrcnn_benchmark import _C + +# nms = _C.nms +from apex import amp + +# Only valid with fp32 inputs - give AMP the hint +nms = amp.float_function(_C.nms) +# nms.__doc__ = """ +# This function performs Non-maximum suppresion""" diff --git a/maskrcnn_benchmark/layers/roi_align.py b/maskrcnn_benchmark/layers/roi_align.py new file mode 100644 index 0000000000000000000000000000000000000000..1036f962db0bfa9d053ab111be4536add3dc3860 --- /dev/null +++ b/maskrcnn_benchmark/layers/roi_align.py @@ -0,0 +1,69 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +import torch +from torch import nn +from torch.autograd import Function +from torch.autograd.function import once_differentiable +from torch.nn.modules.utils import _pair + +from maskrcnn_benchmark import _C +from apex import amp + +class _ROIAlign(Function): + @staticmethod + def forward(ctx, input, roi, output_size, spatial_scale, sampling_ratio): + ctx.save_for_backward(roi) + ctx.output_size = _pair(output_size) + ctx.spatial_scale = spatial_scale + ctx.sampling_ratio = sampling_ratio + ctx.input_shape = input.size() + output = _C.roi_align_forward( + input, roi, spatial_scale, output_size[0], output_size[1], sampling_ratio + ) + return output + + @staticmethod + @once_differentiable + def backward(ctx, grad_output): + rois, = ctx.saved_tensors + output_size = ctx.output_size + spatial_scale = ctx.spatial_scale + sampling_ratio = ctx.sampling_ratio + bs, ch, h, w = ctx.input_shape + grad_input = _C.roi_align_backward( + grad_output, + rois, + spatial_scale, + output_size[0], + output_size[1], + bs, + ch, + h, + w, + sampling_ratio, + ) + return grad_input, None, None, None, None + + +roi_align = _ROIAlign.apply + + +class ROIAlign(nn.Module): + def __init__(self, output_size, spatial_scale, sampling_ratio): + super(ROIAlign, self).__init__() + self.output_size = output_size + self.spatial_scale = spatial_scale + self.sampling_ratio = sampling_ratio + + @amp.float_function + def forward(self, input, rois): + return roi_align( + input, rois, self.output_size, self.spatial_scale, self.sampling_ratio + ) + + def __repr__(self): + tmpstr = self.__class__.__name__ + "(" + tmpstr += "output_size=" + str(self.output_size) + tmpstr += ", spatial_scale=" + str(self.spatial_scale) + tmpstr += ", sampling_ratio=" + str(self.sampling_ratio) + tmpstr += ")" + return tmpstr diff --git a/maskrcnn_benchmark/layers/roi_pool.py b/maskrcnn_benchmark/layers/roi_pool.py new file mode 100644 index 0000000000000000000000000000000000000000..b09e69978f28aed8bd77cecfd35255a61f463a26 --- /dev/null +++ b/maskrcnn_benchmark/layers/roi_pool.py @@ -0,0 +1,64 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +import torch +from torch import nn +from torch.autograd import Function +from torch.autograd.function import once_differentiable +from torch.nn.modules.utils import _pair + +from maskrcnn_benchmark import _C +from apex import amp + +class _ROIPool(Function): + @staticmethod + def forward(ctx, input, roi, output_size, spatial_scale): + ctx.output_size = _pair(output_size) + ctx.spatial_scale = spatial_scale + ctx.input_shape = input.size() + output, argmax = _C.roi_pool_forward( + input, roi, spatial_scale, output_size[0], output_size[1] + ) + ctx.save_for_backward(input, roi, argmax) + return output + + @staticmethod + @once_differentiable + def backward(ctx, grad_output): + input, rois, argmax = ctx.saved_tensors + output_size = ctx.output_size + spatial_scale = ctx.spatial_scale + bs, ch, h, w = ctx.input_shape + grad_input = _C.roi_pool_backward( + grad_output, + input, + rois, + argmax, + spatial_scale, + output_size[0], + output_size[1], + bs, + ch, + h, + w, + ) + return grad_input, None, None, None + + +roi_pool = _ROIPool.apply + + +class ROIPool(nn.Module): + def __init__(self, output_size, spatial_scale): + super(ROIPool, self).__init__() + self.output_size = output_size + self.spatial_scale = spatial_scale + + @amp.float_function + def forward(self, input, rois): + return roi_pool(input, rois, self.output_size, self.spatial_scale) + + def __repr__(self): + tmpstr = self.__class__.__name__ + "(" + tmpstr += "output_size=" + str(self.output_size) + tmpstr += ", spatial_scale=" + str(self.spatial_scale) + tmpstr += ")" + return tmpstr diff --git a/maskrcnn_benchmark/layers/smooth_l1_loss.py b/maskrcnn_benchmark/layers/smooth_l1_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..9c4664bb47b731eb087aa777d6f9a4b28fddd03a --- /dev/null +++ b/maskrcnn_benchmark/layers/smooth_l1_loss.py @@ -0,0 +1,16 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +import torch + + +# TODO maybe push this to nn? +def smooth_l1_loss(input, target, beta=1. / 9, size_average=True): + """ + very similar to the smooth_l1_loss from pytorch, but with + the extra beta parameter + """ + n = torch.abs(input - target) + cond = n < beta + loss = torch.where(cond, 0.5 * n ** 2 / beta, n - 0.5 * beta) + if size_average: + return loss.mean() + return loss.sum() diff --git a/maskrcnn_benchmark/modeling/__init__.py b/maskrcnn_benchmark/modeling/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/maskrcnn_benchmark/modeling/backbone/__init__.py b/maskrcnn_benchmark/modeling/backbone/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4b3da17b811a3a3efda7f40a9af9ada24ffe2be7 --- /dev/null +++ b/maskrcnn_benchmark/modeling/backbone/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +from .backbone import build_backbone diff --git a/maskrcnn_benchmark/modeling/backbone/backbone.py b/maskrcnn_benchmark/modeling/backbone/backbone.py new file mode 100644 index 0000000000000000000000000000000000000000..6f6561ba0dae70eb771e9068dc6b234aaab43733 --- /dev/null +++ b/maskrcnn_benchmark/modeling/backbone/backbone.py @@ -0,0 +1,131 @@ +# # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +# from collections import OrderedDict + +# from torch import nn + +# from . import fpn as fpn_module +# from . import resnet + + +# def build_resnet_backbone(cfg): +# body = resnet.ResNet(cfg) +# model = nn.Sequential(OrderedDict([("body", body)])) +# return model + + +# def build_resnet_fpn_backbone(cfg): +# body = resnet.ResNet(cfg) +# in_channels_stage2 = cfg.MODEL.RESNETS.RES2_OUT_CHANNELS +# out_channels = cfg.MODEL.BACKBONE.OUT_CHANNELS +# fpn = fpn_module.FPN( +# in_channels_list=[ +# in_channels_stage2, +# in_channels_stage2 * 2, +# in_channels_stage2 * 4, +# in_channels_stage2 * 8, +# ], +# out_channels=out_channels, +# top_blocks=fpn_module.LastLevelMaxPool(), +# ) +# model = nn.Sequential(OrderedDict([("body", body), ("fpn", fpn)])) +# return model + + +# _BACKBONES = {"resnet": build_resnet_backbone, "resnet-fpn": build_resnet_fpn_backbone} + + +# def build_backbone(cfg): +# assert cfg.MODEL.BACKBONE.CONV_BODY.startswith( +# "R-" +# ), "Only ResNet and ResNeXt models are currently implemented" +# # Models using FPN end with "-FPN" +# if cfg.MODEL.BACKBONE.CONV_BODY.endswith("-FPN"): +# return build_resnet_fpn_backbone(cfg) +# return build_resnet_backbone(cfg) + + +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +from collections import OrderedDict + +from torch import nn + +from maskrcnn_benchmark.modeling import registry +from maskrcnn_benchmark.modeling.make_layers import conv_with_kaiming_uniform +from . import fpn as fpn_module +# from . import resnet + + +@registry.BACKBONES.register("R-50-C4") +@registry.BACKBONES.register("R-50-C5") +@registry.BACKBONES.register("R-101-C4") +@registry.BACKBONES.register("R-101-C5") +def build_resnet_backbone(cfg): + body = resnet.ResNet(cfg) + model = nn.Sequential(OrderedDict([("body", body)])) + model.out_channels = cfg.MODEL.RESNETS.BACKBONE_OUT_CHANNELS + return model + +@registry.BACKBONES.register("R-18-FPN") +@registry.BACKBONES.register("R-34-FPN") +@registry.BACKBONES.register("R-50-FPN") +@registry.BACKBONES.register("R-101-FPN") +@registry.BACKBONES.register("R-152-FPN") +def build_resnet_fpn_backbone(cfg): + if cfg.MODEL.RESNET34: + from . import resnet34 as resnet + body = resnet.ResNet(layers=cfg.MODEL.RESNETS.LAYERS) + else: + from . import resnet + body = resnet.ResNet(cfg) + in_channels_stage2 = cfg.MODEL.RESNETS.RES2_OUT_CHANNELS + out_channels = cfg.MODEL.RESNETS.BACKBONE_OUT_CHANNELS + fpn = fpn_module.FPN( + in_channels_list=[ + in_channels_stage2, + in_channels_stage2 * 2, + in_channels_stage2 * 4, + in_channels_stage2 * 8, + ], + out_channels=out_channels, + conv_block=conv_with_kaiming_uniform( + cfg.MODEL.FPN.USE_GN, cfg.MODEL.FPN.USE_RELU + ), + top_blocks=fpn_module.LastLevelMaxPool(), + ) + model = nn.Sequential(OrderedDict([("body", body), ("fpn", fpn)])) + model.out_channels = out_channels + return model + + +@registry.BACKBONES.register("R-50-FPN-RETINANET") +@registry.BACKBONES.register("R-101-FPN-RETINANET") +def build_resnet_fpn_p3p7_backbone(cfg): + body = resnet.ResNet(cfg) + in_channels_stage2 = cfg.MODEL.RESNETS.RES2_OUT_CHANNELS + out_channels = cfg.MODEL.RESNETS.BACKBONE_OUT_CHANNELS + in_channels_p6p7 = in_channels_stage2 * 8 if cfg.MODEL.RETINANET.USE_C5 \ + else out_channels + fpn = fpn_module.FPN( + in_channels_list=[ + 0, + in_channels_stage2 * 2, + in_channels_stage2 * 4, + in_channels_stage2 * 8, + ], + out_channels=out_channels, + conv_block=conv_with_kaiming_uniform( + cfg.MODEL.FPN.USE_GN, cfg.MODEL.FPN.USE_RELU + ), + top_blocks=fpn_module.LastLevelP6P7(in_channels_p6p7, out_channels), + ) + model = nn.Sequential(OrderedDict([("body", body), ("fpn", fpn)])) + model.out_channels = out_channels + return model + + +def build_backbone(cfg): + assert cfg.MODEL.BACKBONE.CONV_BODY in registry.BACKBONES, \ + "cfg.MODEL.BACKBONE.CONV_BODY: {} are not registered in registry".format( + cfg.MODEL.BACKBONE.CONV_BODY + ) + return registry.BACKBONES[cfg.MODEL.BACKBONE.CONV_BODY](cfg) diff --git a/maskrcnn_benchmark/modeling/backbone/fpn.py b/maskrcnn_benchmark/modeling/backbone/fpn.py new file mode 100644 index 0000000000000000000000000000000000000000..b1190389164fa5057dc4cd17e959969c588ad1ee --- /dev/null +++ b/maskrcnn_benchmark/modeling/backbone/fpn.py @@ -0,0 +1,175 @@ +# #!/usr/bin/env python3 +# # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +# import torch +# import torch.nn.functional as F +# from torch import nn + + +# class FPN(nn.Module): +# """ +# Module that adds FPN on top of a list of feature maps. +# The feature maps are currently supposed to be in increasing depth +# order, and must be consecutive +# """ + +# def __init__(self, in_channels_list, out_channels, top_blocks=None): +# """ +# Arguments: +# in_channels_list (list[int]): number of channels for each feature map that +# will be fed +# out_channels (int): number of channels of the FPN representation +# top_blocks (nn.Module or None): if provided, an extra operation will +# be performed on the output of the last (smallest resolution) +# FPN output, and the result will extend the result list +# """ +# super(FPN, self).__init__() +# self.inner_blocks = [] +# self.layer_blocks = [] +# for idx, in_channels in enumerate(in_channels_list, 1): +# inner_block = "fpn_inner{}".format(idx) +# layer_block = "fpn_layer{}".format(idx) +# inner_block_module = nn.Conv2d(in_channels, out_channels, 1) +# layer_block_module = nn.Conv2d(out_channels, out_channels, 3, 1, 1) +# for module in [inner_block_module, layer_block_module]: +# # Caffe2 implementation uses XavierFill, which in fact +# # corresponds to kaiming_uniform_ in PyTorch +# nn.init.kaiming_uniform_(module.weight, a=1) +# nn.init.constant_(module.bias, 0) +# self.add_module(inner_block, inner_block_module) +# self.add_module(layer_block, layer_block_module) +# self.inner_blocks.append(inner_block) +# self.layer_blocks.append(layer_block) +# self.top_blocks = top_blocks + +# def forward(self, x): +# """ +# Arguments: +# x (list[Tensor]): feature maps for each feature level. +# Returns: +# results (tuple[Tensor]): feature maps after FPN layers. +# They are ordered from highest resolution first. +# """ +# last_inner = getattr(self, self.inner_blocks[-1])(x[-1]) +# results = [] +# results.append(getattr(self, self.layer_blocks[-1])(last_inner)) +# for feature, inner_block, layer_block in zip( +# x[:-1][::-1], self.inner_blocks[:-1][::-1], self.layer_blocks[:-1][::-1] +# ): +# inner_top_down = F.interpolate(last_inner, scale_factor=2, mode="nearest") +# inner_lateral = getattr(self, inner_block)(feature) +# # TODO use size instead of scale to make it robust to different sizes +# # inner_top_down = F.upsample(last_inner, size=inner_lateral.shape[-2:], +# # mode='bilinear', align_corners=False) +# last_inner = inner_lateral + inner_top_down +# results.insert(0, getattr(self, layer_block)(last_inner)) + +# if self.top_blocks is not None: +# last_results = self.top_blocks(results[-1]) +# results.extend(last_results) + +# return tuple(results) + + +# class LastLevelMaxPool(nn.Module): +# def forward(self, x): +# return [F.max_pool2d(x, 1, 2, 0)] + + +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +import torch +import torch.nn.functional as F +from torch import nn + +class FPN(nn.Module): + """ + Module that adds FPN on top of a list of feature maps. + The feature maps are currently supposed to be in increasing depth + order, and must be consecutive + """ + + def __init__( + self, in_channels_list, out_channels, conv_block, top_blocks=None + ): + """ + Arguments: + in_channels_list (list[int]): number of channels for each feature map that + will be fed + out_channels (int): number of channels of the FPN representation + top_blocks (nn.Module or None): if provided, an extra operation will + be performed on the output of the last (smallest resolution) + FPN output, and the result will extend the result list + """ + super(FPN, self).__init__() + self.inner_blocks = [] + self.layer_blocks = [] + for idx, in_channels in enumerate(in_channels_list, 1): + inner_block = "fpn_inner{}".format(idx) + layer_block = "fpn_layer{}".format(idx) + + if in_channels == 0: + continue + inner_block_module = conv_block(in_channels, out_channels, 1) + layer_block_module = conv_block(out_channels, out_channels, 3, 1) + self.add_module(inner_block, inner_block_module) + self.add_module(layer_block, layer_block_module) + self.inner_blocks.append(inner_block) + self.layer_blocks.append(layer_block) + self.top_blocks = top_blocks + + def forward(self, x): + """ + Arguments: + x (list[Tensor]): feature maps for each feature level. + Returns: + results (tuple[Tensor]): feature maps after FPN layers. + They are ordered from highest resolution first. + """ + last_inner = getattr(self, self.inner_blocks[-1])(x[-1]) + results = [] + results.append(getattr(self, self.layer_blocks[-1])(last_inner)) + for feature, inner_block, layer_block in zip( + x[:-1][::-1], self.inner_blocks[:-1][::-1], self.layer_blocks[:-1][::-1] + ): + if not inner_block: + continue + inner_top_down = F.interpolate(last_inner, scale_factor=2, mode="nearest") + inner_lateral = getattr(self, inner_block)(feature) + # TODO use size instead of scale to make it robust to different sizes + # inner_top_down = F.upsample(last_inner, size=inner_lateral.shape[-2:], + # mode='bilinear', align_corners=False) + last_inner = inner_lateral + inner_top_down + results.insert(0, getattr(self, layer_block)(last_inner)) + + if isinstance(self.top_blocks, LastLevelP6P7): + last_results = self.top_blocks(x[-1], results[-1]) + results.extend(last_results) + elif isinstance(self.top_blocks, LastLevelMaxPool): + last_results = self.top_blocks(results[-1]) + results.extend(last_results) + + return tuple(results) + + +class LastLevelMaxPool(nn.Module): + def forward(self, x): + return [F.max_pool2d(x, 1, 2, 0)] + + +class LastLevelP6P7(nn.Module): + """ + This module is used in RetinaNet to generate extra layers, P6 and P7. + """ + def __init__(self, in_channels, out_channels): + super(LastLevelP6P7, self).__init__() + self.p6 = nn.Conv2d(in_channels, out_channels, 3, 2, 1) + self.p7 = nn.Conv2d(out_channels, out_channels, 3, 2, 1) + for module in [self.p6, self.p7]: + nn.init.kaiming_uniform_(module.weight, a=1) + nn.init.constant_(module.bias, 0) + self.use_P5 = in_channels == out_channels + + def forward(self, c5, p5): + x = p5 if self.use_P5 else c5 + p6 = self.p6(x) + p7 = self.p7(F.relu(p6)) + return [p6, p7] diff --git a/maskrcnn_benchmark/modeling/backbone/resnet.py b/maskrcnn_benchmark/modeling/backbone/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..e728570fe3a1d62fbef47805f7d0bee7720f5b77 --- /dev/null +++ b/maskrcnn_benchmark/modeling/backbone/resnet.py @@ -0,0 +1,773 @@ +# # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +# """ +# Variant of the resnet module that takes cfg as an argument. +# Example usage. Strings may be specified in the config file. +# model = ResNet( +# "StemWithFixedBatchNorm", +# "BottleneckWithFixedBatchNorm", +# "ResNet50StagesTo4", +# ) +# Custom implementations may be written in user code and hooked in via the +# `register_*` functions. +# """ +# from collections import namedtuple + +# import torch +# import torch.nn.functional as F +# from torch import nn + +# from maskrcnn_benchmark.layers import FrozenBatchNorm2d +# from maskrcnn_benchmark.layers import Conv2d + + +# # ResNet stage specification +# StageSpec = namedtuple( +# "StageSpec", +# [ +# "index", # Index of the stage, eg 1, 2, ..,. 5 +# "block_count", # Numer of residual blocks in the stage +# "return_features", # True => return the last feature map from this stage +# ], +# ) + +# # ----------------------------------------------------------------------------- +# # Standard ResNet models +# # ----------------------------------------------------------------------------- +# # ResNet-50 (including all stages) +# ResNet50StagesTo5 = ( +# StageSpec(index=i, block_count=c, return_features=r) +# for (i, c, r) in ((1, 3, False), (2, 4, False), (3, 6, False), (4, 3, True)) +# ) +# # ResNet-50 up to stage 4 (excludes stage 5) +# ResNet50StagesTo4 = ( +# StageSpec(index=i, block_count=c, return_features=r) +# for (i, c, r) in ((1, 3, False), (2, 4, False), (3, 6, True)) +# ) +# # ResNet-50-FPN (including all stages) +# ResNet50FPNStagesTo5 = ( +# StageSpec(index=i, block_count=c, return_features=r) +# for (i, c, r) in ((1, 3, True), (2, 4, True), (3, 6, True), (4, 3, True)) +# ) +# # ResNet-101-FPN (including all stages) +# ResNet101FPNStagesTo5 = ( +# StageSpec(index=i, block_count=c, return_features=r) +# for (i, c, r) in ((1, 3, True), (2, 4, True), (3, 23, True), (4, 3, True)) +# ) + + +# class ResNet(nn.Module): +# def __init__(self, cfg): +# super(ResNet, self).__init__() + +# # If we want to use the cfg in forward(), then we should make a copy +# # of it and store it for later use: +# # self.cfg = cfg.clone() + +# # Translate string names to implementations +# stem_module = _STEM_MODULES[cfg.MODEL.RESNETS.STEM_FUNC] +# stage_specs = _STAGE_SPECS[cfg.MODEL.BACKBONE.CONV_BODY] +# transformation_module = _TRANSFORMATION_MODULES[cfg.MODEL.RESNETS.TRANS_FUNC] + +# # Construct the stem module +# self.stem = stem_module(cfg) + +# # Constuct the specified ResNet stages +# num_groups = cfg.MODEL.RESNETS.NUM_GROUPS +# width_per_group = cfg.MODEL.RESNETS.WIDTH_PER_GROUP +# in_channels = cfg.MODEL.RESNETS.STEM_OUT_CHANNELS +# stage2_bottleneck_channels = num_groups * width_per_group +# stage2_out_channels = cfg.MODEL.RESNETS.RES2_OUT_CHANNELS +# self.stages = [] +# self.return_features = {} +# for stage_spec in stage_specs: +# name = "layer" + str(stage_spec.index) +# stage2_relative_factor = 2 ** (stage_spec.index - 1) +# bottleneck_channels = stage2_bottleneck_channels * stage2_relative_factor +# out_channels = stage2_out_channels * stage2_relative_factor +# module = _make_stage( +# transformation_module, +# in_channels, +# bottleneck_channels, +# out_channels, +# stage_spec.block_count, +# num_groups, +# cfg.MODEL.RESNETS.STRIDE_IN_1X1, +# first_stride=int(stage_spec.index > 1) + 1, +# ) +# in_channels = out_channels +# self.add_module(name, module) +# self.stages.append(name) +# self.return_features[name] = stage_spec.return_features + +# # Optionally freeze (requires_grad=False) parts of the backbone +# self._freeze_backbone(cfg.MODEL.BACKBONE.FREEZE_CONV_BODY_AT) + +# def _freeze_backbone(self, freeze_at): +# for stage_index in range(freeze_at): +# if stage_index == 0: +# m = self.stem # stage 0 is the stem +# else: +# m = getattr(self, "layer" + str(stage_index)) +# for p in m.parameters(): +# p.requires_grad = False + +# def forward(self, x): +# outputs = [] +# x = self.stem(x) +# for stage_name in self.stages: +# x = getattr(self, stage_name)(x) +# if self.return_features[stage_name]: +# outputs.append(x) +# return outputs + + +# class ResNetHead(nn.Module): +# def __init__( +# self, +# block_module, +# stages, +# num_groups=1, +# width_per_group=64, +# stride_in_1x1=True, +# stride_init=None, +# res2_out_channels=256, +# ): +# super(ResNetHead, self).__init__() + +# stage2_relative_factor = 2 ** (stages[0].index - 1) +# stage2_bottleneck_channels = num_groups * width_per_group +# out_channels = res2_out_channels * stage2_relative_factor +# in_channels = out_channels // 2 +# bottleneck_channels = stage2_bottleneck_channels * stage2_relative_factor + +# block_module = _TRANSFORMATION_MODULES[block_module] + +# self.stages = [] +# stride = stride_init +# for stage in stages: +# name = "layer" + str(stage.index) +# if not stride: +# stride = int(stage.index > 1) + 1 +# module = _make_stage( +# block_module, +# in_channels, +# bottleneck_channels, +# out_channels, +# stage.block_count, +# num_groups, +# stride_in_1x1, +# first_stride=stride, +# ) +# stride = None +# self.add_module(name, module) +# self.stages.append(name) + +# def forward(self, x): +# for stage in self.stages: +# x = getattr(self, stage)(x) +# return x + + +# def _make_stage( +# transformation_module, +# in_channels, +# bottleneck_channels, +# out_channels, +# block_count, +# num_groups, +# stride_in_1x1, +# first_stride, +# ): +# blocks = [] +# stride = first_stride +# for _ in range(block_count): +# blocks.append( +# transformation_module( +# in_channels, +# bottleneck_channels, +# out_channels, +# num_groups, +# stride_in_1x1, +# stride, +# ) +# ) +# stride = 1 +# in_channels = out_channels +# return nn.Sequential(*blocks) + + +# class BottleneckWithFixedBatchNorm(nn.Module): +# def __init__( +# self, +# in_channels, +# bottleneck_channels, +# out_channels, +# num_groups=1, +# stride_in_1x1=True, +# stride=1, +# ): +# super(BottleneckWithFixedBatchNorm, self).__init__() + +# self.downsample = None +# if in_channels != out_channels: +# self.downsample = nn.Sequential( +# Conv2d( +# in_channels, out_channels, kernel_size=1, stride=stride, bias=False +# ), +# FrozenBatchNorm2d(out_channels), +# ) + +# # The original MSRA ResNet models have stride in the first 1x1 conv +# # The subsequent fb.torch.resnet and Caffe2 ResNe[X]t implementations have +# # stride in the 3x3 conv +# stride_1x1, stride_3x3 = (stride, 1) if stride_in_1x1 else (1, stride) + +# self.conv1 = Conv2d( +# in_channels, +# bottleneck_channels, +# kernel_size=1, +# stride=stride_1x1, +# bias=False, +# ) +# self.bn1 = FrozenBatchNorm2d(bottleneck_channels) +# # TODO: specify init for the above + +# self.conv2 = Conv2d( +# bottleneck_channels, +# bottleneck_channels, +# kernel_size=3, +# stride=stride_3x3, +# padding=1, +# bias=False, +# groups=num_groups, +# ) +# self.bn2 = FrozenBatchNorm2d(bottleneck_channels) + +# self.conv3 = Conv2d( +# bottleneck_channels, out_channels, kernel_size=1, bias=False +# ) +# self.bn3 = FrozenBatchNorm2d(out_channels) + +# def forward(self, x): +# residual = x + +# out = self.conv1(x) +# out = self.bn1(out) +# out = F.relu_(out) + +# out = self.conv2(out) +# out = self.bn2(out) +# out = F.relu_(out) + +# out0 = self.conv3(out) +# out = self.bn3(out0) + +# if self.downsample is not None: +# residual = self.downsample(x) + +# out += residual +# out = F.relu_(out) + +# return out + + +# class StemWithFixedBatchNorm(nn.Module): +# def __init__(self, cfg): +# super(StemWithFixedBatchNorm, self).__init__() + +# out_channels = cfg.MODEL.RESNETS.STEM_OUT_CHANNELS + +# self.conv1 = Conv2d( +# 3, out_channels, kernel_size=7, stride=2, padding=3, bias=False +# ) +# self.bn1 = FrozenBatchNorm2d(out_channels) + +# def forward(self, x): +# x = self.conv1(x) +# x = self.bn1(x) +# x = F.relu_(x) +# x = F.max_pool2d(x, kernel_size=3, stride=2, padding=1) +# return x + + +# _TRANSFORMATION_MODULES = {"BottleneckWithFixedBatchNorm": BottleneckWithFixedBatchNorm} + +# _STEM_MODULES = {"StemWithFixedBatchNorm": StemWithFixedBatchNorm} + +# _STAGE_SPECS = { +# "R-50-C4": ResNet50StagesTo4, +# "R-50-C5": ResNet50StagesTo5, +# "R-50-FPN": ResNet50FPNStagesTo5, +# "R-101-FPN": ResNet101FPNStagesTo5, +# } + + +# def register_transformation_module(module_name, module): +# _register_generic(_TRANSFORMATION_MODULES, module_name, module) + + +# def register_stem_module(module_name, module): +# _register_generic(_STEM_MODULES, module_name, module) + + +# def register_stage_spec(stage_spec_name, stage_spec): +# _register_generic(_STAGE_SPECS, stage_spec_name, stage_spec) + + +# def _register_generic(module_dict, module_name, module): +# assert module_name not in module_dict +# module_dict[module_name] = module + + + +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +""" +Variant of the resnet module that takes cfg as an argument. +Example usage. Strings may be specified in the config file. + model = ResNet( + "StemWithFixedBatchNorm", + "BottleneckWithFixedBatchNorm", + "ResNet50StagesTo4", + ) +OR: + model = ResNet( + "StemWithGN", + "BottleneckWithGN", + "ResNet50StagesTo4", + ) +Custom implementations may be written in user code and hooked in via the +`register_*` functions. +""" +from collections import namedtuple + +import torch +import torch.nn.functional as F +from torch import nn + +from maskrcnn_benchmark.layers import FrozenBatchNorm2d +from maskrcnn_benchmark.layers import Conv2d +from maskrcnn_benchmark.layers import DFConv2d +from maskrcnn_benchmark.modeling.make_layers import group_norm +from maskrcnn_benchmark.utils.registry import Registry + + +# ResNet stage specification +StageSpec = namedtuple( + "StageSpec", + [ + "index", # Index of the stage, eg 1, 2, ..,. 5 + "block_count", # Number of residual blocks in the stage + "return_features", # True => return the last feature map from this stage + ], +) + +# ----------------------------------------------------------------------------- +# Standard ResNet models +# ----------------------------------------------------------------------------- +# ResNet-50 (including all stages) +ResNet50StagesTo5 = tuple( + StageSpec(index=i, block_count=c, return_features=r) + for (i, c, r) in ((1, 3, False), (2, 4, False), (3, 6, False), (4, 3, True)) +) +# ResNet-50 up to stage 4 (excludes stage 5) +ResNet50StagesTo4 = tuple( + StageSpec(index=i, block_count=c, return_features=r) + for (i, c, r) in ((1, 3, False), (2, 4, False), (3, 6, True)) +) +# ResNet-101 (including all stages) +ResNet101StagesTo5 = tuple( + StageSpec(index=i, block_count=c, return_features=r) + for (i, c, r) in ((1, 3, False), (2, 4, False), (3, 23, False), (4, 3, True)) +) +# ResNet-101 up to stage 4 (excludes stage 5) +ResNet101StagesTo4 = tuple( + StageSpec(index=i, block_count=c, return_features=r) + for (i, c, r) in ((1, 3, False), (2, 4, False), (3, 23, True)) +) +# ResNet-50-FPN (including all stages) +ResNet50FPNStagesTo5 = tuple( + StageSpec(index=i, block_count=c, return_features=r) + for (i, c, r) in ((1, 3, True), (2, 4, True), (3, 6, True), (4, 3, True)) +) +# ResNet-101-FPN (including all stages) +ResNet101FPNStagesTo5 = tuple( + StageSpec(index=i, block_count=c, return_features=r) + for (i, c, r) in ((1, 3, True), (2, 4, True), (3, 23, True), (4, 3, True)) +) +# ResNet-152-FPN (including all stages) +ResNet152FPNStagesTo5 = tuple( + StageSpec(index=i, block_count=c, return_features=r) + for (i, c, r) in ((1, 3, True), (2, 8, True), (3, 36, True), (4, 3, True)) +) + +class ResNet(nn.Module): + def __init__(self, cfg): + super(ResNet, self).__init__() + + # If we want to use the cfg in forward(), then we should make a copy + # of it and store it for later use: + # self.cfg = cfg.clone() + + # Translate string names to implementations + stem_module = _STEM_MODULES[cfg.MODEL.RESNETS.STEM_FUNC] + stage_specs = _STAGE_SPECS[cfg.MODEL.BACKBONE.CONV_BODY] + transformation_module = _TRANSFORMATION_MODULES[cfg.MODEL.RESNETS.TRANS_FUNC] + + # Construct the stem module + self.stem = stem_module(cfg) + + # Constuct the specified ResNet stages + num_groups = cfg.MODEL.RESNETS.NUM_GROUPS + width_per_group = cfg.MODEL.RESNETS.WIDTH_PER_GROUP + in_channels = cfg.MODEL.RESNETS.STEM_OUT_CHANNELS + stage2_bottleneck_channels = num_groups * width_per_group + stage2_out_channels = cfg.MODEL.RESNETS.RES2_OUT_CHANNELS + self.stages = [] + self.return_features = {} + for stage_spec in stage_specs: + name = "layer" + str(stage_spec.index) + stage2_relative_factor = 2 ** (stage_spec.index - 1) + bottleneck_channels = stage2_bottleneck_channels * stage2_relative_factor + out_channels = stage2_out_channels * stage2_relative_factor + stage_with_dcn = cfg.MODEL.RESNETS.STAGE_WITH_DCN[stage_spec.index -1] + module = _make_stage( + transformation_module, + in_channels, + bottleneck_channels, + out_channels, + stage_spec.block_count, + num_groups, + cfg.MODEL.RESNETS.STRIDE_IN_1X1, + first_stride=int(stage_spec.index > 1) + 1, + dcn_config={ + "stage_with_dcn": stage_with_dcn, + "with_modulated_dcn": cfg.MODEL.RESNETS.WITH_MODULATED_DCN, + "deformable_groups": cfg.MODEL.RESNETS.DEFORMABLE_GROUPS, + } + ) + in_channels = out_channels + self.add_module(name, module) + self.stages.append(name) + self.return_features[name] = stage_spec.return_features + + # Optionally freeze (requires_grad=False) parts of the backbone + self._freeze_backbone(cfg.MODEL.BACKBONE.FREEZE_CONV_BODY_AT) + + def _freeze_backbone(self, freeze_at): + if freeze_at < 0: + return + for stage_index in range(freeze_at): + if stage_index == 0: + m = self.stem # stage 0 is the stem + else: + m = getattr(self, "layer" + str(stage_index)) + for p in m.parameters(): + p.requires_grad = False + + def forward(self, x): + outputs = [] + x = self.stem(x) + for stage_name in self.stages: + x = getattr(self, stage_name)(x) + if self.return_features[stage_name]: + outputs.append(x) + return outputs + + +class ResNetHead(nn.Module): + def __init__( + self, + block_module, + stages, + num_groups=1, + width_per_group=64, + stride_in_1x1=True, + stride_init=None, + res2_out_channels=256, + dilation=1, + dcn_config={} + ): + super(ResNetHead, self).__init__() + + stage2_relative_factor = 2 ** (stages[0].index - 1) + stage2_bottleneck_channels = num_groups * width_per_group + out_channels = res2_out_channels * stage2_relative_factor + in_channels = out_channels // 2 + bottleneck_channels = stage2_bottleneck_channels * stage2_relative_factor + + block_module = _TRANSFORMATION_MODULES[block_module] + + self.stages = [] + stride = stride_init + for stage in stages: + name = "layer" + str(stage.index) + if not stride: + stride = int(stage.index > 1) + 1 + module = _make_stage( + block_module, + in_channels, + bottleneck_channels, + out_channels, + stage.block_count, + num_groups, + stride_in_1x1, + first_stride=stride, + dilation=dilation, + dcn_config=dcn_config + ) + stride = None + self.add_module(name, module) + self.stages.append(name) + self.out_channels = out_channels + + def forward(self, x): + for stage in self.stages: + x = getattr(self, stage)(x) + return x + + +def _make_stage( + transformation_module, + in_channels, + bottleneck_channels, + out_channels, + block_count, + num_groups, + stride_in_1x1, + first_stride, + dilation=1, + dcn_config={} +): + blocks = [] + stride = first_stride + for _ in range(block_count): + blocks.append( + transformation_module( + in_channels, + bottleneck_channels, + out_channels, + num_groups, + stride_in_1x1, + stride, + dilation=dilation, + dcn_config=dcn_config + ) + ) + stride = 1 + in_channels = out_channels + return nn.Sequential(*blocks) + + +class Bottleneck(nn.Module): + def __init__( + self, + in_channels, + bottleneck_channels, + out_channels, + num_groups, + stride_in_1x1, + stride, + dilation, + norm_func, + dcn_config + ): + super(Bottleneck, self).__init__() + + self.downsample = None + if in_channels != out_channels: + down_stride = stride if dilation == 1 else 1 + self.downsample = nn.Sequential( + Conv2d( + in_channels, out_channels, + kernel_size=1, stride=down_stride, bias=False + ), + norm_func(out_channels), + ) + for modules in [self.downsample,]: + for l in modules.modules(): + if isinstance(l, Conv2d): + nn.init.kaiming_uniform_(l.weight, a=1) + + if dilation > 1: + stride = 1 # reset to be 1 + + # The original MSRA ResNet models have stride in the first 1x1 conv + # The subsequent fb.torch.resnet and Caffe2 ResNe[X]t implementations have + # stride in the 3x3 conv + stride_1x1, stride_3x3 = (stride, 1) if stride_in_1x1 else (1, stride) + + self.conv1 = Conv2d( + in_channels, + bottleneck_channels, + kernel_size=1, + stride=stride_1x1, + bias=False, + ) + self.bn1 = norm_func(bottleneck_channels) + # TODO: specify init for the above + with_dcn = dcn_config.get("stage_with_dcn", False) + if with_dcn: + deformable_groups = dcn_config.get("deformable_groups", 1) + with_modulated_dcn = dcn_config.get("with_modulated_dcn", False) + self.conv2 = DFConv2d( + bottleneck_channels, + bottleneck_channels, + with_modulated_dcn=with_modulated_dcn, + kernel_size=3, + stride=stride_3x3, + groups=num_groups, + dilation=dilation, + deformable_groups=deformable_groups, + bias=False + ) + else: + self.conv2 = Conv2d( + bottleneck_channels, + bottleneck_channels, + kernel_size=3, + stride=stride_3x3, + padding=dilation, + bias=False, + groups=num_groups, + dilation=dilation + ) + nn.init.kaiming_uniform_(self.conv2.weight, a=1) + + self.bn2 = norm_func(bottleneck_channels) + + self.conv3 = Conv2d( + bottleneck_channels, out_channels, kernel_size=1, bias=False + ) + self.bn3 = norm_func(out_channels) + + for l in [self.conv1, self.conv3,]: + nn.init.kaiming_uniform_(l.weight, a=1) + + def forward(self, x): + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = F.relu_(out) + + out = self.conv2(out) + out = self.bn2(out) + out = F.relu_(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = F.relu_(out) + + return out + + +class BaseStem(nn.Module): + def __init__(self, cfg, norm_func): + super(BaseStem, self).__init__() + + out_channels = cfg.MODEL.RESNETS.STEM_OUT_CHANNELS + + self.conv1 = Conv2d( + 3, out_channels, kernel_size=7, stride=2, padding=3, bias=False + ) + self.bn1 = norm_func(out_channels) + + for l in [self.conv1,]: + nn.init.kaiming_uniform_(l.weight, a=1) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = F.relu_(x) + x = F.max_pool2d(x, kernel_size=3, stride=2, padding=1) + return x + + +class BottleneckWithFixedBatchNorm(Bottleneck): + def __init__( + self, + in_channels, + bottleneck_channels, + out_channels, + num_groups=1, + stride_in_1x1=True, + stride=1, + dilation=1, + dcn_config={} + ): + super(BottleneckWithFixedBatchNorm, self).__init__( + in_channels=in_channels, + bottleneck_channels=bottleneck_channels, + out_channels=out_channels, + num_groups=num_groups, + stride_in_1x1=stride_in_1x1, + stride=stride, + dilation=dilation, + norm_func=FrozenBatchNorm2d, + dcn_config=dcn_config + ) + + +class StemWithFixedBatchNorm(BaseStem): + def __init__(self, cfg): + super(StemWithFixedBatchNorm, self).__init__( + cfg, norm_func=FrozenBatchNorm2d + ) + + +class BottleneckWithGN(Bottleneck): + def __init__( + self, + in_channels, + bottleneck_channels, + out_channels, + num_groups=1, + stride_in_1x1=True, + stride=1, + dilation=1, + dcn_config={} + ): + super(BottleneckWithGN, self).__init__( + in_channels=in_channels, + bottleneck_channels=bottleneck_channels, + out_channels=out_channels, + num_groups=num_groups, + stride_in_1x1=stride_in_1x1, + stride=stride, + dilation=dilation, + norm_func=group_norm, + dcn_config=dcn_config + ) + + +class StemWithGN(BaseStem): + def __init__(self, cfg): + super(StemWithGN, self).__init__(cfg, norm_func=group_norm) + + +_TRANSFORMATION_MODULES = Registry({ + "BottleneckWithFixedBatchNorm": BottleneckWithFixedBatchNorm, + "BottleneckWithGN": BottleneckWithGN, +}) + +_STEM_MODULES = Registry({ + "StemWithFixedBatchNorm": StemWithFixedBatchNorm, + "StemWithGN": StemWithGN, +}) + +_STAGE_SPECS = Registry({ + "R-50-C4": ResNet50StagesTo4, + "R-50-C5": ResNet50StagesTo5, + "R-101-C4": ResNet101StagesTo4, + "R-101-C5": ResNet101StagesTo5, + "R-50-FPN": ResNet50FPNStagesTo5, + "R-50-FPN-RETINANET": ResNet50FPNStagesTo5, + "R-101-FPN": ResNet101FPNStagesTo5, + "R-101-FPN-RETINANET": ResNet101FPNStagesTo5, + "R-152-FPN": ResNet152FPNStagesTo5, +}) diff --git a/maskrcnn_benchmark/modeling/backbone/resnet34.py b/maskrcnn_benchmark/modeling/backbone/resnet34.py new file mode 100644 index 0000000000000000000000000000000000000000..e2c43147a91e723f8acbbc75e745aef9805a96f7 --- /dev/null +++ b/maskrcnn_benchmark/modeling/backbone/resnet34.py @@ -0,0 +1,97 @@ +import torch +import torch.nn.functional as F +from torch import nn +import math +from maskrcnn_benchmark.layers import FrozenBatchNorm2d +from maskrcnn_benchmark.layers import Conv2d + + +def conv3x3(in_planes, out_planes, stride=1): + """3x3 convolution with padding""" + return Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=1, bias=False) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(BasicBlock, self).__init__() + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = FrozenBatchNorm2d(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = FrozenBatchNorm2d(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class ResNet(nn.Module): + + def __init__(self, block=BasicBlock, layers=[3, 4, 6, 3]): + self.inplanes = 64 + super(ResNet, self).__init__() + self.conv1 = Conv2d(3, 64, kernel_size=7, stride=2, padding=3, + bias=False) + self.bn1 = FrozenBatchNorm2d(64) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2) + + for m in self.modules(): + if isinstance(m, Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. / n)) + elif isinstance(m, FrozenBatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + Conv2d(self.inplanes, planes * block.expansion, + kernel_size=1, stride=stride, bias=False), + FrozenBatchNorm2d(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + x2 = self.layer1(x) + x3 = self.layer2(x2) + x4 = self.layer3(x3) + x5 = self.layer4(x4) + return [x2, x3, x4, x5] diff --git a/maskrcnn_benchmark/modeling/balanced_positive_negative_sampler.py b/maskrcnn_benchmark/modeling/balanced_positive_negative_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..6201c5368aa29ff1eae60b7bed33bb6bd015cb1d --- /dev/null +++ b/maskrcnn_benchmark/modeling/balanced_positive_negative_sampler.py @@ -0,0 +1,69 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +import torch + +# TODO +class BalancedPositiveNegativeSampler(object): + """ + This class samples batches, + ensuring that they contain a fixed proportion of positives + """ + + def __init__(self, batch_size_per_image, positive_fraction): + """ + Arguments: + batch_size_per_image (int): number of elements to be selected per image + positive_fraction (float): percentace of positive elements per batch + """ + self.batch_size_per_image = batch_size_per_image + self.positive_fraction = positive_fraction + + def __call__(self, matched_idxs): + """ + Arguments: + matched idxs: list of tensors containing -1, 0 or positive values. + Each tensor corresponds to a specific image. + -1 values are ignored, 0 are considered as negatives and > 0 as + positives. + + Returns: + pos_idx (list[tensor]) + neg_idx (list[tensor]) + + Returns two lists of binary masks for each image. + The first list contains the positive elements that were selected, + and the second list the negative example. + """ + pos_idx = [] + neg_idx = [] + for matched_idxs_per_image in matched_idxs: + positive = torch.nonzero(matched_idxs_per_image >= 1).squeeze(1) + negative = torch.nonzero(matched_idxs_per_image == 0).squeeze(1) + + num_pos = int(self.batch_size_per_image * self.positive_fraction) + # protect against not enough positive examples + num_pos = min(positive.numel(), num_pos) + num_neg = self.batch_size_per_image - num_pos + # protect against not enough negative examples + num_neg = min(negative.numel(), num_neg) + + # randomly select positive and negative examples + perm1 = torch.randperm(positive.numel())[:num_pos] + perm2 = torch.randperm(negative.numel())[:num_neg] + + pos_idx_per_image = positive[perm1] + neg_idx_per_image = negative[perm2] + + # create binary mask from indices + pos_idx_per_image_mask = torch.zeros_like( + matched_idxs_per_image, dtype=torch.bool + ) + neg_idx_per_image_mask = torch.zeros_like( + matched_idxs_per_image, dtype=torch.bool + ) + pos_idx_per_image_mask[pos_idx_per_image] = 1 + neg_idx_per_image_mask[neg_idx_per_image] = 1 + + pos_idx.append(pos_idx_per_image_mask) + neg_idx.append(neg_idx_per_image_mask) + + return pos_idx, neg_idx diff --git a/maskrcnn_benchmark/modeling/box_coder.py b/maskrcnn_benchmark/modeling/box_coder.py new file mode 100644 index 0000000000000000000000000000000000000000..5299d0a3256fac4decc8f52d39357d269779af68 --- /dev/null +++ b/maskrcnn_benchmark/modeling/box_coder.py @@ -0,0 +1,98 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +import math + +import torch + + +class BoxCoder(object): + """ + This class encodes and decodes a set of bounding boxes into + the representation used for training the regressors. + """ + + def __init__(self, weights, bbox_xform_clip=None): + """ + Arguments: + weights (4-element tuple) + bbox_xform_clip (float) + """ + self.weights = weights + if bbox_xform_clip is None: + bbox_xform_clip = math.log(1000.0 / 16) + self.bbox_xform_clip = bbox_xform_clip + + def encode(self, reference_boxes, proposals): + """ + Encode a set of proposals with respect to some + reference boxes + + Arguments: + reference_boxes (Tensor): reference boxes + proposals (Tensor): boxes to be encoded + """ + + TO_REMOVE = 1 # TODO remove + ex_widths = proposals[:, 2] - proposals[:, 0] + TO_REMOVE + ex_heights = proposals[:, 3] - proposals[:, 1] + TO_REMOVE + ex_ctr_x = proposals[:, 0] + 0.5 * ex_widths + ex_ctr_y = proposals[:, 1] + 0.5 * ex_heights + + gt_widths = reference_boxes[:, 2] - reference_boxes[:, 0] + TO_REMOVE + gt_heights = reference_boxes[:, 3] - reference_boxes[:, 1] + TO_REMOVE + gt_ctr_x = reference_boxes[:, 0] + 0.5 * gt_widths + gt_ctr_y = reference_boxes[:, 1] + 0.5 * gt_heights + + wx, wy, ww, wh = self.weights + targets_dx = wx * (gt_ctr_x - ex_ctr_x) / ex_widths + targets_dy = wy * (gt_ctr_y - ex_ctr_y) / ex_heights + targets_dw = ww * torch.log(gt_widths / ex_widths) + targets_dh = wh * torch.log(gt_heights / ex_heights) + + targets = torch.stack((targets_dx, targets_dy, targets_dw, targets_dh), dim=1) + return targets + + def decode(self, rel_codes, boxes): + """ + From a set of original boxes and encoded relative box offsets, + get the decoded boxes. + + Arguments: + rel_codes (Tensor): encoded boxes + boxes (Tensor): reference boxes. + """ + + boxes = boxes.to(rel_codes.dtype) + + TO_REMOVE = 1 # TODO remove + widths = boxes[:, 2] - boxes[:, 0] + TO_REMOVE + heights = boxes[:, 3] - boxes[:, 1] + TO_REMOVE + ctr_x = boxes[:, 0] + 0.5 * widths + ctr_y = boxes[:, 1] + 0.5 * heights + + wx, wy, ww, wh = self.weights + dx = rel_codes[:, 0::4] / wx + dy = rel_codes[:, 1::4] / wy + dw = rel_codes[:, 2::4] / ww + dh = rel_codes[:, 3::4] / wh + + # Prevent sending too large values into torch.exp() + dw = torch.clamp(dw, max=self.bbox_xform_clip) + dh = torch.clamp(dh, max=self.bbox_xform_clip) + + pred_ctr_x = dx * widths[:, None] + ctr_x[:, None] + pred_ctr_y = dy * heights[:, None] + ctr_y[:, None] + pred_w = torch.exp(dw) * widths[:, None] + pred_h = torch.exp(dh) * heights[:, None] + + pred_boxes = torch.zeros_like(rel_codes) + # x1 + pred_boxes[:, 0::4] = pred_ctr_x - 0.5 * pred_w + # y1 + pred_boxes[:, 1::4] = pred_ctr_y - 0.5 * pred_h + # x2 (note: "- 1" is correct; don't be fooled by the asymmetry) + pred_boxes[:, 2::4] = pred_ctr_x + 0.5 * pred_w - 1 + # y2 (note: "- 1" is correct; don't be fooled by the asymmetry) + pred_boxes[:, 3::4] = pred_ctr_y + 0.5 * pred_h - 1 + + return pred_boxes diff --git a/maskrcnn_benchmark/modeling/detector/__init__.py b/maskrcnn_benchmark/modeling/detector/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ff421e281e16e6623bab2551b242ea003d1f2166 --- /dev/null +++ b/maskrcnn_benchmark/modeling/detector/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +from .detectors import build_detection_model diff --git a/maskrcnn_benchmark/modeling/detector/detectors.py b/maskrcnn_benchmark/modeling/detector/detectors.py new file mode 100644 index 0000000000000000000000000000000000000000..c96423ea8ab3c3a080cde00730dd27d06a3c52c0 --- /dev/null +++ b/maskrcnn_benchmark/modeling/detector/detectors.py @@ -0,0 +1,11 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +from .generalized_rcnn import GeneralizedRCNN + + +_DETECTION_META_ARCHITECTURES = {"GeneralizedRCNN": GeneralizedRCNN} + + +def build_detection_model(cfg): + meta_arch = _DETECTION_META_ARCHITECTURES[cfg.MODEL.META_ARCHITECTURE] + return meta_arch(cfg) diff --git a/maskrcnn_benchmark/modeling/detector/generalized_rcnn.py b/maskrcnn_benchmark/modeling/detector/generalized_rcnn.py new file mode 100644 index 0000000000000000000000000000000000000000..e7693244d7f66f4cd739c402cdd053866412aa13 --- /dev/null +++ b/maskrcnn_benchmark/modeling/detector/generalized_rcnn.py @@ -0,0 +1,103 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +""" +Implements the Generalized R-CNN framework +""" + +import torch +from torch import nn + +from maskrcnn_benchmark.structures.image_list import to_image_list + +from ..backbone import build_backbone +from ..rpn.rpn import build_rpn +from ..segmentation.segmentation import build_segmentation +from ..roi_heads.roi_heads import build_roi_heads +import time + +class GeneralizedRCNN(nn.Module): + """ + Main class for Generalized R-CNN. Currently supports boxes and masks. + It consists of three main parts: + - backbone + = rpn + - heads: takes the features + the proposals from the RPN and computes + detections / masks from it. + """ + + def __init__(self, cfg): + super(GeneralizedRCNN, self).__init__() + self.cfg = cfg + self.backbone = build_backbone(cfg) + if cfg.MODEL.SEG_ON: + self.proposal = build_segmentation(cfg) + else: + self.proposal = build_rpn(cfg) + if cfg.MODEL.TRAIN_DETECTION_ONLY: + self.roi_heads = None + else: + self.roi_heads = build_roi_heads(cfg) + + def forward(self, images, targets=None): + """ + Arguments: + images (list[Tensor] or ImageList): images to be processed + targets (list[BoxList]): ground-truth boxes present in the image (optional) + + Returns: + result (list[BoxList] or dict[Tensor]): the output from the model. + During training, it returns a dict[Tensor] which contains the losses. + During testing, it returns list[BoxList] contains additional fields + like `scores`, `labels` and `mask` (for Mask R-CNN models). + + """ + if self.training and targets is None: + raise ValueError("In training mode, targets should be passed") + # torch.cuda.synchronize() + # start_time = time.time() + images = to_image_list(images) + # torch.cuda.synchronize() + # end_time = time.time() + # print('image load time:', end_time - start_time) + # torch.cuda.synchronize() + # start_time = time.time() + features = self.backbone(images.tensors) + # torch.cuda.synchronize() + # end_time = time.time() + # print('backbone time:', end_time - start_time) + if self.cfg.MODEL.SEG_ON and not self.training: + # torch.cuda.synchronize() + # start_time = time.time() + (proposals, seg_results), fuse_feature = self.proposal(images, features, targets) + # torch.cuda.synchronize() + # end_time = time.time() + # print('seg time:', end_time - start_time) + else: + if self.cfg.MODEL.SEG_ON: + (proposals, proposal_losses), fuse_feature = self.proposal(images, features, targets) + else: + proposals, proposal_losses = self.proposal(images, features, targets) + if self.roi_heads is not None: + if self.cfg.MODEL.SEG_ON and self.cfg.MODEL.SEG.USE_FUSE_FEATURE: + x, result, detector_losses = self.roi_heads(fuse_feature, proposals, targets) + else: + x, result, detector_losses = self.roi_heads(features, proposals, targets) + else: + # RPN-only models don't have roi_heads + # x = features + result = proposals + detector_losses = {} + + if self.training: + losses = {} + if self.roi_heads is not None: + losses.update(detector_losses) + losses.update(proposal_losses) + return losses + else: + if self.cfg.MODEL.SEG_ON: + return result, proposals, seg_results + else: + return result + + # return result diff --git a/maskrcnn_benchmark/modeling/make_layers.py b/maskrcnn_benchmark/modeling/make_layers.py new file mode 100644 index 0000000000000000000000000000000000000000..049aee6d13ae9d9c17440655cbde0a09a6e21918 --- /dev/null +++ b/maskrcnn_benchmark/modeling/make_layers.py @@ -0,0 +1,122 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +""" +Miscellaneous utility functions +""" + +import torch +from torch import nn +from torch.nn import functional as F +from maskrcnn_benchmark.config import cfg +from maskrcnn_benchmark.layers import Conv2d +from maskrcnn_benchmark.modeling.poolers import Pooler + + +def get_group_gn(dim, dim_per_gp, num_groups): + """get number of groups used by GroupNorm, based on number of channels.""" + assert dim_per_gp == -1 or num_groups == -1, \ + "GroupNorm: can only specify G or C/G." + + if dim_per_gp > 0: + assert dim % dim_per_gp == 0, \ + "dim: {}, dim_per_gp: {}".format(dim, dim_per_gp) + group_gn = dim // dim_per_gp + else: + assert dim % num_groups == 0, \ + "dim: {}, num_groups: {}".format(dim, num_groups) + group_gn = num_groups + + return group_gn + + +def group_norm(out_channels, affine=True, divisor=1): + out_channels = out_channels // divisor + dim_per_gp = cfg.MODEL.GROUP_NORM.DIM_PER_GP // divisor + num_groups = cfg.MODEL.GROUP_NORM.NUM_GROUPS // divisor + eps = cfg.MODEL.GROUP_NORM.EPSILON # default: 1e-5 + return torch.nn.GroupNorm( + get_group_gn(out_channels, dim_per_gp, num_groups), + out_channels, + eps, + affine + ) + + +def make_conv3x3( + in_channels, + out_channels, + dilation=1, + stride=1, + use_gn=False, + use_relu=False, + kaiming_init=True +): + conv = Conv2d( + in_channels, + out_channels, + kernel_size=3, + stride=stride, + padding=dilation, + dilation=dilation, + bias=False if use_gn else True + ) + if kaiming_init: + nn.init.kaiming_normal_( + conv.weight, mode="fan_out", nonlinearity="relu" + ) + else: + torch.nn.init.normal_(conv.weight, std=0.01) + if not use_gn: + nn.init.constant_(conv.bias, 0) + module = [conv,] + if use_gn: + module.append(group_norm(out_channels)) + if use_relu: + module.append(nn.ReLU(inplace=True)) + if len(module) > 1: + return nn.Sequential(*module) + return conv + + +def make_fc(dim_in, hidden_dim, use_gn=False): + ''' + Caffe2 implementation uses XavierFill, which in fact + corresponds to kaiming_uniform_ in PyTorch + ''' + if use_gn: + fc = nn.Linear(dim_in, hidden_dim, bias=False) + nn.init.kaiming_uniform_(fc.weight, a=1) + return nn.Sequential(fc, group_norm(hidden_dim)) + fc = nn.Linear(dim_in, hidden_dim) + nn.init.kaiming_uniform_(fc.weight, a=1) + nn.init.constant_(fc.bias, 0) + return fc + + +def conv_with_kaiming_uniform(use_gn=False, use_relu=False): + def make_conv( + in_channels, out_channels, kernel_size, stride=1, dilation=1 + ): + conv = Conv2d( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + padding=dilation * (kernel_size - 1) // 2, + dilation=dilation, + bias=False if use_gn else True + ) + # Caffe2 implementation uses XavierFill, which in fact + # corresponds to kaiming_uniform_ in PyTorch + nn.init.kaiming_uniform_(conv.weight, a=1) + if not use_gn: + nn.init.constant_(conv.bias, 0) + module = [conv,] + if use_gn: + module.append(group_norm(out_channels)) + if use_relu: + module.append(nn.ReLU(inplace=True)) + if len(module) > 1: + return nn.Sequential(*module) + return conv + + return make_conv diff --git a/maskrcnn_benchmark/modeling/matcher.py b/maskrcnn_benchmark/modeling/matcher.py new file mode 100644 index 0000000000000000000000000000000000000000..e051d3f593f663834e15517ba0a9be912c890358 --- /dev/null +++ b/maskrcnn_benchmark/modeling/matcher.py @@ -0,0 +1,106 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +import torch + + +class Matcher(object): + """ + This class assigns to each predicted "element" (e.g., a box) a ground-truth + element. Each predicted element will have exactly zero or one matches; each + ground-truth element may be assigned to zero or more predicted elements. + + Matching is based on the MxN match_quality_matrix, that characterizes how well + each (ground-truth, predicted)-pair match. For example, if the elements are + boxes, the matrix may contain box IoU overlap values. + + The matcher returns a tensor of size N containing the index of the ground-truth + element m that matches to prediction n. If there is no match, a negative value + is returned. + """ + + BELOW_LOW_THRESHOLD = -1 + BETWEEN_THRESHOLDS = -2 + + def __init__(self, high_threshold, low_threshold, allow_low_quality_matches=False): + """ + Args: + high_threshold (float): quality values greater than or equal to + this value are candidate matches. + low_threshold (float): a lower quality threshold used to stratify + matches into three levels: + 1) matches >= high_threshold + 2) BETWEEN_THRESHOLDS matches in [low_threshold, high_threshold) + 3) BELOW_LOW_THRESHOLD matches in [0, low_threshold) + allow_low_quality_matches (bool): if True, produce additional matches + for predictions that have only low-quality match candidates. See + set_low_quality_matches_ for more details. + """ + assert low_threshold <= high_threshold + self.high_threshold = high_threshold + self.low_threshold = low_threshold + self.allow_low_quality_matches = allow_low_quality_matches + + def __call__(self, match_quality_matrix): + """ + Args: + match_quality_matrix (Tensor[float]): an MxN tensor, containing the + pairwise quality between M ground-truth elements and N predicted elements. + + Returns: + matches (Tensor[int64]): an N tensor where N[i] is a matched gt in + [0, M - 1] or a negative value indicating that prediction i could not + be matched. + """ + if match_quality_matrix.numel() == 0: + # handle empty case + device = match_quality_matrix.device + return torch.empty((0,), dtype=torch.int64, device=device) + + # match_quality_matrix is M (gt) x N (predicted) + # Max over gt elements (dim 0) to find best gt candidate for each prediction + matched_vals, matches = match_quality_matrix.max(dim=0) + if self.allow_low_quality_matches: + all_matches = matches.clone() + + # Assign candidate matches with low quality to negative (unassigned) values + below_low_threshold = matched_vals < self.low_threshold + between_thresholds = (matched_vals >= self.low_threshold) & ( + matched_vals < self.high_threshold + ) + matches[below_low_threshold] = Matcher.BELOW_LOW_THRESHOLD + matches[between_thresholds] = Matcher.BETWEEN_THRESHOLDS + + if self.allow_low_quality_matches: + self.set_low_quality_matches_(matches, all_matches, match_quality_matrix) + + return matches + + def set_low_quality_matches_(self, matches, all_matches, match_quality_matrix): + """ + Produce additional matches for predictions that have only low-quality matches. + Specifically, for each ground-truth find the set of predictions that have + maximum overlap with it (including ties); for each prediction in that set, if + it is unmatched, then match it to the ground-truth with which it has the highest + quality value. + """ + # For each gt, find the prediction with which it has highest quality + highest_quality_foreach_gt, _ = match_quality_matrix.max(dim=1) + # Find highest quality match available, even if it is low, including ties + gt_pred_pairs_of_highest_quality = torch.nonzero( + match_quality_matrix == highest_quality_foreach_gt[:, None] + ) + # Example gt_pred_pairs_of_highest_quality: + # tensor([[ 0, 39796], + # [ 1, 32055], + # [ 1, 32070], + # [ 2, 39190], + # [ 2, 40255], + # [ 3, 40390], + # [ 3, 41455], + # [ 4, 45470], + # [ 5, 45325], + # [ 5, 46390]]) + # Each row is a (gt index, prediction index) + # Note how gt items 1, 2, 3, and 5 each have two ties + + pred_inds_to_update = gt_pred_pairs_of_highest_quality[:, 1] + matches[pred_inds_to_update] = all_matches[pred_inds_to_update] diff --git a/maskrcnn_benchmark/modeling/poolers.py b/maskrcnn_benchmark/modeling/poolers.py new file mode 100644 index 0000000000000000000000000000000000000000..15df1b4b88da00a120ad657fd66f1c0f222df2aa --- /dev/null +++ b/maskrcnn_benchmark/modeling/poolers.py @@ -0,0 +1,123 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +import math +import torch +import torch.nn.functional as F +from torch import nn + +from maskrcnn_benchmark.layers import ROIAlign + +from .utils import cat + + +class LevelMapper(object): + """Determine which FPN level each RoI in a set of RoIs should map to based + on the heuristic in the FPN paper. + """ + + def __init__(self, k_min, k_max, canonical_scale=224, canonical_level=4, eps=1e-6): + """ + Arguments: + k_min (int) + k_max (int) + canonical_scale (int) + canonical_level (int) + eps (float) + """ + self.k_min = k_min + self.k_max = k_max + self.s0 = canonical_scale + self.lvl0 = canonical_level + self.eps = eps + + def __call__(self, boxlists): + """ + Arguments: + boxlists (list[BoxList]) + """ + # Compute level ids + s = torch.sqrt(cat([boxlist.area() for boxlist in boxlists])) + + # Eqn.(1) in FPN paper + target_lvls = torch.floor(self.lvl0 + torch.log2(s / self.s0 + self.eps)) + target_lvls = torch.clamp(target_lvls, min=self.k_min, max=self.k_max) + return target_lvls.to(torch.int64) - self.k_min + + +class Pooler(nn.Module): + """ + Pooler for Detection with or without FPN. + It currently hard-code ROIAlign in the implementation, + but that can be made more generic later on. + Also, the requirement of passing the scales is not strictly necessary, as they + can be inferred from the size of the feature map / size of original image, + which is available thanks to the BoxList. + """ + + def __init__(self, output_size, scales, sampling_ratio): + """ + Arguments: + output_size (list[tuple[int]] or list[int]): output size for the pooled region + scales (list[flaot]): scales for each Pooler + sampling_ratio (int): sampling ratio for ROIAlign + """ + super(Pooler, self).__init__() + poolers = [] + for scale in scales: + poolers.append( + ROIAlign( + output_size, spatial_scale=scale, sampling_ratio=sampling_ratio + ) + ) + self.poolers = nn.ModuleList(poolers) + self.output_size = output_size + # get the levels in the feature map by leveraging the fact that the network always + # downsamples by a factor of 2 at each level. + lvl_min = -math.log2(scales[0]) + lvl_max = -math.log2(scales[-1]) + self.map_levels = LevelMapper(lvl_min, lvl_max) + + def convert_to_roi_format(self, boxes): + concat_boxes = cat([b.bbox for b in boxes], dim=0) + device, dtype = concat_boxes.device, concat_boxes.dtype + ids = cat( + [ + torch.full((len(b), 1), i, dtype=dtype, device=device) + for i, b in enumerate(boxes) + ], + dim=0, + ) + rois = torch.cat([ids, concat_boxes], dim=1) + return rois + + def forward(self, x, boxes): + """ + Arguments: + x (list[Tensor]): feature maps for each level + boxes (list[BoxList]): boxes to be used to perform the pooling operation. + Returns: + result (Tensor) + """ + num_levels = len(self.poolers) + rois = self.convert_to_roi_format(boxes) + if num_levels == 1: + return self.poolers[0](x[0], rois) + + levels = self.map_levels(boxes) + + num_rois = len(rois) + num_channels = x[0].shape[1] + output_size_h = self.output_size[0] + output_size_w = self.output_size[1] + + dtype, device = x[0].dtype, x[0].device + result = torch.zeros( + (num_rois, num_channels, output_size_h, output_size_w), + dtype=dtype, + device=device, + ) + for level, (per_level_feature, pooler) in enumerate(zip(x, self.poolers)): + idx_in_level = torch.nonzero(levels == level).squeeze(1) + rois_per_level = rois[idx_in_level] + result[idx_in_level] = pooler(per_level_feature, rois_per_level) + + return result diff --git a/maskrcnn_benchmark/modeling/registry.py b/maskrcnn_benchmark/modeling/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..e14fb118c458d0ba97d2a699be3004c6bdd3913c --- /dev/null +++ b/maskrcnn_benchmark/modeling/registry.py @@ -0,0 +1,12 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. + +from maskrcnn_benchmark.utils.registry import Registry + +BACKBONES = Registry() +RPN_HEADS = Registry() +ROI_BOX_FEATURE_EXTRACTORS = Registry() +ROI_BOX_PREDICTOR = Registry() +ROI_KEYPOINT_FEATURE_EXTRACTORS = Registry() +ROI_KEYPOINT_PREDICTOR = Registry() +ROI_MASK_FEATURE_EXTRACTORS = Registry() +ROI_MASK_PREDICTOR = Registry() diff --git a/maskrcnn_benchmark/modeling/roi_heads/__init__.py b/maskrcnn_benchmark/modeling/roi_heads/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/maskrcnn_benchmark/modeling/roi_heads/box_head/__init__.py b/maskrcnn_benchmark/modeling/roi_heads/box_head/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/maskrcnn_benchmark/modeling/roi_heads/box_head/box_head.py b/maskrcnn_benchmark/modeling/roi_heads/box_head/box_head.py new file mode 100644 index 0000000000000000000000000000000000000000..d268f96b3a4e582ab3fc71305d42eb0ab6a4e414 --- /dev/null +++ b/maskrcnn_benchmark/modeling/roi_heads/box_head/box_head.py @@ -0,0 +1,84 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +import torch + +from .inference import make_roi_box_post_processor +from .loss import make_roi_box_loss_evaluator +from .roi_box_feature_extractors import make_roi_box_feature_extractor +from .roi_box_predictors import make_roi_box_predictor + + +class ROIBoxHead(torch.nn.Module): + """ + Generic Box Head class. + """ + + def __init__(self, cfg): + super(ROIBoxHead, self).__init__() + self.cfg = cfg + self.feature_extractor = make_roi_box_feature_extractor(cfg) + self.predictor = make_roi_box_predictor(cfg) + self.post_processor = make_roi_box_post_processor(cfg) + self.loss_evaluator = make_roi_box_loss_evaluator(cfg) + + def forward(self, features, proposals, targets=None): + """ + Arguments: + features (list[Tensor]): feature-maps from possibly several levels + proposals (list[BoxList]): proposal boxes + targets (list[BoxList], optional): the ground-truth targets. + + Returns: + x (Tensor): the result of the feature extractor + proposals (list[BoxList]): during training, the subsampled proposals + are returned. During testing, the predicted boxlists are returned + losses (dict[Tensor]): During training, returns the losses for the + head. During testing, returns an empty dict. + """ + + if self.training: + # Faster R-CNN subsamples during training the proposals with a fixed + # positive / negative ratio + with torch.no_grad(): + proposals = self.loss_evaluator.subsample(proposals, targets) + + # extract features that will be fed to the final classifier. The + # feature_extractor generally corresponds to the pooler + heads + x = self.feature_extractor(features, proposals) + + # final classifier that converts the features into predictions + class_logits, box_regression = self.predictor(x) + + if not self.training: + if self.cfg.MODEL.ROI_BOX_HEAD.INFERENCE_USE_BOX: + result = self.post_processor((class_logits, box_regression), proposals) + # print(result[0].get_field('masks')) + return x, result, {} + else: + return x, proposals, {} + + loss_classifier, loss_box_reg = self.loss_evaluator( + [class_logits], [box_regression] + ) + if self.cfg.MODEL.ROI_BOX_HEAD.USE_REGRESSION: + return ( + x, + proposals, + dict(loss_classifier=loss_classifier, loss_box_reg=loss_box_reg), + ) + else: + return ( + x, + proposals, + dict(loss_classifier=loss_classifier), + ) + + +def build_roi_box_head(cfg): + """ + Constructs a new box head. + By default, uses ROIBoxHead, + but if it turns out not to be enough, just register a new class + and make it a parameter in the config + """ + return ROIBoxHead(cfg) diff --git a/maskrcnn_benchmark/modeling/roi_heads/box_head/inference.py b/maskrcnn_benchmark/modeling/roi_heads/box_head/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..25ee971f2a30963524da6601dd386adf3f4ffb09 --- /dev/null +++ b/maskrcnn_benchmark/modeling/roi_heads/box_head/inference.py @@ -0,0 +1,176 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +import torch +import torch.nn.functional as F +from torch import nn + +from maskrcnn_benchmark.structures.bounding_box import BoxList +from maskrcnn_benchmark.structures.boxlist_ops import boxlist_nms +from maskrcnn_benchmark.structures.boxlist_ops import cat_boxlist +from maskrcnn_benchmark.modeling.box_coder import BoxCoder + + +class PostProcessor(nn.Module): + """ + From a set of classification scores, box regression and proposals, + computes the post-processed boxes, and applies NMS to obtain the + final results + """ + + def __init__( + self, score_thresh=0.05, nms=0.5, detections_per_img=100, box_coder=None, cfg=None + ): + """ + Arguments: + score_thresh (float) + nms (float) + detections_per_img (int) + box_coder (BoxCoder) + """ + super(PostProcessor, self).__init__() + self.cfg = cfg + self.score_thresh = score_thresh + self.nms = nms + self.detections_per_img = detections_per_img + if cfg.MODEL.ROI_BOX_HEAD.USE_REGRESSION: + if box_coder is None: + box_coder = BoxCoder(weights=(10., 10., 5., 5.)) + self.box_coder = box_coder + + def forward(self, x, boxes): + """ + Arguments: + x (tuple[tensor, tensor]): x contains the class logits + and the box_regression from the model. + boxes (list[BoxList]): bounding boxes that are used as + reference, one for ech image + + Returns: + results (list[BoxList]): one BoxList for each image, containing + the extra fields labels and scores + """ + class_logits, box_regression = x + class_prob = F.softmax(class_logits, -1) + + # TODO think about a representation of batch of boxes + image_shapes = [box.size for box in boxes] + boxes_per_image = [len(box) for box in boxes] + if self.cfg.MODEL.SEG.USE_SEG_POLY or self.cfg.MODEL.ROI_BOX_HEAD.USE_MASKED_FEATURE or self.cfg.MODEL.ROI_MASK_HEAD.USE_MASKED_FEATURE: + masks = [box.get_field('masks') for box in boxes] + if self.cfg.MODEL.ROI_BOX_HEAD.USE_REGRESSION: + concat_boxes = torch.cat([a.bbox for a in boxes], dim=0) + proposals = self.box_coder.decode( + box_regression.view(sum(boxes_per_image), -1), concat_boxes + ) + proposals = proposals.split(boxes_per_image, dim=0) + else: + proposals = boxes + num_classes = class_prob.shape[1] + class_prob = class_prob.split(boxes_per_image, dim=0) + + results = [] + if self.cfg.MODEL.SEG.USE_SEG_POLY or self.cfg.MODEL.ROI_BOX_HEAD.USE_MASKED_FEATURE or self.cfg.MODEL.ROI_MASK_HEAD.USE_MASKED_FEATURE: + for prob, boxes_per_img, image_shape, mask in zip( + class_prob, proposals, image_shapes, masks + ): + boxlist = self.prepare_boxlist(boxes_per_img, prob, image_shape, mask) + if self.cfg.MODEL.ROI_BOX_HEAD.USE_REGRESSION: + boxlist = boxlist.clip_to_image(remove_empty=False) + boxlist = self.filter_results(boxlist, num_classes) + results.append(boxlist) + else: + for prob, boxes_per_img, image_shape in zip( + class_prob, proposals, image_shapes + ): + boxlist = self.prepare_boxlist(boxes_per_img, prob, image_shape) + if self.cfg.MODEL.ROI_BOX_HEAD.USE_REGRESSION: + boxlist = boxlist.clip_to_image(remove_empty=False) + boxlist = self.filter_results(boxlist, num_classes) + results.append(boxlist) + return results + + def prepare_boxlist(self, boxes, scores, image_shape, mask=None): + """ + Returns BoxList from `boxes` and adds probability scores information + as an extra field + `boxes` has shape (#detections, 4 * #classes), where each row represents + a list of predicted bounding boxes for each of the object classes in the + dataset (including the background class). The detections in each row + originate from the same object proposal. + `scores` has shape (#detection, #classes), where each row represents a list + of object detection confidence scores for each of the object classes in the + dataset (including the background class). `scores[i, j]`` corresponds to the + box at `boxes[i, j * 4:(j + 1) * 4]`. + """ + if not self.cfg.MODEL.ROI_BOX_HEAD.USE_REGRESSION: + scores = scores.reshape(-1) + boxes.add_field("scores", scores) + return boxes + boxes = boxes.reshape(-1, 4) + scores = scores.reshape(-1) + boxlist = BoxList(boxes, image_shape, mode="xyxy") + boxlist.add_field("scores", scores) + if mask is not None: + boxlist.add_field('masks', mask) + return boxlist + + def filter_results(self, boxlist, num_classes): + """Returns bounding-box detection results by thresholding on scores and + applying non-maximum suppression (NMS). + """ + # unwrap the boxlist to avoid additional overhead. + # if we had multi-class NMS, we could perform this directly on the boxlist + boxes = boxlist.bbox.reshape(-1, num_classes * 4) + scores = boxlist.get_field("scores").reshape(-1, num_classes) + + device = scores.device + result = [] + # Apply threshold on detection probabilities and apply NMS + # Skip j = 0, because it's the background class + inds_all = scores > self.score_thresh + for j in range(1, num_classes): + inds = inds_all[:, j].nonzero().squeeze(1) + scores_j = scores[inds, j] + boxes_j = boxes[inds, j * 4 : (j + 1) * 4] + boxlist_for_class = BoxList(boxes_j, boxlist.size, mode="xyxy") + boxlist_for_class.add_field("scores", scores_j) + boxlist_for_class = boxlist_nms( + boxlist_for_class, self.nms, score_field="scores" + ) + num_labels = len(boxlist_for_class) + boxlist_for_class.add_field( + "labels", torch.full((num_labels,), j, dtype=torch.int64, device=device) + ) + if self.cfg.MODEL.SEG.USE_SEG_POLY or self.cfg.MODEL.ROI_BOX_HEAD.USE_MASKED_FEATURE or self.cfg.MODEL.ROI_MASK_HEAD.USE_MASKED_FEATURE: + boxlist_for_class.add_field('masks', boxlist.get_field('masks')) + result.append(boxlist_for_class) + + result = cat_boxlist(result) + number_of_detections = len(result) + + # Limit to max_per_image detections **over all classes** + if number_of_detections > self.detections_per_img > 0: + cls_scores = result.get_field("scores") + image_thresh, _ = torch.kthvalue( + cls_scores.cpu(), number_of_detections - self.detections_per_img + 1 + ) + keep = cls_scores >= image_thresh.item() + keep = torch.nonzero(keep).squeeze(1) + result = result[keep] + return result + + +def make_roi_box_post_processor(cfg): + # use_fpn = cfg.MODEL.ROI_HEADS.USE_FPN + + bbox_reg_weights = cfg.MODEL.ROI_HEADS.BBOX_REG_WEIGHTS + box_coder = BoxCoder(weights=bbox_reg_weights) + + score_thresh = cfg.MODEL.ROI_HEADS.SCORE_THRESH + nms_thresh = cfg.MODEL.ROI_HEADS.NMS + detections_per_img = cfg.MODEL.ROI_HEADS.DETECTIONS_PER_IMG + + postprocessor = PostProcessor( + score_thresh, nms_thresh, detections_per_img, box_coder, cfg + ) + return postprocessor diff --git a/maskrcnn_benchmark/modeling/roi_heads/box_head/loss.py b/maskrcnn_benchmark/modeling/roi_heads/box_head/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..31e9ff8e464b8ecf463c866e09ebdf737daec908 --- /dev/null +++ b/maskrcnn_benchmark/modeling/roi_heads/box_head/loss.py @@ -0,0 +1,184 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +import torch +from maskrcnn_benchmark.layers import smooth_l1_loss +from maskrcnn_benchmark.modeling.balanced_positive_negative_sampler import ( + BalancedPositiveNegativeSampler, +) +from maskrcnn_benchmark.modeling.box_coder import BoxCoder +from maskrcnn_benchmark.modeling.matcher import Matcher +from maskrcnn_benchmark.modeling.utils import cat +from maskrcnn_benchmark.structures.boxlist_ops import boxlist_iou +from torch.nn import functional as F + + +class FastRCNNLossComputation(object): + """ + Computes the loss for Faster R-CNN. + Also supports FPN + """ + + def __init__(self, proposal_matcher, fg_bg_sampler, box_coder, cfg=None): + """ + Arguments: + proposal_matcher (Matcher) + fg_bg_sampler (BalancedPositiveNegativeSampler) + box_coder (BoxCoder) + """ + self.proposal_matcher = proposal_matcher + self.fg_bg_sampler = fg_bg_sampler + self.box_coder = box_coder + self.cfg = cfg + + def match_targets_to_proposals(self, proposal, target): + match_quality_matrix = boxlist_iou(target, proposal) + matched_idxs = self.proposal_matcher(match_quality_matrix) + # Fast RCNN only need "labels" field for selecting the targets + target = target.copy_with_fields("labels") + # get the targets corresponding GT for each proposal + # NB: need to clamp the indices because we can have a single + # GT in the image, and matched_idxs can be -2, which goes + # out of bounds + matched_targets = target[matched_idxs.clamp(min=0)] + matched_targets.add_field("matched_idxs", matched_idxs) + return matched_targets + + def prepare_targets(self, proposals, targets): + labels = [] + regression_targets = [] + for proposals_per_image, targets_per_image in zip(proposals, targets): + matched_targets = self.match_targets_to_proposals( + proposals_per_image, targets_per_image + ) + matched_idxs = matched_targets.get_field("matched_idxs") + + labels_per_image = matched_targets.get_field("labels") + labels_per_image = labels_per_image.to(dtype=torch.int64) + + # Label background (below the low threshold) + bg_inds = matched_idxs == Matcher.BELOW_LOW_THRESHOLD + labels_per_image[bg_inds] = 0 + + # Label ignore proposals (between low and high thresholds) + ignore_inds = matched_idxs == Matcher.BETWEEN_THRESHOLDS + labels_per_image[ignore_inds] = -1 # -1 is ignored by sampler + + # compute regression targets + regression_targets_per_image = self.box_coder.encode( + matched_targets.bbox, proposals_per_image.bbox + ) + + labels.append(labels_per_image) + regression_targets.append(regression_targets_per_image) + + return labels, regression_targets + + def subsample(self, proposals, targets): + """ + This method performs the positive/negative sampling, and return + the sampled proposals. + Note: this function keeps a state. + + Arguments: + proposals (list[BoxList]) + targets (list[BoxList]) + """ + + labels, regression_targets = self.prepare_targets(proposals, targets) + sampled_pos_inds, sampled_neg_inds = self.fg_bg_sampler(labels) + # print('sampled_pos_inds:', sampled_pos_inds[0].sum()) + # print('sampled_neg_inds:', sampled_neg_inds[0].sum()) + + proposals = list(proposals) + # add corresponding label and + # regression_targets information to the bounding boxes + for labels_per_image, regression_targets_per_image, proposals_per_image in zip( + labels, regression_targets, proposals + ): + proposals_per_image.add_field("labels", labels_per_image) + proposals_per_image.add_field( + "regression_targets", regression_targets_per_image + ) + + # distributed sampled proposals, that were obtained on all feature maps + # concatenated via the fg_bg_sampler, into individual feature map levels + for img_idx, (pos_inds_img, neg_inds_img) in enumerate( + zip(sampled_pos_inds, sampled_neg_inds) + ): + img_sampled_inds = torch.nonzero(pos_inds_img | neg_inds_img).squeeze(1) + proposals_per_image = proposals[img_idx][img_sampled_inds] + proposals[img_idx] = proposals_per_image + + self._proposals = proposals + return proposals + + def __call__(self, class_logits, box_regression): + """ + Computes the loss for Faster R-CNN. + This requires that the subsample method has been called beforehand. + + Arguments: + class_logits (list[Tensor]) + box_regression (list[Tensor]) + + Returns: + classification_loss (Tensor) + box_loss (Tensor) + """ + + class_logits = cat(class_logits, dim=0) + if self.cfg.MODEL.ROI_BOX_HEAD.USE_REGRESSION: + box_regression = cat(box_regression, dim=0) + device = class_logits.device + + if not hasattr(self, "_proposals"): + raise RuntimeError("subsample needs to be called before") + + proposals = self._proposals + + labels = cat([proposal.get_field("labels") for proposal in proposals], dim=0) + if self.cfg.MODEL.ROI_BOX_HEAD.USE_REGRESSION: + regression_targets = cat( + [proposal.get_field("regression_targets") for proposal in proposals], dim=0 + ) + + classification_loss = F.cross_entropy(class_logits, labels) + + if self.cfg.MODEL.ROI_BOX_HEAD.USE_REGRESSION: + # get indices that correspond to the regression targets for + # the corresponding ground truth labels, to be used with + # advanced indexing + sampled_pos_inds_subset = torch.nonzero(labels > 0).squeeze(1) + labels_pos = labels[sampled_pos_inds_subset] + map_inds = 4 * labels_pos[:, None] + torch.tensor([0, 1, 2, 3], device=device) + + box_loss = smooth_l1_loss( + box_regression[sampled_pos_inds_subset[:, None], map_inds], + regression_targets[sampled_pos_inds_subset], + size_average=False, + beta=1, + ) + box_loss = box_loss / labels.numel() + else: + box_loss = 0 + + return classification_loss, box_loss + + +def make_roi_box_loss_evaluator(cfg): + matcher = Matcher( + cfg.MODEL.ROI_HEADS.FG_IOU_THRESHOLD, + cfg.MODEL.ROI_HEADS.BG_IOU_THRESHOLD, + allow_low_quality_matches=False, + ) + + bbox_reg_weights = cfg.MODEL.ROI_HEADS.BBOX_REG_WEIGHTS + box_coder = BoxCoder(weights=bbox_reg_weights) + + fg_bg_sampler = BalancedPositiveNegativeSampler( + cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE, cfg.MODEL.ROI_HEADS.POSITIVE_FRACTION + ) + + loss_evaluator = FastRCNNLossComputation(matcher, fg_bg_sampler, box_coder, cfg) + + return loss_evaluator diff --git a/maskrcnn_benchmark/modeling/roi_heads/box_head/roi_box_feature_extractors.py b/maskrcnn_benchmark/modeling/roi_heads/box_head/roi_box_feature_extractors.py new file mode 100644 index 0000000000000000000000000000000000000000..1e7bfd6384af2c9eb4894147aaa9a0e6caed3630 --- /dev/null +++ b/maskrcnn_benchmark/modeling/roi_heads/box_head/roi_box_feature_extractors.py @@ -0,0 +1,170 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +import torch +from torch import nn +from torch.nn import functional as F + +from maskrcnn_benchmark.modeling.backbone import resnet +from maskrcnn_benchmark.modeling.poolers import Pooler +from maskrcnn_benchmark.modeling.utils import cat +from maskrcnn_benchmark.layers import Conv2d + +def conv3x3(in_planes, out_planes, stride=1, has_bias=False): + "3x3 convolution with padding" + return nn.Conv2d( + in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=has_bias + ) + + +def conv3x3_bn_relu(in_planes, out_planes, stride=1, has_bias=False): + return nn.Sequential( + conv3x3(in_planes, out_planes, stride), + nn.BatchNorm2d(out_planes), + nn.ReLU(inplace=True), + ) + +class ResNet50Conv5ROIFeatureExtractor(nn.Module): + def __init__(self, config): + super(ResNet50Conv5ROIFeatureExtractor, self).__init__() + + resolution = config.MODEL.ROI_BOX_HEAD.POOLER_RESOLUTION + scales = config.MODEL.ROI_BOX_HEAD.POOLER_SCALES + sampling_ratio = config.MODEL.ROI_BOX_HEAD.POOLER_SAMPLING_RATIO + pooler = Pooler( + output_size=(resolution, resolution), + scales=scales, + sampling_ratio=sampling_ratio, + ) + + stage = resnet.StageSpec(index=4, block_count=3, return_features=False) + head = resnet.ResNetHead( + block_module=config.MODEL.RESNETS.TRANS_FUNC, + stages=(stage,), + num_groups=config.MODEL.RESNETS.NUM_GROUPS, + width_per_group=config.MODEL.RESNETS.WIDTH_PER_GROUP, + stride_in_1x1=config.MODEL.RESNETS.STRIDE_IN_1X1, + stride_init=None, + res2_out_channels=config.MODEL.RESNETS.RES2_OUT_CHANNELS, + ) + + self.pooler = pooler + self.head = head + + def forward(self, x, proposals): + x = self.pooler(x, proposals) + x = self.head(x) + return x + + +class FPN2MLPFeatureExtractor(nn.Module): + """ + Heads for FPN for classification + """ + + def __init__(self, cfg): + super(FPN2MLPFeatureExtractor, self).__init__() + self.cfg = cfg + resolution = cfg.MODEL.ROI_BOX_HEAD.POOLER_RESOLUTION + scales = cfg.MODEL.ROI_BOX_HEAD.POOLER_SCALES + sampling_ratio = cfg.MODEL.ROI_BOX_HEAD.POOLER_SAMPLING_RATIO + pooler = Pooler( + output_size=(resolution, resolution), + scales=scales, + sampling_ratio=sampling_ratio, + ) + if self.cfg.MODEL.ROI_BOX_HEAD.MIX_OPTION == 'CAT': + input_size = (cfg.MODEL.BACKBONE.OUT_CHANNELS + 1) * resolution ** 2 + else: + input_size = cfg.MODEL.BACKBONE.OUT_CHANNELS * resolution ** 2 + representation_size = cfg.MODEL.ROI_BOX_HEAD.MLP_HEAD_DIM + self.pooler = pooler + self.fc6 = nn.Linear(input_size, representation_size) + self.fc7 = nn.Linear(representation_size, representation_size) + # if self.cfg.MODEL.ROI_BOX_HEAD.MIX_OPTION == 'ATTENTION': + # self.attention = nn.Sequential( + # conv3x3_bn_relu(cfg.MODEL.BACKBONE.OUT_CHANNELS + 1, 32), + # conv3x3(32, 1), + # nn.Sigmoid() + # ) + # self.attention.apply(self.weights_init) + # if self.cfg.MODEL.ROI_BOX_HEAD.MIX_OPTION == 'ATTENTION': + # self.attention = nn.Sequential( + # Conv2d(cfg.MODEL.BACKBONE.OUT_CHANNELS + 1, 1, 1, 1, 0), + # nn.Sigmoid() + # ) + # for name, param in self.named_parameters(): + # if "bias" in name: + # nn.init.constant_(param, 0) + # elif "weight" in name: + # # Caffe2 implementation uses MSRAFill, which in fact + # # corresponds to kaiming_normal_ in PyTorch + # nn.init.kaiming_normal_(param, mode="fan_out", nonlinearity="relu") + + for l in [self.fc6, self.fc7]: + # Caffe2 implementation uses XavierFill, which in fact + # corresponds to kaiming_uniform_ in PyTorch + nn.init.kaiming_uniform_(l.weight, a=1) + nn.init.constant_(l.bias, 0) + + def weights_init(self, m): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + nn.init.kaiming_normal_(m.weight.data) + elif classname.find("BatchNorm") != -1: + m.weight.data.fill_(1.0) + m.bias.data.fill_(1e-4) + + def feature_mask(self, x, proposals): + masks = [] + for proposal in proposals: + segmentation_masks = proposal.get_field("masks") + boxes = proposal.bbox.to(torch.device("cpu")) + for segmentation_mask, box in zip(segmentation_masks, boxes): + cropped_mask = segmentation_mask.crop(box) + scaled_mask = cropped_mask.resize((self.cfg.MODEL.ROI_BOX_HEAD.POOLER_RESOLUTION, self.cfg.MODEL.ROI_BOX_HEAD.POOLER_RESOLUTION)) + mask = scaled_mask.convert(mode="mask") + masks.append(mask) + if len(masks) == 0: + if self.cfg.MODEL.ROI_BOX_HEAD.MIX_OPTION == 'CAT': + x = cat([x, torch.ones((x.shape[0], 1, x.shape[2], x.shape[3]), device=x.device)], dim=1) + return x + masks = torch.stack(masks, dim=0).to(x.device, dtype=torch.float32) + if self.cfg.MODEL.ROI_BOX_HEAD.MIX_OPTION == 'CAT': + x = cat([x, masks.unsqueeze(1)], dim=1) + return x + if self.cfg.MODEL.ROI_BOX_HEAD.MIX_OPTION == 'ATTENTION': + # x_cat = cat([x, masks.unsqueeze(1)], dim=1) + # attention = self.attention(x_cat) + # x = x * attention + return x + soft_ratio = self.cfg.MODEL.ROI_BOX_HEAD.SOFT_MASKED_FEATURE_RATIO + if soft_ratio > 0: + if soft_ratio < 1.0: + x = x * (soft_ratio + (1 - soft_ratio) * masks.unsqueeze(1)) + else: + x = x * (1.0 + soft_ratio * masks.unsqueeze(1)) + else: + x = x * masks.unsqueeze(1) + return x + + def forward(self, x, proposals): + x = self.pooler(x, proposals) + if self.cfg.MODEL.ROI_BOX_HEAD.USE_MASKED_FEATURE: + x = self.feature_mask(x, proposals) + x = x.view(x.size(0), -1) + + x = F.relu(self.fc6(x)) + x = F.relu(self.fc7(x)) + + return x + + +_ROI_BOX_FEATURE_EXTRACTORS = { + "ResNet50Conv5ROIFeatureExtractor": ResNet50Conv5ROIFeatureExtractor, + "FPN2MLPFeatureExtractor": FPN2MLPFeatureExtractor, +} + + +def make_roi_box_feature_extractor(cfg): + func = _ROI_BOX_FEATURE_EXTRACTORS[cfg.MODEL.ROI_BOX_HEAD.FEATURE_EXTRACTOR] + return func(cfg) diff --git a/maskrcnn_benchmark/modeling/roi_heads/box_head/roi_box_predictors.py b/maskrcnn_benchmark/modeling/roi_heads/box_head/roi_box_predictors.py new file mode 100644 index 0000000000000000000000000000000000000000..c9d32f56830483f85a8288cb5816b91f101e8f53 --- /dev/null +++ b/maskrcnn_benchmark/modeling/roi_heads/box_head/roi_box_predictors.py @@ -0,0 +1,72 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +from torch import nn + + +class FastRCNNPredictor(nn.Module): + def __init__(self, config, pretrained=None): + super(FastRCNNPredictor, self).__init__() + + stage_index = 4 + stage2_relative_factor = 2 ** (stage_index - 1) + res2_out_channels = config.MODEL.RESNETS.RES2_OUT_CHANNELS + num_inputs = res2_out_channels * stage2_relative_factor + + num_classes = config.MODEL.ROI_BOX_HEAD.NUM_CLASSES + self.avgpool = nn.AvgPool2d(kernel_size=7, stride=7) + self.cls_score = nn.Linear(num_inputs, num_classes) + self.bbox_pred = nn.Linear(num_inputs, num_classes * 4) + + nn.init.normal_(self.cls_score.weight, mean=0, std=0.01) + nn.init.constant_(self.cls_score.bias, 0) + + nn.init.normal_(self.bbox_pred.weight, mean=0, std=0.001) + nn.init.constant_(self.bbox_pred.bias, 0) + + def forward(self, x): + x = self.avgpool(x) + x = x.view(x.size(0), -1) + cls_logit = self.cls_score(x) + bbox_pred = self.bbox_pred(x) + return cls_logit, bbox_pred + + +class FPNPredictor(nn.Module): + def __init__(self, cfg): + super(FPNPredictor, self).__init__() + self.cfg = cfg + num_classes = cfg.MODEL.ROI_BOX_HEAD.NUM_CLASSES + representation_size = cfg.MODEL.ROI_BOX_HEAD.MLP_HEAD_DIM + + self.cls_score = nn.Linear(representation_size, num_classes) + if cfg.MODEL.ROI_BOX_HEAD.USE_REGRESSION: + self.bbox_pred = nn.Linear(representation_size, num_classes * 4) + nn.init.normal_(self.bbox_pred.weight, std=0.001) + nn.init.constant_(self.bbox_pred.bias, 0) + + nn.init.normal_(self.cls_score.weight, std=0.01) + nn.init.constant_(self.cls_score.bias, 0) + # nn.init.normal_(self.cls_score.weight, std=0.01) + # nn.init.normal_(self.bbox_pred.weight, std=0.001) + # for l in [self.cls_score, self.bbox_pred]: + # nn.init.constant_(l.bias, 0) + + def forward(self, x): + scores = self.cls_score(x) + if self.cfg.MODEL.ROI_BOX_HEAD.USE_REGRESSION: + bbox_deltas = self.bbox_pred(x) + else: + bbox_deltas = None + + return scores, bbox_deltas + + +_ROI_BOX_PREDICTOR = { + "FastRCNNPredictor": FastRCNNPredictor, + "FPNPredictor": FPNPredictor, +} + + +def make_roi_box_predictor(cfg): + func = _ROI_BOX_PREDICTOR[cfg.MODEL.ROI_BOX_HEAD.PREDICTOR] + return func(cfg) diff --git a/maskrcnn_benchmark/modeling/roi_heads/mask_head/__init__.py b/maskrcnn_benchmark/modeling/roi_heads/mask_head/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/maskrcnn_benchmark/modeling/roi_heads/mask_head/inference.py b/maskrcnn_benchmark/modeling/roi_heads/mask_head/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..9575608a9f28f57f8f4eee3c8dc721808cbd7b07 --- /dev/null +++ b/maskrcnn_benchmark/modeling/roi_heads/mask_head/inference.py @@ -0,0 +1,254 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +import numpy as np +import torch +from PIL import Image +from torch import nn +import cv2 +from torch.nn import functional as F + +from maskrcnn_benchmark.structures.bounding_box import BoxList + +# TODO check if want to return a single BoxList or a composite +# object +class MaskPostProcessor(nn.Module): + """ + From the results of the CNN, post process the masks + by taking the mask corresponding to the class with max + probability (which are of fixed size and directly output + by the CNN) and return the masks in the mask field of the BoxList. + + If a masker object is passed, it will additionally + project the masks in the image according to the locations in boxes, + """ + + def __init__(self, masker=None): + super(MaskPostProcessor, self).__init__() + self.masker = masker + + def forward(self, x, boxes): + """ + Arguments: + x (Tensor): the mask logits + boxes (list[BoxList]): bounding boxes that are used as + reference, one for ech image + + Returns: + results (list[BoxList]): one BoxList for each image, containing + the extra field mask + """ + mask_prob = x.sigmoid() + + # select masks coresponding to the predicted classes + num_masks = x.shape[0] + labels = [bbox.get_field("labels") for bbox in boxes] + labels = torch.cat(labels) + index = torch.arange(num_masks, device=labels.device) + mask_prob = mask_prob[index, labels][:, None] + + if self.masker: + mask_prob = self.masker(mask_prob, boxes) + + boxes_per_image = [len(box) for box in boxes] + mask_prob = mask_prob.split(boxes_per_image, dim=0) + + results = [] + for prob, box in zip(mask_prob, boxes): + bbox = BoxList(box.bbox, box.size, mode="xyxy") + for field in box.fields(): + bbox.add_field(field, box.get_field(field)) + bbox.add_field("mask", prob) + results.append(bbox) + + return results +# TODO +class CharMaskPostProcessor(nn.Module): + """ + From the results of the CNN, post process the masks + by taking the mask corresponding to the class with max + probability (which are of fixed size and directly output + by the CNN) and return the masks in the mask field of the BoxList. + + If a masker object is passed, it will additionally + project the masks in the image according to the locations in boxes, + """ + + def __init__(self, cfg, masker=None): + super(CharMaskPostProcessor, self).__init__() + self.masker = masker + self.cfg = cfg + + def forward(self, x, char_mask, boxes, seq_outputs=None, seq_scores=None, detailed_seq_scores=None): + """ + Arguments: + x (Tensor): the mask logits + char_mask (Tensor): the char mask logits + boxes (list[BoxList]): bounding boxes that are used as + reference, one for ech image + + Returns: + results (list[BoxList]): one BoxList for each image, containing + the extra field mask + """ + if x is not None: + mask_prob = x.sigmoid() + mask_prob = mask_prob.squeeze(dim=1)[:, None] + if self.masker: + mask_prob = self.masker(mask_prob, boxes) + boxes_per_image = [len(box) for box in boxes] + if x is not None: + mask_prob = mask_prob.split(boxes_per_image, dim=0) + if self.cfg.MODEL.CHAR_MASK_ON: + char_mask_softmax = F.softmax(char_mask, dim=1) + char_results = {'char_mask': char_mask_softmax.cpu().numpy(), 'boxes': boxes[0].bbox.cpu().numpy(), 'seq_outputs': seq_outputs, 'seq_scores': seq_scores, 'detailed_seq_scores': detailed_seq_scores} + else: + char_results = {'char_mask': None, 'boxes': boxes[0].bbox.cpu().numpy(), 'seq_outputs': seq_outputs, 'seq_scores': seq_scores, 'detailed_seq_scores': detailed_seq_scores} + results = [] + if x is not None: + for prob, box in zip(mask_prob, boxes): + bbox = BoxList(box.bbox, box.size, mode="xyxy") + for field in box.fields(): + bbox.add_field(field, box.get_field(field)) + bbox.add_field("mask", prob) + results.append(bbox) + else: + for box in boxes: + bbox = BoxList(box.bbox, box.size, mode="xyxy") + for field in box.fields(): + bbox.add_field(field, box.get_field(field)) + results.append(bbox) + + return [results, char_results] + +class MaskPostProcessorCOCOFormat(MaskPostProcessor): + """ + From the results of the CNN, post process the results + so that the masks are pasted in the image, and + additionally convert the results to COCO format. + """ + + def forward(self, x, boxes): + import pycocotools.mask as mask_util + import numpy as np + + results = super(MaskPostProcessorCOCOFormat, self).forward(x, boxes) + for result in results: + masks = result.get_field("mask").cpu() + rles = [ + mask_util.encode(np.array(mask[0, :, :, np.newaxis], order="F"))[0] + for mask in masks + ] + for rle in rles: + rle["counts"] = rle["counts"].decode("utf-8") + result.add_field("mask", rles) + return results + + +# the next two functions should be merged inside Masker +# but are kept here for the moment while we need them +# temporarily gor paste_mask_in_image +def expand_boxes(boxes, scale): + w_half = (boxes[:, 2] - boxes[:, 0]) * .5 + h_half = (boxes[:, 3] - boxes[:, 1]) * .5 + x_c = (boxes[:, 2] + boxes[:, 0]) * .5 + y_c = (boxes[:, 3] + boxes[:, 1]) * .5 + + w_half *= scale[1] + h_half *= scale[0] + + boxes_exp = torch.zeros_like(boxes) + boxes_exp[:, 0] = x_c - w_half + boxes_exp[:, 2] = x_c + w_half + boxes_exp[:, 1] = y_c - h_half + boxes_exp[:, 3] = y_c + h_half + return boxes_exp + + +def expand_masks(mask, padding): + N = mask.shape[0] + M_H = mask.shape[-2] + M_W = mask.shape[-1] + pad2 = 2 * padding + scale = (float(M_H + pad2) / M_H, float(M_W + pad2) / M_W) + padded_mask = mask.new_zeros((N, 1, M_H + pad2, M_W + pad2)) + padded_mask[:, :, padding:-padding, padding:-padding] = mask + return padded_mask, scale + + +def paste_mask_in_image(mask, box, im_h, im_w, thresh=0.5, padding=1): + # Need to work on the CPU, where fp16 isn't supported - cast to float to avoid this + mask = mask.float() + box = box.float() + + padded_mask, scale = expand_masks(mask[None], padding=padding) + mask = padded_mask[0, 0] + box = expand_boxes(box[None], scale)[0] + box = box.numpy().astype(np.int32) + + TO_REMOVE = 1 + w = box[2] - box[0] + TO_REMOVE + h = box[3] - box[1] + TO_REMOVE + w = max(w, 1) + h = max(h, 1) + + mask = Image.fromarray(mask.cpu().numpy()) + mask = mask.resize((w, h), resample=Image.BILINEAR) + mask = np.array(mask, copy=False) + + if thresh >= 0: + mask = np.array(mask > thresh, dtype=np.uint8) + mask = torch.from_numpy(mask) + else: + # for visualization and debugging, we also + # allow it to return an unmodified mask + mask = torch.from_numpy(mask * 255).to(torch.bool) + + im_mask = torch.zeros((im_h, im_w), dtype=torch.bool) + x_0 = max(box[0], 0) + x_1 = min(box[2] + 1, im_w) + y_0 = max(box[1], 0) + y_1 = min(box[3] + 1, im_h) + + im_mask[y_0:y_1, x_0:x_1] = mask[ + (y_0 - box[1]) : (y_1 - box[1]), (x_0 - box[0]) : (x_1 - box[0]) + ] + return im_mask + + +class Masker(object): + """ + Projects a set of masks in an image on the locations + specified by the bounding boxes + """ + + def __init__(self, threshold=0.5, padding=1): + self.threshold = threshold + self.padding = padding + + def forward_single_image(self, masks, boxes): + boxes = boxes.convert("xyxy") + im_w, im_h = boxes.size + res = [ + paste_mask_in_image(mask[0], box, im_h, im_w, self.threshold, self.padding) + for mask, box in zip(masks, boxes.bbox) + ] + if len(res) > 0: + res = torch.stack(res, dim=0)[:, None] + else: + res = masks.new_empty((0, 1, masks.shape[-2], masks.shape[-1])) + return res + + def __call__(self, masks, boxes): + # TODO do this properly + if isinstance(boxes, BoxList): + boxes = [boxes] + assert len(boxes) == 1, "Only single image batch supported" + result = self.forward_single_image(masks, boxes[0]) + return result + +def make_roi_mask_post_processor(cfg): + masker = None + if cfg.MODEL.CHAR_MASK_ON or cfg.SEQUENCE.SEQ_ON: + mask_post_processor = CharMaskPostProcessor(cfg, masker) + else: + mask_post_processor = MaskPostProcessor(masker) + return mask_post_processor diff --git a/maskrcnn_benchmark/modeling/roi_heads/mask_head/loss.py b/maskrcnn_benchmark/modeling/roi_heads/mask_head/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..dc77cd3106127af8b8fbfb458fbca1967dcaf0d1 --- /dev/null +++ b/maskrcnn_benchmark/modeling/roi_heads/mask_head/loss.py @@ -0,0 +1,237 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +import torch + +# from maskrcnn_benchmark.layers import smooth_l1_loss +from maskrcnn_benchmark.modeling.matcher import Matcher +from maskrcnn_benchmark.modeling.utils import cat +from maskrcnn_benchmark.structures.boxlist_ops import boxlist_iou +from torch.nn import functional as F + + +def project_masks_on_boxes(segmentation_masks, proposals, discretization_size): + """ + Given segmentation masks and the bounding boxes corresponding + to the location of the masks in the image, this function + crops and resizes the masks in the position defined by the + boxes. This prepares the masks for them to be fed to the + loss computation as the targets. + + Arguments: + segmentation_masks: an instance of SegmentationMask + proposals: an instance of BoxList + """ + masks = [] + M = discretization_size + device = proposals.bbox.device + proposals = proposals.convert("xyxy") + assert segmentation_masks.size == proposals.size, "{}, {}".format( + segmentation_masks, proposals + ) + # TODO put the proposals on the CPU, as the representation for the + # masks is not efficient GPU-wise (possibly several small tensors for + # representing a single instance mask) + proposals = proposals.bbox.to(torch.device("cpu")) + for segmentation_mask, proposal in zip(segmentation_masks, proposals): + # crop the masks, resize them to the desired resolution and + # then convert them to the tensor representation, + # instead of the list representation that was used + cropped_mask = segmentation_mask.crop(proposal) + scaled_mask = cropped_mask.resize((M, M)) + mask = scaled_mask.convert(mode="mask") + masks.append(mask) + if len(masks) == 0: + return torch.empty(0, dtype=torch.float32, device=device) + return torch.stack(masks, dim=0).to(device, dtype=torch.float32) + + +class MaskRCNNLossComputation(object): + def __init__(self, proposal_matcher, discretization_size): + """ + Arguments: + proposal_matcher (Matcher) + discretization_size (int) + """ + self.proposal_matcher = proposal_matcher + self.discretization_size = discretization_size + + def match_targets_to_proposals(self, proposal, target): + match_quality_matrix = boxlist_iou(target, proposal) + matched_idxs = self.proposal_matcher(match_quality_matrix) + # Mask RCNN needs "labels" and "masks "fields for creating the targets + target = target.copy_with_fields(["labels", "masks"]) + # get the targets corresponding GT for each proposal + # NB: need to clamp the indices because we can have a single + # GT in the image, and matched_idxs can be -2, which goes + # out of bounds + matched_targets = target[matched_idxs.clamp(min=0)] + matched_targets.add_field("matched_idxs", matched_idxs) + return matched_targets + + def prepare_targets(self, proposals, targets): + labels = [] + masks = [] + for proposals_per_image, targets_per_image in zip(proposals, targets): + matched_targets = self.match_targets_to_proposals( + proposals_per_image, targets_per_image + ) + matched_idxs = matched_targets.get_field("matched_idxs") + + labels_per_image = matched_targets.get_field("labels") + labels_per_image = labels_per_image.to(dtype=torch.int64) + + # this can probably be removed, but is left here for clarity + # and completeness + neg_inds = matched_idxs == Matcher.BELOW_LOW_THRESHOLD + labels_per_image[neg_inds] = 0 + + # mask scores are only computed on positive samples + positive_inds = torch.nonzero(labels_per_image > 0).squeeze(1) + + segmentation_masks = matched_targets.get_field("masks") + segmentation_masks = segmentation_masks[positive_inds] + + positive_proposals = proposals_per_image[positive_inds] + + masks_per_image = project_masks_on_boxes( + segmentation_masks, positive_proposals, self.discretization_size + ) + + labels.append(labels_per_image) + masks.append(masks_per_image) + + return labels, masks + + def __call__(self, proposals, mask_logits, targets): + """ + Arguments: + proposals (list[BoxList]) + mask_logits (Tensor) + targets (list[BoxList]) + + Return: + mask_loss (Tensor): scalar tensor containing the loss + """ + labels, mask_targets = self.prepare_targets(proposals, targets) + + labels = cat(labels, dim=0) + mask_targets = cat(mask_targets, dim=0) + + positive_inds = torch.nonzero(labels > 0).squeeze(1) + labels_pos = labels[positive_inds] + + # torch.mean (in binary_cross_entropy_with_logits) doesn't + # accept empty tensors, so handle it separately + if mask_targets.numel() == 0: + return mask_logits.sum() * 0 + + mask_loss = F.binary_cross_entropy_with_logits( + mask_logits[positive_inds, labels_pos], mask_targets + ) + return mask_loss + + +class CharMaskRCNNLossComputation(object): + def __init__(self, use_weighted_loss=False): + """ + Arguments: + proposal_matcher (Matcher) + discretization_size (int) + """ + self.use_weighted_loss = use_weighted_loss + + def __call__( + self, + proposals, + mask_logits, + char_mask_logits, + mask_targets, + char_mask_targets, + char_mask_weights, + ): + """ + Arguments: + proposals (list[BoxList]) + mask_logits (Tensor) + targets (list[BoxList]) + + Return: + mask_loss (Tensor): scalar tensor containing the loss + """ + mask_targets = cat(mask_targets, dim=0) + char_mask_targets = cat(char_mask_targets, dim=0) + char_mask_weights = cat(char_mask_weights, dim=0) + char_mask_weights = char_mask_weights.mean(dim=0) + + # torch.mean (in binary_cross_entropy_with_logits) doesn't + # accept empty tensors, so handle it separately + if mask_targets.numel() == 0 or char_mask_targets.numel() == 0: + return mask_logits.sum() * 0, char_mask_targets.sum() * 0 + + mask_loss = F.binary_cross_entropy_with_logits( + mask_logits.squeeze(dim=1), mask_targets + ) + if self.use_weighted_loss: + char_mask_loss = F.cross_entropy( + char_mask_logits, char_mask_targets, char_mask_weights, ignore_index=-1 + ) + else: + char_mask_loss = F.cross_entropy( + char_mask_logits, char_mask_targets, ignore_index=-1 + ) + return mask_loss, char_mask_loss + +class SeqMaskRCNNLossComputation(object): + def __init__(self): + """ + Arguments: + proposal_matcher (Matcher) + discretization_size (int) + """ + + def __call__( + self, + proposals, + mask_logits, + mask_targets, + ): + """ + Arguments: + proposals (list[BoxList]) + mask_logits (Tensor) + targets (list[BoxList]) + + Return: + mask_loss (Tensor): scalar tensor containing the loss + """ + mask_targets = cat(mask_targets, dim=0) + + # torch.mean (in binary_cross_entropy_with_logits) doesn't + # accept empty tensors, so handle it separately + if mask_targets.numel() == 0: + return mask_logits.sum() * 0 + + mask_loss = F.binary_cross_entropy_with_logits( + mask_logits.squeeze(dim=1), mask_targets + ) + return mask_loss + + +def make_roi_mask_loss_evaluator(cfg): + matcher = Matcher( + cfg.MODEL.ROI_HEADS.FG_IOU_THRESHOLD, + cfg.MODEL.ROI_HEADS.BG_IOU_THRESHOLD, + allow_low_quality_matches=False, + ) + if cfg.MODEL.CHAR_MASK_ON: + loss_evaluator = CharMaskRCNNLossComputation( + use_weighted_loss=cfg.MODEL.ROI_MASK_HEAD.USE_WEIGHTED_CHAR_MASK + ) + else: + if cfg.SEQUENCE.SEQ_ON: + loss_evaluator = SeqMaskRCNNLossComputation() + else: + loss_evaluator = MaskRCNNLossComputation( + matcher, cfg.MODEL.ROI_MASK_HEAD.RESOLUTION + ) + + return loss_evaluator diff --git a/maskrcnn_benchmark/modeling/roi_heads/mask_head/mask_head.py b/maskrcnn_benchmark/modeling/roi_heads/mask_head/mask_head.py new file mode 100644 index 0000000000000000000000000000000000000000..30076c3853405f2ae2ef9906e90a4989c0baae2c --- /dev/null +++ b/maskrcnn_benchmark/modeling/roi_heads/mask_head/mask_head.py @@ -0,0 +1,591 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +import torch +from torch import nn +from maskrcnn_benchmark.modeling.matcher import Matcher +from maskrcnn_benchmark.modeling.utils import cat +from maskrcnn_benchmark.structures.bounding_box import BoxList +from maskrcnn_benchmark.structures.boxlist_ops import boxlist_iou + +from .inference import make_roi_mask_post_processor +from .loss import make_roi_mask_loss_evaluator +from .roi_mask_feature_extractors import make_roi_mask_feature_extractor +from .roi_mask_predictors import make_roi_mask_predictor + +from maskrcnn_benchmark.layers import Conv2d +import math + +def conv3x3(in_planes, out_planes, stride=1, has_bias=False): + "3x3 convolution with padding" + return nn.Conv2d( + in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=has_bias + ) + + +def conv3x3_bn_relu(in_planes, out_planes, stride=1, has_bias=False): + return nn.Sequential( + conv3x3(in_planes, out_planes, stride), + nn.BatchNorm2d(out_planes), + nn.ReLU(inplace=True), + ) + +def keep_only_positive_boxes(boxes, batch_size_per_im): + """ + Given a set of BoxList containing the `labels` field, + return a set of BoxList for which `labels > 0`. + + Arguments: + boxes (list of BoxList) + """ + assert isinstance(boxes, (list, tuple)) + assert isinstance(boxes[0], BoxList) + assert boxes[0].has_field("labels") + positive_boxes = [] + positive_inds = [] + for boxes_per_image in boxes: + labels = boxes_per_image.get_field("labels") + inds_mask = labels > 0 + inds = inds_mask.nonzero().squeeze(1) + if len(inds) > batch_size_per_im: + new_inds = inds[:batch_size_per_im] + inds_mask[inds[batch_size_per_im:]] = 0 + else: + new_inds = inds + positive_boxes.append(boxes_per_image[new_inds]) + positive_inds.append(inds_mask) + return positive_boxes, positive_inds + + +# TODO +def project_char_masks_on_boxes( + segmentation_masks, segmentation_char_masks, proposals, discretization_size +): + """ + Given segmentation masks and the bounding boxes corresponding + to the location of the masks in the image, this function + crops and resizes the masks in the position defined by the + boxes. This prepares the masks for them to be fed to the + loss computation as the targets. + + Arguments: + segmentation_masks: an instance of SegmentationMask + proposals: an instance of BoxList + """ + masks = [] + char_masks = [] + char_mask_weights = [] + decoder_targets = [] + word_targets = [] + M_H, M_W = discretization_size[0], discretization_size[1] + device = proposals.bbox.device + proposals = proposals.convert("xyxy") + assert segmentation_masks.size == proposals.size, "{}, {}".format( + segmentation_masks, proposals + ) + assert segmentation_char_masks.size == proposals.size, "{}, {}".format( + segmentation_char_masks, proposals + ) + # TODO put the proposals on the CPU, as the representation for the + # masks is not efficient GPU-wise (possibly several small tensors for + # representing a single instance mask) + proposals = proposals.bbox.to(torch.device("cpu")) + for segmentation_mask, segmentation_char_mask, proposal in zip( + segmentation_masks, segmentation_char_masks, proposals + ): + # crop the masks, resize them to the desired resolution and + # then convert them to the tensor representation, + # instead of the list representation that was used + cropped_mask = segmentation_mask.crop(proposal) + scaled_mask = cropped_mask.resize((M_W, M_H)) + mask = scaled_mask.convert(mode="mask") + masks.append(mask) + cropped_char_mask = segmentation_char_mask.crop(proposal) + scaled_char_mask = cropped_char_mask.resize((M_W, M_H)) + char_mask, char_mask_weight, decoder_target, word_target = scaled_char_mask.convert( + mode="seq_char_mask" + ) + char_masks.append(char_mask) + char_mask_weights.append(char_mask_weight) + decoder_targets.append(decoder_target) + word_targets.append(word_target) + if len(masks) == 0: + return ( + torch.empty(0, dtype=torch.float32, device=device), + torch.empty(0, dtype=torch.long, device=device), + torch.empty(0, dtype=torch.float32, device=device), + torch.empty(0, dtype=torch.long, device=device), + ) + return ( + torch.stack(masks, dim=0).to(device, dtype=torch.float32), + torch.stack(char_masks, dim=0).to(device, dtype=torch.long), + torch.stack(char_mask_weights, dim=0).to(device, dtype=torch.float32), + torch.stack(decoder_targets, dim=0).to(device, dtype=torch.long), + torch.stack(word_targets, dim=0).to(device, dtype=torch.long), + ) + + +class ROIMaskHead(torch.nn.Module): + def __init__(self, cfg, proposal_matcher, discretization_size): + super(ROIMaskHead, self).__init__() + self.proposal_matcher = proposal_matcher + self.discretization_size = discretization_size + self.cfg = cfg.clone() + self.feature_extractor = make_roi_mask_feature_extractor(cfg) + self.predictor = make_roi_mask_predictor(cfg) + self.post_processor = make_roi_mask_post_processor(cfg) + self.loss_evaluator = make_roi_mask_loss_evaluator(cfg) + if self.cfg.MODEL.ROI_MASK_HEAD.MIX_OPTION == 'ATTENTION': + self.mask_attention = nn.Sequential( + conv3x3_bn_relu(cfg.MODEL.ROI_MASK_HEAD.CONV_LAYERS[-1] + 1, 32), + conv3x3(32, 1), + # Conv2d(cfg.MODEL.ROI_MASK_HEAD.CONV_LAYERS[-1] + 1, 1, 1, 1, 0), + nn.Sigmoid() + ) + self.mask_attention.apply(self.weights_init) + # for name, param in self.named_parameters(): + # if "bias" in name: + # nn.init.constant_(param, 0) + # elif "weight" in name: + # # Caffe2 implementation uses MSRAFill, which in fact + # # corresponds to kaiming_normal_ in PyTorch + # nn.init.kaiming_normal_(param, mode="fan_out", nonlinearity="relu") + if self.cfg.MODEL.ROI_MASK_HEAD.MIX_OPTION == 'ATTENTION_DOWN': + self.mask_attention = nn.Sequential( + conv3x3_bn_relu(cfg.MODEL.ROI_MASK_HEAD.CONV_LAYERS[-1] + 1, 32, stride=2), + conv3x3(32, 1, stride=2), + nn.Upsample(scale_factor=4, mode='nearest'), + nn.Sigmoid() + ) + self.mask_attention.apply(self.weights_init) + + if self.cfg.MODEL.ROI_MASK_HEAD.MIX_OPTION == 'ATTENTION_CHANNEL': + num_channel = cfg.MODEL.ROI_MASK_HEAD.CONV_LAYERS[-1] * 2 + self.channel_attention = nn.Sequential( + nn.MaxPool2d(2), + conv3x3_bn_relu(num_channel, num_channel, stride=2), + conv3x3(num_channel, num_channel, stride=2), + nn.AdaptiveAvgPool2d((1,1)), + nn.Sigmoid() + ) + self.channel_attention.apply(self.weights_init) + if self.cfg.MODEL.ROI_MASK_HEAD.MIX_OPTION == 'ATTENTION_CHANNEL_SPLIT' or self.cfg.MODEL.ROI_MASK_HEAD.MIX_OPTION == 'ATTENTION_CHANNEL_SPLIT_BINARY': + num_channel = cfg.MODEL.ROI_MASK_HEAD.CONV_LAYERS[-1] * 2 + self.channel_attention = nn.Sequential( + nn.MaxPool2d(2), + conv3x3_bn_relu(num_channel, int(num_channel / 4), stride=2), + conv3x3(int(num_channel / 4), 2, stride=2), + nn.AdaptiveAvgPool2d((1,1)), + # nn.Sigmoid() + nn.Softmax(dim=1) + ) + self.channel_attention.apply(self.weights_init) + if self.cfg.MODEL.ROI_MASK_HEAD.MIX_OPTION == 'ATTENTION_CHANNEL_2': + num_channel = cfg.MODEL.ROI_MASK_HEAD.CONV_LAYERS[-1] * 2 + self.channel_attention_2 = nn.Sequential( + nn.AdaptiveAvgPool2d((1,1)), + nn.Conv2d( + num_channel, num_channel, kernel_size=1, stride=1, padding=0 + ), + nn.Conv2d( + num_channel, num_channel, kernel_size=1, stride=1, padding=0 + ), + nn.Softmax(dim=1) + ) + self.channel_attention_2.apply(self.weights_init) + if self.cfg.MODEL.ROI_MASK_HEAD.MIX_OPTION == 'ATTENTION_CHANNEL_TANH': + feature_dim = 128 + num_channel = cfg.MODEL.ROI_MASK_HEAD.CONV_LAYERS[-1] * 2 + self.mask_pooler = nn.Sequential( + nn.MaxPool2d(2), + conv3x3_bn_relu(num_channel, num_channel, stride=2), + ) + self.attn = nn.Linear(feature_dim, feature_dim) + self.v = nn.Parameter(torch.rand(feature_dim)) + stdv = 1.0 / math.sqrt(self.v.size(0)) + self.v.data.normal_(mean=0, std=stdv) + self.mask_pooler.apply(self.weights_init) + if self.cfg.MODEL.ROI_MASK_HEAD.MIX_OPTION == 'NEW_CAT': + num_channel = cfg.MODEL.ROI_MASK_HEAD.CONV_LAYERS[-1] + self.enlarge_recepitve_field = nn.Sequential( + nn.Conv2d( + 2 * num_channel, num_channel, kernel_size=3, stride=1, padding=2, dilation=2 + ), + nn.Conv2d( + num_channel, num_channel, kernel_size=3, stride=1, padding=2, dilation=2 + ), + ) + self.enlarge_recepitve_field.apply(self.weights_init) + if self.cfg.MODEL.ROI_MASK_HEAD.MIX_OPTION == 'NEW_MASK': + num_channel = cfg.MODEL.ROI_MASK_HEAD.CONV_LAYERS[-1] + self.new_mask = nn.Sequential( + nn.Conv2d( + 2 * num_channel, num_channel, kernel_size=3, stride=1, padding=2, dilation=2 + ), + nn.Conv2d( + num_channel, 32, kernel_size=3, stride=1, padding=2, dilation=2 + ), + nn.Conv2d( + 32, 1, kernel_size=3, stride=1, padding=2, dilation=2 + ), + nn.Sigmoid() + ) + self.new_mask.apply(self.weights_init) + + def weights_init(self, m): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + nn.init.kaiming_normal_(m.weight.data) + elif classname.find("BatchNorm") != -1: + m.weight.data.fill_(1.0) + m.bias.data.fill_(1e-4) + + def step_function(self, x): + return torch.reciprocal(1 + torch.exp(-50 * (x - 0.5))) + + def channel_attention_tanh(self, feature, mask): + """ + :param hidden: + previous hidden state of the decoder, in shape (B, hidden_size) + :param encoder_outputs: + encoder outputs from Encoder, in shape (H*W, B, hidden_size) + :return + attention energies in shape (B, H*W) + """ + feature = feature.reshape((feature.shape[0], feature.shape[1], -1)) # (B, C, H*W) + masks = mask.reshape((mask.shape[0], mask.shape[1], -1)).repeat(1, feature.shape[1], 1) # (B, C, H*W) + fuse_feature = torch.cat([feature, masks], 2) + energy = torch.tanh(self.attn(fuse_feature)) # (B, C, 2*H*W)->(B, C, 2*H*W) + energy = energy.transpose(2, 1) # (B, 2*H*W, C) + v = self.v.repeat(feature.shape[0], 1).unsqueeze( + 1 + ) # (B, 1, 2*H*W) + energy = torch.bmm(v, energy) # (B, 1, C) + energy = energy.squeeze(1) # (B, C) + return nn.functional.softmax(energy, dim=1).unsqueeze(2).unsqueeze(3) # normalize with softmax (B, C) + + def match_targets_to_proposals(self, proposal, target): + match_quality_matrix = boxlist_iou(target, proposal) + # match_quality_matrix = boxlist_polygon_iou(target, proposal) + matched_idxs = self.proposal_matcher(match_quality_matrix) + # Mask RCNN needs "labels" and "masks "fields for creating the targets + target = target.copy_with_fields(["labels", "masks", "char_masks"]) + # get the targets corresponding GT for each proposal + # NB: need to clamp the indices because we can have a single + # GT in the image, and matched_idxs can be -2, which goes + # out of bounds + matched_targets = target[matched_idxs.clamp(min=0)] + matched_targets.add_field("matched_idxs", matched_idxs) + return matched_targets + + def prepare_targets(self, proposals, targets): + masks = [] + char_masks = [] + char_mask_weights = [] + decoder_targets = [] + word_targets = [] + for proposals_per_image, targets_per_image in zip(proposals, targets): + matched_targets = self.match_targets_to_proposals( + proposals_per_image, targets_per_image + ) + matched_idxs = matched_targets.get_field("matched_idxs") + + labels_per_image = matched_targets.get_field("labels") + labels_per_image = labels_per_image.to(dtype=torch.int64) + + # this can probably be removed, but is left here for clarity + # and completeness + neg_inds = matched_idxs == Matcher.BELOW_LOW_THRESHOLD + labels_per_image[neg_inds] = 0 + + # mask scores are only computed on positive samples + positive_inds = torch.nonzero(labels_per_image > 0).squeeze(1) + + segmentation_masks = matched_targets.get_field("masks") + segmentation_masks = segmentation_masks[positive_inds] + + char_segmentation_masks = matched_targets.get_field("char_masks") + char_segmentation_masks = char_segmentation_masks[positive_inds] + + positive_proposals = proposals_per_image[positive_inds] + + masks_per_image, char_masks_per_image, char_masks_weight_per_image, decoder_targets_per_image, word_targets_per_image = project_char_masks_on_boxes( + segmentation_masks, + char_segmentation_masks, + positive_proposals, + self.discretization_size, + ) + + masks.append(masks_per_image) + char_masks.append(char_masks_per_image) + char_mask_weights.append(char_masks_weight_per_image) + decoder_targets.append(decoder_targets_per_image) + word_targets.append(word_targets_per_image) + + return masks, char_masks, char_mask_weights, decoder_targets, word_targets + + def feature_mask(self, x, proposals): + masks = [] + for proposal in proposals: + segmentation_masks = proposal.get_field("masks") + boxes = proposal.bbox.to(torch.device("cpu")) + for segmentation_mask, box in zip(segmentation_masks, boxes): + cropped_mask = segmentation_mask.crop(box) + scaled_mask = cropped_mask.resize((self.cfg.MODEL.ROI_MASK_HEAD.POOLER_RESOLUTION_W, self.cfg.MODEL.ROI_MASK_HEAD.POOLER_RESOLUTION_H)) + mask = scaled_mask.convert(mode="mask") + masks.append(mask) + if len(masks) == 0: + if self.cfg.MODEL.ROI_MASK_HEAD.MIX_OPTION == 'CAT': + x = cat([x, torch.ones((x.shape[0], 1, x.shape[2], x.shape[3]), device=x.device)], dim=1) + if self.cfg.MODEL.ROI_MASK_HEAD.MIX_OPTION == 'MIX' or 'ATTENTION_CHANNEL' in self.cfg.MODEL.ROI_MASK_HEAD.MIX_OPTION: + x = cat([x, x], dim=1) + return x + masks = torch.stack(masks, dim=0).to(x.device, dtype=torch.float32) + if self.cfg.MODEL.ROI_MASK_HEAD.MIX_OPTION == 'CAT': + x = cat([x, masks.unsqueeze(1)], dim=1) + return x + if self.cfg.MODEL.ROI_MASK_HEAD.MIX_OPTION == 'NEW_CAT': + cat_x = cat([x, x * masks.unsqueeze(1)], dim=1) + out_x = self.enlarge_recepitve_field(cat_x) + return out_x + if self.cfg.MODEL.ROI_MASK_HEAD.MIX_OPTION == 'NEW_MASK': + cat_x = cat([x, x * masks.unsqueeze(1)], dim=1) + new_mask = self.new_mask(cat_x) + out_x = x * new_mask + return out_x + if self.cfg.MODEL.ROI_MASK_HEAD.MIX_OPTION == 'ATTENTION' or self.cfg.MODEL.ROI_MASK_HEAD.MIX_OPTION == 'ATTENTION_DOWN': + x_cat = cat([x, masks.unsqueeze(1)], dim=1) + attention = self.mask_attention(x_cat) + x = x * attention + return x + if self.cfg.MODEL.ROI_MASK_HEAD.MIX_OPTION == 'MIX': + mask_x = x * masks.unsqueeze(1) + cat_x = cat([x, mask_x], dim=1) + return cat_x + if self.cfg.MODEL.ROI_MASK_HEAD.MIX_OPTION == 'ATTENTION_CHANNEL': + mask_x = x * masks.unsqueeze(1) + cat_x = cat([x, mask_x], dim=1) + channel_attention = self.channel_attention(cat_x) + attentioned_x = cat_x * channel_attention + return attentioned_x + if self.cfg.MODEL.ROI_MASK_HEAD.MIX_OPTION == 'ATTENTION_CHANNEL_2': + mask_x = x * masks.unsqueeze(1) + cat_x = cat([x, mask_x], dim=1) + channel_attention = self.channel_attention_2(cat_x) + # print(channel_attention[0, :, 0, 0]) + attentioned_x = cat_x * channel_attention + return attentioned_x + if self.cfg.MODEL.ROI_MASK_HEAD.MIX_OPTION == 'ATTENTION_CHANNEL_SPLIT': + mask_x = x * masks.unsqueeze(1) + cat_x = cat([x, mask_x], dim=1) + channel_attention = self.channel_attention(cat_x) + print(channel_attention[0, :, 0, 0]) + attentioned_x = cat([x * channel_attention[:, 0:1, :, :], mask_x * channel_attention[:, 1:, :, :]], dim=1) + return attentioned_x + if self.cfg.MODEL.ROI_MASK_HEAD.MIX_OPTION == 'ATTENTION_CHANNEL_SPLIT_BINARY': + mask_x = x * masks.unsqueeze(1) + cat_x = cat([x, mask_x], dim=1) + channel_attention = self.step_function(self.channel_attention(cat_x)) + # print(channel_attention[:, :, 0, 0]) + attentioned_x = cat([x * channel_attention[:, 0:1, :, :], mask_x * channel_attention[:, 1:, :, :]], dim=1) + # attentioned_x = cat([x * channel_attention[:, 1:, :, :], mask_x * channel_attention[:, 0:1, :, :]], dim=1) + return attentioned_x + if self.cfg.MODEL.ROI_MASK_HEAD.MIX_OPTION == 'ATTENTION_CHANNEL_TANH': + mask_x = x * masks.unsqueeze(1) + cat_x = cat([x, mask_x], dim=1) + pooler_x = self.mask_pooler(cat_x) + pooler_mask = nn.functional.interpolate(masks.unsqueeze(1), scale_factor=0.25, mode='bilinear') + channel_attention = self.channel_attention_tanh(pooler_x, pooler_mask) + attentioned_x = cat_x * channel_attention + return attentioned_x + soft_ratio = self.cfg.MODEL.ROI_MASK_HEAD.SOFT_MASKED_FEATURE_RATIO + if soft_ratio > 0: + if soft_ratio < 1.0: + x = x * (soft_ratio + (1 - soft_ratio) * masks.unsqueeze(1)) + else: + x = x * (1.0 + soft_ratio * masks.unsqueeze(1)) + else: + x = x * masks.unsqueeze(1) + return x + + def forward(self, features, proposals, targets=None): + """ + Arguments: + features (list[Tensor]): feature-maps from possibly several levels + proposals (list[BoxList]): proposal boxes + targets (list[BoxList], optional): the ground-truth targets. + + Returns: + x (Tensor): the result of the feature extractor + proposals (list[BoxList]): during training, the original proposals + are returned. During testing, the predicted boxlists are returned + with the `mask` field set + losses (dict[Tensor]): During training, returns the losses for the + head. During testing, returns an empty dict. + """ + if self.training: + # during training, only focus on positive boxes + all_proposals = proposals + proposals, positive_inds = keep_only_positive_boxes( + proposals, self.cfg.MODEL.ROI_MASK_HEAD.MASK_BATCH_SIZE_PER_IM + ) + if all(len(proposal) == 0 for proposal in proposals): + return None, None, None + if self.training and self.cfg.MODEL.ROI_MASK_HEAD.SHARE_BOX_FEATURE_EXTRACTOR: + x = features + x = x[torch.cat(positive_inds, dim=0)] + else: + x = self.feature_extractor(features, proposals) + if self.cfg.MODEL.ROI_MASK_HEAD.USE_MASKED_FEATURE: + x = self.feature_mask(x, proposals) + if self.training: + mask_targets, char_mask_targets, char_mask_weights, \ + decoder_targets, word_targets = self.prepare_targets( + proposals, targets + ) + decoder_targets = cat(decoder_targets, dim=0) + word_targets = cat(word_targets, dim=0) + + # proposals_not_empty, targets_not = [], [] + # for proposal, target, mask_target, char_mask_target, char_mask_weight in zip(proposals, targets, mask_targets, char_mask_targets, char_mask_weights): + # if len(proposal_target[0]) > 0: + # proposals_not_empty.append(proposal) + # targets_not.append(proposal_target[1]) + # proposals = proposals_not_empty + # targets = targets_not + if self.cfg.MODEL.CHAR_MASK_ON: + if self.cfg.SEQUENCE.SEQ_ON: + if not self.training: + if x.numel() > 0: + mask_logits, char_mask_logits, seq_outputs, seq_scores, \ + detailed_seq_scores = self.predictor(x) + result = self.post_processor( + mask_logits, + char_mask_logits, + proposals, + seq_outputs=seq_outputs, + seq_scores=seq_scores, + detailed_seq_scores=detailed_seq_scores, + ) + return x, result, {} + else: + return None, None, {} + mask_logits, char_mask_logits, seq_outputs = self.predictor( + x, decoder_targets=decoder_targets, word_targets=word_targets + ) + loss_mask, loss_char_mask = self.loss_evaluator( + proposals, + mask_logits, + char_mask_logits, + mask_targets, + char_mask_targets, + char_mask_weights, + ) + return ( + x, + all_proposals, + dict( + loss_mask=loss_mask, + loss_char_mask=loss_char_mask, + loss_seq=seq_outputs, + ), + ) + else: + mask_logits, char_mask_logits = self.predictor(x) + if not self.training: + result = self.post_processor( + mask_logits, char_mask_logits, proposals + ) + return x, result, {} + loss_mask, loss_char_mask = self.loss_evaluator( + proposals, + mask_logits, + char_mask_logits, + mask_targets, + char_mask_targets, + char_mask_weights, + ) + return ( + x, + all_proposals, + dict(loss_mask=loss_mask, loss_char_mask=loss_char_mask), + ) + else: + if self.cfg.SEQUENCE.SEQ_ON: + if self.cfg.MODEL.MASK_ON: + if not self.training: + if x.numel() > 0: + mask_logits, seq_outputs, seq_scores, \ + detailed_seq_scores = self.predictor(x) + result = self.post_processor( + mask_logits, + None, + proposals, + seq_outputs=seq_outputs, + seq_scores=seq_scores, + detailed_seq_scores=detailed_seq_scores, + ) + return x, result, {} + else: + return None, None, {} + mask_logits, seq_outputs = self.predictor( + x, decoder_targets=decoder_targets, word_targets=word_targets + ) + loss_mask = self.loss_evaluator( + proposals, + mask_logits, + mask_targets, + ) + return ( + x, + all_proposals, + dict( + loss_mask=loss_mask, + loss_seq=seq_outputs, + ), + ) + else: + if not self.training: + if x.numel() > 0: + _, seq_outputs, seq_scores, \ + detailed_seq_scores = self.predictor(x) + result = self.post_processor( + None, + None, + proposals, + seq_outputs=seq_outputs, + seq_scores=seq_scores, + detailed_seq_scores=detailed_seq_scores, + ) + return x, result, {} + else: + return None, None, {} + _, seq_outputs = self.predictor( + x, decoder_targets=decoder_targets, word_targets=word_targets + ) + return ( + x, + all_proposals, + dict( + loss_seq=seq_outputs, + ), + ) + else: + mask_logits = self.predictor(x) + if not self.training: + result = self.post_processor(mask_logits, proposals) + return x, result, {} + loss_mask = self.loss_evaluator(proposals, mask_logits, targets) + return x, all_proposals, dict(loss_mask=loss_mask) + + +def build_roi_mask_head(cfg): + matcher = Matcher( + cfg.MODEL.ROI_HEADS.FG_IOU_THRESHOLD, + cfg.MODEL.ROI_HEADS.BG_IOU_THRESHOLD, + allow_low_quality_matches=False, + ) + return ROIMaskHead( + cfg, + matcher, + (cfg.MODEL.ROI_MASK_HEAD.RESOLUTION_H, cfg.MODEL.ROI_MASK_HEAD.RESOLUTION_W), + ) diff --git a/maskrcnn_benchmark/modeling/roi_heads/mask_head/roi_mask_feature_extractors.py b/maskrcnn_benchmark/modeling/roi_heads/mask_head/roi_mask_feature_extractors.py new file mode 100644 index 0000000000000000000000000000000000000000..6bebccceecb9d166f2a1605f531199bdf36a56d5 --- /dev/null +++ b/maskrcnn_benchmark/modeling/roi_heads/mask_head/roi_mask_feature_extractors.py @@ -0,0 +1,72 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +from torch import nn +from torch.nn import functional as F + +from ..box_head.roi_box_feature_extractors import ResNet50Conv5ROIFeatureExtractor +from maskrcnn_benchmark.modeling.poolers import Pooler +from maskrcnn_benchmark.layers import Conv2d + + +class MaskRCNNFPNFeatureExtractor(nn.Module): + """ + Heads for FPN for classification + """ + + def __init__(self, cfg): + """ + Arguments: + num_classes (int): number of output classes + input_size (int): number of channels of the input once it's flattened + representation_size (int): size of the intermediate representation + """ + super(MaskRCNNFPNFeatureExtractor, self).__init__() + + # resolution = cfg.MODEL.ROI_MASK_HEAD.POOLER_RESOLUTION + if cfg.MODEL.CHAR_MASK_ON or cfg.SEQUENCE.SEQ_ON: + resolution_h = cfg.MODEL.ROI_MASK_HEAD.POOLER_RESOLUTION_H + resolution_w = cfg.MODEL.ROI_MASK_HEAD.POOLER_RESOLUTION_W + else: + resolution_h = cfg.MODEL.ROI_MASK_HEAD.POOLER_RESOLUTION + resolution_w = cfg.MODEL.ROI_MASK_HEAD.POOLER_RESOLUTION + scales = cfg.MODEL.ROI_MASK_HEAD.POOLER_SCALES + sampling_ratio = cfg.MODEL.ROI_MASK_HEAD.POOLER_SAMPLING_RATIO + pooler = Pooler( + output_size=(resolution_h, resolution_w), + scales=scales, + sampling_ratio=sampling_ratio, + ) + input_size = cfg.MODEL.BACKBONE.OUT_CHANNELS + self.pooler = pooler + + layers = cfg.MODEL.ROI_MASK_HEAD.CONV_LAYERS + + next_feature = input_size + self.blocks = [] + for layer_idx, layer_features in enumerate(layers, 1): + layer_name = "mask_fcn{}".format(layer_idx) + module = Conv2d(next_feature, layer_features, 3, stride=1, padding=1) + # Caffe2 implementation uses MSRAFill, which in fact + # corresponds to kaiming_normal_ in PyTorch + nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu") + nn.init.constant_(module.bias, 0) + self.add_module(layer_name, module) + next_feature = layer_features + self.blocks.append(layer_name) + + def forward(self, x, proposals): + x = self.pooler(x, proposals) + for layer_name in self.blocks: + x = F.relu(getattr(self, layer_name)(x)) + + return x + + +_ROI_MASK_FEATURE_EXTRACTORS = { + "ResNet50Conv5ROIFeatureExtractor": ResNet50Conv5ROIFeatureExtractor, + "MaskRCNNFPNFeatureExtractor": MaskRCNNFPNFeatureExtractor, +} + + +def make_roi_mask_feature_extractor(cfg): + func = _ROI_MASK_FEATURE_EXTRACTORS[cfg.MODEL.ROI_MASK_HEAD.FEATURE_EXTRACTOR] + return func(cfg) diff --git a/maskrcnn_benchmark/modeling/roi_heads/mask_head/roi_mask_predictors.py b/maskrcnn_benchmark/modeling/roi_heads/mask_head/roi_mask_predictors.py new file mode 100644 index 0000000000000000000000000000000000000000..1a4f790de16f5808908285ada394234ca40e1ada --- /dev/null +++ b/maskrcnn_benchmark/modeling/roi_heads/mask_head/roi_mask_predictors.py @@ -0,0 +1,268 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +from maskrcnn_benchmark.layers import Conv2d, ConvTranspose2d +from torch import nn +from torch.nn import functional as F + +from .roi_seq_predictors import make_roi_seq_predictor + + +class MaskRCNNC4Predictor(nn.Module): + def __init__(self, cfg): + super(MaskRCNNC4Predictor, self).__init__() + num_classes = cfg.MODEL.ROI_BOX_HEAD.NUM_CLASSES + dim_reduced = cfg.MODEL.ROI_MASK_HEAD.CONV_LAYERS[-1] + + if cfg.MODEL.ROI_HEADS.USE_FPN: + if cfg.MODEL.ROI_MASK_HEAD.MIX_OPTION == 'CAT': + num_inputs = dim_reduced + 1 + elif cfg.MODEL.ROI_MASK_HEAD.MIX_OPTION == 'MIX' or cfg.MODEL.ROI_MASK_HEAD.MIX_OPTION == 'ATTENTION_CHANNEL': + num_inputs = dim_reduced * 2 + else: + num_inputs = dim_reduced + else: + stage_index = 4 + stage2_relative_factor = 2 ** (stage_index - 1) + res2_out_channels = cfg.MODEL.RESNETS.RES2_OUT_CHANNELS + num_inputs = res2_out_channels * stage2_relative_factor + + self.conv5_mask = ConvTranspose2d(num_inputs, dim_reduced, 2, 2, 0) + self.mask_fcn_logits = Conv2d(dim_reduced, num_classes, 1, 1, 0) + + for name, param in self.named_parameters(): + if "bias" in name: + nn.init.constant_(param, 0) + elif "weight" in name: + # Caffe2 implementation uses MSRAFill, which in fact + # corresponds to kaiming_normal_ in PyTorch + nn.init.kaiming_normal_(param, mode="fan_out", nonlinearity="relu") + + def forward(self, x): + x = F.relu(self.conv5_mask(x)) + return self.mask_fcn_logits(x) + + +class CharMaskRCNNC4Predictor(nn.Module): + def __init__(self, cfg): + super(CharMaskRCNNC4Predictor, self).__init__() + # num_classes = cfg.MODEL.ROI_BOX_HEAD.NUM_CLASSES + num_classes = 1 + char_num_classes = cfg.MODEL.ROI_MASK_HEAD.CHAR_NUM_CLASSES + dim_reduced = cfg.MODEL.ROI_MASK_HEAD.CONV_LAYERS[-1] + + if cfg.MODEL.ROI_HEADS.USE_FPN: + if cfg.MODEL.ROI_MASK_HEAD.MIX_OPTION == 'CAT': + num_inputs = dim_reduced + 1 + elif cfg.MODEL.ROI_MASK_HEAD.MIX_OPTION == 'MIX' or cfg.MODEL.ROI_MASK_HEAD.MIX_OPTION == 'ATTENTION_CHANNEL': + num_inputs = dim_reduced * 2 + else: + num_inputs = dim_reduced + else: + stage_index = 4 + stage2_relative_factor = 2 ** (stage_index - 1) + res2_out_channels = cfg.MODEL.RESNETS.RES2_OUT_CHANNELS + num_inputs = res2_out_channels * stage2_relative_factor + + self.conv5_mask = ConvTranspose2d(num_inputs, dim_reduced, 2, 2, 0) + if cfg.MODEL.CHAR_MASK_ON: + self.mask_fcn_logits = Conv2d(dim_reduced, num_classes, 1, 1, 0) + self.char_mask_fcn_logits = Conv2d(dim_reduced, char_num_classes, 1, 1, 0) + else: + self.mask_fcn_logits = Conv2d(dim_reduced, num_classes, 1, 1, 0) + + for name, param in self.named_parameters(): + if "bias" in name: + nn.init.constant_(param, 0) + elif "weight" in name: + # Caffe2 implementation uses MSRAFill, which in fact + # corresponds to kaiming_normal_ in PyTorch + nn.init.kaiming_normal_(param, mode="fan_out", nonlinearity="relu") + + def forward(self, x): + x = F.relu(self.conv5_mask(x)) + return self.mask_fcn_logits(x), self.char_mask_fcn_logits(x) + + +class SeqCharMaskRCNNC4Predictor(nn.Module): + def __init__(self, cfg): + super(SeqCharMaskRCNNC4Predictor, self).__init__() + # num_classes = cfg.MODEL.ROI_BOX_HEAD.NUM_CLASSES + num_classes = 1 + char_num_classes = cfg.MODEL.ROI_MASK_HEAD.CHAR_NUM_CLASSES + dim_reduced = cfg.MODEL.ROI_MASK_HEAD.CONV_LAYERS[-1] + + if cfg.MODEL.ROI_HEADS.USE_FPN: + if cfg.MODEL.ROI_MASK_HEAD.MIX_OPTION == 'CAT': + num_inputs = dim_reduced + 1 + elif cfg.MODEL.ROI_MASK_HEAD.MIX_OPTION == 'MIX' or 'ATTENTION_CHANNEL' in cfg.MODEL.ROI_MASK_HEAD.MIX_OPTION: + num_inputs = dim_reduced * 2 + else: + num_inputs = dim_reduced + else: + stage_index = 4 + stage2_relative_factor = 2 ** (stage_index - 1) + res2_out_channels = cfg.MODEL.RESNETS.RES2_OUT_CHANNELS + num_inputs = res2_out_channels * stage2_relative_factor + + self.conv5_mask = ConvTranspose2d(num_inputs, dim_reduced, 2, 2, 0) + if cfg.MODEL.CHAR_MASK_ON: + self.mask_fcn_logits = Conv2d(dim_reduced, num_classes, 1, 1, 0) + self.char_mask_fcn_logits = Conv2d(dim_reduced, char_num_classes, 1, 1, 0) + self.seq = make_roi_seq_predictor(cfg, dim_reduced) + else: + self.mask_fcn_logits = Conv2d(dim_reduced, num_classes, 1, 1, 0) + + for name, param in self.named_parameters(): + if "bias" in name: + nn.init.constant_(param, 0) + elif "weight" in name: + # Caffe2 implementation uses MSRAFill, which in fact + # corresponds to kaiming_normal_ in PyTorch + nn.init.kaiming_normal_(param, mode="fan_out", nonlinearity="relu") + + def forward(self, x, decoder_targets=None, word_targets=None): + x = F.relu(self.conv5_mask(x)) + if self.training: + loss_seq_decoder = self.seq( + x, decoder_targets=decoder_targets, word_targets=word_targets + ) + return ( + self.mask_fcn_logits(x), + self.char_mask_fcn_logits(x), + loss_seq_decoder, + ) + else: + decoded_chars, decoded_scores, detailed_decoded_scores = self.seq( + x, use_beam_search=True + ) + return ( + self.mask_fcn_logits(x), + self.char_mask_fcn_logits(x), + decoded_chars, + decoded_scores, + detailed_decoded_scores, + ) + +class SeqMaskRCNNC4Predictor(nn.Module): + def __init__(self, cfg): + super(SeqMaskRCNNC4Predictor, self).__init__() + num_classes = 1 + # char_num_classes = cfg.MODEL.ROI_MASK_HEAD.CHAR_NUM_CLASSES + dim_reduced = cfg.MODEL.ROI_MASK_HEAD.CONV_LAYERS[-1] + + if cfg.MODEL.ROI_HEADS.USE_FPN: + if cfg.MODEL.ROI_MASK_HEAD.MIX_OPTION == 'CAT': + num_inputs = dim_reduced + 1 + elif cfg.MODEL.ROI_MASK_HEAD.MIX_OPTION == 'MIX' or cfg.MODEL.ROI_MASK_HEAD.MIX_OPTION == 'ATTENTION_CHANNEL': + num_inputs = dim_reduced * 2 + else: + num_inputs = dim_reduced + else: + stage_index = 4 + stage2_relative_factor = 2 ** (stage_index - 1) + res2_out_channels = cfg.MODEL.RESNETS.RES2_OUT_CHANNELS + num_inputs = res2_out_channels * stage2_relative_factor + + self.conv5_mask = ConvTranspose2d(num_inputs, dim_reduced, 2, 2, 0) + if cfg.SEQUENCE.SEQ_ON: + self.mask_fcn_logits = Conv2d(dim_reduced, num_classes, 1, 1, 0) + self.seq = make_roi_seq_predictor(cfg, dim_reduced) + else: + self.mask_fcn_logits = Conv2d(dim_reduced, num_classes, 1, 1, 0) + + for name, param in self.named_parameters(): + if "bias" in name: + nn.init.constant_(param, 0) + elif "weight" in name: + # Caffe2 implementation uses MSRAFill, which in fact + # corresponds to kaiming_normal_ in PyTorch + nn.init.kaiming_normal_(param, mode="fan_out", nonlinearity="relu") + + def forward(self, x, decoder_targets=None, word_targets=None): + x = F.relu(self.conv5_mask(x)) + if self.training: + loss_seq_decoder = self.seq( + x, decoder_targets=decoder_targets, word_targets=word_targets + ) + return ( + self.mask_fcn_logits(x), + loss_seq_decoder, + ) + else: + decoded_chars, decoded_scores, detailed_decoded_scores = self.seq( + x, use_beam_search=True + ) + return ( + self.mask_fcn_logits(x), + decoded_chars, + decoded_scores, + detailed_decoded_scores, + ) + +class SeqRCNNC4Predictor(nn.Module): + def __init__(self, cfg): + super(SeqRCNNC4Predictor, self).__init__() + num_classes = 1 + # char_num_classes = cfg.MODEL.ROI_MASK_HEAD.CHAR_NUM_CLASSES + dim_reduced = cfg.MODEL.ROI_MASK_HEAD.CONV_LAYERS[-1] + + if cfg.MODEL.ROI_HEADS.USE_FPN: + if cfg.MODEL.ROI_MASK_HEAD.MIX_OPTION == 'CAT': + num_inputs = dim_reduced + 1 + elif cfg.MODEL.ROI_MASK_HEAD.MIX_OPTION == 'MIX' or cfg.MODEL.ROI_MASK_HEAD.MIX_OPTION == 'ATTENTION_CHANNEL': + num_inputs = dim_reduced * 2 + else: + num_inputs = dim_reduced + else: + stage_index = 4 + stage2_relative_factor = 2 ** (stage_index - 1) + res2_out_channels = cfg.MODEL.RESNETS.RES2_OUT_CHANNELS + num_inputs = res2_out_channels * stage2_relative_factor + + self.conv5_mask = ConvTranspose2d(num_inputs, dim_reduced, 2, 2, 0) + if cfg.SEQUENCE.SEQ_ON: + # self.mask_fcn_logits = Conv2d(dim_reduced, num_classes, 1, 1, 0) + self.seq = make_roi_seq_predictor(cfg, dim_reduced) + # else: + # self.mask_fcn_logits = Conv2d(dim_reduced, num_classes, 1, 1, 0) + + for name, param in self.named_parameters(): + if "bias" in name: + nn.init.constant_(param, 0) + elif "weight" in name: + # Caffe2 implementation uses MSRAFill, which in fact + # corresponds to kaiming_normal_ in PyTorch + nn.init.kaiming_normal_(param, mode="fan_out", nonlinearity="relu") + + def forward(self, x, decoder_targets=None, word_targets=None): + x = F.relu(self.conv5_mask(x)) + if self.training: + loss_seq_decoder = self.seq( + x, decoder_targets=decoder_targets, word_targets=word_targets + ) + return ( + None, + loss_seq_decoder, + ) + else: + decoded_chars, decoded_scores, detailed_decoded_scores = self.seq( + x, use_beam_search=True + ) + return ( + None, + decoded_chars, + decoded_scores, + detailed_decoded_scores, + ) + +_ROI_MASK_PREDICTOR = { + "MaskRCNNC4Predictor": MaskRCNNC4Predictor, + "CharMaskRCNNC4Predictor": CharMaskRCNNC4Predictor, + "SeqCharMaskRCNNC4Predictor": SeqCharMaskRCNNC4Predictor, + "SeqMaskRCNNC4Predictor": SeqMaskRCNNC4Predictor, + "SeqRCNNC4Predictor": SeqRCNNC4Predictor, +} + + +def make_roi_mask_predictor(cfg): + func = _ROI_MASK_PREDICTOR[cfg.MODEL.ROI_MASK_HEAD.PREDICTOR] + return func(cfg) diff --git a/maskrcnn_benchmark/modeling/roi_heads/mask_head/roi_seq_predictors.py b/maskrcnn_benchmark/modeling/roi_heads/mask_head/roi_seq_predictors.py new file mode 100644 index 0000000000000000000000000000000000000000..40ae7ac3884ffda378a3a667594e763eb60c9fef --- /dev/null +++ b/maskrcnn_benchmark/modeling/roi_heads/mask_head/roi_seq_predictors.py @@ -0,0 +1,381 @@ +# Written by Minghui Liao +import math +import random + +import numpy as np +import torch +from maskrcnn_benchmark.utils.chars import char2num, num2char +from torch import nn +from torch.nn import functional as F + + +gpu_device = torch.device("cpu") +cpu_device = torch.device("cpu") + + +def reduce_mul(l): + out = 1.0 + for x in l: + out *= x + return out + + +def check_all_done(seqs): + for seq in seqs: + if not seq[-1]: + return False + return True + + +# TODO +class SequencePredictor(nn.Module): + def __init__(self, cfg, dim_in): + super(SequencePredictor, self).__init__() + self.cfg = cfg + if cfg.SEQUENCE.TWO_CONV: + self.seq_encoder = nn.Sequential( + nn.Conv2d(dim_in, dim_in, 3, padding=1), + nn.ReLU(inplace=True), + nn.MaxPool2d(2, stride=2, ceil_mode=True), + nn.Conv2d(dim_in, 256, 3, padding=1), + nn.ReLU(inplace=True), + ) + else: + self.seq_encoder = nn.Sequential( + nn.Conv2d(dim_in, 256, 3, padding=1), + nn.ReLU(inplace=True), + nn.MaxPool2d(2, stride=2, ceil_mode=True), + ) + x_onehot_size = int(cfg.SEQUENCE.RESIZE_WIDTH / 2) + y_onehot_size = int(cfg.SEQUENCE.RESIZE_HEIGHT / 2) + self.seq_decoder = BahdanauAttnDecoderRNN( + 256, cfg.SEQUENCE.NUM_CHAR, cfg.SEQUENCE.NUM_CHAR, n_layers=1, dropout_p=0.1, onehot_size = (y_onehot_size, x_onehot_size) + ) + # self.criterion_seq_decoder = nn.NLLLoss(ignore_index = -1, reduce=False) + self.criterion_seq_decoder = nn.NLLLoss(ignore_index=-1, reduction="none") + # self.rescale = nn.Upsample(size=(16, 64), mode="bilinear", align_corners=False) + self.rescale = nn.Upsample(size=(cfg.SEQUENCE.RESIZE_HEIGHT, cfg.SEQUENCE.RESIZE_WIDTH), mode="bilinear", align_corners=False) + + self.x_onehot = nn.Embedding(x_onehot_size, x_onehot_size) + self.x_onehot.weight.data = torch.eye(x_onehot_size) + self.y_onehot = nn.Embedding(y_onehot_size, y_onehot_size) + self.y_onehot.weight.data = torch.eye(y_onehot_size) + + for name, param in self.named_parameters(): + if "bias" in name: + nn.init.constant_(param, 0) + elif "weight" in name: + # Caffe2 implementation uses MSRAFill, which in fact + # corresponds to kaiming_normal_ in PyTorch + nn.init.kaiming_normal_(param, mode="fan_out", nonlinearity="relu") + + def forward( + self, x, decoder_targets=None, word_targets=None, use_beam_search=False + ): + rescale_out = self.rescale(x) + seq_decoder_input = self.seq_encoder(rescale_out) + x_onehot_size = int(self.cfg.SEQUENCE.RESIZE_WIDTH / 2) + y_onehot_size = int(self.cfg.SEQUENCE.RESIZE_HEIGHT / 2) + x_t, y_t = np.meshgrid(np.linspace(0, x_onehot_size - 1, x_onehot_size), np.linspace(0, y_onehot_size - 1, y_onehot_size)) + x_t = torch.LongTensor(x_t, device=cpu_device) + y_t = torch.LongTensor(y_t, device=cpu_device) + x_onehot_embedding = ( + self.x_onehot(x_t) + .transpose(0, 2) + .transpose(1, 2) + .repeat(seq_decoder_input.size(0), 1, 1, 1) + ) + y_onehot_embedding = ( + self.y_onehot(y_t) + .transpose(0, 2) + .transpose(1, 2) + .repeat(seq_decoder_input.size(0), 1, 1, 1) + ) + seq_decoder_input_loc = torch.cat( + [seq_decoder_input, x_onehot_embedding, y_onehot_embedding], 1 + ) + seq_decoder_input_reshape = ( + seq_decoder_input_loc.view( + seq_decoder_input_loc.size(0), seq_decoder_input_loc.size(1), -1 + ) + .transpose(0, 2) + .transpose(1, 2) + ) + if self.training: + bos_onehot = np.zeros( + (seq_decoder_input_reshape.size(1), 1), dtype=np.int32 + ) + bos_onehot[:, 0] = self.cfg.SEQUENCE.BOS_TOKEN + decoder_input = torch.tensor(bos_onehot.tolist(), device=gpu_device) + decoder_hidden = torch.zeros( + (seq_decoder_input_reshape.size(1), 256), device=gpu_device + ) + use_teacher_forcing = ( + True + if random.random() < self.cfg.SEQUENCE.TEACHER_FORCE_RATIO + else False + ) + target_length = decoder_targets.size(1) + if use_teacher_forcing: + # Teacher forcing: Feed the target as the next input + for di in range(target_length): + decoder_output, decoder_hidden, decoder_attention = self.seq_decoder( + decoder_input, decoder_hidden, seq_decoder_input_reshape + ) + if di == 0: + loss_seq_decoder = self.criterion_seq_decoder( + decoder_output, word_targets[:, di] + ) + else: + loss_seq_decoder += self.criterion_seq_decoder( + decoder_output, word_targets[:, di] + ) + decoder_input = decoder_targets[:, di] # Teacher forcing + else: + # Without teacher forcing: use its own predictions as the next input + for di in range(target_length): + decoder_output, decoder_hidden, decoder_attention = self.seq_decoder( + decoder_input, decoder_hidden, seq_decoder_input_reshape + ) + topv, topi = decoder_output.topk(1) + decoder_input = topi.squeeze( + 1 + ).detach() # detach from history as input + if di == 0: + loss_seq_decoder = self.criterion_seq_decoder( + decoder_output, word_targets[:, di] + ) + else: + loss_seq_decoder += self.criterion_seq_decoder( + decoder_output, word_targets[:, di] + ) + loss_seq_decoder = loss_seq_decoder.sum() / loss_seq_decoder.size(0) + loss_seq_decoder = 0.2 * loss_seq_decoder + return loss_seq_decoder + else: + words = [] + decoded_scores = [] + detailed_decoded_scores = [] + # real_length = 0 + if use_beam_search: + for batch_index in range(seq_decoder_input_reshape.size(1)): + decoder_hidden = torch.zeros((1, 256), device=gpu_device) + word = [] + char_scores = [] + detailed_char_scores = [] + top_seqs = self.beam_search( + seq_decoder_input_reshape[:, batch_index : batch_index + 1, :], + decoder_hidden, + beam_size=6, + max_len=self.cfg.SEQUENCE.MAX_LENGTH, + ) + top_seq = top_seqs[0] + for character in top_seq[1:]: + character_index = character[0] + if character_index == self.cfg.SEQUENCE.NUM_CHAR - 1: + char_scores.append(character[1]) + detailed_char_scores.append(character[2]) + break + else: + if character_index == 0: + word.append("~") + char_scores.append(0.0) + else: + word.append(num2char(character_index)) + char_scores.append(character[1]) + detailed_char_scores.append(character[2]) + words.append("".join(word)) + decoded_scores.append(char_scores) + detailed_decoded_scores.append(detailed_char_scores) + else: + for batch_index in range(seq_decoder_input_reshape.size(1)): + bos_onehot = np.zeros((1, 1), dtype=np.int32) + bos_onehot[:, 0] = self.cfg.SEQUENCE.BOS_TOKEN + decoder_input = torch.tensor(bos_onehot.tolist(), device=gpu_device) + decoder_hidden = torch.zeros((1, 256), device=gpu_device) + word = [] + char_scores = [] + for di in range(self.cfg.SEQUENCE.MAX_LENGTH): + decoder_output, decoder_hidden, decoder_attention = self.seq_decoder( + decoder_input, + decoder_hidden, + seq_decoder_input_reshape[ + :, batch_index : batch_index + 1, : + ], + ) + # decoder_attentions[di] = decoder_attention.data + topv, topi = decoder_output.data.topk(1) + char_scores.append(topv.item()) + if topi.item() == self.cfg.SEQUENCE.NUM_CHAR - 1: + break + else: + if topi.item() == 0: + word.append("~") + else: + word.append(num2char(topi.item())) + + # real_length = di + decoder_input = topi.squeeze(1).detach() + words.append("".join(word)) + decoded_scores.append(char_scores) + return words, decoded_scores, detailed_decoded_scores + + def beam_search_step(self, encoder_context, top_seqs, k): + all_seqs = [] + for seq in top_seqs: + seq_score = reduce_mul([_score for _, _score, _, _ in seq]) + if seq[-1][0] == self.cfg.SEQUENCE.NUM_CHAR - 1: + all_seqs.append((seq, seq_score, seq[-1][2], True)) + continue + decoder_hidden = seq[-1][-1][0] + onehot = np.zeros((1, 1), dtype=np.int32) + onehot[:, 0] = seq[-1][0] + decoder_input = torch.tensor(onehot.tolist(), device=gpu_device) + decoder_output, decoder_hidden, decoder_attention = self.seq_decoder( + decoder_input, decoder_hidden, encoder_context + ) + detailed_char_scores = decoder_output.cpu().numpy() + # print(decoder_output.shape) + scores, candidates = decoder_output.data[:, 1:].topk(k) + for i in range(k): + character_score = scores[:, i] + character_index = candidates[:, i] + score = seq_score * character_score.item() + char_score = seq_score * detailed_char_scores + rs_seq = seq + [ + ( + character_index.item() + 1, + character_score.item(), + char_score, + [decoder_hidden], + ) + ] + done = character_index.item() + 1 == self.cfg.SEQUENCE.NUM_CHAR - 1 + all_seqs.append((rs_seq, score, char_score, done)) + all_seqs = sorted(all_seqs, key=lambda seq: seq[1], reverse=True) + topk_seqs = [seq for seq, _, _, _ in all_seqs[:k]] + all_done = check_all_done(all_seqs[:k]) + return topk_seqs, all_done + + def beam_search(self, encoder_context, decoder_hidden, beam_size=6, max_len=32): + char_score = np.zeros(self.cfg.SEQUENCE.NUM_CHAR) + top_seqs = [[(self.cfg.SEQUENCE.BOS_TOKEN, 1.0, char_score, [decoder_hidden])]] + # loop + for _ in range(max_len): + top_seqs, all_done = self.beam_search_step( + encoder_context, top_seqs, beam_size + ) + if all_done: + break + return top_seqs + + +class Attn(nn.Module): + def __init__(self, method, hidden_size, embed_size, onehot_size): + super(Attn, self).__init__() + self.method = method + self.hidden_size = hidden_size + self.embed_size = embed_size + self.attn = nn.Linear(2 * self.hidden_size + onehot_size, hidden_size) + # self.attn = nn.Linear(hidden_size, hidden_size) + self.v = nn.Parameter(torch.rand(hidden_size)) + stdv = 1.0 / math.sqrt(self.v.size(0)) + self.v.data.normal_(mean=0, std=stdv) + + def forward(self, hidden, encoder_outputs): + """ + :param hidden: + previous hidden state of the decoder, in shape (B, hidden_size) + :param encoder_outputs: + encoder outputs from Encoder, in shape (H*W, B, hidden_size) + :return + attention energies in shape (B, H*W) + """ + max_len = encoder_outputs.size(0) + # this_batch_size = encoder_outputs.size(1) + H = hidden.repeat(max_len, 1, 1).transpose(0, 1) # (B, H*W, hidden_size) + encoder_outputs = encoder_outputs.transpose(0, 1) # (B, H*W, hidden_size) + attn_energies = self.score( + H, encoder_outputs + ) # compute attention score (B, H*W) + return F.softmax(attn_energies, dim=1).unsqueeze( + 1 + ) # normalize with softmax (B, 1, H*W) + + def score(self, hidden, encoder_outputs): + energy = torch.tanh( + self.attn(torch.cat([hidden, encoder_outputs], 2)) + ) # (B, H*W, 2*hidden_size+H+W)->(B, H*W, hidden_size) + energy = energy.transpose(2, 1) # (B, hidden_size, H*W) + v = self.v.repeat(encoder_outputs.data.shape[0], 1).unsqueeze( + 1 + ) # (B, 1, hidden_size) + energy = torch.bmm(v, energy) # (B, 1, H*W) + return energy.squeeze(1) # (B, H*W) + + +class BahdanauAttnDecoderRNN(nn.Module): + def __init__( + self, + hidden_size, + embed_size, + output_size, + n_layers=1, + dropout_p=0, + bidirectional=False, + onehot_size = (8, 32) + ): + super(BahdanauAttnDecoderRNN, self).__init__() + # Define parameters + self.hidden_size = hidden_size + self.embed_size = embed_size + self.output_size = output_size + self.n_layers = n_layers + self.dropout_p = dropout_p + # Define layers + self.embedding = nn.Embedding(output_size, embed_size) + self.embedding.weight.data = torch.eye(embed_size) + # self.dropout = nn.Dropout(dropout_p) + self.word_linear = nn.Linear(embed_size, hidden_size) + self.attn = Attn("concat", hidden_size, embed_size, onehot_size[0] + onehot_size[1]) + self.rnn = nn.GRUCell(2 * hidden_size + onehot_size[0] + onehot_size[1], hidden_size) + self.out = nn.Linear(hidden_size, output_size) + + def forward(self, word_input, last_hidden, encoder_outputs): + """ + :param word_input: + word input for current time step, in shape (B) + :param last_hidden: + last hidden stat of the decoder, in shape (layers*direction*B, hidden_size) + :param encoder_outputs: + encoder outputs in shape (H*W, B, C) + :return + decoder output + """ + # Get the embedding of the current input word (last output word) + word_embedded_onehot = self.embedding(word_input).view( + 1, word_input.size(0), -1 + ) # (1,B,embed_size) + word_embedded = self.word_linear(word_embedded_onehot) # (1, B, hidden_size) + attn_weights = self.attn(last_hidden, encoder_outputs) # (B, 1, H*W) + context = attn_weights.bmm( + encoder_outputs.transpose(0, 1) + ) # (B, 1, H*W) * (B, H*W, C) = (B,1,C) + context = context.transpose(0, 1) # (1,B,C) + # Combine embedded input word and attended context, run through RNN + # 2 * hidden_size + W + H: 256 + 256 + 32 + 8 = 552 + rnn_input = torch.cat((word_embedded, context), 2) + last_hidden = last_hidden.view(last_hidden.size(0), -1) + rnn_input = rnn_input.view(word_input.size(0), -1) + hidden = self.rnn(rnn_input, last_hidden) + if not self.training: + output = F.softmax(self.out(hidden), dim=1) + else: + output = F.log_softmax(self.out(hidden), dim=1) + # Return final output, hidden state + # print(output.shape) + return output, hidden, attn_weights + + +def make_roi_seq_predictor(cfg, dim_in): + return SequencePredictor(cfg, dim_in) diff --git a/maskrcnn_benchmark/modeling/roi_heads/roi_heads.py b/maskrcnn_benchmark/modeling/roi_heads/roi_heads.py new file mode 100644 index 0000000000000000000000000000000000000000..899b3a27515079606278030e1ffbd98148718ed6 --- /dev/null +++ b/maskrcnn_benchmark/modeling/roi_heads/roi_heads.py @@ -0,0 +1,58 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +import torch + +from .box_head.box_head import build_roi_box_head +from .mask_head.mask_head import build_roi_mask_head + + +class CombinedROIHeads(torch.nn.ModuleDict): + """ + Combines a set of individual heads (for box prediction or masks) into a single + head. + """ + + def __init__(self, cfg, heads): + super(CombinedROIHeads, self).__init__(heads) + self.cfg = cfg.clone() + if cfg.MODEL.MASK_ON and cfg.MODEL.ROI_MASK_HEAD.SHARE_BOX_FEATURE_EXTRACTOR: + self.mask.feature_extractor = self.box.feature_extractor + + def forward(self, features, proposals, targets=None): + losses = {} + # TODO rename x to roi_box_features, if it doesn't increase memory consumption + x, detections, loss_box = self.box(features, proposals, targets) + losses.update(loss_box) + if self.cfg.MODEL.MASK_ON or self.cfg.SEQUENCE.SEQ_ON: + mask_features = features + # optimization: during training, if we share the feature extractor between + # the box and the mask heads, + # then we can reuse the features already computed + if ( + self.training + and self.cfg.MODEL.ROI_MASK_HEAD.SHARE_BOX_FEATURE_EXTRACTOR + ): + mask_features = x + # During training, self.box() will return + # the unaltered proposals as "detections" + # this makes the API consistent during training and testing + x, detections, loss_mask = self.mask(mask_features, detections, targets) + if loss_mask is not None: + losses.update(loss_mask) + return x, detections, losses + + +def build_roi_heads(cfg): + # individually create the heads, that will be combined together + # afterwards + roi_heads = [] + if not cfg.MODEL.RPN_ONLY: + roi_heads.append(("box", build_roi_box_head(cfg))) + if cfg.MODEL.MASK_ON or cfg.SEQUENCE.SEQ_ON: + roi_heads.append(("mask", build_roi_mask_head(cfg))) + + # combine individual heads in a single module + if roi_heads: + roi_heads = CombinedROIHeads(cfg, roi_heads) + + return roi_heads diff --git a/maskrcnn_benchmark/modeling/rpn/__init__.py b/maskrcnn_benchmark/modeling/rpn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b01f30cfddd8ed97d5a39f55641fbc929297d885 --- /dev/null +++ b/maskrcnn_benchmark/modeling/rpn/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +# from .rpn import build_rpn diff --git a/maskrcnn_benchmark/modeling/rpn/anchor_generator.py b/maskrcnn_benchmark/modeling/rpn/anchor_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..feebedbf54bfb53006467c0c761aec2692ad455d --- /dev/null +++ b/maskrcnn_benchmark/modeling/rpn/anchor_generator.py @@ -0,0 +1,263 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +import math + +import numpy as np +import torch +from torch import nn + +from maskrcnn_benchmark.structures.bounding_box import BoxList + + +class BufferList(nn.Module): + """ + Similar to nn.ParameterList, but for buffers + """ + + def __init__(self, buffers=None): + super(BufferList, self).__init__() + if buffers is not None: + self.extend(buffers) + + def extend(self, buffers): + offset = len(self) + for i, buffer in enumerate(buffers): + self.register_buffer(str(offset + i), buffer) + return self + + def __len__(self): + return len(self._buffers) + + def __iter__(self): + return iter(self._buffers.values()) + + +class AnchorGenerator(nn.Module): + """ + For a set of image sizes and feature maps, computes a set + of anchors + """ + + def __init__( + self, + sizes=(128, 256, 512), + aspect_ratios=(0.5, 1.0, 2.0), + anchor_strides=(8, 16, 32), + straddle_thresh=0, + ): + super(AnchorGenerator, self).__init__() + + if len(anchor_strides) == 1: + anchor_stride = anchor_strides[0] + cell_anchors = [ + generate_anchors(anchor_stride, sizes, aspect_ratios).float() + ] + else: + if len(anchor_strides) != len(sizes): + raise RuntimeError("FPN should have #anchor_strides == #sizes") + cell_anchors = [ + generate_anchors(anchor_stride, (size,), aspect_ratios).float() + for anchor_stride, size in zip(anchor_strides, sizes) + ] + self.strides = anchor_strides + self.cell_anchors = BufferList(cell_anchors) + self.straddle_thresh = straddle_thresh + + def num_anchors_per_location(self): + return [len(cell_anchors) for cell_anchors in self.cell_anchors] + + def grid_anchors(self, grid_sizes): + anchors = [] + for size, stride, base_anchors in zip( + grid_sizes, self.strides, self.cell_anchors + ): + grid_height, grid_width = size + device = base_anchors.device + shifts_x = torch.arange( + 0, grid_width * stride, step=stride, dtype=torch.float32, device=device + ) + shifts_y = torch.arange( + 0, grid_height * stride, step=stride, dtype=torch.float32, device=device + ) + shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x) + shift_x = shift_x.reshape(-1) + shift_y = shift_y.reshape(-1) + shifts = torch.stack((shift_x, shift_y, shift_x, shift_y), dim=1) + + anchors.append( + (shifts.view(-1, 1, 4) + base_anchors.view(1, -1, 4)).reshape(-1, 4) + ) + + return anchors + + def add_visibility_to(self, boxlist): + image_width, image_height = boxlist.size + anchors = boxlist.bbox + if self.straddle_thresh >= 0: + inds_inside = ( + (anchors[..., 0] >= -self.straddle_thresh) + & (anchors[..., 1] >= -self.straddle_thresh) + & (anchors[..., 2] < image_width + self.straddle_thresh) + & (anchors[..., 3] < image_height + self.straddle_thresh) + ) + else: + device = anchors.device + inds_inside = torch.ones(anchors.shape[0], dtype=torch.bool, device=device) + boxlist.add_field("visibility", inds_inside) + + def forward(self, image_list, feature_maps): + grid_height, grid_width = feature_maps[0].shape[-2:] + grid_sizes = [feature_map.shape[-2:] for feature_map in feature_maps] + anchors_over_all_feature_maps = self.grid_anchors(grid_sizes) + anchors = [] + for i, (image_height, image_width) in enumerate(image_list.image_sizes): + anchors_in_image = [] + for anchors_per_feature_map in anchors_over_all_feature_maps: + boxlist = BoxList( + anchors_per_feature_map, (image_width, image_height), mode="xyxy" + ) + self.add_visibility_to(boxlist) + anchors_in_image.append(boxlist) + anchors.append(anchors_in_image) + return anchors + + +def make_anchor_generator(config): + anchor_sizes = config.MODEL.RPN.ANCHOR_SIZES + aspect_ratios = config.MODEL.RPN.ASPECT_RATIOS + anchor_stride = config.MODEL.RPN.ANCHOR_STRIDE + straddle_thresh = config.MODEL.RPN.STRADDLE_THRESH + + if config.MODEL.RPN.USE_FPN: + assert len(anchor_stride) == len( + anchor_sizes + ), "FPN should have len(ANCHOR_STRIDE) == len(ANCHOR_SIZES)" + else: + assert len(anchor_stride) == 1, "Non-FPN should have a single ANCHOR_STRIDE" + anchor_generator = AnchorGenerator( + anchor_sizes, aspect_ratios, anchor_stride, straddle_thresh + ) + return anchor_generator + + +# Copyright (c) 2017-present, Facebook, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +############################################################################## +# +# Based on: +# -------------------------------------------------------- +# Faster R-CNN +# Copyright (c) 2015 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Written by Ross Girshick and Sean Bell +# -------------------------------------------------------- + + +# Verify that we compute the same anchors as Shaoqing's matlab implementation: +# +# >> load output/rpn_cachedir/faster_rcnn_VOC2007_ZF_stage1_rpn/anchors.mat +# >> anchors +# +# anchors = +# +# -83 -39 100 56 +# -175 -87 192 104 +# -359 -183 376 200 +# -55 -55 72 72 +# -119 -119 136 136 +# -247 -247 264 264 +# -35 -79 52 96 +# -79 -167 96 184 +# -167 -343 184 360 + +# array([[ -83., -39., 100., 56.], +# [-175., -87., 192., 104.], +# [-359., -183., 376., 200.], +# [ -55., -55., 72., 72.], +# [-119., -119., 136., 136.], +# [-247., -247., 264., 264.], +# [ -35., -79., 52., 96.], +# [ -79., -167., 96., 184.], +# [-167., -343., 184., 360.]]) + + +def generate_anchors( + stride=16, sizes=(32, 64, 128, 256, 512), aspect_ratios=(0.5, 1, 2) +): + """Generates a matrix of anchor boxes in (x1, y1, x2, y2) format. Anchors + are centered on stride / 2, have (approximate) sqrt areas of the specified + sizes, and aspect ratios as given. + """ + return _generate_anchors( + stride, + np.array(sizes, dtype=np.float) / stride, + np.array(aspect_ratios, dtype=np.float), + ) + + +def _generate_anchors(base_size, scales, aspect_ratios): + """Generate anchor (reference) windows by enumerating aspect ratios X + scales wrt a reference (0, 0, base_size - 1, base_size - 1) window. + """ + anchor = np.array([1, 1, base_size, base_size], dtype=np.float) - 1 + anchors = _ratio_enum(anchor, aspect_ratios) + anchors = np.vstack( + [_scale_enum(anchors[i, :], scales) for i in range(anchors.shape[0])] + ) + return torch.from_numpy(anchors) + + +def _whctrs(anchor): + """Return width, height, x center, and y center for an anchor (window).""" + w = anchor[2] - anchor[0] + 1 + h = anchor[3] - anchor[1] + 1 + x_ctr = anchor[0] + 0.5 * (w - 1) + y_ctr = anchor[1] + 0.5 * (h - 1) + return w, h, x_ctr, y_ctr + + +def _mkanchors(ws, hs, x_ctr, y_ctr): + """Given a vector of widths (ws) and heights (hs) around a center + (x_ctr, y_ctr), output a set of anchors (windows). + """ + ws = ws[:, np.newaxis] + hs = hs[:, np.newaxis] + anchors = np.hstack( + ( + x_ctr - 0.5 * (ws - 1), + y_ctr - 0.5 * (hs - 1), + x_ctr + 0.5 * (ws - 1), + y_ctr + 0.5 * (hs - 1), + ) + ) + return anchors + + +def _ratio_enum(anchor, ratios): + """Enumerate a set of anchors for each aspect ratio wrt an anchor.""" + w, h, x_ctr, y_ctr = _whctrs(anchor) + size = w * h + size_ratios = size / ratios + ws = np.round(np.sqrt(size_ratios)) + hs = np.round(ws * ratios) + anchors = _mkanchors(ws, hs, x_ctr, y_ctr) + return anchors + + +def _scale_enum(anchor, scales): + """Enumerate a set of anchors for each scale wrt an anchor.""" + w, h, x_ctr, y_ctr = _whctrs(anchor) + ws = w * scales + hs = h * scales + anchors = _mkanchors(ws, hs, x_ctr, y_ctr) + return anchors diff --git a/maskrcnn_benchmark/modeling/rpn/inference.py b/maskrcnn_benchmark/modeling/rpn/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..bbd2355479d303e2421924873addfd8cfcb432c2 --- /dev/null +++ b/maskrcnn_benchmark/modeling/rpn/inference.py @@ -0,0 +1,203 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +import torch + +from maskrcnn_benchmark.modeling.box_coder import BoxCoder +from maskrcnn_benchmark.structures.bounding_box import BoxList +from maskrcnn_benchmark.structures.boxlist_ops import cat_boxlist +from maskrcnn_benchmark.structures.boxlist_ops import boxlist_nms +from maskrcnn_benchmark.structures.boxlist_ops import remove_small_boxes + +import pdb + + +class RPNPostProcessor(torch.nn.Module): + """ + Performs post-processing on the outputs of the RPN boxes, before feeding the + proposals to the heads + """ + + def __init__( + self, + pre_nms_top_n, + post_nms_top_n, + nms_thresh, + min_size, + box_coder=None, + fpn_post_nms_top_n=None, + ): + """ + Arguments: + pre_nms_top_n (int) + post_nms_top_n (int) + nms_thresh (float) + min_size (int) + box_coder (BoxCoder) + fpn_post_nms_top_n (int) + """ + super(RPNPostProcessor, self).__init__() + self.pre_nms_top_n = pre_nms_top_n + self.post_nms_top_n = post_nms_top_n + self.nms_thresh = nms_thresh + self.min_size = min_size + + if box_coder is None: + box_coder = BoxCoder(weights=(1.0, 1.0, 1.0, 1.0)) + self.box_coder = box_coder + + if fpn_post_nms_top_n is None: + fpn_post_nms_top_n = post_nms_top_n + self.fpn_post_nms_top_n = fpn_post_nms_top_n + + def add_gt_proposals(self, proposals, targets): + """ + Arguments: + proposals: list[BoxList] + targets: list[BoxList] + """ + # Get the device we're operating on + device = proposals[0].bbox.device + + gt_boxes = [target.copy_with_fields([]) for target in targets] + + # later cat of bbox requires all fields to be present for all bbox + # so we need to add a dummy for objectness that's missing + for gt_box in gt_boxes: + gt_box.add_field("objectness", torch.ones(len(gt_box), device=device)) + + proposals = [ + cat_boxlist((proposal, gt_box)) + for proposal, gt_box in zip(proposals, gt_boxes) + ] + + return proposals + + def forward_for_single_feature_map(self, anchors, objectness, box_regression): + """ + Arguments: + anchors: list[BoxList] + objectness: tensor of size N, A, H, W + box_regression: tensor of size N, A * 4, H, W + """ + device = objectness.device + N, A, H, W = objectness.shape + + # put in the same format as anchors + objectness = objectness.permute(0, 2, 3, 1).reshape(N, -1) + objectness = objectness.sigmoid() + box_regression = box_regression.view(N, -1, 4, H, W).permute(0, 3, 4, 1, 2) + box_regression = box_regression.reshape(N, -1, 4) + + num_anchors = A * H * W + + pre_nms_top_n = min(self.pre_nms_top_n, num_anchors) + objectness, topk_idx = objectness.topk(pre_nms_top_n, dim=1, sorted=True) + + batch_idx = torch.arange(N, device=device)[:, None] + box_regression = box_regression[batch_idx, topk_idx] + + image_shapes = [box.size for box in anchors] + concat_anchors = torch.cat([a.bbox for a in anchors], dim=0) + concat_anchors = concat_anchors.reshape(N, -1, 4)[batch_idx, topk_idx] + + proposals = self.box_coder.decode( + box_regression.view(-1, 4), concat_anchors.view(-1, 4) + ) + + proposals = proposals.view(N, -1, 4) + + result = [] + for proposal, score, im_shape in zip(proposals, objectness, image_shapes): + boxlist = BoxList(proposal, im_shape, mode="xyxy") + boxlist.add_field("objectness", score) + boxlist = boxlist.clip_to_image(remove_empty=False) + boxlist = remove_small_boxes(boxlist, self.min_size) + boxlist = boxlist_nms( + boxlist, + self.nms_thresh, + max_proposals=self.post_nms_top_n, + score_field="objectness", + ) + result.append(boxlist) + return result + + def forward(self, anchors, objectness, box_regression, targets=None): + """ + Arguments: + anchors: list[list[BoxList]] + objectness: list[tensor] + box_regression: list[tensor] + + Returns: + boxlists (list[BoxList]): the post-processed anchors, after + applying box decoding and NMS + """ + sampled_boxes = [] + num_levels = len(objectness) + anchors = list(zip(*anchors)) + for a, o, b in zip(anchors, objectness, box_regression): + sampled_boxes.append(self.forward_for_single_feature_map(a, o, b)) + + boxlists = list(zip(*sampled_boxes)) + boxlists = [cat_boxlist(boxlist) for boxlist in boxlists] + + if num_levels > 1: + boxlists = self.select_over_all_levels(boxlists) + + # append ground-truth bboxes to proposals + if self.training and targets is not None: + boxlists = self.add_gt_proposals(boxlists, targets) + + return boxlists + + def select_over_all_levels(self, boxlists): + num_images = len(boxlists) + # different behavior during training and during testing: + # during training, post_nms_top_n is over *all* the proposals combined, while + # during testing, it is over the proposals for each image + # TODO resolve this difference and make it consistent. It should be per image, + # and not per batch + if self.training: + objectness = torch.cat( + [boxlist.get_field("objectness") for boxlist in boxlists], dim=0 + ) + box_sizes = [len(boxlist) for boxlist in boxlists] + post_nms_top_n = min(self.fpn_post_nms_top_n, len(objectness)) + _, inds_sorted = torch.topk(objectness, post_nms_top_n, dim=0, sorted=True) + inds_mask = torch.zeros_like(objectness, dtype=torch.bool) + inds_mask[inds_sorted] = 1 + inds_mask = inds_mask.split(box_sizes) + for i in range(num_images): + boxlists[i] = boxlists[i][inds_mask[i]] + else: + for i in range(num_images): + objectness = boxlists[i].get_field("objectness") + post_nms_top_n = min(self.fpn_post_nms_top_n, len(objectness)) + _, inds_sorted = torch.topk( + objectness, post_nms_top_n, dim=0, sorted=True + ) + boxlists[i] = boxlists[i][inds_sorted] + return boxlists + + +def make_rpn_postprocessor(config, rpn_box_coder, is_train): + fpn_post_nms_top_n = config.MODEL.RPN.FPN_POST_NMS_TOP_N_TRAIN + if not is_train: + fpn_post_nms_top_n = config.MODEL.RPN.FPN_POST_NMS_TOP_N_TEST + + pre_nms_top_n = config.MODEL.RPN.PRE_NMS_TOP_N_TRAIN + post_nms_top_n = config.MODEL.RPN.POST_NMS_TOP_N_TRAIN + if not is_train: + pre_nms_top_n = config.MODEL.RPN.PRE_NMS_TOP_N_TEST + post_nms_top_n = config.MODEL.RPN.POST_NMS_TOP_N_TEST + nms_thresh = config.MODEL.RPN.NMS_THRESH + min_size = config.MODEL.RPN.MIN_SIZE + box_selector = RPNPostProcessor( + pre_nms_top_n=pre_nms_top_n, + post_nms_top_n=post_nms_top_n, + nms_thresh=nms_thresh, + min_size=min_size, + box_coder=rpn_box_coder, + fpn_post_nms_top_n=fpn_post_nms_top_n, + ) + return box_selector diff --git a/maskrcnn_benchmark/modeling/rpn/loss.py b/maskrcnn_benchmark/modeling/rpn/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..94b89a4613f07876d6163a10329e7d71113289ed --- /dev/null +++ b/maskrcnn_benchmark/modeling/rpn/loss.py @@ -0,0 +1,152 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +""" +This file contains specific functions for computing losses on the RPN +file +""" + +import torch +from torch.nn import functional as F + +from ..balanced_positive_negative_sampler import BalancedPositiveNegativeSampler +from ..utils import cat + +from maskrcnn_benchmark.layers import smooth_l1_loss +from maskrcnn_benchmark.modeling.matcher import Matcher +from maskrcnn_benchmark.structures.boxlist_ops import boxlist_iou +from maskrcnn_benchmark.structures.boxlist_ops import cat_boxlist + + +class RPNLossComputation(object): + """ + This class computes the RPN loss. + """ + + def __init__(self, proposal_matcher, fg_bg_sampler, box_coder): + """ + Arguments: + proposal_matcher (Matcher) + fg_bg_sampler (BalancedPositiveNegativeSampler) + box_coder (BoxCoder) + """ + # self.target_preparator = target_preparator + self.proposal_matcher = proposal_matcher + self.fg_bg_sampler = fg_bg_sampler + self.box_coder = box_coder + + def match_targets_to_anchors(self, anchor, target): + match_quality_matrix = boxlist_iou(target, anchor) + matched_idxs = self.proposal_matcher(match_quality_matrix) + # RPN doesn't need any fields from target + # for creating the labels, so clear them all + target = target.copy_with_fields([]) + # get the targets corresponding GT for each anchor + # NB: need to clamp the indices because we can have a single + # GT in the image, and matched_idxs can be -2, which goes + # out of bounds + matched_targets = target[matched_idxs.clamp(min=0)] + matched_targets.add_field("matched_idxs", matched_idxs) + return matched_targets + + def prepare_targets(self, anchors, targets): + labels = [] + regression_targets = [] + for anchors_per_image, targets_per_image in zip(anchors, targets): + matched_targets = self.match_targets_to_anchors( + anchors_per_image, targets_per_image + ) + + matched_idxs = matched_targets.get_field("matched_idxs") + labels_per_image = matched_idxs >= 0 + labels_per_image = labels_per_image.to(dtype=torch.float32) + # discard anchors that go out of the boundaries of the image + labels_per_image[~anchors_per_image.get_field("visibility")] = -1 + + # discard indices that are between thresholds + inds_to_discard = matched_idxs == Matcher.BETWEEN_THRESHOLDS + labels_per_image[inds_to_discard] = -1 + + # compute regression targets + regression_targets_per_image = self.box_coder.encode( + matched_targets.bbox, anchors_per_image.bbox + ) + + labels.append(labels_per_image) + regression_targets.append(regression_targets_per_image) + + return labels, regression_targets + + def __call__(self, anchors, objectness, box_regression, targets): + """ + Arguments: + anchors (list[BoxList]) + objectness (list[Tensor]) + box_regression (list[Tensor]) + targets (list[BoxList]) + + Returns: + objectness_loss (Tensor) + box_loss (Tensor + """ + anchors = [cat_boxlist(anchors_per_image) for anchors_per_image in anchors] + labels, regression_targets = self.prepare_targets(anchors, targets) + sampled_pos_inds, sampled_neg_inds = self.fg_bg_sampler(labels) + sampled_pos_inds = torch.nonzero(torch.cat(sampled_pos_inds, dim=0)).squeeze(1) + sampled_neg_inds = torch.nonzero(torch.cat(sampled_neg_inds, dim=0)).squeeze(1) + + sampled_inds = torch.cat([sampled_pos_inds, sampled_neg_inds], dim=0) + + objectness_flattened = [] + box_regression_flattened = [] + # for each feature level, permute the outputs to make them be in the + # same format as the labels. Note that the labels are computed for + # all feature levels concatenated, so we keep the same representation + # for the objectness and the box_regression + for objectness_per_level, box_regression_per_level in zip( + objectness, box_regression + ): + N, A, H, W = objectness_per_level.shape + objectness_per_level = objectness_per_level.permute(0, 2, 3, 1).reshape( + N, -1 + ) + box_regression_per_level = box_regression_per_level.view(N, -1, 4, H, W) + box_regression_per_level = box_regression_per_level.permute(0, 3, 4, 1, 2) + box_regression_per_level = box_regression_per_level.reshape(N, -1, 4) + objectness_flattened.append(objectness_per_level) + box_regression_flattened.append(box_regression_per_level) + # concatenate on the first dimension (representing the feature levels), to + # take into account the way the labels were generated (with all feature maps + # being concatenated as well) + objectness = cat(objectness_flattened, dim=1).reshape(-1) + box_regression = cat(box_regression_flattened, dim=1).reshape(-1, 4) + + labels = torch.cat(labels, dim=0) + regression_targets = torch.cat(regression_targets, dim=0) + + box_loss = smooth_l1_loss( + box_regression[sampled_pos_inds], + regression_targets[sampled_pos_inds], + beta=1.0 / 9, + size_average=False, + ) / (sampled_inds.numel()) + + objectness_loss = F.binary_cross_entropy_with_logits( + objectness[sampled_inds], labels[sampled_inds] + ) + + return objectness_loss, box_loss + + +def make_rpn_loss_evaluator(cfg, box_coder): + matcher = Matcher( + cfg.MODEL.RPN.FG_IOU_THRESHOLD, + cfg.MODEL.RPN.BG_IOU_THRESHOLD, + allow_low_quality_matches=True, + ) + + fg_bg_sampler = BalancedPositiveNegativeSampler( + cfg.MODEL.RPN.BATCH_SIZE_PER_IMAGE, cfg.MODEL.RPN.POSITIVE_FRACTION + ) + + loss_evaluator = RPNLossComputation(matcher, fg_bg_sampler, box_coder) + return loss_evaluator diff --git a/maskrcnn_benchmark/modeling/rpn/rpn.py b/maskrcnn_benchmark/modeling/rpn/rpn.py new file mode 100644 index 0000000000000000000000000000000000000000..c7160a11cef44dfc51a9122ab0a50e296b07ab41 --- /dev/null +++ b/maskrcnn_benchmark/modeling/rpn/rpn.py @@ -0,0 +1,140 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +import torch +import torch.nn.functional as F +from torch import nn + +from maskrcnn_benchmark.modeling.box_coder import BoxCoder +from .loss import make_rpn_loss_evaluator +from .anchor_generator import make_anchor_generator +from .inference import make_rpn_postprocessor + + +class RPNHead(nn.Module): + """ + Adds a simple RPN Head with classification and regression heads + """ + + def __init__(self, in_channels, num_anchors): + """ + Arguments: + in_channels (int): number of channels of the input feature + num_anchors (int): number of anchors to be predicted + """ + super(RPNHead, self).__init__() + self.conv = nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=1, padding=1 + ) + self.cls_logits = nn.Conv2d(in_channels, num_anchors, kernel_size=1, stride=1) + self.bbox_pred = nn.Conv2d( + in_channels, num_anchors * 4, kernel_size=1, stride=1 + ) + + for l in [self.conv, self.cls_logits, self.bbox_pred]: + torch.nn.init.normal_(l.weight, std=0.01) + torch.nn.init.constant_(l.bias, 0) + + def forward(self, x): + logits = [] + bbox_reg = [] + for feature in x: + t = F.relu(self.conv(feature)) + logits.append(self.cls_logits(t)) + bbox_reg.append(self.bbox_pred(t)) + return logits, bbox_reg + + +class RPNModule(torch.nn.Module): + """ + Module for RPN computation. Takes feature maps from the backbone and RPN + proposals and losses. Works for both FPN and non-FPN. + """ + + def __init__(self, cfg): + super(RPNModule, self).__init__() + + self.cfg = cfg.clone() + + anchor_generator = make_anchor_generator(cfg) + + in_channels = cfg.MODEL.BACKBONE.OUT_CHANNELS + head = RPNHead(in_channels, anchor_generator.num_anchors_per_location()[0]) + + rpn_box_coder = BoxCoder(weights=(1.0, 1.0, 1.0, 1.0)) + + box_selector_train = make_rpn_postprocessor(cfg, rpn_box_coder, is_train=True) + box_selector_test = make_rpn_postprocessor(cfg, rpn_box_coder, is_train=False) + + loss_evaluator = make_rpn_loss_evaluator(cfg, rpn_box_coder) + + self.anchor_generator = anchor_generator + self.head = head + self.box_selector_train = box_selector_train + self.box_selector_test = box_selector_test + self.loss_evaluator = loss_evaluator + + def forward(self, images, features, targets=None): + """ + Arguments: + images (ImageList): images for which we want to compute the predictions + features (list[Tensor]): features computed from the images that are + used for computing the predictions. Each tensor in the list + correspond to different feature levels + targets (list[BoxList): ground-truth boxes present in the image (optional) + + Returns: + boxes (list[BoxList]): the predicted boxes from the RPN, one BoxList per + image. + losses (dict[Tensor]): the losses for the model during training. During + testing, it is an empty dict. + """ + objectness, rpn_box_regression = self.head(features) + anchors = self.anchor_generator(images, features) + + if self.training: + return self._forward_train(anchors, objectness, rpn_box_regression, targets) + else: + return self._forward_test(anchors, objectness, rpn_box_regression) + + def _forward_train(self, anchors, objectness, rpn_box_regression, targets): + if self.cfg.MODEL.RPN_ONLY: + # When training an RPN-only model, the loss is determined by the + # predicted objectness and rpn_box_regression values and there is + # no need to transform the anchors into predicted boxes; this is an + # optimization that avoids the unnecessary transformation. + boxes = anchors + else: + # For end-to-end models, anchors must be transformed into boxes and + # sampled into a training batch. + with torch.no_grad(): + boxes = self.box_selector_train( + anchors, objectness, rpn_box_regression, targets + ) + loss_objectness, loss_rpn_box_reg = self.loss_evaluator( + anchors, objectness, rpn_box_regression, targets + ) + losses = { + "loss_objectness": loss_objectness, + "loss_rpn_box_reg": loss_rpn_box_reg, + } + return boxes, losses + + def _forward_test(self, anchors, objectness, rpn_box_regression): + boxes = self.box_selector_test(anchors, objectness, rpn_box_regression) + if self.cfg.MODEL.RPN_ONLY: + # For end-to-end models, the RPN proposals are an intermediate state + # and don't bother to sort them in decreasing score order. For RPN-only + # models, the proposals are the final output and we return them in + # high-to-low confidence order. + inds = [ + box.get_field("objectness").sort(descending=True)[1] for box in boxes + ] + boxes = [box[ind] for box, ind in zip(boxes, inds)] + return boxes, {} + + +def build_rpn(cfg): + """ + This gives the gist of it. Not super important because it doesn't change as much + """ + return RPNModule(cfg) diff --git a/maskrcnn_benchmark/modeling/segmentation/inference.py b/maskrcnn_benchmark/modeling/segmentation/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..8f638b04bb77a9163921f0abee31c5888d036cf7 --- /dev/null +++ b/maskrcnn_benchmark/modeling/segmentation/inference.py @@ -0,0 +1,378 @@ +#!/usr/bin/env python3 +import numpy as np +import torch +import cv2 +import pyclipper +from shapely.geometry import Polygon + +from maskrcnn_benchmark.structures.bounding_box import BoxList +from maskrcnn_benchmark.structures.boxlist_ops import cat_boxlist, cat_boxlist_gt +from maskrcnn_benchmark.structures.boxlist_ops import remove_small_boxes +from maskrcnn_benchmark.structures.segmentation_mask import SegmentationMask +import random + +import time + + +class SEGPostProcessor(torch.nn.Module): + """ + Performs post-processing on the outputs of the RPN boxes, before feeding the + proposals to the heads + """ + + def __init__( + self, + top_n, + binary_thresh, + box_thresh, + min_size, + cfg, + ): + """ + Arguments: + top_n (int) + binary_thresh (float) + box_thresh (float) + min_size (int) + """ + super(SEGPostProcessor, self).__init__() + self.top_n = top_n + self.binary_thresh = binary_thresh + self.box_thresh = box_thresh + self.min_size = min_size + self.cfg = cfg + + def add_gt_proposals(self, proposals, targets): + """ + Arguments: + proposals: list[BoxList] + targets: list[BoxList] + """ + # Get the device we're operating on + # device = proposals[0].bbox. + if self.cfg.MODEL.SEG.USE_SEG_POLY or self.cfg.MODEL.ROI_BOX_HEAD.USE_MASKED_FEATURE or self.cfg.MODEL.ROI_MASK_HEAD.USE_MASKED_FEATURE: + gt_boxes = [target.copy_with_fields(['masks']) for target in targets] + else: + gt_boxes = [target.copy_with_fields([]) for target in targets] + # later cat of bbox requires all fields to be present for all bbox + # so we need to add a dummy for objectness that's missing + # for gt_box in gt_boxes: + # gt_box.add_field("objectness", torch.ones(len(gt_box), device=device)) + proposals = [ + cat_boxlist_gt([proposal, gt_box]) + for proposal, gt_box in zip(proposals, gt_boxes) + ] + + return proposals + + def aug_tensor_proposals(self, boxes): + # boxes: N * 4 + boxes = boxes.float() + N = boxes.shape[0] + device = boxes.device + aug_boxes = torch.zeros((4, N, 4), device=device) + aug_boxes[0, :, :] = boxes.clone() + xmin, ymin, xmax, ymax = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3] + x_center = (xmin + xmax) / 2. + y_center = (ymin + ymax) / 2. + width = xmax - xmin + height = ymax - ymin + for i in range(3): + choice = random.random() + if choice < 0.5: + # shrink or expand + ratio = (torch.randn((N,), device=device) * 3 + 1) / 2. + height = height * ratio + ratio = (torch.randn((N,), device=device) * 3 + 1) / 2. + width = width * ratio + else: + move_x = width * (torch.randn((N,), device=device) * 4 - 2) + move_y = height * (torch.randn((N,), device=device) * 4 - 2) + x_center += move_x + y_center += move_y + boxes[:, 0] = x_center - width / 2 + boxes[:, 2] = x_center + width / 2 + boxes[:, 1] = y_center - height / 2 + boxes[:, 3] = y_center + height / 2 + aug_boxes[i+1, :, :] = boxes.clone() + return aug_boxes.reshape((-1, 4)) + + def forward_for_single_feature_map(self, pred, image_shapes): + """ + Arguments: + pred: tensor of size N, 1, H, W + """ + device = pred.device + # torch.cuda.synchronize() + # start_time = time.time() + bitmap = self.binarize(pred) + # torch.cuda.synchronize() + # end_time = time.time() + # print('binarize time:', end_time - start_time) + N, height, width = pred.shape[0], pred.shape[2], pred.shape[3] + # torch.cuda.synchronize() + # start_time = time.time() + bitmap_numpy = bitmap.cpu().numpy() # The first channel + pred_map_numpy = pred.cpu().numpy() + # torch.cuda.synchronize() + # end_time = time.time() + # print('gpu2numpy time:', end_time - start_time) + boxes_batch = [] + rotated_boxes_batch = [] + polygons_batch = [] + scores_batch = [] + # torch.cuda.synchronize() + # start_time = time.time() + for batch_index in range(N): + image_shape = image_shapes[batch_index] + boxes, scores, rotated_boxes, polygons = self.boxes_from_bitmap( + pred_map_numpy[batch_index], + bitmap_numpy[batch_index], width, height) + boxes = boxes.to(device) + if self.training and self.cfg.MODEL.SEG.AUG_PROPOSALS: + boxes = self.aug_tensor_proposals(boxes) + if boxes.shape[0] > self.top_n: + boxes = boxes[:self.top_n, :] + # _, top_index = scores.topk(self.top_n, 0, sorted=False) + # boxes = boxes[top_index, :] + # scores = scores[top_index] + # boxlist = BoxList(boxes, (width, height), mode="xyxy") + boxlist = BoxList(boxes, (image_shape[1], image_shape[0]), mode="xyxy") + if self.cfg.MODEL.SEG.USE_SEG_POLY or self.cfg.MODEL.ROI_BOX_HEAD.USE_MASKED_FEATURE or self.cfg.MODEL.ROI_MASK_HEAD.USE_MASKED_FEATURE: + masks = SegmentationMask(polygons, (image_shape[1], image_shape[0])) + boxlist.add_field('masks', masks) + boxlist = boxlist.clip_to_image(remove_empty=False) + # boxlist = remove_small_boxes(boxlist, self.min_size) + boxes_batch.append(boxlist) + rotated_boxes_batch.append(rotated_boxes) + polygons_batch.append(polygons) + scores_batch.append(scores) + # torch.cuda.synchronize() + # end_time = time.time() + # print('loop time:', end_time - start_time) + return boxes_batch, rotated_boxes_batch, polygons_batch, scores_batch + + def forward(self, seg_output, image_shapes, targets=None): + """ + Arguments: + seg_output: list[tensor] + + Returns: + boxlists (list[BoxList]): bounding boxes + """ + sampled_boxes = [] + boxes_batch, rotated_boxes_batch, polygons_batch, scores_batch = self.forward_for_single_feature_map(seg_output, image_shapes) + if not self.training: + return boxes_batch, rotated_boxes_batch, polygons_batch, scores_batch + sampled_boxes.append(boxes_batch) + + boxlists = list(zip(*sampled_boxes)) + boxlists = [cat_boxlist(boxlist) for boxlist in boxlists] + + # append ground-truth bboxes to proposals + if self.training and targets is not None: + boxlists = self.add_gt_proposals(boxlists, targets) + return boxlists + + # def select_over_all_levels(self, boxlists): + # num_images = len(boxlists) + # # different behavior during training and during testing: + # # during training, post_nms_top_n is over *all* the proposals combined, while + # # during testing, it is over the proposals for each image + # # TODO resolve this difference and make it consistent. It should be per image, + # # and not per batch + # if self.training: + # objectness = torch.cat( + # [boxlist.get_field("objectness") for boxlist in boxlists], dim=0 + # ) + # box_sizes = [len(boxlist) for boxlist in boxlists] + # post_nms_top_n = min(self.fpn_post_nms_top_n, len(objectness)) + # _, inds_sorted = torch.topk(objectness, post_nms_top_n, dim=0, sorted=True) + # inds_mask = torch.zeros_like(objectness, dtype=torch.uint8) + # inds_mask[inds_sorted] = 1 + # inds_mask = inds_mask.split(box_sizes) + # for i in range(num_images): + # boxlists[i] = boxlists[i][inds_mask[i]] + # else: + # for i in range(num_images): + # objectness = boxlists[i].get_field("objectness") + # post_nms_top_n = min(self.fpn_post_nms_top_n, len(objectness)) + # _, inds_sorted = torch.topk( + # objectness, post_nms_top_n, dim=0, sorted=True + # ) + # boxlists[i] = boxlists[i][inds_sorted] + # return boxlists + + def binarize(self, pred): + if self.cfg.MODEL.SEG.USE_MULTIPLE_THRESH: + binary_maps = [] + for thre in self.cfg.MODEL.SEG.MULTIPLE_THRESH: + binary_map = pred > thre + binary_maps.append(binary_map) + return torch.cat(binary_maps, dim=1) + else: + return pred > self.binary_thresh + + def boxes_from_bitmap(self, pred, bitmap, dest_width, dest_height): + """ + _bitmap: single map with shape (1, H, W), + whose values are binarized as {0, 1} + """ + # assert _bitmap.size(0) == 1 + # bitmap = _bitmap[0] # The first channel + pred = pred[0] + height, width = bitmap.shape[1], bitmap.shape[2] + boxes = [] + scores = [] + rotated_boxes = [] + polygons = [] + contours_all = [] + for i in range(bitmap.shape[0]): + try: + _, contours, _ = cv2.findContours( + (bitmap[i] * 255).astype(np.uint8), + cv2.RETR_LIST, + cv2.CHAIN_APPROX_NONE, + ) + except BaseException: + contours, _ = cv2.findContours( + (bitmap[i] * 255).astype(np.uint8), + cv2.RETR_LIST, + cv2.CHAIN_APPROX_NONE, + ) + contours_all.extend(contours) + for contour in contours_all: + epsilon = 0.01 * cv2.arcLength(contour, True) + approx = cv2.approxPolyDP(contour, epsilon, True) + polygon = approx.reshape((-1, 2)) + points, sside = self.get_mini_boxes(contour) + if sside < self.min_size: + continue + points = np.array(points) + score = self.box_score_fast(pred, points) + if not self.training and self.box_thresh > score: + continue + if polygon.shape[0] > 2: + polygon = self.unclip(polygon, expand_ratio=self.cfg.MODEL.SEG.EXPAND_RATIO) + if len(polygon) > 1: + continue + else: + continue + # polygon = polygon.reshape(-1, 2) + polygon = polygon.reshape(-1) + box = self.unclip(points, expand_ratio=self.cfg.MODEL.SEG.BOX_EXPAND_RATIO).reshape(-1, 2) + box = np.array(box) + box[:, 0] = np.clip(np.round(box[:, 0] / width * dest_width), 0, dest_width) + box[:, 1] = np.clip( + np.round(box[:, 1] / height * dest_height), 0, dest_height + ) + min_x, min_y = min(box[:, 0]), min(box[:, 1]) + max_x, max_y = max(box[:, 0]), max(box[:, 1]) + horizontal_box = torch.from_numpy(np.array([min_x, min_y, max_x, max_y])) + boxes.append(horizontal_box) + scores.append(score) + rotated_box, _ = self.get_mini_boxes(box.reshape(-1, 1, 2)) + rotated_box = np.array(rotated_box) + rotated_boxes.append(rotated_box) + polygons.append([polygon]) + if len(boxes) == 0: + boxes = [torch.from_numpy(np.array([0, 0, 0, 0]))] + scores = [0.] + + boxes = torch.stack(boxes) + scores = torch.from_numpy(np.array(scores)) + return boxes, scores, rotated_boxes, polygons + + def aug_proposals(self, box): + xmin, ymin, xmax, ymax = box[0], box[1], box[2], box[3] + x_center = int((xmin + xmax) / 2.) + y_center = int((ymin + ymax) / 2.) + width = xmax - xmin + height = ymax - ymin + choice = random.random() + if choice < 0.5: + # shrink or expand + ratio = (random.random() * 3 + 1) / 2. + height = height * ratio + ratio = (random.random() * 3 + 1) / 2. + width = width * ratio + else: + move_x = width * (random.random() * 4 - 2) + move_y = height * (random.random() * 4 - 2) + x_center += move_x + y_center += move_y + xmin = int(x_center - width / 2) + xmax = int(x_center + width / 2) + ymin = int(y_center - height / 2) + ymax = int(y_center + height / 2) + return [xmin, ymin, xmax, ymax] + + def unclip(self, box, expand_ratio=1.5): + poly = Polygon(box) + distance = poly.area * expand_ratio / poly.length + offset = pyclipper.PyclipperOffset() + offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON) + expanded = np.array(offset.Execute(distance)) + return expanded + + def get_mini_boxes(self, contour): + bounding_box = cv2.minAreaRect(contour) + points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0]) + + index_1, index_2, index_3, index_4 = 0, 1, 2, 3 + if points[1][1] > points[0][1]: + index_1 = 0 + index_4 = 1 + else: + index_1 = 1 + index_4 = 0 + if points[3][1] > points[2][1]: + index_2 = 2 + index_3 = 3 + else: + index_2 = 3 + index_3 = 2 + + box = [points[index_1], points[index_2], points[index_3], points[index_4]] + return box, min(bounding_box[1]) + + def box_score(self, bitmap, box): + """ + naive version of box score computation, + only for helping principle understand. + """ + mask = np.zeros_like(bitmap, dtype=np.uint8) + cv2.fillPoly(mask, box.reshape(1, 4, 2).astype(np.int32), 1) + return cv2.mean(bitmap, mask)[0] + + def box_score_fast(self, bitmap, _box): + h, w = bitmap.shape[:2] + box = _box.copy() + xmin = np.clip(np.floor(box[:, 0].min()).astype(np.int), 0, w - 1) + xmax = np.clip(np.ceil(box[:, 0].max()).astype(np.int), 0, w - 1) + ymin = np.clip(np.floor(box[:, 1].min()).astype(np.int), 0, h - 1) + ymax = np.clip(np.ceil(box[:, 1].max()).astype(np.int), 0, h - 1) + + mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8) + box[:, 0] = box[:, 0] - xmin + box[:, 1] = box[:, 1] - ymin + cv2.fillPoly(mask, box.reshape(1, 4, 2).astype(np.int32), 1) + return cv2.mean(bitmap[ymin : ymax + 1, xmin : xmax + 1], mask)[0] + + +def make_seg_postprocessor(config, is_train): + top_n = config.MODEL.SEG.TOP_N_TRAIN + if not is_train: + top_n = config.MODEL.SEG.TOP_N_TEST + + binary_thresh = config.MODEL.SEG.BINARY_THRESH + box_thresh = config.MODEL.SEG.BOX_THRESH + min_size = config.MODEL.SEG.MIN_SIZE + box_selector = SEGPostProcessor( + top_n=top_n, + binary_thresh=binary_thresh, + box_thresh=box_thresh, + min_size=min_size, + cfg = config + ) + return box_selector diff --git a/maskrcnn_benchmark/modeling/segmentation/loss.py b/maskrcnn_benchmark/modeling/segmentation/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..08d662cced182ba55859dc29d8e146e27acb43ac --- /dev/null +++ b/maskrcnn_benchmark/modeling/segmentation/loss.py @@ -0,0 +1,64 @@ +#!/usr/bin/env python3 +""" +This file contains specific functions for computing losses on the SEG +file +""" + +import torch + + +class SEGLossComputation(object): + """ + This class computes the SEG loss. + """ + + def __init__(self, cfg): + self.eps = 1e-6 + self.cfg = cfg + + def __call__(self, preds, targets): + """ + Arguments: + preds (Tensor) + targets (list[Tensor]) + masks (list[Tensor]) + Returns: + seg_loss (Tensor) + """ + image_size = (preds.shape[2], preds.shape[3]) + segm_targets, masks = self.prepare_targets(targets, image_size) + device = preds.device + segm_targets = segm_targets.float().to(device) + masks = masks.float().to(device) + seg_loss = self.dice_loss(preds, segm_targets, masks) + return seg_loss + + def dice_loss(self, pred, gt, m): + intersection = torch.sum(pred * gt * m) + union = torch.sum(pred * m) + torch.sum(gt * m) + self.eps + loss = 1 - 2.0 * intersection / union + return loss + + def project_masks_on_image(self, mask_polygons, labels, shrink_ratio, image_size): + seg_map, training_mask = mask_polygons.convert_seg_map( + labels, shrink_ratio, image_size, self.cfg.MODEL.SEG.IGNORE_DIFFICULT + ) + return torch.from_numpy(seg_map), torch.from_numpy(training_mask) + + def prepare_targets(self, targets, image_size): + segms = [] + training_masks = [] + for target_per_image in targets: + segmentation_masks = target_per_image.get_field("masks") + labels = target_per_image.get_field("labels") + seg_maps_per_image, training_masks_per_image = self.project_masks_on_image( + segmentation_masks, labels, self.cfg.MODEL.SEG.SHRINK_RATIO, image_size + ) + segms.append(seg_maps_per_image) + training_masks.append(training_masks_per_image) + return torch.stack(segms), torch.stack(training_masks) + + +def make_seg_loss_evaluator(cfg): + loss_evaluator = SEGLossComputation(cfg) + return loss_evaluator diff --git a/maskrcnn_benchmark/modeling/segmentation/segmentation.py b/maskrcnn_benchmark/modeling/segmentation/segmentation.py new file mode 100644 index 0000000000000000000000000000000000000000..6a4a01f8a274416908ad2524e026adee48897db6 --- /dev/null +++ b/maskrcnn_benchmark/modeling/segmentation/segmentation.py @@ -0,0 +1,182 @@ +#!/usr/bin/env python3 +import torch +from torch import nn + +from .inference import make_seg_postprocessor +from .loss import make_seg_loss_evaluator +import time + + +def conv3x3(in_planes, out_planes, stride=1, has_bias=False): + "3x3 convolution with padding" + return nn.Conv2d( + in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=has_bias + ) + + +def conv3x3_bn_relu(in_planes, out_planes, stride=1, has_bias=False): + return nn.Sequential( + conv3x3(in_planes, out_planes, stride), + nn.BatchNorm2d(out_planes), + nn.ReLU(inplace=True), + ) + + +class SEGHead(nn.Module): + """ + Adds a simple SEG Head with pixel-level prediction + """ + + def __init__(self, in_channels, cfg): + """ + Arguments: + in_channels (int): number of channels of the input feature + """ + super(SEGHead, self).__init__() + self.cfg = cfg + ndim = 256 + self.fpn_out5 = nn.Sequential( + conv3x3(ndim, 64), nn.Upsample(scale_factor=8, mode="nearest") + ) + self.fpn_out4 = nn.Sequential( + conv3x3(ndim, 64), nn.Upsample(scale_factor=4, mode="nearest") + ) + self.fpn_out3 = nn.Sequential( + conv3x3(ndim, 64), nn.Upsample(scale_factor=2, mode="nearest") + ) + self.fpn_out2 = conv3x3(ndim, 64) + self.seg_out = nn.Sequential( + conv3x3_bn_relu(in_channels, 64, 1), + nn.ConvTranspose2d(64, 64, 2, 2), + nn.BatchNorm2d(64), + nn.ReLU(True), + nn.ConvTranspose2d(64, 1, 2, 2), + nn.Sigmoid(), + ) + if self.cfg.MODEL.SEG.USE_PPM: + # PPM Module + pool_scales=(2, 4, 8) + fc_dim = 256 + self.ppm_pooling = [] + self.ppm_conv = [] + for scale in pool_scales: + self.ppm_pooling.append(nn.AdaptiveAvgPool2d(scale)) + self.ppm_conv.append(nn.Sequential( + nn.Conv2d(fc_dim, 512, kernel_size=1, bias=False), + nn.BatchNorm2d(512), + nn.ReLU(inplace=True) + )) + self.ppm_pooling = nn.ModuleList(self.ppm_pooling) + self.ppm_conv = nn.ModuleList(self.ppm_conv) + self.ppm_last_conv = conv3x3_bn_relu(fc_dim + len(pool_scales)*512, ndim, 1) + self.ppm_conv.apply(self.weights_init) + self.ppm_last_conv.apply(self.weights_init) + self.fpn_out5.apply(self.weights_init) + self.fpn_out4.apply(self.weights_init) + self.fpn_out3.apply(self.weights_init) + self.fpn_out2.apply(self.weights_init) + self.seg_out.apply(self.weights_init) + + def forward(self, x): + if self.cfg.MODEL.SEG.USE_PPM: + conv5 = x[-2] + input_size = conv5.size() + ppm_out = [conv5] + for pool_scale, pool_conv in zip(self.ppm_pooling, self.ppm_conv): + ppm_out.append(pool_conv(nn.functional.interpolate( + pool_scale(conv5), + (input_size[2], input_size[3]), + mode='bilinear', align_corners=False))) + ppm_out = torch.cat(ppm_out, 1) + f = self.ppm_last_conv(ppm_out) + else: + f = x[-2] + # p5 = self.fpn_out5(x[-2]) + p5 = self.fpn_out5(f) + p4 = self.fpn_out4(x[-3]) + p3 = self.fpn_out3(x[-4]) + p2 = self.fpn_out2(x[-5]) + fuse = torch.cat((p5, p4, p3, p2), 1) + out = self.seg_out(fuse) + return out, fuse + + def weights_init(self, m): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + nn.init.kaiming_normal_(m.weight.data) + elif classname.find("BatchNorm") != -1: + m.weight.data.fill_(1.0) + m.bias.data.fill_(1e-4) + + +class SEGModule(torch.nn.Module): + """ + Module for RPN computation. Takes feature maps from the backbone and RPN + proposals and losses. Works for both FPN and non-FPN. + """ + + def __init__(self, cfg): + super(SEGModule, self).__init__() + + self.cfg = cfg.clone() + + in_channels = cfg.MODEL.BACKBONE.OUT_CHANNELS + head = SEGHead(in_channels, cfg) + + box_selector_train = make_seg_postprocessor(cfg, is_train=True) + box_selector_test = make_seg_postprocessor(cfg, is_train=False) + + loss_evaluator = make_seg_loss_evaluator(cfg) + + # self.anchor_generator = anchor_generator + self.head = head + self.box_selector_train = box_selector_train + self.box_selector_test = box_selector_test + self.loss_evaluator = loss_evaluator + + def forward(self, images, features, targets=None): + """ + Arguments: + images (ImageList): images for which we want to compute the predictions + features (Tensor): fused feature from FPN + targets (Tensor): segmentaion gt map + + Returns: + boxes (list[BoxList]): the predicted boxes from the RPN, one BoxList per + image. + losses (dict[Tensor]): the losses for the model during training. During + testing, it is an empty dict. + """ + preds, fuse_feature = self.head(features) + # anchors = self.anchor_generator(images, features) + image_shapes = images.get_sizes() + if self.training: + return self._forward_train(preds, targets, image_shapes), [fuse_feature] + else: + return self._forward_test(preds, image_shapes), [fuse_feature] + + def _forward_train(self, preds, targets, image_shapes): + # Segmentation map must be transformed into boxes for detection. + # sampled into a training batch. + with torch.no_grad(): + boxes = self.box_selector_train(preds, image_shapes, targets) + loss_seg = self.loss_evaluator(preds, targets) + losses = {"loss_seg": loss_seg} + return boxes, losses + + def _forward_test(self, preds, image_shapes): + # torch.cuda.synchronize() + # start_time = time.time() + boxes, rotated_boxes, polygons, scores = self.box_selector_test(preds, image_shapes) + # torch.cuda.synchronize() + # end_time = time.time() + # print('post time:', end_time - start_time) + seg_results = {'rotated_boxes': rotated_boxes, 'polygons': polygons, 'preds': preds, 'scores': scores} + return boxes, seg_results + + +def build_segmentation(cfg): + """ + This gives the gist of it. Not super important because it doesn't change as much + """ + return SEGModule(cfg) diff --git a/maskrcnn_benchmark/modeling/utils.py b/maskrcnn_benchmark/modeling/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..14e79e680a2c4976d2092fb8720f6f72cc6bfc39 --- /dev/null +++ b/maskrcnn_benchmark/modeling/utils.py @@ -0,0 +1,17 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +""" +Miscellaneous utility functions +""" + +import torch + + +def cat(tensors, dim=0): + """ + Efficient version of torch.cat that avoids a copy if there is only a single element in a list + """ + assert isinstance(tensors, (list, tuple)) + if len(tensors) == 1: + return tensors[0] + return torch.cat(tensors, dim) + \ No newline at end of file diff --git a/maskrcnn_benchmark/solver/__init__.py b/maskrcnn_benchmark/solver/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..75f40530cccb6b989d33193de92a6c26a07cf751 --- /dev/null +++ b/maskrcnn_benchmark/solver/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +from .build import make_optimizer +from .build import make_lr_scheduler +from .lr_scheduler import WarmupMultiStepLR diff --git a/maskrcnn_benchmark/solver/build.py b/maskrcnn_benchmark/solver/build.py new file mode 100644 index 0000000000000000000000000000000000000000..bed60e911fc77b7edab5bf1d5e0049425547da9b --- /dev/null +++ b/maskrcnn_benchmark/solver/build.py @@ -0,0 +1,37 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +import torch + +from .lr_scheduler import WarmupMultiStepLR + + +def make_optimizer(cfg, model): + params = [] + for key, value in model.named_parameters(): + if not value.requires_grad: + continue + lr = cfg.SOLVER.BASE_LR + weight_decay = cfg.SOLVER.WEIGHT_DECAY + if "bias" in key: + lr = cfg.SOLVER.BASE_LR * cfg.SOLVER.BIAS_LR_FACTOR + weight_decay = cfg.SOLVER.WEIGHT_DECAY_BIAS + params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}] + + if cfg.SOLVER.USE_ADAM: + optimizer = torch.optim.Adam(model.parameters(), lr=lr) + else: + optimizer = torch.optim.SGD(params, lr, momentum=cfg.SOLVER.MOMENTUM) + + return optimizer + + +def make_lr_scheduler(cfg, optimizer): + return WarmupMultiStepLR( + optimizer, + cfg.SOLVER.STEPS, + cfg.SOLVER.GAMMA, + warmup_factor=cfg.SOLVER.WARMUP_FACTOR, + warmup_iters=cfg.SOLVER.WARMUP_ITERS, + warmup_method=cfg.SOLVER.WARMUP_METHOD, + pow_schedule_mode = cfg.SOLVER.POW_SCHEDULE, + max_iter = cfg.SOLVER.MAX_ITER, + ) diff --git a/maskrcnn_benchmark/solver/lr_scheduler.py b/maskrcnn_benchmark/solver/lr_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..9e3a5feafe3f089f26125acc5b19dd89ac5599b0 --- /dev/null +++ b/maskrcnn_benchmark/solver/lr_scheduler.py @@ -0,0 +1,65 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +from bisect import bisect_right + +import torch + + +# FIXME ideally this would be achieved with a CombinedLRScheduler, +# separating MultiStepLR with WarmupLR +# but the current LRScheduler design doesn't allow it +class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler): + def __init__( + self, + optimizer, + milestones, + gamma=0.1, + warmup_factor=1.0 / 3, + warmup_iters=500, + warmup_method="linear", + last_epoch=-1, + pow_schedule_mode = False, + max_iter = 300000, + lr_pow = 0.9 + ): + if not list(milestones) == sorted(milestones): + raise ValueError( + "Milestones should be a list of" " increasing integers. Got {}", + milestones, + ) + + if warmup_method not in ("constant", "linear"): + raise ValueError( + "Only 'constant' or 'linear' warmup_method accepted" + "got {}".format(warmup_method) + ) + self.milestones = milestones + self.gamma = gamma + self.warmup_factor = warmup_factor + self.warmup_iters = warmup_iters + self.warmup_method = warmup_method + self.pow_schedule_mode = pow_schedule_mode + self.max_iter = max_iter + self.lr_pow = lr_pow + super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch) + + def get_lr(self): + warmup_factor = 1 + if self.last_epoch < self.warmup_iters: + if self.warmup_method == "constant": + warmup_factor = self.warmup_factor + elif self.warmup_method == "linear": + alpha = self.last_epoch / self.warmup_iters + warmup_factor = self.warmup_factor * (1 - alpha) + alpha + if self.pow_schedule_mode: + scale_running_lr = ((1. - float(self.last_epoch) / self.max_iter) ** self.lr_pow) + return [ + base_lr * warmup_factor * scale_running_lr + for base_lr in self.base_lrs + ] + else: + return [ + base_lr + * warmup_factor + * self.gamma ** bisect_right(self.milestones, self.last_epoch) + for base_lr in self.base_lrs + ] diff --git a/maskrcnn_benchmark/structures/__init__.py b/maskrcnn_benchmark/structures/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/maskrcnn_benchmark/structures/bounding_box.py b/maskrcnn_benchmark/structures/bounding_box.py new file mode 100644 index 0000000000000000000000000000000000000000..87ca55dd2c8617471825d3a2ef315db0263d2b0e --- /dev/null +++ b/maskrcnn_benchmark/structures/bounding_box.py @@ -0,0 +1,315 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +import numpy as np +import torch +# from shapely import affinity +# from shapely.geometry import box + + +# transpose +FLIP_LEFT_RIGHT = 0 +FLIP_TOP_BOTTOM = 1 + + +class BoxList(object): + """ + This class represents a set of bounding boxes. + The bounding boxes are represented as a Nx4 Tensor. + In order ot uniquely determine the bounding boxes with respect + to an image, we also store the corresponding image dimensions. + They can contain extra information that is specific to each bounding box, such as + labels. + """ + + def __init__(self, bbox, image_size, mode="xyxy", use_char_ann=True, is_fake=False): + device = bbox.device if isinstance(bbox, torch.Tensor) else torch.device("cpu") + bbox = torch.as_tensor(bbox, dtype=torch.float32, device=device) + if bbox.ndimension() != 2: + raise ValueError( + "bbox should have 2 dimensions, got {}".format(bbox.ndimension()) + ) + if bbox.size(-1) != 4: + raise ValueError( + "last dimenion of bbox should have a " + "size of 4, got {}".format(bbox.size(-1)) + ) + if mode not in ("xyxy", "xywh"): + raise ValueError("mode should be 'xyxy' or 'xywh'") + + self.bbox = bbox + self.size = image_size # (image_width, image_height) + self.mode = mode + self.extra_fields = {} + self.use_char_ann = use_char_ann + + def set_size(self, size): + self.size = size + bbox = BoxList( + self.bbox, size, mode=self.mode, use_char_ann=self.use_char_ann + ) + for k, v in self.extra_fields.items(): + if not isinstance(v, torch.Tensor): + v = v.set_size(size) + bbox.add_field(k, v) + + return bbox.convert(self.mode) + + def add_field(self, field, field_data): + self.extra_fields[field] = field_data + + def get_field(self, field): + return self.extra_fields[field] + + def has_field(self, field): + return field in self.extra_fields + + def fields(self): + return list(self.extra_fields.keys()) + + def _copy_extra_fields(self, bbox): + for k, v in bbox.extra_fields.items(): + self.extra_fields[k] = v + + def convert(self, mode): + if mode not in ("xyxy", "xywh"): + raise ValueError("mode should be 'xyxy' or 'xywh'") + if mode == self.mode: + return self + # we only have two modes, so don't need to check + # self.mode + xmin, ymin, xmax, ymax = self._split_into_xyxy() + if mode == "xyxy": + bbox = torch.cat((xmin, ymin, xmax, ymax), dim=-1) + bbox = BoxList(bbox, self.size, mode=mode, use_char_ann=self.use_char_ann) + else: + TO_REMOVE = 1 + bbox = torch.cat( + (xmin, ymin, xmax - xmin + TO_REMOVE, ymax - ymin + TO_REMOVE), dim=-1 + ) + bbox = BoxList(bbox, self.size, mode=mode, use_char_ann=self.use_char_ann) + bbox._copy_extra_fields(self) + return bbox + + def _split_into_xyxy(self): + if self.mode == "xyxy": + xmin, ymin, xmax, ymax = self.bbox.split(1, dim=-1) + return xmin, ymin, xmax, ymax + elif self.mode == "xywh": + TO_REMOVE = 1 + xmin, ymin, w, h = self.bbox.split(1, dim=-1) + return ( + xmin, + ymin, + xmin + (w - TO_REMOVE).clamp(min=0), + ymin + (h - TO_REMOVE).clamp(min=0), + ) + else: + raise RuntimeError("Should not be here") + + def resize(self, size, *args, **kwargs): + """ + Returns a resized copy of this bounding box + + :param size: The requested size in pixels, as a 2-tuple: + (width, height). + """ + + ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(size, self.size)) + if ratios[0] == ratios[1]: + ratio = ratios[0] + scaled_box = self.bbox * ratio + bbox = BoxList( + scaled_box, size, mode=self.mode, use_char_ann=self.use_char_ann + ) + # bbox._copy_extra_fields(self) + for k, v in self.extra_fields.items(): + if not isinstance(v, torch.Tensor): + v = v.resize(size, *args, **kwargs) + bbox.add_field(k, v) + return bbox + + ratio_width, ratio_height = ratios + xmin, ymin, xmax, ymax = self._split_into_xyxy() + scaled_xmin = xmin * ratio_width + scaled_xmax = xmax * ratio_width + scaled_ymin = ymin * ratio_height + scaled_ymax = ymax * ratio_height + scaled_box = torch.cat( + (scaled_xmin, scaled_ymin, scaled_xmax, scaled_ymax), dim=-1 + ) + bbox = BoxList(scaled_box, size, mode="xyxy", use_char_ann=self.use_char_ann) + # bbox._copy_extra_fields(self) + for k, v in self.extra_fields.items(): + if not isinstance(v, torch.Tensor): + v = v.resize(size, *args, **kwargs) + bbox.add_field(k, v) + + return bbox.convert(self.mode) + + def poly2box(self, poly): + xmin = min(poly[0::2]) + xmax = max(poly[0::2]) + ymin = min(poly[1::2]) + ymax = max(poly[1::2]) + return [xmin, ymin, xmax, ymax] + + def rotate(self, angle, r_c, start_h, start_w): + masks = self.extra_fields["masks"] + masks = masks.rotate(angle, r_c, start_h, start_w) + polys = masks.polygons + boxes = [] + for poly in polys: + box = self.poly2box(poly.polygons[0].numpy()) + boxes.append(box) + self.size = (r_c[0] * 2, r_c[1] * 2) + bbox = BoxList(boxes, self.size, mode="xyxy", use_char_ann=self.use_char_ann) + for k, v in self.extra_fields.items(): + if k == "masks": + v = masks + else: + if self.use_char_ann: + if not isinstance(v, torch.Tensor): + v = v.rotate(angle, r_c, start_h, start_w) + else: + if not isinstance(v, torch.Tensor) and k != "char_masks": + v = v.rotate(angle, r_c, start_h, start_w) + bbox.add_field(k, v) + return bbox.convert(self.mode) + + def transpose(self, method): + """ + Transpose bounding box (flip or rotate in 90 degree steps) + :param method: One of :py:attr:`PIL.Image.FLIP_LEFT_RIGHT`, + :py:attr:`PIL.Image.FLIP_TOP_BOTTOM`, :py:attr:`PIL.Image.ROTATE_90`, + :py:attr:`PIL.Image.ROTATE_180`, :py:attr:`PIL.Image.ROTATE_270`, + :py:attr:`PIL.Image.TRANSPOSE` or :py:attr:`PIL.Image.TRANSVERSE`. + """ + if method not in (FLIP_LEFT_RIGHT, FLIP_TOP_BOTTOM): + raise NotImplementedError( + "Only FLIP_LEFT_RIGHT and FLIP_TOP_BOTTOM implemented" + ) + + image_width, image_height = self.size + xmin, ymin, xmax, ymax = self._split_into_xyxy() + if method == FLIP_LEFT_RIGHT: + TO_REMOVE = 1 + transposed_xmin = image_width - xmax - TO_REMOVE + transposed_xmax = image_width - xmin - TO_REMOVE + transposed_ymin = ymin + transposed_ymax = ymax + elif method == FLIP_TOP_BOTTOM: + transposed_xmin = xmin + transposed_xmax = xmax + transposed_ymin = image_height - ymax + transposed_ymax = image_height - ymin + + transposed_boxes = torch.cat( + (transposed_xmin, transposed_ymin, transposed_xmax, transposed_ymax), dim=-1 + ) + bbox = BoxList( + transposed_boxes, self.size, mode="xyxy", use_char_ann=self.use_char_ann + ) + # bbox._copy_extra_fields(self) + for k, v in self.extra_fields.items(): + if not isinstance(v, torch.Tensor): + v = v.transpose(method) + bbox.add_field(k, v) + return bbox.convert(self.mode) + + def crop(self, box): + """ + Cropss a rectangular region from this bounding box. The box is a + 4-tuple defining the left, upper, right, and lower pixel + coordinate. + """ + xmin, ymin, xmax, ymax = self._split_into_xyxy() + w, h = box[2] - box[0], box[3] - box[1] + cropped_xmin = (xmin - box[0]).clamp(min=0, max=w) + cropped_ymin = (ymin - box[1]).clamp(min=0, max=h) + cropped_xmax = (xmax - box[0]).clamp(min=0, max=w) + cropped_ymax = (ymax - box[1]).clamp(min=0, max=h) + + keep_ind = None + not_empty = np.where( + (cropped_xmin != cropped_xmax) & (cropped_ymin != cropped_ymax) + )[0] + if len(not_empty) > 0: + keep_ind = not_empty + cropped_box = torch.cat( + (cropped_xmin, cropped_ymin, cropped_xmax, cropped_ymax), dim=-1 + ) + cropped_box = cropped_box[not_empty] + bbox = BoxList(cropped_box, (w, h), mode="xyxy", use_char_ann=self.use_char_ann) + # bbox._copy_extra_fields(self) + for k, v in self.extra_fields.items(): + if self.use_char_ann: + if not isinstance(v, torch.Tensor): + v = v.crop(box, keep_ind) + else: + if not isinstance(v, torch.Tensor) and k != "char_masks": + v = v.crop(box, keep_ind) + bbox.add_field(k, v) + return bbox.convert(self.mode) + + # Tensor-like methods + + def to(self, device): + bbox = BoxList(self.bbox.to(device), self.size, self.mode, self.use_char_ann) + for k, v in self.extra_fields.items(): + if hasattr(v, "to"): + v = v.to(device) + bbox.add_field(k, v) + return bbox + + def __getitem__(self, item): + bbox = BoxList(self.bbox[item], self.size, self.mode, self.use_char_ann) + for k, v in self.extra_fields.items(): + bbox.add_field(k, v[item]) + return bbox + + def __len__(self): + return self.bbox.shape[0] + + def clip_to_image(self, remove_empty=True): + TO_REMOVE = 1 + self.bbox[:, 0].clamp_(min=0, max=self.size[0] - TO_REMOVE) + self.bbox[:, 1].clamp_(min=0, max=self.size[1] - TO_REMOVE) + self.bbox[:, 2].clamp_(min=0, max=self.size[0] - TO_REMOVE) + self.bbox[:, 3].clamp_(min=0, max=self.size[1] - TO_REMOVE) + if remove_empty: + box = self.bbox + keep = (box[:, 3] > box[:, 1]) & (box[:, 2] > box[:, 0]) + return self[keep] + return self + + def area(self): + TO_REMOVE = 1 + box = self.bbox + area = (box[:, 2] - box[:, 0] + TO_REMOVE) * (box[:, 3] - box[:, 1] + TO_REMOVE) + return area + + def copy_with_fields(self, fields): + bbox = BoxList(self.bbox, self.size, self.mode, self.use_char_ann) + if not isinstance(fields, (list, tuple)): + fields = [fields] + for field in fields: + bbox.add_field(field, self.get_field(field)) + return bbox + + def __repr__(self): + s = self.__class__.__name__ + "(" + s += "num_boxes={}, ".format(len(self)) + s += "image_width={}, ".format(self.size[0]) + s += "image_height={}, ".format(self.size[1]) + s += "mode={})".format(self.mode) + return s + + +if __name__ == "__main__": + bbox = BoxList([[0, 0, 10, 10], [0, 0, 5, 5]], (10, 10)) + s_bbox = bbox.resize((5, 5)) + print(s_bbox) + print(s_bbox.bbox) + + t_bbox = bbox.transpose(0) + print(t_bbox) + print(t_bbox.bbox) diff --git a/maskrcnn_benchmark/structures/boxlist_ops.py b/maskrcnn_benchmark/structures/boxlist_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..7645f0a8bbddbb0dcad868276cd77ec11d12e894 --- /dev/null +++ b/maskrcnn_benchmark/structures/boxlist_ops.py @@ -0,0 +1,234 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +import torch +from maskrcnn_benchmark.layers import nms as _box_nms + +from .bounding_box import BoxList +from maskrcnn_benchmark.structures.segmentation_mask import SegmentationMask +import numpy as np +import shapely +from shapely.geometry import Polygon,MultiPoint + +def boxlist_nms(boxlist, nms_thresh, max_proposals=-1, score_field="score"): + """ + Performs non-maximum suppression on a boxlist, with scores specified + in a boxlist field via score_field. + + Arguments: + boxlist(BoxList) + nms_thresh (float) + max_proposals (int): if > 0, then only the top max_proposals are kept + after non-maxium suppression + score_field (str) + """ + if nms_thresh <= 0: + return boxlist + mode = boxlist.mode + boxlist = boxlist.convert("xyxy") + boxes = boxlist.bbox + score = boxlist.get_field(score_field) + keep = _box_nms(boxes, score, nms_thresh) + if max_proposals > 0: + keep = keep[:max_proposals] + boxlist = boxlist[keep] + return boxlist.convert(mode) + + +def remove_small_boxes(boxlist, min_size): + """ + Only keep boxes with both sides >= min_size + + Arguments: + boxlist (Boxlist) + min_size (int) + """ + # TODO maybe add an API for querying the ws / hs + xywh_boxes = boxlist.convert("xywh").bbox + _, _, ws, hs = xywh_boxes.unbind(dim=1) + keep = ((ws >= min_size) & (hs >= min_size)).nonzero().squeeze(1) + return boxlist[keep] + + +# implementation from https://github.com/kuangliu/torchcv/blob/master/torchcv/utils/box.py +# with slight modifications +def boxlist_iou(boxlist1, boxlist2): + """Compute the intersection over union of two set of boxes. + The box order must be (xmin, ymin, xmax, ymax). + + Arguments: + box1: (BoxList) bounding boxes, sized [N,4]. + box2: (BoxList) bounding boxes, sized [M,4]. + + Returns: + (tensor) iou, sized [N,M]. + + Reference: + https://github.com/chainer/chainercv/blob/master/chainercv/utils/bbox/bbox_iou.py + """ + if boxlist1.size != boxlist2.size: + raise RuntimeError( + "boxlists should have same image size, got {}, {}".format( + boxlist1, boxlist2 + ) + ) + + # N = len(boxlist1) + # M = len(boxlist2) + + area1 = boxlist1.area() + area2 = boxlist2.area() + + box1, box2 = boxlist1.bbox, boxlist2.bbox + + lt = torch.max(box1[:, None, :2], box2[:, :2]) # [N,M,2] + rb = torch.min(box1[:, None, 2:], box2[:, 2:]) # [N,M,2] + + TO_REMOVE = 1 + + wh = (rb - lt + TO_REMOVE).clamp(min=0) # [N,M,2] + inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] + + iou = inter / (area1[:, None] + area2 - inter) + return iou + +# def boxlist_polygon_iou(target, proposal): +# """Compute the intersection over union of two set of boxes. +# The box order must be (xmin, ymin, xmax, ymax). + +# Arguments: +# box1: (BoxList) bounding boxes, sized [N,4]. +# box2: (BoxList) bounding boxes, sized [M,4]. + +# Returns: +# (tensor) iou, sized [N,M]. + +# Reference: +# https://github.com/chainer/chainercv/blob/master/chainercv/utils/bbox/bbox_iou.py +# """ +# if target.size != proposal.size: +# raise RuntimeError( +# "boxlists should have same image size, got {}, {}".format( +# target, proposal +# ) +# ) +# target_polygon = target.get_field("masks").to_np_polygon() +# proposal_polygon = proposal.get_field("masks").to_np_polygon() +# print(target_polygon) +# print(proposal_polygon) +# polygon_points1 = target_polygon[0].reshape(-1, 2) +# poly1 = Polygon(polygon_points1).convex_hull +# polygon_points2 = proposal_polygon[0].reshape(-1, 2) +# poly2 = Polygon(polygon_points2).convex_hull +# union_poly = np.concatenate((polygon_points1, polygon_points2)) +# if not poly1.intersects(poly2): # this test is fast and can accelerate calculation +# iou = 0 +# else: +# try: +# inter_area = poly1.intersection(poly2).area +# #union_area = poly1.area + poly2.area - inter_area +# union_area = MultiPoint(union_poly).convex_hull.area +# if union_area == 0: +# return 0 +# iou = float(inter_area) / union_area +# except shapely.geos.TopologicalError: +# print('shapely.geos.TopologicalError occured, iou set to 0') +# iou = 0 +# return iou + + +# TODO redundant, remove +def _cat(tensors, dim=0): + """ + Efficient version of torch.cat + avoids a copy if there is only a single element in a list + """ + assert isinstance(tensors, (list, tuple)) + if len(tensors) == 1: + return tensors[0] + return torch.cat(tensors, dim) + +def _cat_mask(masks): + polygons_cat = [] + size = masks[0].size + for mask in masks: + polygons = mask.get_polygons() + polygons_cat.extend(polygons) + masks_cat = SegmentationMask(polygons_cat, size) + return masks_cat + + +def cat_boxlist(bboxes): + """ + Concatenates a list of BoxList (having the same image size) into a + single BoxList + + Arguments: + bboxes (list[BoxList]) + """ + # if bboxes is None: + # return None + # if bboxes[0] is None: + # bboxes = [bboxes[1] + assert isinstance(bboxes, (list, tuple)) + assert all(isinstance(bbox, BoxList) for bbox in bboxes) + + size = bboxes[0].size + assert all(bbox.size == size for bbox in bboxes) + + mode = bboxes[0].mode + assert all(bbox.mode == mode for bbox in bboxes) + + fields = set(bboxes[0].fields()) + assert all(set(bbox.fields()) == fields for bbox in bboxes) + + cat_boxes = BoxList(_cat([bbox.bbox for bbox in bboxes], dim=0), size, mode) + + for field in fields: + if field == 'masks': + data = _cat_mask([bbox.get_field(field) for bbox in bboxes]) + else: + data = _cat([bbox.get_field(field) for bbox in bboxes], dim=0) + cat_boxes.add_field(field, data) + + return cat_boxes + + +def cat_boxlist_gt(bboxes): + """ + Concatenates a list of BoxList (having the same image size) into a + single BoxList + + Arguments: + bboxes (list[BoxList]) + """ + assert isinstance(bboxes, (list, tuple)) + assert all(isinstance(bbox, BoxList) for bbox in bboxes) + + size = bboxes[0].size + # bboxes[1].set_size(size) + assert all(bbox.size == size for bbox in bboxes) + + mode = bboxes[0].mode + assert all(bbox.mode == mode for bbox in bboxes) + + fields = set(bboxes[0].fields()) + assert all(set(bbox.fields()) == fields for bbox in bboxes) + if bboxes[0].bbox.sum().item() == 0: + cat_boxes = BoxList(bboxes[1].bbox, size, mode) + else: + cat_boxes = BoxList(_cat([bbox.bbox for bbox in bboxes], dim=0), size, mode) + + for field in fields: + if bboxes[0].bbox.sum().item() == 0: + if field == 'masks': + data = _cat_mask([bbox.get_field(field) for bbox in bboxes[1:]]) + else: + data = _cat([bbox.get_field(field) for bbox in bboxes[1:]], dim=0) + else: + if field == 'masks': + data = _cat_mask([bbox.get_field(field) for bbox in bboxes]) + else: + data = _cat([bbox.get_field(field) for bbox in bboxes], dim=0) + cat_boxes.add_field(field, data) + + return cat_boxes diff --git a/maskrcnn_benchmark/structures/image_list.py b/maskrcnn_benchmark/structures/image_list.py new file mode 100644 index 0000000000000000000000000000000000000000..42b0b716131a3fc3e019143d0fa52fbd53fda37d --- /dev/null +++ b/maskrcnn_benchmark/structures/image_list.py @@ -0,0 +1,121 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +import torch + + +class ImageList(object): + """ + Structure that holds a list of images (of possibly + varying sizes) as a single tensor. + This works by padding the images to the same size, + and storing in a field the original sizes of each image + """ + + def __init__(self, tensors, image_sizes): + """ + Arguments: + tensors (tensor) + image_sizes (list[tuple[int, int]]) + """ + self.tensors = tensors + self.image_sizes = image_sizes + + def to(self, *args, **kwargs): + cast_tensor = self.tensors.to(*args, **kwargs) + return ImageList(cast_tensor, self.image_sizes) + + def get_sizes(self): + return self.image_sizes + + +def to_image_list(tensors, size_divisible=0): + """ + tensors can be an ImageList, a torch.Tensor or + an iterable of Tensors. It can't be a numpy array. + When tensors is an iterable of Tensors, it pads + the Tensors with zeros so that they have the same + shape + """ + if isinstance(tensors, torch.Tensor) and size_divisible > 0: + tensors = [tensors] + + if isinstance(tensors, ImageList): + return tensors + elif isinstance(tensors, torch.Tensor): + # single tensor shape can be inferred + assert tensors.dim() == 4 + image_sizes = [tensor.shape[-2:] for tensor in tensors] + return ImageList(tensors, image_sizes) + elif isinstance(tensors, (tuple, list)): + max_size = tuple(max(s) for s in zip(*[img.shape for img in tensors])) + + # TODO Ideally, just remove this and let me model handle arbitrary + # input sizs + if size_divisible > 0: + import math + + stride = size_divisible + max_size = list(max_size) + max_size[1] = int(math.ceil(max_size[1] / stride) * stride) + max_size[2] = int(math.ceil(max_size[2] / stride) * stride) + max_size = tuple(max_size) + + batch_shape = (len(tensors),) + max_size + batched_imgs = tensors[0].new(*batch_shape).zero_() + for img, pad_img in zip(tensors, batched_imgs): + pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) + + image_sizes = [im.shape[-2:] for im in tensors] + + return ImageList(batched_imgs, image_sizes) + else: + raise TypeError("Unsupported type for to_image_list: {}".format(type(tensors))) + + +def to_image_target_list(tensors, size_divisible=0, targets=None): + """ + tensors can be an ImageList, a torch.Tensor or + an iterable of Tensors. It can't be a numpy array. + When tensors is an iterable of Tensors, it pads + the Tensors with zeros so that they have the same + shape + """ + if isinstance(tensors, torch.Tensor) and size_divisible > 0: + tensors = [tensors] + + if isinstance(tensors, ImageList): + return tensors + elif isinstance(tensors, torch.Tensor): + # single tensor shape can be inferred + assert tensors.dim() == 4 + image_sizes = [tensor.shape[-2:] for tensor in tensors] + return ImageList(tensors, image_sizes) + elif isinstance(tensors, (tuple, list)): + max_size = tuple(max(s) for s in zip(*[img.shape for img in tensors])) + + # TODO Ideally, just remove this and let me model handle arbitrary + # input sizs + if size_divisible > 0: + import math + + stride = size_divisible + max_size = list(max_size) + max_size[1] = int(math.ceil(max_size[1] / stride) * stride) + max_size[2] = int(math.ceil(max_size[2] / stride) * stride) + max_size = tuple(max_size) + + batch_shape = (len(tensors),) + max_size + batched_imgs = tensors[0].new(*batch_shape).zero_() + if targets is None: + for img, pad_img in zip(tensors, batched_imgs): + pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) + else: + for img, pad_img, target in zip(tensors, batched_imgs, targets): + pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) + if target is not None: + target.set_size((pad_img.shape[2], pad_img.shape[1])) + + image_sizes = [im.shape[-2:] for im in tensors] + + return ImageList(batched_imgs, image_sizes), targets + else: + raise TypeError("Unsupported type for to_image_list: {}".format(type(tensors))) \ No newline at end of file diff --git a/maskrcnn_benchmark/structures/segmentation_mask.py b/maskrcnn_benchmark/structures/segmentation_mask.py new file mode 100644 index 0000000000000000000000000000000000000000..0d1b0722d4f3366de11166dca7803dd914ef6bd7 --- /dev/null +++ b/maskrcnn_benchmark/structures/segmentation_mask.py @@ -0,0 +1,766 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +import cv2 +import numpy as np +import pycocotools.mask as mask_utils +import torch +from maskrcnn_benchmark.utils.chars import char2num +import pyclipper +# from PIL import Image +from shapely import affinity +from shapely.geometry import Polygon as ShapePolygon + + +# transpose +FLIP_LEFT_RIGHT = 0 +FLIP_TOP_BOTTOM = 1 + + +def convert_2d_tuple(t): + a = [] + for i in t: + a.extend(list(i)) + return a + + +class Mask(object): + """ + This class is unfinished and not meant for use yet + It is supposed to contain the mask for an object as + a 2d tensor + """ + + def __init__(self, masks, size, mode): + self.masks = masks + self.size = size + self.mode = mode + + def transpose(self, method): + if method not in (FLIP_LEFT_RIGHT, FLIP_TOP_BOTTOM): + raise NotImplementedError( + "Only FLIP_LEFT_RIGHT and FLIP_TOP_BOTTOM implemented" + ) + + width, height = self.size + if method == FLIP_LEFT_RIGHT: + dim = width + # idx = 2 + elif method == FLIP_TOP_BOTTOM: + dim = height + # idx = 1 + + flip_idx = list(range(dim)[::-1]) + flipped_masks = self.masks.index_select(dim, flip_idx) + return Mask(flipped_masks, self.size, self.mode) + + def crop(self, box): + w, h = box[2] - box[0], box[3] - box[1] + + cropped_masks = self.masks[:, box[1] : box[3], box[0] : box[2]] + return Mask(cropped_masks, size=(w, h), mode=self.mode) + + def resize(self, size, *args, **kwargs): + pass + + +class SegmentationMask(object): + """ + This class stores the segmentations for all objects in the image + """ + + def __init__(self, polygons, size, mode=None): + """ + Arguments: + polygons: a list of list of lists of numbers. The first + level of the list correspond to individual instances, + the second level to all the polygons that compose the + object, and the third level to the polygon coordinates. + """ + assert isinstance(polygons, list) + + self.polygons = [Polygons(p, size, mode) for p in polygons] + self.size = size + self.mode = mode + + def transpose(self, method): + if method not in (FLIP_LEFT_RIGHT, FLIP_TOP_BOTTOM): + raise NotImplementedError( + "Only FLIP_LEFT_RIGHT and FLIP_TOP_BOTTOM implemented" + ) + + flipped = [] + for polygon in self.polygons: + flipped.append(polygon.transpose(method)) + return SegmentationMask(flipped, size=self.size, mode=self.mode) + + def crop(self, box, keep_ind=None): + w, h = box[2] - box[0], box[3] - box[1] + if keep_ind is not None: + self.polygons = np.array(self.polygons) + self.polygons = self.polygons[keep_ind] + cropped = [] + for polygon in self.polygons: + cropped.append(polygon.crop(box)) + return SegmentationMask(cropped, size=(w, h), mode=self.mode) + + def rotate(self, angle, r_c, start_h, start_w): + rotated = [] + for polygon in self.polygons: + rotated.append(polygon.rotate(angle, r_c, start_h, start_w)) + return SegmentationMask(rotated, size=(r_c[0] * 2, r_c[1] * 2), mode=self.mode) + + def resize(self, size, *args, **kwargs): + scaled = [] + for polygon in self.polygons: + scaled.append(polygon.resize(size, *args, **kwargs)) + return SegmentationMask(scaled, size=size, mode=self.mode) + + def set_size(self, size): + self.size = size + for polygon in self.polygons: + polygon.set_size(size) + + def to(self, *args, **kwargs): + return self + + def __getitem__(self, item): + if isinstance(item, (int, slice)): + selected_polygons = [self.polygons[item]] + else: + # advanced indexing on a single dimension + selected_polygons = [] + if isinstance(item, torch.Tensor) and item.dtype == torch.bool: + item = item.nonzero() + item = item.squeeze(1) if item.numel() > 0 else item + item = item.tolist() + for i in item: + selected_polygons.append(self.polygons[i]) + return SegmentationMask(selected_polygons, size=self.size, mode=self.mode) + + def __iter__(self): + return iter(self.polygons) + + def __repr__(self): + s = self.__class__.__name__ + "(" + s += "num_instances={}, ".format(len(self.polygons)) + s += "image_width={}, ".format(self.size[0]) + s += "image_height={})".format(self.size[1]) + return s + + def size(self): + return self.size + + def get_polygons(self): + return self.polygons + + def to_np_polygon(self): + np_polygons = [] + for polygon in self.polygons: + polys = polygon.get_polygons() + for poly in polys: + np_poly = poly.numpy() + np_polygons.append(np_poly) + return np_polygons + + + def convert_seg_map(self, labels, shrink_ratio, seg_size, ignore_difficult=True): + # width, height = self.size + # assert self.size[0] == seg_size[1] + # assert self.size[1] == seg_size[0] + height, width = seg_size[0], seg_size[1] + seg_map = np.zeros((1, height, width), dtype=np.uint8) + training_mask = np.ones((height, width), dtype=np.uint8) + for poly, label in zip(self.polygons, labels): + poly = poly.get_polygons()[0] + poly = poly.reshape((-1, 2)).numpy() + if ignore_difficult and label.item() == -1: + cv2.fillPoly(training_mask, poly.astype(np.int32)[np.newaxis, :, :], 0) + continue + if poly.shape[0] < 4: + continue + p = ShapePolygon(poly) + if p.length == 0: + continue + try: + d = p.area * (1 - np.power(shrink_ratio, 2)) / p.length + except: + continue + subj = [tuple(s) for s in poly] + pco = pyclipper.PyclipperOffset() + pco.AddPath(subj, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON) + s = pco.Execute(-d) + if s == []: + cv2.fillPoly(training_mask, poly.astype(np.int32)[np.newaxis, :, :], 0) + continue + out = convert_2d_tuple(s[0]) + out = np.array(out).reshape(-1, 2) + cv2.fillPoly(seg_map[0, :, :], [out.astype(np.int32)], 1) + return seg_map, training_mask + + +class Polygons(object): + """ + This class holds a set of polygons that represents a single instance + of an object mask. The object can be represented as a set of + polygons + """ + + def __init__(self, polygons, size, mode): + # assert isinstance(polygons, list), '{}'.format(polygons) + if isinstance(polygons, list): + polygons = [torch.as_tensor(p, dtype=torch.float32) for p in polygons] + elif isinstance(polygons, Polygons): + polygons = polygons.polygons + + self.polygons = polygons + self.size = size + self.mode = mode + + def transpose(self, method): + if method not in (FLIP_LEFT_RIGHT, FLIP_TOP_BOTTOM): + raise NotImplementedError( + "Only FLIP_LEFT_RIGHT and FLIP_TOP_BOTTOM implemented" + ) + + flipped_polygons = [] + width, height = self.size + if method == FLIP_LEFT_RIGHT: + dim = width + idx = 0 + elif method == FLIP_TOP_BOTTOM: + dim = height + idx = 1 + + for poly in self.polygons: + p = poly.clone() + TO_REMOVE = 1 + p[idx::2] = dim - poly[idx::2] - TO_REMOVE + flipped_polygons.append(p) + + return Polygons(flipped_polygons, size=self.size, mode=self.mode) + + def rotate(self, angle, r_c, start_h, start_w): + poly = self.polygons[0].numpy().reshape(-1, 2) + poly[:, 0] += start_w + poly[:, 1] += start_h + polys = ShapePolygon(poly) + r_polys = list(affinity.rotate(polys, angle, r_c).boundary.coords[:-1]) + p = [] + for r in r_polys: + p += list(r) + return Polygons([p], size=self.size, mode=self.mode) + + def crop(self, box): + w, h = box[2] - box[0], box[3] - box[1] + + # TODO chck if necessary + w = max(w, 1) + h = max(h, 1) + + cropped_polygons = [] + + for poly in self.polygons: + p = poly.clone() + p[0::2] = p[0::2] - box[0] # .clamp(min=0, max=w) + p[1::2] = p[1::2] - box[1] # .clamp(min=0, max=h) + cropped_polygons.append(p) + + return Polygons(cropped_polygons, size=(w, h), mode=self.mode) + + def resize(self, size, *args, **kwargs): + ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(size, self.size)) + if ratios[0] == ratios[1]: + ratio = ratios[0] + scaled_polys = [p * ratio for p in self.polygons] + return Polygons(scaled_polys, size, mode=self.mode) + + ratio_w, ratio_h = ratios + scaled_polygons = [] + for poly in self.polygons: + p = poly.clone() + p[0::2] *= ratio_w + p[1::2] *= ratio_h + scaled_polygons.append(p) + + return Polygons(scaled_polygons, size=size, mode=self.mode) + + def convert(self, mode): + width, height = self.size + if mode == "mask": + # print([p.numpy() for p in self.polygons]) + try: + rles = mask_utils.frPyObjects( + [p.numpy() for p in self.polygons], height, width + ) + except: + print([p.numpy() for p in self.polygons]) + mask = torch.ones((height, width), dtype=torch.uint8) + return mask + rle = mask_utils.merge(rles) + mask = mask_utils.decode(rle) + mask = torch.from_numpy(mask) + # TODO add squeeze? + return mask + + def set_size(self, size): + self.size = size + + def get_polygons(self): + return self.polygons + + def __repr__(self): + s = self.__class__.__name__ + "(" + s += "num_polygons={}, ".format(len(self.polygons)) + s += "image_width={}, ".format(self.size[0]) + s += "image_height={}, ".format(self.size[1]) + s += "mode={})".format(self.mode) + return s + + +class CharPolygons(object): + """ + This class holds a set of polygons that represents a single instance + of an object mask. The object can be represented as a set of + polygons + """ + + def __init__( + self, + char_boxes, + word=None, + use_char_ann=False, + char_classes=None, + size=None, + mode=None, + char_num_classes=37, + ): + if isinstance(char_boxes, CharPolygons): + if char_classes is None: + char_classes = char_boxes.char_classes + self.word = char_boxes.word + char_boxes = char_boxes.char_boxes + else: + if char_classes is None: + char_classes = [ + torch.as_tensor(p[8], dtype=torch.float32) for p in char_boxes + ] + char_boxes = [ + torch.as_tensor(p[:8], dtype=torch.float32) for p in char_boxes + ] + self.word = word + self.char_boxes = char_boxes + self.char_classes = char_classes + self.size = size + self.mode = mode + self.use_char_ann = use_char_ann + self.char_num_classes = char_num_classes + + def transpose(self, method): + if method not in (FLIP_LEFT_RIGHT, FLIP_TOP_BOTTOM): + raise NotImplementedError( + "Only FLIP_LEFT_RIGHT and FLIP_TOP_BOTTOM implemented" + ) + + flipped_polygons = [] + width, height = self.size + if method == FLIP_LEFT_RIGHT: + dim = width + idx = 0 + elif method == FLIP_TOP_BOTTOM: + dim = height + idx = 1 + + for char_box in self.char_boxes: + p = char_box.clone() + TO_REMOVE = 1 + p[idx::2] = dim - char_box[idx::2] - TO_REMOVE + flipped_polygons.append(p) + + return CharPolygons( + flipped_polygons, + word=self.word, + use_char_ann=self.use_char_ann, + char_classes=self.char_classes, + size=self.size, + mode=self.mode, + char_num_classes=self.char_num_classes, + ) + + def crop(self, box): + w, h = box[2] - box[0], box[3] - box[1] + + # TODO chck if necessary + w = max(w, 1) + h = max(h, 1) + cropped_polygons = [] + for char_box in self.char_boxes: + p = char_box.clone() + p[0::2] = p[0::2] - box[0] # .clamp(min=0, max=w) + p[1::2] = p[1::2] - box[1] # .clamp(min=0, max=h) + cropped_polygons.append(p) + + return CharPolygons( + cropped_polygons, + word=self.word, + use_char_ann=self.use_char_ann, + char_classes=self.char_classes, + size=(w, h), + mode=self.mode, + char_num_classes=self.char_num_classes, + ) + + def rotate(self, angle, r_c, start_h, start_w): + r_polys = [] + for poly in self.char_boxes: + poly = poly.numpy() + poly[0::2] += start_w + poly[1::2] += start_h + poly = ShapePolygon(np.array(poly).reshape(4, 2)) + r_poly = np.array( + list(affinity.rotate(poly, angle, r_c).boundary.coords[:-1]) + ).reshape(-1, 8) + r_polys.append(r_poly[0]) + return CharPolygons( + r_polys, + word=self.word, + use_char_ann=self.use_char_ann, + char_classes=self.char_classes, + size=(r_c[0] * 2, r_c[1] * 2), + mode=self.mode, + char_num_classes=self.char_num_classes, + ) + + def resize(self, size, *args, **kwargs): + ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(size, self.size)) + if ratios[0] == ratios[1]: + ratio = ratios[0] + scaled_polys = [p * ratio for p in self.char_boxes] + return CharPolygons( + scaled_polys, + word=self.word, + use_char_ann=self.use_char_ann, + char_classes=self.char_classes, + size=size, + mode=self.mode, + char_num_classes=self.char_num_classes, + ) + + ratio_w, ratio_h = ratios + scaled_polygons = [] + for poly in self.char_boxes: + p = poly.clone() + p[0::2] *= ratio_w + p[1::2] *= ratio_h + scaled_polygons.append(p) + + return CharPolygons( + scaled_polygons, + word=self.word, + use_char_ann=self.use_char_ann, + char_classes=self.char_classes, + size=size, + mode=self.mode, + char_num_classes=self.char_num_classes, + ) + + def set_size(self, size): + self.size = size + + def convert(self, mode): + width, height = self.size + if mode == "char_mask": + if not self.use_char_ann: + char_map = -np.ones((height, width)) + char_map_weight = np.zeros((self.char_num_classes,)) + else: + char_map = np.zeros((height, width)) + char_map_weight = np.ones((self.char_num_classes,)) + for i, p in enumerate(self.char_boxes): + poly = p.numpy().reshape(4, 2) + poly = shrink_poly(poly, 0.25) + cv2.fillPoly( + char_map, [poly.astype(np.int32)], int(self.char_classes[i]) + ) + pos_index = np.where(char_map > 0) + pos_num = pos_index[0].size + if pos_num > 0: + pos_weight = 1.0 * (height * width - pos_num) / pos_num + char_map_weight[1:] = pos_weight + return torch.from_numpy(char_map), torch.from_numpy(char_map_weight) + elif mode == "seq_char_mask": + decoder_target = self.char_num_classes * np.ones((32,)) + word_target = -np.ones((32,)) + if not self.use_char_ann: + char_map = -np.ones((height, width)) + char_map_weight = np.zeros((self.char_num_classes,)) + for i, char in enumerate(self.word): + if i > 31: + break + decoder_target[i] = char2num(char) + word_target[i] = char2num(char) + end_point = min(max(1, len(self.word)), 31) + word_target[end_point] = self.char_num_classes + else: + char_map = np.zeros((height, width)) + char_map_weight = np.ones((self.char_num_classes,)) + word_length = 0 + for i, p in enumerate(self.char_boxes): + poly = p.numpy().reshape(4, 2) + if i < 32: + decoder_target[i] = int(self.char_classes[i]) + word_target[i] = int(self.char_classes[i]) + word_length += 1 + poly = shrink_poly(poly, 0.25) + cv2.fillPoly( + char_map, [poly.astype(np.int32)], int(self.char_classes[i]) + ) + end_point = min(max(1, word_length), 31) + word_target[end_point] = self.char_num_classes + pos_index = np.where(char_map > 0) + pos_num = pos_index[0].size + if pos_num > 0: + pos_weight = 1.0 * (height * width - pos_num) / pos_num + char_map_weight[1:] = pos_weight + return ( + torch.from_numpy(char_map), + torch.from_numpy(char_map_weight), + torch.from_numpy(decoder_target), + torch.from_numpy(word_target), + ) + + def creat_color_map(self, n_class, width): + splits = int(np.ceil(np.power((n_class * 1.0), 1.0 / 3))) + maps = [] + for i in range(splits): + r = int(i * width * 1.0 / (splits - 1)) + for j in range(splits): + g = int(j * width * 1.0 / (splits - 1)) + for k in range(splits - 1): + b = int(k * width * 1.0 / (splits - 1)) + maps.append([r, g, b]) + return np.array(maps) + + def __repr__(self): + s = self.__class__.__name__ + "(" + s += "num_char_boxes={}, ".format(len(self.char_boxes)) + s += "num_char_classes={}, ".format(len(self.char_classes)) + s += "image_width={}, ".format(self.size[0]) + s += "image_height={}, ".format(self.size[1]) + s += "mode={})".format(self.mode) + return s + + +class SegmentationCharMask(object): + def __init__( + self, chars_boxes, words=None, use_char_ann=True, size=None, mode=None, char_num_classes=37 + ): + # self.chars_boxes=[CharPolygons(char_boxes, word=word, use_char_ann=use_char_ann, size=size, mode=mode) for char_boxes, word in zip(chars_boxes, words)] + if words is None: + self.chars_boxes = [ + CharPolygons( + char_boxes, + word=None, + use_char_ann=use_char_ann, + size=size, + mode=mode, + char_num_classes=char_num_classes, + ) + for char_boxes in chars_boxes + ] + else: + self.chars_boxes = [ + CharPolygons( + char_boxes, + word=words[i], + use_char_ann=use_char_ann, + size=size, + mode=mode, + char_num_classes=char_num_classes, + ) + for i, char_boxes in enumerate(chars_boxes) + ] + self.size = size + self.mode = mode + self.use_char_ann = use_char_ann + self.char_num_classes = char_num_classes + + def transpose(self, method): + if method not in (FLIP_LEFT_RIGHT, FLIP_TOP_BOTTOM): + raise NotImplementedError( + "Only FLIP_LEFT_RIGHT and FLIP_TOP_BOTTOM implemented" + ) + + flipped = [] + for char_boxes in self.chars_boxes: + flipped.append(char_boxes.transpose(method)) + return SegmentationCharMask( + flipped, use_char_ann=self.use_char_ann, size=self.size, mode=self.mode, char_num_classes=self.char_num_classes + ) + + def crop(self, box, keep_ind): + cropped = [] + w, h = box[2] - box[0], box[3] - box[1] + if keep_ind is not None: + self.chars_boxes = np.array(self.chars_boxes) + self.chars_boxes = self.chars_boxes[keep_ind] + for char_boxes in self.chars_boxes: + cropped.append(char_boxes.crop(box)) + return SegmentationCharMask( + cropped, use_char_ann=self.use_char_ann, size=(w, h), mode=self.mode + ) + + def resize(self, size, *args, **kwargs): + scaled = [] + for char_boxes in self.chars_boxes: + scaled.append(char_boxes.resize(size, *args, **kwargs)) + return SegmentationCharMask( + scaled, use_char_ann=self.use_char_ann, size=size, mode=self.mode, char_num_classes=self.char_num_classes + ) + + def set_size(self, size): + self.size = size + for char_box in self.chars_boxes: + char_box.set_size(size) + + def rotate(self, angle, r_c, start_h, start_w): + rotated = [] + for char_boxes in self.chars_boxes: + rotated.append(char_boxes.rotate(angle, r_c, start_h, start_w)) + return SegmentationCharMask( + rotated, + use_char_ann=self.use_char_ann, + size=(r_c[0] * 2, r_c[1] * 2), + mode=self.mode, + char_num_classes=self.char_num_classes, + ) + + def __iter__(self): + return iter(self.chars_boxes) + + def __getitem__(self, item): + if isinstance(item, (int, slice)): + selected_chars_boxes = [self.chars_boxes[item]] + else: + # advanced indexing on a single dimension + selected_chars_boxes = [] + if isinstance(item, torch.Tensor) and item.dtype == torch.bool: + item = item.nonzero() + item = item.squeeze(1) if item.numel() > 0 else item + item = item.tolist() + for i in item: + if i >= len(self.chars_boxes): + print(i) + print("chars_boxes.shape: ", len(self.chars_boxes)) + input() + selected_chars_boxes.append(self.chars_boxes[i]) + return SegmentationCharMask( + selected_chars_boxes, + use_char_ann=self.use_char_ann, + size=self.size, + mode=self.mode, + char_num_classes=self.char_num_classes, + ) + + def __repr__(self): + s = self.__class__.__name__ + "(" + s += "num_char_boxes={}, ".format(len(self.chars_boxes)) + s += "image_width={}, ".format(self.size[0]) + s += "image_height={})".format(self.size[1]) + return s + + +def shrink_poly(poly, shrink): + # shrink ratio + R = shrink + r = [None, None, None, None] + for i in range(4): + r[i] = min( + np.linalg.norm(poly[i] - poly[(i + 1) % 4]), + np.linalg.norm(poly[i] - poly[(i - 1) % 4]), + ) + # find the longer pair + if np.linalg.norm(poly[0] - poly[1]) + np.linalg.norm( + poly[2] - poly[3] + ) > np.linalg.norm(poly[0] - poly[3]) + np.linalg.norm(poly[1] - poly[2]): + # first move (p0, p1), (p2, p3), then (p0, p3), (p1, p2) + ## p0, p1 + theta = np.arctan2((poly[1][1] - poly[0][1]), (poly[1][0] - poly[0][0])) + poly[0][0] += R * r[0] * np.cos(theta) + poly[0][1] += R * r[0] * np.sin(theta) + poly[1][0] -= R * r[1] * np.cos(theta) + poly[1][1] -= R * r[1] * np.sin(theta) + ## p2, p3 + theta = np.arctan2((poly[2][1] - poly[3][1]), (poly[2][0] - poly[3][0])) + poly[3][0] += R * r[3] * np.cos(theta) + poly[3][1] += R * r[3] * np.sin(theta) + poly[2][0] -= R * r[2] * np.cos(theta) + poly[2][1] -= R * r[2] * np.sin(theta) + ## p0, p3 + theta = np.arctan2((poly[3][0] - poly[0][0]), (poly[3][1] - poly[0][1])) + poly[0][0] += R * r[0] * np.sin(theta) + poly[0][1] += R * r[0] * np.cos(theta) + poly[3][0] -= R * r[3] * np.sin(theta) + poly[3][1] -= R * r[3] * np.cos(theta) + ## p1, p2 + theta = np.arctan2((poly[2][0] - poly[1][0]), (poly[2][1] - poly[1][1])) + poly[1][0] += R * r[1] * np.sin(theta) + poly[1][1] += R * r[1] * np.cos(theta) + poly[2][0] -= R * r[2] * np.sin(theta) + poly[2][1] -= R * r[2] * np.cos(theta) + else: + ## p0, p3 + # print poly + theta = np.arctan2((poly[3][0] - poly[0][0]), (poly[3][1] - poly[0][1])) + poly[0][0] += R * r[0] * np.sin(theta) + poly[0][1] += R * r[0] * np.cos(theta) + poly[3][0] -= R * r[3] * np.sin(theta) + poly[3][1] -= R * r[3] * np.cos(theta) + ## p1, p2 + theta = np.arctan2((poly[2][0] - poly[1][0]), (poly[2][1] - poly[1][1])) + poly[1][0] += R * r[1] * np.sin(theta) + poly[1][1] += R * r[1] * np.cos(theta) + poly[2][0] -= R * r[2] * np.sin(theta) + poly[2][1] -= R * r[2] * np.cos(theta) + ## p0, p1 + theta = np.arctan2((poly[1][1] - poly[0][1]), (poly[1][0] - poly[0][0])) + poly[0][0] += R * r[0] * np.cos(theta) + poly[0][1] += R * r[0] * np.sin(theta) + poly[1][0] -= R * r[1] * np.cos(theta) + poly[1][1] -= R * r[1] * np.sin(theta) + ## p2, p3 + theta = np.arctan2((poly[2][1] - poly[3][1]), (poly[2][0] - poly[3][0])) + poly[3][0] += R * r[3] * np.cos(theta) + poly[3][1] += R * r[3] * np.sin(theta) + poly[2][0] -= R * r[2] * np.cos(theta) + poly[2][1] -= R * r[2] * np.sin(theta) + return poly + + +def shrink_rect(poly, shrink): + xmin = min(poly[:, 0]) + xmax = max(poly[:, 0]) + ymin = min(poly[:, 1]) + ymax = max(poly[:, 1]) + # assert xmax > xmin and ymax > ymin + xc = (xmax + xmin) / 2 + yc = (ymax + ymin) / 2 + w = xmax - xmin + h = ymax - ymin + sxmin = xc - w / 2 * shrink + sxmax = xc + w / 2 * shrink + symin = yc - h / 2 * shrink + symax = yc + h / 2 * shrink + return np.array([sxmin, symin, sxmax, symin, sxmax, symax, sxmin, symax]).reshape( + (4, 2) + ) + + +def is_poly_inbox(poly, height, width): + min_x = min(poly[:, 0]) + min_y = min(poly[:, 1]) + max_x = max(poly[:, 0]) + max_y = max(poly[:, 1]) + if (max_x < 0 and max_y < 0) or (min_x > width and min_y > height): + return False + else: + return True diff --git a/maskrcnn_benchmark/utils/README.md b/maskrcnn_benchmark/utils/README.md new file mode 100644 index 0000000000000000000000000000000000000000..9765b24a730b77556104187ac3ef5439ab0859fd --- /dev/null +++ b/maskrcnn_benchmark/utils/README.md @@ -0,0 +1,5 @@ +# Utility functions + +This folder contain utility functions that are not used in the +core library, but are useful for building models or training +code using the config system. diff --git a/maskrcnn_benchmark/utils/__init__.py b/maskrcnn_benchmark/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/maskrcnn_benchmark/utils/c2_model_loading.py b/maskrcnn_benchmark/utils/c2_model_loading.py new file mode 100644 index 0000000000000000000000000000000000000000..4cc6ecad1c848fc675b13821488ce3f305e7970c --- /dev/null +++ b/maskrcnn_benchmark/utils/c2_model_loading.py @@ -0,0 +1,172 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +import logging +import pickle +from collections import OrderedDict + +import torch + +from maskrcnn_benchmark.utils.model_serialization import load_state_dict + + +def _rename_basic_resnet_weights(layer_keys): + layer_keys = [k.replace("_", ".") for k in layer_keys] + layer_keys = [k.replace(".w", ".weight") for k in layer_keys] + layer_keys = [k.replace(".bn", "_bn") for k in layer_keys] + layer_keys = [k.replace(".b", ".bias") for k in layer_keys] + layer_keys = [k.replace("_bn.s", "_bn.scale") for k in layer_keys] + layer_keys = [k.replace(".biasranch", ".branch") for k in layer_keys] + layer_keys = [k.replace("bbox.pred", "bbox_pred") for k in layer_keys] + layer_keys = [k.replace("cls.score", "cls_score") for k in layer_keys] + layer_keys = [k.replace("res.conv1_", "conv1_") for k in layer_keys] + + # RPN / Faster RCNN + layer_keys = [k.replace(".biasbox", ".bbox") for k in layer_keys] + layer_keys = [k.replace("conv.rpn", "rpn.conv") for k in layer_keys] + layer_keys = [k.replace("rpn.bbox.pred", "rpn.bbox_pred") for k in layer_keys] + layer_keys = [k.replace("rpn.cls.logits", "rpn.cls_logits") for k in layer_keys] + + # Affine-Channel -> BatchNorm enaming + layer_keys = [k.replace("_bn.scale", "_bn.weight") for k in layer_keys] + + # Make torchvision-compatible + layer_keys = [k.replace("conv1_bn.", "bn1.") for k in layer_keys] + + layer_keys = [k.replace("res2.", "layer1.") for k in layer_keys] + layer_keys = [k.replace("res3.", "layer2.") for k in layer_keys] + layer_keys = [k.replace("res4.", "layer3.") for k in layer_keys] + layer_keys = [k.replace("res5.", "layer4.") for k in layer_keys] + + layer_keys = [k.replace(".branch2a.", ".conv1.") for k in layer_keys] + layer_keys = [k.replace(".branch2a_bn.", ".bn1.") for k in layer_keys] + layer_keys = [k.replace(".branch2b.", ".conv2.") for k in layer_keys] + layer_keys = [k.replace(".branch2b_bn.", ".bn2.") for k in layer_keys] + layer_keys = [k.replace(".branch2c.", ".conv3.") for k in layer_keys] + layer_keys = [k.replace(".branch2c_bn.", ".bn3.") for k in layer_keys] + + layer_keys = [k.replace(".branch1.", ".downsample.0.") for k in layer_keys] + layer_keys = [k.replace(".branch1_bn.", ".downsample.1.") for k in layer_keys] + + return layer_keys + +def _rename_fpn_weights(layer_keys, stage_names): + for mapped_idx, stage_name in enumerate(stage_names, 1): + suffix = "" + if mapped_idx < 4: + suffix = ".lateral" + layer_keys = [ + k.replace("fpn.inner.layer{}.sum{}".format(stage_name, suffix), "fpn_inner{}".format(mapped_idx)) for k in layer_keys + ] + layer_keys = [k.replace("fpn.layer{}.sum".format(stage_name), "fpn_layer{}".format(mapped_idx)) for k in layer_keys] + + + layer_keys = [k.replace("rpn.conv.fpn2", "rpn.conv") for k in layer_keys] + layer_keys = [k.replace("rpn.bbox_pred.fpn2", "rpn.bbox_pred") for k in layer_keys] + layer_keys = [ + k.replace("rpn.cls_logits.fpn2", "rpn.cls_logits") for k in layer_keys + ] + + return layer_keys + + +def _rename_weights_for_resnet(weights, stage_names): + original_keys = sorted(weights.keys()) + layer_keys = sorted(weights.keys()) + + # for X-101, rename output to fc1000 to avoid conflicts afterwards + layer_keys = [k if k != "pred_b" else "fc1000_b" for k in layer_keys] + layer_keys = [k if k != "pred_w" else "fc1000_w" for k in layer_keys] + + # performs basic renaming: _ -> . , etc + layer_keys = _rename_basic_resnet_weights(layer_keys) + + # FPN + layer_keys = _rename_fpn_weights(layer_keys, stage_names) + + # Mask R-CNN + layer_keys = [k.replace("mask.fcn.logits", "mask_fcn_logits") for k in layer_keys] + layer_keys = [k.replace(".[mask].fcn", "mask_fcn") for k in layer_keys] + layer_keys = [k.replace("conv5.mask", "conv5_mask") for k in layer_keys] + + # Keypoint R-CNN + layer_keys = [k.replace("kps.score.lowres", "kps_score_lowres") for k in layer_keys] + layer_keys = [k.replace("kps.score", "kps_score") for k in layer_keys] + layer_keys = [k.replace("conv.fcn", "conv_fcn") for k in layer_keys] + + # Rename for our RPN structure + layer_keys = [k.replace("rpn.", "rpn.head.") for k in layer_keys] + + key_map = {k: v for k, v in zip(original_keys, layer_keys)} + + logger = logging.getLogger(__name__) + logger.info("Remapping C2 weights") + max_c2_key_size = max([len(k) for k in original_keys if "_momentum" not in k]) + + new_weights = OrderedDict() + for k in original_keys: + v = weights[k] + if "_momentum" in k: + continue + # if 'fc1000' in k: + # continue + w = torch.from_numpy(v) + # if "bn" in k: + # w = w.view(1, -1, 1, 1) + logger.info("C2 name: {: <{}} mapped name: {}".format(k, max_c2_key_size, key_map[k])) + new_weights[key_map[k]] = w + + return new_weights + + +def _load_c2_pickled_weights(file_path): + with open(file_path, "rb") as f: + data = pickle.load(f, encoding="latin1") + if "blobs" in data: + weights = data["blobs"] + else: + weights = data + return weights + +def _rename_conv_weights_for_deformable_conv_layers(state_dict, cfg): + import re + logger = logging.getLogger(__name__) + logger.info("Remapping conv weights for deformable conv weights") + layer_keys = sorted(state_dict.keys()) + for ix, stage_with_dcn in enumerate(cfg.MODEL.RESNETS.STAGE_WITH_DCN, 1): + if not stage_with_dcn: + continue + for old_key in layer_keys: + pattern = ".*layer{}.*conv2.*".format(ix) + r = re.match(pattern, old_key) + if r is None: + continue + for param in ["weight", "bias"]: + if old_key.find(param) is -1: + continue + new_key = old_key.replace( + "conv2.{}".format(param), "conv2.conv.{}".format(param) + ) + logger.info("pattern: {}, old_key: {}, new_key: {}".format( + pattern, old_key, new_key + )) + state_dict[new_key] = state_dict[old_key] + del state_dict[old_key] + return state_dict + + +_C2_STAGE_NAMES = { + "R-50": ["1.2", "2.3", "3.5", "4.2"], + "R-101": ["1.2", "2.3", "3.22", "4.2"], +} + +def load_c2_format(cfg, f): + # TODO make it support other architectures + state_dict = _load_c2_pickled_weights(f) + conv_body = cfg.MODEL.BACKBONE.CONV_BODY + arch = conv_body.replace("-C4", "").replace("-FPN", "") + stages = _C2_STAGE_NAMES[arch] + state_dict = _rename_weights_for_resnet(state_dict, stages) + # *********************************** + # for deformable convolutional layer + state_dict = _rename_conv_weights_for_deformable_conv_layers(state_dict, cfg) + # *********************************** + return dict(model=state_dict) diff --git a/maskrcnn_benchmark/utils/chars.py b/maskrcnn_benchmark/utils/chars.py new file mode 100644 index 0000000000000000000000000000000000000000..1b2638899f76674812973777e8e0770152b8e1f9 --- /dev/null +++ b/maskrcnn_benchmark/utils/chars.py @@ -0,0 +1,199 @@ +import os + +import cv2 +import numpy as np + + +def char2num(char): + if char in "0123456789": + num = ord(char) - ord("0") + 1 + elif char in "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ": + num = ord(char.lower()) - ord("a") + 11 + else: + num = 0 + return num + + +def num2char(num): + chars = "_0123456789abcdefghijklmnopqrstuvwxyz" + char = chars[num] + # if num >=1 and num <=10: + # char = chr(ord('0') + num - 1) + # elif num > 10 and num <= 36: + # char = chr(ord('a') + num - 11) + # else: + # print('error number:%d'%(num)) + # exit() + return char + + +def getstr_grid(seg, box, threshold=192): + pos = 255 - (seg[0] * 255).astype(np.uint8) + mask_index = np.argmax(seg, axis=0) + mask_index = mask_index.astype(np.uint8) + pos = pos.astype(np.uint8) + string, score, rec_scores, char_polygons = seg2text( + pos, mask_index, seg, box, threshold=threshold + ) + return string, score, rec_scores, char_polygons + + +def seg2text(gray, mask, seg, box, threshold=192): + ## input numpy + img_h, img_w = gray.shape + box_w = box[2] - box[0] + box_h = box[3] - box[1] + ratio_h = float(box_h) / img_h + ratio_w = float(box_w) / img_w + # SE1=cv2.getStructuringElement(cv2.MORPH_RECT,(3,3)) + # gray = cv2.erode(gray,SE1) + # gray = cv2.dilate(gray,SE1) + # gray = cv2.morphologyEx(gray,cv2.MORPH_CLOSE,SE1) + ret, thresh = cv2.threshold(gray, threshold, 255, cv2.THRESH_BINARY) + try: + _, contours, _ = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) + except: + contours, _ = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) + chars = [] + scores = [] + char_polygons = [] + for i in range(len(contours)): + char = {} + temp = np.zeros((img_h, img_w)).astype(np.uint8) + cv2.drawContours(temp, [contours[i]], 0, (255), -1) + x, y, w, h = cv2.boundingRect(contours[i]) + c_x, c_y = x + w / 2, y + h / 2 + perimeter = cv2.arcLength(contours[i], True) + epsilon = 0.01 * cv2.arcLength(contours[i], True) + approx = cv2.approxPolyDP(contours[i], epsilon, True) + pts = approx.reshape((-1, 2)) + pts[:, 0] = pts[:, 0] * ratio_w + box[0] + pts[:, 1] = pts[:, 1] * ratio_h + box[1] + polygon = list(pts.reshape((-1,))) + polygon = list(map(int, polygon)) + if len(polygon) >= 6: + char_polygons.append(polygon) + # x1 = x * ratio_w + box[0] + # y1 = y * ratio_h + box[1] + # x3 = (x + w) * ratio_w + box[0] + # y3 = (y + h) * ratio_h + box[1] + # polygon = [x1, y1, x3, y1, x3, y3, x1, y3] + regions = seg[1:, temp == 255].reshape((36, -1)) + cs = np.mean(regions, axis=1) + sym = num2char(np.argmax(cs.reshape((-1))) + 1) + char["x"] = c_x + char["y"] = c_y + char["s"] = sym + char["cs"] = cs.reshape((-1, 1)) + scores.append(np.max(char["cs"], axis=0)[0]) + + chars.append(char) + chars = sorted(chars, key=lambda x: x["x"]) + string = "" + css = [] + for char in chars: + string = string + char["s"] + css.append(char["cs"]) + if len(scores) > 0: + score = sum(scores) / len(scores) + else: + score = 0.00 + if not css: + css = [0.0] + return string, score, np.hstack(css), char_polygons + + +# def get_tight_rect(points, start_x, start_y, image_height, image_width, scale): +# points = list(points) +# ps = sorted(points, key=lambda x: x[0]) +# +# if ps[1][1] > ps[0][1]: +# px1 = ps[0][0] * scale + start_x +# py1 = ps[0][1] * scale + start_y +# px4 = ps[1][0] * scale + start_x +# py4 = ps[1][1] * scale + start_y +# else: +# px1 = ps[1][0] * scale + start_x +# py1 = ps[1][1] * scale + start_y +# px4 = ps[0][0] * scale + start_x +# py4 = ps[0][1] * scale + start_y +# if ps[3][1] > ps[2][1]: +# px2 = ps[2][0] * scale + start_x +# py2 = ps[2][1] * scale + start_y +# px3 = ps[3][0] * scale + start_x +# py3 = ps[3][1] * scale + start_y +# else: +# px2 = ps[3][0] * scale + start_x +# py2 = ps[3][1] * scale + start_y +# px3 = ps[2][0] * scale + start_x +# py3 = ps[2][1] * scale + start_y +# +# if px1 < 0: +# px1 = 1 +# if px1 > image_width: +# px1 = image_width - 1 +# if px2 < 0: +# px2 = 1 +# if px2 > image_width: +# px2 = image_width - 1 +# if px3 < 0: +# px3 = 1 +# if px3 > image_width: +# px3 = image_width - 1 +# if px4 < 0: +# px4 = 1 +# if px4 > image_width: +# px4 = image_width - 1 +# +# if py1 < 0: +# py1 = 1 +# if py1 > image_height: +# py1 = image_height - 1 +# if py2 < 0: +# py2 = 1 +# if py2 > image_height: +# py2 = image_height - 1 +# if py3 < 0: +# py3 = 1 +# if py3 > image_height: +# py3 = image_height - 1 +# if py4 < 0: +# py4 = 1 +# if py4 > image_height: +# py4 = image_height - 1 +# return [px1, py1, px2, py2, px3, py3, px4, py4] + +def get_tight_rect(points, start_x, start_y, image_height, image_width, scale): + points = list(points) + ps = sorted(points, key=lambda x: x[0]) + + if ps[1][1] > ps[0][1]: + px1 = ps[0][0] * scale + start_x + py1 = ps[0][1] * scale + start_y + px4 = ps[1][0] * scale + start_x + py4 = ps[1][1] * scale + start_y + else: + px1 = ps[1][0] * scale + start_x + py1 = ps[1][1] * scale + start_y + px4 = ps[0][0] * scale + start_x + py4 = ps[0][1] * scale + start_y + if ps[3][1] > ps[2][1]: + px2 = ps[2][0] * scale + start_x + py2 = ps[2][1] * scale + start_y + px3 = ps[3][0] * scale + start_x + py3 = ps[3][1] * scale + start_y + else: + px2 = ps[3][0] * scale + start_x + py2 = ps[3][1] * scale + start_y + px3 = ps[2][0] * scale + start_x + py3 = ps[2][1] * scale + start_y + + px1 = min(max(px1, 1), image_width - 1) + px2 = min(max(px2, 1), image_width - 1) + px3 = min(max(px3, 1), image_width - 1) + px4 = min(max(px4, 1), image_width - 1) + py1 = min(max(py1, 1), image_height - 1) + py2 = min(max(py2, 1), image_height - 1) + py3 = min(max(py3, 1), image_height - 1) + py4 = min(max(py4, 1), image_height - 1) + return [px1, py1, px2, py2, px3, py3, px4, py4] diff --git a/maskrcnn_benchmark/utils/checkpoint.py b/maskrcnn_benchmark/utils/checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..c436be8bdab7d5cc7fabbd5333549dc115807aea --- /dev/null +++ b/maskrcnn_benchmark/utils/checkpoint.py @@ -0,0 +1,139 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +import logging +import os + +import torch + +from maskrcnn_benchmark.utils.model_serialization import load_state_dict +from maskrcnn_benchmark.utils.c2_model_loading import load_c2_format +from maskrcnn_benchmark.utils.imports import import_file +from maskrcnn_benchmark.utils.model_zoo import cache_url + + +class Checkpointer(object): + def __init__( + self, + model, + optimizer=None, + scheduler=None, + save_dir="", + save_to_disk=None, + logger=None, + ): + self.model = model + self.optimizer = optimizer + self.scheduler = scheduler + self.save_dir = save_dir + self.save_to_disk = save_to_disk + if logger is None: + logger = logging.getLogger(__name__) + self.logger = logger + + def save(self, name, **kwargs): + if not self.save_dir: + return + + if not self.save_to_disk: + return + + data = {} + data["model"] = self.model.state_dict() + if self.optimizer is not None: + data["optimizer"] = self.optimizer.state_dict() + if self.scheduler is not None: + data["scheduler"] = self.scheduler.state_dict() + data.update(kwargs) + + save_file = os.path.join(self.save_dir, "{}.pth".format(name)) + self.logger.info("Saving checkpoint to {}".format(save_file)) + torch.save(data, save_file) + self.tag_last_checkpoint(save_file) + + def load(self, f=None, resume=False): + if self.has_checkpoint(): + # override argument with existing checkpoint + f = self.get_checkpoint_file() + if not f: + # no checkpoint could be found + self.logger.info("No checkpoint found. Initializing model from scratch") + return {} + self.logger.info("Loading checkpoint from {}".format(f)) + checkpoint = self._load_file(f) + self._load_model(checkpoint) + if resume: + if "optimizer" in checkpoint and self.optimizer: + self.logger.info("Loading optimizer from {}".format(f)) + self.optimizer.load_state_dict(checkpoint.pop("optimizer")) + if "scheduler" in checkpoint and self.scheduler: + self.logger.info("Loading scheduler from {}".format(f)) + self.scheduler.load_state_dict(checkpoint.pop("scheduler")) + + # return any further checkpoint data + return checkpoint + + def has_checkpoint(self): + save_file = os.path.join(self.save_dir, "last_checkpoint") + return os.path.exists(save_file) + + def get_checkpoint_file(self): + save_file = os.path.join(self.save_dir, "last_checkpoint") + try: + with open(save_file, "r") as f: + last_saved = f.read() + except IOError: + # if file doesn't exist, maybe because it has just been + # deleted by a separate process + last_saved = "" + return last_saved + + def tag_last_checkpoint(self, last_filename): + save_file = os.path.join(self.save_dir, "last_checkpoint") + with open(save_file, "w") as f: + f.write(last_filename) + + def _load_file(self, f): + return torch.load(f, map_location=torch.device("cpu")) + + def _load_model(self, checkpoint): + load_state_dict(self.model, checkpoint.pop("model")) + + +class DetectronCheckpointer(Checkpointer): + def __init__( + self, + cfg, + model, + optimizer=None, + scheduler=None, + save_dir="", + save_to_disk=None, + logger=None, + ): + super(DetectronCheckpointer, self).__init__( + model, optimizer, scheduler, save_dir, save_to_disk, logger + ) + self.cfg = cfg.clone() + + def _load_file(self, f): + # catalog lookup + if f.startswith("catalog://"): + paths_catalog = import_file( + "maskrcnn_benchmark.config.paths_catalog", self.cfg.PATHS_CATALOG, True + ) + catalog_f = paths_catalog.ModelCatalog.get(f[len("catalog://") :]) + self.logger.info("{} points to {}".format(f, catalog_f)) + f = catalog_f + # download url files + if f.startswith("http"): + # if the file is a url path, download it and cache it + cached_f = cache_url(f) + self.logger.info("url {} cached in {}".format(f, cached_f)) + f = cached_f + # convert Caffe2 checkpoint from pkl + if f.endswith(".pkl"): + return load_c2_format(self.cfg, f) + # load native detectron.pytorch checkpoint + loaded = super(DetectronCheckpointer, self)._load_file(f) + if "model" not in loaded: + loaded = dict(model=loaded) + return loaded diff --git a/maskrcnn_benchmark/utils/collect_env.py b/maskrcnn_benchmark/utils/collect_env.py new file mode 100644 index 0000000000000000000000000000000000000000..2d0641dda61c9950cb54d0552106246248e571ef --- /dev/null +++ b/maskrcnn_benchmark/utils/collect_env.py @@ -0,0 +1,14 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +import PIL + +from torch.utils.collect_env import get_pretty_env_info + + +def get_pil_version(): + return "\n Pillow ({})".format(PIL.__version__) + + +def collect_env_info(): + env_str = get_pretty_env_info() + env_str += get_pil_version() + return env_str diff --git a/maskrcnn_benchmark/utils/comm.py b/maskrcnn_benchmark/utils/comm.py new file mode 100644 index 0000000000000000000000000000000000000000..6ed535ab65996a1a65136e04f92ab2c9dfd59d17 --- /dev/null +++ b/maskrcnn_benchmark/utils/comm.py @@ -0,0 +1,378 @@ +# # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +# """ +# This file contains primitives for multi-gpu communication. +# This is useful when doing distributed training. +# """ + +# import os +# import pickle +# import tempfile +# import time + +# import torch +# import torch.distributed as dist + + + +# # def get_world_size(): +# # if not dist.is_initialized(): +# # return 1 +# # return dist.get_world_size() +# # +# # +# # def is_main_process(): +# # if not dist.is_initialized(): +# # return True +# # return dist.get_rank() == 0 +# # +# # def get_rank(): +# # if not dist.is_initialized(): +# # return 0 +# # return dist.get_rank() +# # +# # def synchronize(): +# # """ +# # Helper function to synchronize between multiple processes when +# # using distributed training +# # """ +# # if not dist.is_initialized(): +# # return +# # world_size = dist.get_world_size() +# # rank = dist.get_rank() +# # if world_size == 1: +# # return +# # +# # def _send_and_wait(r): +# # if rank == r: +# # tensor = torch.tensor(0, device="cuda") +# # else: +# # tensor = torch.tensor(1, device="cuda") +# # dist.broadcast(tensor, r) +# # while tensor.item() == 1: +# # time.sleep(1) +# # +# # _send_and_wait(0) +# # # now sync on the main process +# # _send_and_wait(1) +# # +# # +# def _encode(encoded_data, data): +# # gets a byte representation for the data +# encoded_bytes = pickle.dumps(data) +# # convert this byte string into a byte tensor +# storage = torch.ByteStorage.from_buffer(encoded_bytes) +# tensor = torch.ByteTensor(storage).to("cuda") +# # encoding: first byte is the size and then rest is the data +# s = tensor.numel() +# assert s <= 255, "Can't encode data greater than 255 bytes" +# # put the encoded data in encoded_data +# encoded_data[0] = s +# encoded_data[1 : (s + 1)] = tensor + + +# def _decode(encoded_data): +# size = encoded_data[0] +# encoded_tensor = encoded_data[1 : (size + 1)].to("cpu") +# return pickle.loads(bytearray(encoded_tensor.tolist())) + + +# # TODO try to use tensor in shared-memory instead of serializing to disk +# # this involves getting the all_gather to work +# def scatter_gather(data): +# """ +# This function gathers data from multiple processes, and returns them +# in a list, as they were obtained from each process. + +# This function is useful for retrieving data from multiple processes, +# when launching the code with torch.distributed.launch + +# Note: this function is slow and should not be used in tight loops, i.e., +# do not use it in the training loop. + +# Arguments: +# data: the object to be gathered from multiple processes. +# It must be serializable + +# Returns: +# result (list): a list with as many elements as there are processes, +# where each element i in the list corresponds to the data that was +# gathered from the process of rank i. +# """ +# # strategy: the main process creates a temporary directory, and communicates +# # the location of the temporary directory to all other processes. +# # each process will then serialize the data to the folder defined by +# # the main process, and then the main process reads all of the serialized +# # files and returns them in a list +# if not dist.is_initialized(): +# return [data] +# synchronize() +# # get rank of the current process +# rank = dist.get_rank() + +# # the data to communicate should be small +# data_to_communicate = torch.empty(256, dtype=torch.uint8, device="cuda") +# if rank == 0: +# # manually creates a temporary directory, that needs to be cleaned +# # afterwards +# tmp_dir = tempfile.mkdtemp() +# _encode(data_to_communicate, tmp_dir) + +# synchronize() +# # the main process (rank=0) communicates the data to all processes +# dist.broadcast(data_to_communicate, 0) + +# # get the data that was communicated +# tmp_dir = _decode(data_to_communicate) + +# # each process serializes to a different file +# file_template = "file{}.pth" +# tmp_file = os.path.join(tmp_dir, file_template.format(rank)) +# torch.save(data, tmp_file) + +# # synchronize before loading the data +# synchronize() + +# # only the master process returns the data +# if rank == 0: +# data_list = [] +# world_size = dist.get_world_size() +# for r in range(world_size): +# file_path = os.path.join(tmp_dir, file_template.format(r)) +# d = torch.load(file_path) +# data_list.append(d) +# # cleanup +# os.remove(file_path) +# # cleanup +# os.rmdir(tmp_dir) +# return data_list + + +# def get_world_size(): +# if not dist.is_available(): +# print('distributed is not available') +# return 1 +# if not dist.is_initialized(): +# print('distributed is not initialized') +# return 1 +# return dist.get_world_size() + + +# def get_rank(): +# if not dist.is_available(): +# return 0 +# if not dist.is_initialized(): +# return 0 +# return dist.get_rank() + + +# def is_main_process(): +# return get_rank() == 0 + + +# def synchronize(): +# """ +# Helper function to synchronize (barrier) among all processes when +# using distributed training +# """ +# if not dist.is_available(): +# return +# if not dist.is_initialized(): +# return +# world_size = dist.get_world_size() +# if world_size == 1: +# return +# dist.barrier() + + +# def all_gather(data): +# """ +# Run all_gather on arbitrary picklable data (not necessarily tensors) + +# Args: +# data: any picklable object + +# Returns: +# list[data]: list of data gathered from each rank +# """ +# world_size = get_world_size() +# if world_size == 1: +# return [data] + +# # serialized to a Tensor +# buffer = pickle.dumps(data) +# storage = torch.ByteStorage.from_buffer(buffer) +# tensor = torch.ByteTensor(storage).to("cuda") + +# # obtain Tensor size of each rank +# local_size = torch.IntTensor([tensor.numel()]).to("cuda") +# size_list = [torch.IntTensor([0]).to("cuda") for _ in range(world_size)] +# dist.all_gather(size_list, local_size) +# size_list = [int(size.item()) for size in size_list] +# max_size = max(size_list) + +# # receiving Tensor from all ranks +# # we pad the tensor because torch all_gather does not support +# # gathering tensors of different shapes +# tensor_list = [] +# for _ in size_list: +# tensor_list.append(torch.ByteTensor(size=(max_size,)).to("cuda")) +# if local_size != max_size: +# padding = torch.ByteTensor(size=(max_size - local_size,)).to("cuda") +# tensor = torch.cat((tensor, padding), dim=0) +# dist.all_gather(tensor_list, tensor) + +# data_list = [] +# for size, tensor in zip(size_list, tensor_list): +# buffer = tensor.cpu().numpy().tobytes()[:size] +# data_list.append(pickle.loads(buffer)) + +# return data_list + + +# def reduce_dict(input_dict, average=True): +# """ +# Args: +# input_dict (dict): all the values will be reduced +# average (bool): whether to do average or sum + +# Reduce the values in the dictionary from all processes so that process with rank +# 0 has the averaged results. Returns a dict with the same fields as +# input_dict, after reduction. +# """ +# world_size = get_world_size() +# if world_size < 2: +# return input_dict +# with torch.no_grad(): +# names = [] +# values = [] +# # sort the keys so that they are consistent across processes +# for k in sorted(input_dict.keys()): +# names.append(k) +# values.append(input_dict[k]) +# values = torch.stack(values, dim=0) +# dist.reduce(values, dst=0) +# if dist.get_rank() == 0 and average: +# # only main process gets accumulated, so only divide by +# # world_size in this case +# values /= world_size +# reduced_dict = {k: v for k, v in zip(names, values)} +# return reduced_dict + + +""" +This file contains primitives for multi-gpu communication. +This is useful when doing distributed training. +""" + +import pickle +import time + +import torch +import torch.distributed as dist + + +def get_world_size(): + if not dist.is_available(): + return 1 + if not dist.is_initialized(): + return 1 + return dist.get_world_size() + + +def get_rank(): + if not dist.is_available(): + return 0 + if not dist.is_initialized(): + return 0 + return dist.get_rank() + + +def is_main_process(): + return get_rank() == 0 + + +def synchronize(): + """ + Helper function to synchronize (barrier) among all processes when + using distributed training + """ + if not dist.is_available(): + return + if not dist.is_initialized(): + return + world_size = dist.get_world_size() + if world_size == 1: + return + dist.barrier() + + +def scatter_gather(data): + """ + Run all_gather on arbitrary picklable data (not necessarily tensors) + Args: + data: any picklable object + Returns: + list[data]: list of data gathered from each rank + """ + world_size = get_world_size() + if world_size == 1: + return [data] + + # serialized to a Tensor + buffer = pickle.dumps(data) + storage = torch.ByteStorage.from_buffer(buffer) + tensor = torch.ByteTensor(storage).to("cuda") + + # obtain Tensor size of each rank + local_size = torch.LongTensor([tensor.numel()]).to("cuda") + size_list = [torch.LongTensor([0]).to("cuda") for _ in range(world_size)] + dist.all_gather(size_list, local_size) + size_list = [int(size.item()) for size in size_list] + max_size = max(size_list) + + # receiving Tensor from all ranks + # we pad the tensor because torch all_gather does not support + # gathering tensors of different shapes + tensor_list = [] + for _ in size_list: + tensor_list.append(torch.ByteTensor(size=(max_size,)).to("cuda")) + if local_size != max_size: + padding = torch.ByteTensor(size=(max_size - local_size,)).to("cuda") + tensor = torch.cat((tensor, padding), dim=0) + dist.all_gather(tensor_list, tensor) + + data_list = [] + for size, tensor in zip(size_list, tensor_list): + buffer = tensor.cpu().numpy().tobytes()[:size] + data_list.append(pickle.loads(buffer)) + + return data_list + + +def reduce_dict(input_dict, average=True): + """ + Args: + input_dict (dict): all the values will be reduced + average (bool): whether to do average or sum + Reduce the values in the dictionary from all processes so that process with rank + 0 has the averaged results. Returns a dict with the same fields as + input_dict, after reduction. + """ + world_size = get_world_size() + if world_size < 2: + return input_dict + with torch.no_grad(): + names = [] + values = [] + # sort the keys so that they are consistent across processes + for k in sorted(input_dict.keys()): + names.append(k) + values.append(input_dict[k]) + values = torch.stack(values, dim=0) + dist.reduce(values, dst=0) + if dist.get_rank() == 0 and average: + # only main process gets accumulated, so only divide by + # world_size in this case + values /= world_size + reduced_dict = {k: v for k, v in zip(names, values)} + return reduced_dict diff --git a/maskrcnn_benchmark/utils/env.py b/maskrcnn_benchmark/utils/env.py new file mode 100644 index 0000000000000000000000000000000000000000..1c7db32e41ec266ead9734f90d0173b4feff61ef --- /dev/null +++ b/maskrcnn_benchmark/utils/env.py @@ -0,0 +1,37 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +import os + +from maskrcnn_benchmark.utils.imports import import_file + + +def setup_environment(): + """Perform environment setup work. The default setup is a no-op, but this + function allows the user to specify a Python source file that performs + custom setup work that may be necessary to their computing environment. + """ + custom_module_path = os.environ.get("TORCH_DETECTRON_ENV_MODULE") + if custom_module_path: + setup_custom_environment(custom_module_path) + else: + # The default setup is a no-op + pass + + +def setup_custom_environment(custom_module_path): + """Load custom environment setup from a Python source file and run the setup + function. + """ + module = import_file("maskrcnn_benchmark.utils.env.custom_module", custom_module_path) + assert hasattr(module, "setup_environment") and callable( + module.setup_environment + ), ( + "Custom environment module defined in {} does not have the " + "required callable attribute 'setup_environment'." + ).format( + custom_module_path + ) + module.setup_environment() + + +# Force environment setup when this module is imported +setup_environment() diff --git a/maskrcnn_benchmark/utils/imports.py b/maskrcnn_benchmark/utils/imports.py new file mode 100644 index 0000000000000000000000000000000000000000..4b3cfa6616de7b04203ece24af8f54854dafe3f7 --- /dev/null +++ b/maskrcnn_benchmark/utils/imports.py @@ -0,0 +1,14 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +import importlib +import importlib.util +import sys + + +# from https://stackoverflow.com/questions/67631/how-to-import-a-module-given-the-full-path?utm_medium=organic&utm_source=google_rich_qa&utm_campaign=google_rich_qa +def import_file(module_name, file_path, make_importable=False): + spec = importlib.util.spec_from_file_location(module_name, file_path) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + if make_importable: + sys.modules[module_name] = module + return module diff --git a/maskrcnn_benchmark/utils/logging.py b/maskrcnn_benchmark/utils/logging.py new file mode 100644 index 0000000000000000000000000000000000000000..d2333ce75390ee9643162f8902907d1a43ef2dbe --- /dev/null +++ b/maskrcnn_benchmark/utils/logging.py @@ -0,0 +1,41 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +import logging +import os +import sys + +from tensorboardX import SummaryWriter + + +def setup_logger(name, save_dir, distributed_rank=0): + logger = logging.getLogger(name) + logger.setLevel(logging.DEBUG) + # don't log results for the non-master process + if distributed_rank > 0: + return logger + ch = logging.StreamHandler(stream=sys.stdout) + ch.setLevel(logging.DEBUG) + formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s") + ch.setFormatter(formatter) + logger.addHandler(ch) + + if save_dir: + fh = logging.FileHandler(os.path.join(save_dir, "log.txt")) + fh.setLevel(logging.DEBUG) + fh.setFormatter(formatter) + logger.addHandler(fh) + + return logger + + +class Logger(object): + def __init__(self, log_dir, distributed_rank=0): + """Create a summary writer logging to log_dir.""" + self.distributed_rank = distributed_rank + if distributed_rank == 0: + self.writer = SummaryWriter(log_dir) + + + def scalar_summary(self, tag, value, step): + """Log a scalar variable.""" + if self.distributed_rank == 0: + self.writer.add_scalar(tag, value, step) diff --git a/maskrcnn_benchmark/utils/metric_logger.py b/maskrcnn_benchmark/utils/metric_logger.py new file mode 100644 index 0000000000000000000000000000000000000000..c314e1311777d9085a6287cc44f3532a7550c3fe --- /dev/null +++ b/maskrcnn_benchmark/utils/metric_logger.py @@ -0,0 +1,63 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +from collections import defaultdict +from collections import deque + +import torch + + +class SmoothedValue(object): + """Track a series of values and provide access to smoothed values over a + window or the global series average. + """ + + def __init__(self, window_size=20): + self.deque = deque(maxlen=window_size) + self.series = [] + self.total = 0.0 + self.count = 0 + + def update(self, value): + self.deque.append(value) + self.series.append(value) + self.count += 1 + self.total += value + + @property + def median(self): + d = torch.tensor(list(self.deque)) + return d.median().item() + + @property + def avg(self): + d = torch.tensor(list(self.deque)) + return d.mean().item() + + @property + def global_avg(self): + return self.total / self.count + + +class MetricLogger(object): + def __init__(self, delimiter="\t"): + self.meters = defaultdict(SmoothedValue) + self.delimiter = delimiter + + def update(self, **kwargs): + for k, v in kwargs.items(): + if isinstance(v, torch.Tensor): + v = v.item() + assert isinstance(v, (float, int)) + self.meters[k].update(v) + + def __getattr__(self, attr): + if attr in self.meters: + return self.meters[attr] + return object.__getattr__(self, attr) + + def __str__(self): + loss_str = [] + for name, meter in self.meters.items(): + loss_str.append( + "{}: {:.4f} ({:.4f})".format(name, meter.median, meter.global_avg) + ) + return self.delimiter.join(loss_str) diff --git a/maskrcnn_benchmark/utils/miscellaneous.py b/maskrcnn_benchmark/utils/miscellaneous.py new file mode 100644 index 0000000000000000000000000000000000000000..db9a8b3679ceea2a5cd2b807421793bbbd3d3677 --- /dev/null +++ b/maskrcnn_benchmark/utils/miscellaneous.py @@ -0,0 +1,11 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +import errno +import os + + +def mkdir(path): + try: + os.makedirs(path) + except OSError as e: + if e.errno != errno.EEXIST: + raise diff --git a/maskrcnn_benchmark/utils/model_serialization.py b/maskrcnn_benchmark/utils/model_serialization.py new file mode 100644 index 0000000000000000000000000000000000000000..a95ad8b2a7a787d62dc3ea580b2dfd30e358da28 --- /dev/null +++ b/maskrcnn_benchmark/utils/model_serialization.py @@ -0,0 +1,80 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +from collections import OrderedDict +import logging + +import torch + +from maskrcnn_benchmark.utils.imports import import_file + + +def align_and_update_state_dicts(model_state_dict, loaded_state_dict): + """ + Strategy: suppose that the models that we will create will have prefixes appended + to each of its keys, for example due to an extra level of nesting that the original + pre-trained weights from ImageNet won't contain. For example, model.state_dict() + might return backbone[0].body.res2.conv1.weight, while the pre-trained model contains + res2.conv1.weight. We thus want to match both parameters together. + For that, we look for each model weight, look among all loaded keys if there is one + that is a suffix of the current weight name, and use it if that's the case. + If multiple matches exist, take the one with longest size + of the corresponding name. For example, for the same model as before, the pretrained + weight file can contain both res2.conv1.weight, as well as conv1.weight. In this case, + we want to match backbone[0].body.conv1.weight to conv1.weight, and + backbone[0].body.res2.conv1.weight to res2.conv1.weight. + """ + current_keys = sorted(list(model_state_dict.keys())) + loaded_keys = sorted(list(loaded_state_dict.keys())) + # get a matrix of string matches, where each (i, j) entry correspond to the size of the + # loaded_key string, if it matches + match_matrix = [ + len(j) if i.endswith(j) else 0 for i in current_keys for j in loaded_keys + ] + match_matrix = torch.as_tensor(match_matrix).view( + len(current_keys), len(loaded_keys) + ) + max_match_size, idxs = match_matrix.max(1) + # remove indices that correspond to no-match + idxs[max_match_size == 0] = -1 + + # used for logging + max_size = max([len(key) for key in current_keys]) if current_keys else 1 + max_size_loaded = max([len(key) for key in loaded_keys]) if loaded_keys else 1 + log_str_template = "{: <{}} loaded from {: <{}} of shape {}" + logger = logging.getLogger(__name__) + for idx_new, idx_old in enumerate(idxs.tolist()): + if idx_old == -1: + continue + key = current_keys[idx_new] + key_old = loaded_keys[idx_old] + model_state_dict[key] = loaded_state_dict[key_old] + logger.info( + log_str_template.format( + key, + max_size, + key_old, + max_size_loaded, + tuple(loaded_state_dict[key_old].shape), + ) + ) + + +def strip_prefix_if_present(state_dict, prefix): + keys = sorted(state_dict.keys()) + if not all(key.startswith(prefix) for key in keys): + return state_dict + stripped_state_dict = OrderedDict() + for key, value in state_dict.items(): + stripped_state_dict[key.replace(prefix, "")] = value + return stripped_state_dict + + +def load_state_dict(model, loaded_state_dict): + model_state_dict = model.state_dict() + # if the state_dict comes from a model that was wrapped in a + # DataParallel or DistributedDataParallel during serialization, + # remove the "module" prefix before performing the matching + loaded_state_dict = strip_prefix_if_present(loaded_state_dict, prefix="module.") + align_and_update_state_dicts(model_state_dict, loaded_state_dict) + + # use strict loading + model.load_state_dict(model_state_dict) diff --git a/maskrcnn_benchmark/utils/model_zoo.py b/maskrcnn_benchmark/utils/model_zoo.py new file mode 100644 index 0000000000000000000000000000000000000000..22aede8f8f551195405787c39b0924a9f7152c86 --- /dev/null +++ b/maskrcnn_benchmark/utils/model_zoo.py @@ -0,0 +1,61 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +import os +import sys + +try: + from torch.hub import _download_url_to_file + from torch.hub import urlparse + from torch.hub import HASH_REGEX +except ImportError: + from torch.utils.model_zoo import _download_url_to_file + from torch.utils.model_zoo import urlparse + from torch.utils.model_zoo import HASH_REGEX + +from maskrcnn_benchmark.utils.comm import is_main_process +from maskrcnn_benchmark.utils.comm import synchronize + + +# very similar to https://github.com/pytorch/pytorch/blob/master/torch/utils/model_zoo.py +# but with a few improvements and modifications +def cache_url(url, model_dir=None, progress=True): + r"""Loads the Torch serialized object at the given URL. + If the object is already present in `model_dir`, it's deserialized and + returned. The filename part of the URL should follow the naming convention + ``filename-.ext`` where ```` is the first eight or more + digits of the SHA256 hash of the contents of the file. The hash is used to + ensure unique names and to verify the contents of the file. + The default value of `model_dir` is ``$TORCH_HOME/models`` where + ``$TORCH_HOME`` defaults to ``~/.torch``. The default directory can be + overridden with the ``$TORCH_MODEL_ZOO`` environment variable. + Args: + url (string): URL of the object to download + model_dir (string, optional): directory in which to save the object + progress (bool, optional): whether or not to display a progress bar to stderr + Example: + >>> cached_file = maskrcnn_benchmark.utils.model_zoo.cache_url('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth') + """ + if model_dir is None: + torch_home = os.path.expanduser(os.getenv("TORCH_HOME", "~/.torch")) + model_dir = os.getenv("TORCH_MODEL_ZOO", os.path.join(torch_home, "models")) + if not os.path.exists(model_dir): + os.makedirs(model_dir) + parts = urlparse(url) + filename = os.path.basename(parts.path) + if filename == "model_final.pkl": + # workaround as pre-trained Caffe2 models from Detectron have all the same filename + # so make the full path the filename by replacing / with _ + filename = parts.path.replace("/", "_") + cached_file = os.path.join(model_dir, filename) + if not os.path.exists(cached_file) and is_main_process(): + sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file)) + hash_prefix = HASH_REGEX.search(filename) + if hash_prefix is not None: + hash_prefix = hash_prefix.group(1) + # workaround: Caffe2 models don't have a hash, but follow the R-50 convention, + # which matches the hash PyTorch uses. So we skip the hash matching + # if the hash_prefix is less than 6 characters + if len(hash_prefix) < 6: + hash_prefix = None + _download_url_to_file(url, cached_file, hash_prefix, progress=progress) + synchronize() + return cached_file \ No newline at end of file diff --git a/maskrcnn_benchmark/utils/registry.py b/maskrcnn_benchmark/utils/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..1e8d91e7d29283b76f4a42c2137f5026d3528fd9 --- /dev/null +++ b/maskrcnn_benchmark/utils/registry.py @@ -0,0 +1,45 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. + + +def _register_generic(module_dict, module_name, module): + assert module_name not in module_dict + module_dict[module_name] = module + + +class Registry(dict): + ''' + A helper class for managing registering modules, it extends a dictionary + and provides a register functions. + + Eg. creeting a registry: + some_registry = Registry({"default": default_module}) + + There're two ways of registering new modules: + 1): normal way is just calling register function: + def foo(): + ... + some_registry.register("foo_module", foo) + 2): used as decorator when declaring the module: + @some_registry.register("foo_module") + @some_registry.register("foo_modeul_nickname") + def foo(): + ... + + Access of module is just like using a dictionary, eg: + f = some_registry["foo_modeul"] + ''' + def __init__(self, *args, **kwargs): + super(Registry, self).__init__(*args, **kwargs) + + def register(self, module_name, module=None): + # used as function call + if module is not None: + _register_generic(self, module_name, module) + return + + # used as decorator + def register_fn(fn): + _register_generic(self, module_name, fn) + return fn + + return register_fn \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..bb33c48e30527e4e7d746e2651d347b7f08902e6 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,4 @@ +yacs +gdown +pyclipper +pytorch-extension \ No newline at end of file diff --git a/setup.py b/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..102dede268d07fe145acbf6145e21c70a780c39b --- /dev/null +++ b/setup.py @@ -0,0 +1,70 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +#!/usr/bin/env python + +import glob +import os + +import torch +from setuptools import find_packages +from setuptools import setup +from torch.utils.cpp_extension import CUDA_HOME +from torch.utils.cpp_extension import CppExtension +from torch.utils.cpp_extension import CUDAExtension + +requirements = ["torch", "torchvision"] + + +def get_extensions(): + this_dir = os.path.dirname(os.path.abspath(__file__)) + extensions_dir = os.path.join(this_dir, "maskrcnn_benchmark", "csrc") + + main_file = glob.glob(os.path.join(extensions_dir, "*.cpp")) + source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp")) + source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu")) + + sources = main_file + source_cpu + extension = CppExtension + + extra_compile_args = {"cxx": []} + define_macros = [] + + if torch.cuda.is_available() and CUDA_HOME is not None: + # if True: + extension = CUDAExtension + sources += source_cuda + define_macros += [("WITH_CUDA", None)] + extra_compile_args["nvcc"] = [ + "-DCUDA_HAS_FP16=1", + "-D__CUDA_NO_HALF_OPERATORS__", + "-D__CUDA_NO_HALF_CONVERSIONS__", + "-D__CUDA_NO_HALF2_OPERATORS__", + ] + + sources = [os.path.join(extensions_dir, s) for s in sources] + + include_dirs = [extensions_dir] + + ext_modules = [ + extension( + "maskrcnn_benchmark._C", + sources, + include_dirs=include_dirs, + define_macros=define_macros, + extra_compile_args=extra_compile_args, + ) + ] + + return ext_modules + + +setup( + name="maskrcnn_benchmark", + version="0.1", + author="fmassa", + url="https://github.com/facebookresearch/maskrnn-benchmark", + description="object detection in pytorch", + packages=find_packages(exclude=("configs", "examples", "test",)), + # install_requires=requirements, + ext_modules=get_extensions(), + cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension}, +) diff --git a/test.sh b/test.sh new file mode 100644 index 0000000000000000000000000000000000000000..528647c5bde432fd03c83dfb40cd0ed3bfcffcdf --- /dev/null +++ b/test.sh @@ -0,0 +1 @@ +python tools/test_net.py --config-file configs/mixtrain/seg_rec_poly_fuse_feature.yaml \ No newline at end of file diff --git a/tools/convert_dataset.py b/tools/convert_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..779d6b55169f78928b6adda39264b015e5122c34 --- /dev/null +++ b/tools/convert_dataset.py @@ -0,0 +1,178 @@ +import os +import numpy as np +import cv2 +from shapely.geometry import box, Polygon +from shapely import affinity +import math + + +def _rect2quad(boxes): + x_min, y_min, x_max, y_max = boxes[:, 0].reshape((-1, 1)), boxes[:, 1].reshape((-1, 1)), boxes[:, 2].reshape((-1, 1)), boxes[:, 3].reshape((-1, 1)) + return np.hstack((x_min, y_min, x_max, y_min, x_max, y_max, x_min, y_max)) + +def _quad2rect(boxes): + ## only support rectangle + return np.hstack((boxes[:, 0].reshape((-1, 1)), boxes[:, 1].reshape((-1, 1)), boxes[:, 4].reshape((-1, 1)), boxes[:, 5].reshape((-1, 1)))) + +def _quad2minrect(boxes): + ## trans a quad(N*4) to a rectangle(N*4) which has miniual area to cover it + return np.hstack((boxes[:, ::2].min(axis=1).reshape((-1, 1)), boxes[:, 1::2].min(axis=1).reshape((-1, 1)), boxes[:, ::2].max(axis=1).reshape((-1, 1)), boxes[:, 1::2].max(axis=1).reshape((-1, 1)))) + + +def _quad2boxlist(boxes): + res = [] + for i in range(boxes.shape[0]): + res.append([[boxes[i][0], boxes[i][1]], [boxes[i][2], boxes[i][3]], [boxes[i][4], boxes[i][5]], [boxes[i][6], boxes[i][7]]]) + return res + +def _boxlist2quads(boxlist): + res = np.zeros((len(boxlist), 8)) + for i, box in enumerate(boxlist): + # print(box) + res[i] = np.array([box[0][0], box[0][1], box[1][0], box[1][1], box[2][0], box[2][1], box[3][0], box[3][1]]) + return res + +def _rotate_image(im, polygons, angle): + new_polygons = polygons + ## rotate image first + height, width, _ = im.shape + ## get the minimal rect to cover the rotated image + img_box = np.array([[0, 0, width, 0, width, height, 0, height]]) + rotated_img_box = _quad2minrect(_rotate_polygons(img_box, -1*angle, (width/2, height/2))) + r_height = int(max(rotated_img_box[0][3], rotated_img_box[0][1]) - min(rotated_img_box[0][3], rotated_img_box[0][1])) + r_width = int(max(rotated_img_box[0][2], rotated_img_box[0][0]) - min(rotated_img_box[0][2], rotated_img_box[0][0])) + r_height_padding = max(r_height, height) + r_width_padding = max(r_width, width) + ## padding im + im_padding = np.zeros((r_height_padding, r_width_padding, 3)) + start_h, start_w = int((r_height_padding - height)/2.0), int((r_width_padding - width)/2.0) + # start_h = max(start_h, 0) + # start_w = max(start_w, 0) + end_h, end_w = start_h + height, start_w + width + # print(start_h, end_h, start_w, end_w, im.shape) + im_padding[start_h:end_h, start_w:end_w, :] = im + + M = cv2.getRotationMatrix2D((r_width/2, r_height/2), angle, 1) + im = cv2.warpAffine(im_padding, M, (r_width, r_height)) + + ## polygons + new_polygons = _rotate_segms(polygons, -1*angle, (r_width/2, r_height/2), start_h, start_w) + + return im, new_polygons + +def _rotate_polygons(polygons, angle, r_c): + ## polygons: N*8 + ## r_x: rotate center x + ## r_y: rotate center y + ## angle: -15~15 + + poly_list = _quad2boxlist(polygons) + rotate_boxes_list = [] + for poly in poly_list: + box = Polygon(poly) + rbox = affinity.rotate(box, angle, r_c) + if len(list(rbox.exterior.coords))<5: + print(poly) + print(rbox) + # assert(len(list(rbox.exterior.coords))>=5) + rotate_boxes_list.append(rbox.boundary.coords[:-1]) + res = _boxlist2quads(rotate_boxes_list) + return res + +def _rotate_segms(polygons, angle, r_c, start_h, start_w): + ## polygons: N*8 + ## r_x: rotate center x + ## r_y: rotate center y + ## angle: -15~15 + poly_list=[] + for polygon in polygons: + tmp=[] + for i in range(int(len(polygon) / 2)): + tmp.append([polygon[2*i] + start_w, polygon[2*i+1] + start_h]) + poly_list.append(tmp) + + rotate_boxes_list = [] + for poly in poly_list: + box = Polygon(poly) + rbox = affinity.rotate(box, angle, r_c) + if len(list(rbox.exterior.coords))<5: + print(poly) + print(rbox) + rotate_boxes_list.append(rbox.boundary.coords[:-1]) + res = [] + for i, box in enumerate(rotate_boxes_list): + tmp = [] + for point in box: + tmp.append(point[0]) + tmp.append(point[1]) + res.append([tmp]) + + return res + +def _read_gt(gt_path): + polygons = [] + words = [] + with open(gt_path, 'r') as fid: + lines = fid.readlines() + for line in lines: + line = line.strip() + polygon = line.split(',')[:8] + word = line.split(',')[8] + polygon = [float(x) for x in polygon] + polygons.append(polygon) + words.append(word) + return polygons, words + +def format_new_gt(polygons, words, new_gt_path): + with open(new_gt_path, 'wt') as fid: + for polygon, word in zip(polygons, words): + # print(polygon) + polygon = [str(int(x)) for x in polygon[0]] + # polygon = [str(int(x)) for x in polygon] + line = ','.join(polygon) + ',' + word + # print(line) + fid.write(line+'\n') + +def visu_gt(img, polygons, visu_path): + for polygon in polygons: + pts = np.array(polygon, np.int32) + pts = pts.reshape((-1,1,2)) + cv2.polylines(img,[pts],True,(0,255,255)) + cv2.imwrite(visu_path, img) + + +img_dir = '../datasets/icdar2013/test_images' +gt_dir = '../datasets/icdar2013/test_gts' +angle = 45 +new_img_dir = '../datasets/icdar2013/rotated_test_images'+'_'+str(angle) +new_gt_dir = '../datasets/icdar2013/rotated_test_gts'+'_'+str(angle) +if not os.path.isdir(new_img_dir): + os.mkdir(new_img_dir) +if not os.path.isdir(new_gt_dir): + os.mkdir(new_gt_dir) + +visu_dir = '../output/visu/' + +for i in range(233): + img_name = 'img_' + str(i+1) + '.jpg' + img_path = os.path.join(img_dir, img_name) + img = cv2.imread(img_path) + gt_path = os.path.join(gt_dir, img_name + '.txt') + new_img_path = os.path.join(new_img_dir, img_name) + visu_path = os.path.join(visu_dir, img_name) + new_gt_path = os.path.join(new_gt_dir, 'gt_' + img_name.split('.')[0] + '.txt') + polygons, words = _read_gt(gt_path) + # print(img_name) + if angle == 90: + (h, w) = img.shape[:2] + img = cv2.transpose(img) + img = cv2.flip(img,flipCode=0) + # M = cv2.getRotationMatrix2D(center, 90, 1) + # img = cv2.warpAffine(img, M, (h, w)) + new_polygons = [[polygon[1], w-polygon[0], polygon[3], w-polygon[2], polygon[5], w-polygon[4], polygon[7], w-polygon[6]] for polygon in polygons] + else: + img, new_polygons = _rotate_image(img, polygons, angle) + format_new_gt(new_polygons, words, new_gt_path) + # visu_gt(img, new_polygons, visu_path) + cv2.imwrite(new_img_path, img) + \ No newline at end of file diff --git a/tools/demo.py b/tools/demo.py new file mode 100644 index 0000000000000000000000000000000000000000..1d32ec8f775f602da4978fa5a5aeece86bc52879 --- /dev/null +++ b/tools/demo.py @@ -0,0 +1,239 @@ +import os +import cv2 +import torch +from torchvision import transforms as T + +from maskrcnn_benchmark.modeling.detector import build_detection_model +from maskrcnn_benchmark.utils.checkpoint import DetectronCheckpointer +from maskrcnn_benchmark.structures.image_list import to_image_list +from maskrcnn_benchmark.config import cfg +from maskrcnn_benchmark.utils.chars import getstr_grid, get_tight_rect + +from PIL import Image +import numpy as np +import argparse + +class TextDemo(object): + def __init__( + self, + cfg, + confidence_threshold=0.7, + min_image_size=224, + output_polygon=True + ): + self.cfg = cfg.clone() + self.model = build_detection_model(cfg) + self.model.eval() + self.device = torch.device(cfg.MODEL.DEVICE) + self.model.to(self.device) + self.min_image_size = min_image_size + + checkpointer = DetectronCheckpointer(cfg, self.model) + _ = checkpointer.load(cfg.MODEL.WEIGHT) + + self.transforms = self.build_transform() + self.cpu_device = torch.device("cpu") + self.confidence_threshold = confidence_threshold + self.output_polygon = output_polygon + + def build_transform(self): + """ + Creates a basic transformation that was used to train the models + """ + cfg = self.cfg + # we are loading images with OpenCV, so we don't need to convert them + # to BGR, they are already! So all we need to do is to normalize + # by 255 if we want to convert to BGR255 format, or flip the channels + # if we want it to be in RGB in [0-1] range. + if cfg.INPUT.TO_BGR255: + to_bgr_transform = T.Lambda(lambda x: x * 255) + else: + to_bgr_transform = T.Lambda(lambda x: x[[2, 1, 0]]) + + normalize_transform = T.Normalize( + mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD + ) + + transform = T.Compose( + [ + T.ToPILImage(), + T.Resize(self.min_image_size), + T.ToTensor(), + to_bgr_transform, + normalize_transform, + ] + ) + return transform + + def run_on_opencv_image(self, image): + """ + Arguments: + image (np.ndarray): an image as returned by OpenCV + Returns: + result_polygons (list): detection results + result_words (list): recognition results + """ + result_polygons, result_words = self.compute_prediction(image) + return result_polygons, result_words + + def compute_prediction(self, original_image): + # apply pre-processing to image + image = self.transforms(original_image) + # convert to an ImageList, padded so that it is divisible by + # cfg.DATALOADER.SIZE_DIVISIBILITY + image_list = to_image_list(image, self.cfg.DATALOADER.SIZE_DIVISIBILITY) + image_list = image_list.to(self.device) + # compute predictions + with torch.no_grad(): + predictions, _, _ = self.model(image_list) + global_predictions = predictions[0] + char_predictions = predictions[1] + char_mask = char_predictions['char_mask'] + char_boxes = char_predictions['boxes'] + words, rec_scores = self.process_char_mask(char_mask, char_boxes) + seq_words = char_predictions['seq_outputs'] + seq_scores = char_predictions['seq_scores'] + + global_predictions = [o.to(self.cpu_device) for o in global_predictions] + + # always single image is passed at a time + global_prediction = global_predictions[0] + + # reshape prediction (a BoxList) into the original image size + height, width = original_image.shape[:-1] + global_prediction = global_prediction.resize((width, height)) + boxes = global_prediction.bbox.tolist() + scores = global_prediction.get_field("scores").tolist() + masks = global_prediction.get_field("mask").cpu().numpy() + + result_polygons = [] + result_words = [] + for k, box in enumerate(boxes): + score = scores[k] + if score < self.confidence_threshold: + continue + box = list(map(int, box)) + mask = masks[k,0,:,:] + polygon = self.mask2polygon(mask, box, original_image.shape, threshold=0.5, output_polygon=self.output_polygon) + if polygon is None: + polygon = [box[0], box[1], box[2], box[1], box[2], box[3], box[0], box[3]] + result_polygons.append(polygon) + word = words[k] + rec_score = rec_scores[k] + seq_word = seq_words[k] + seq_char_scores = seq_scores[k] + seq_score = sum(seq_char_scores) / float(len(seq_char_scores)) + if seq_score > rec_score: + result_words.append(seq_word) + else: + result_words.append(word) + return result_polygons, result_words + + def process_char_mask(self, char_masks, boxes, threshold=192): + texts, rec_scores = [], [] + for index in range(char_masks.shape[0]): + box = list(boxes[index]) + box = list(map(int, box)) + text, rec_score, _, _ = getstr_grid(char_masks[index,:,:,:].copy(), box, threshold=threshold) + texts.append(text) + rec_scores.append(rec_score) + return texts, rec_scores + + def mask2polygon(self, mask, box, im_size, threshold=0.5, output_polygon=True): + # mask 32*128 + image_width, image_height = im_size[1], im_size[0] + box_h = box[3] - box[1] + box_w = box[2] - box[0] + cls_polys = (mask*255).astype(np.uint8) + poly_map = np.array(Image.fromarray(cls_polys).resize((box_w, box_h))) + poly_map = poly_map.astype(np.float32) / 255 + poly_map=cv2.GaussianBlur(poly_map,(3,3),sigmaX=3) + ret, poly_map = cv2.threshold(poly_map,0.5,1,cv2.THRESH_BINARY) + if output_polygon: + SE1=cv2.getStructuringElement(cv2.MORPH_RECT,(3,3)) + poly_map = cv2.erode(poly_map,SE1) + poly_map = cv2.dilate(poly_map,SE1); + poly_map = cv2.morphologyEx(poly_map,cv2.MORPH_CLOSE,SE1) + try: + _, contours, _ = cv2.findContours((poly_map * 255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_NONE) + except: + contours, _ = cv2.findContours((poly_map * 255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_NONE) + if len(contours)==0: + print(contours) + print(len(contours)) + return None + max_area=0 + max_cnt = contours[0] + for cnt in contours: + area=cv2.contourArea(cnt) + if area > max_area: + max_area = area + max_cnt = cnt + perimeter = cv2.arcLength(max_cnt,True) + epsilon = 0.01*cv2.arcLength(max_cnt,True) + approx = cv2.approxPolyDP(max_cnt,epsilon,True) + pts = approx.reshape((-1,2)) + pts[:,0] = pts[:,0] + box[0] + pts[:,1] = pts[:,1] + box[1] + polygon = list(pts.reshape((-1,))) + polygon = list(map(int, polygon)) + if len(polygon)<6: + return None + else: + SE1=cv2.getStructuringElement(cv2.MORPH_RECT,(3,3)) + poly_map = cv2.erode(poly_map,SE1) + poly_map = cv2.dilate(poly_map,SE1); + poly_map = cv2.morphologyEx(poly_map,cv2.MORPH_CLOSE,SE1) + idy,idx=np.where(poly_map == 1) + xy=np.vstack((idx,idy)) + xy=np.transpose(xy) + hull = cv2.convexHull(xy, clockwise=True) + #reverse order of points. + if hull is None: + return None + hull=hull[::-1] + #find minimum area bounding box. + rect = cv2.minAreaRect(hull) + corners = cv2.boxPoints(rect) + corners = np.array(corners, dtype="int") + pts = get_tight_rect(corners, box[0], box[1], image_height, image_width, 1) + polygon = [x * 1.0 for x in pts] + polygon = list(map(int, polygon)) + return polygon + + def visualization(self, image, polygons, words): + for polygon, word in zip(polygons, words): + pts = np.array(polygon, np.int32) + pts = pts.reshape((-1,1,2)) + xmin = min(pts[:,0,0]) + ymin = min(pts[:,0,1]) + cv2.polylines(image,[pts],True,(0,0,255)) + cv2.putText(image, word, (xmin, ymin), cv2.FONT_HERSHEY_COMPLEX, 1, (0,0,255), 2) + + +def main(args): + # update the config options with the config file + cfg.merge_from_file(args.config_file) + # manual override some options + # cfg.merge_from_list(["MODEL.DEVICE", "cpu"]) + + text_demo = TextDemo( + cfg, + min_image_size=800, + confidence_threshold=0.7, + output_polygon=True + ) + # load image and then run prediction + + image = cv2.imread(args.image_path) + result_polygons, result_words = text_demo.run_on_opencv_image(image) + text_demo.visualization(image, result_polygons, result_words) + cv2.imwrite(args.visu_path, image) + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='parameters for demo') + parser.add_argument("--config-file", type=str, default='configs/mixtrain/seg_rec_poly_fuse_feature.yaml') + parser.add_argument("--image_path", type=str, default='./demo_images/demo.jpg') + parser.add_argument("--visu_path", type=str, default='./demo_images/demo_results.jpg') + args = parser.parse_args() + main(args) \ No newline at end of file diff --git a/tools/test_net.py b/tools/test_net.py new file mode 100644 index 0000000000000000000000000000000000000000..63a7ff431bb414e2a190cc0959655140a86c2439 --- /dev/null +++ b/tools/test_net.py @@ -0,0 +1,103 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +# Set up custom environment before nearly anything else is imported +# NOTE: this should be the first import (no not reorder) +from maskrcnn_benchmark.utils.env import setup_environment # noqa F401 isort:skip + +import argparse +import os + +import torch +from maskrcnn_benchmark.config import cfg +from maskrcnn_benchmark.data import make_data_loader +from maskrcnn_benchmark.engine.text_inference import inference +from maskrcnn_benchmark.modeling.detector import build_detection_model +from maskrcnn_benchmark.utils.checkpoint import DetectronCheckpointer +from maskrcnn_benchmark.utils.collect_env import collect_env_info +from maskrcnn_benchmark.utils.comm import synchronize, get_rank +from maskrcnn_benchmark.utils.logging import setup_logger +from maskrcnn_benchmark.utils.miscellaneous import mkdir +# Check if we can enable mixed-precision via apex.amp +try: + from apex import amp +except ImportError: + raise ImportError('Use APEX for mixed precision via apex.amp') + +def main(): + parser = argparse.ArgumentParser(description="PyTorch Object Detection Inference") + parser.add_argument( + "--config-file", + default="./configs/seq.yaml", + metavar="FILE", + help="path to config file", + ) + parser.add_argument("--local_rank", type=int, default=0) + parser.add_argument( + "opts", + help="Modify config options using the command-line", + default=None, + nargs=argparse.REMAINDER, + ) + + args = parser.parse_args() + + num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1 + distributed = num_gpus > 1 + + if distributed: + torch.cuda.set_device(args.local_rank) + torch.distributed.deprecated.init_process_group( + backend="nccl", init_method="env://" + ) + + cfg.merge_from_file(args.config_file) + cfg.merge_from_list(args.opts) + cfg.freeze() + + save_dir = "" + logger = setup_logger("maskrcnn_benchmark", save_dir, get_rank()) + logger.info("Using {} GPUs".format(num_gpus)) + logger.info(cfg) + + logger.info("Collecting env info (might take some time)") + logger.info("\n" + collect_env_info()) + + model = build_detection_model(cfg) + model.to(cfg.MODEL.DEVICE) + + # Initialize mixed-precision if necessary + use_mixed_precision = cfg.DTYPE == 'float16' + amp_handle = amp.init(enabled=use_mixed_precision, verbose=cfg.AMP_VERBOSE) + + checkpointer = DetectronCheckpointer(cfg, model) + _ = checkpointer.load(cfg.MODEL.WEIGHT) + + iou_types = ("bbox",) + if cfg.MODEL.MASK_ON: + iou_types = iou_types + ("segm",) + output_folders = [None] * len(cfg.DATASETS.TEST) + if cfg.OUTPUT_DIR: + dataset_names = cfg.DATASETS.TEST + for idx, dataset_name in enumerate(dataset_names): + output_folder = os.path.join(cfg.OUTPUT_DIR, "inference", dataset_name) + mkdir(output_folder) + output_folders[idx] = output_folder + data_loaders_val = make_data_loader(cfg, is_train=False, is_distributed=distributed) + model_name = cfg.MODEL.WEIGHT.split('/')[-1] + for output_folder, data_loader_val in zip(output_folders, data_loaders_val): + inference( + model, + data_loader_val, + iou_types=iou_types, + box_only=cfg.MODEL.RPN_ONLY, + device=cfg.MODEL.DEVICE, + expected_results=cfg.TEST.EXPECTED_RESULTS, + expected_results_sigma_tol=cfg.TEST.EXPECTED_RESULTS_SIGMA_TOL, + output_folder=output_folder, + model_name=model_name, + cfg=cfg, + ) + synchronize() + + +if __name__ == "__main__": + main() diff --git a/tools/train_net.py b/tools/train_net.py new file mode 100644 index 0000000000000000000000000000000000000000..20fefe59581f6e681ec51876762353bf8b0c0ee5 --- /dev/null +++ b/tools/train_net.py @@ -0,0 +1,153 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +r""" +Basic training script for PyTorch +""" + +# Set up custom environment before nearly anything else is imported +# NOTE: this should be the first import (no not reorder) +from maskrcnn_benchmark.utils.env import setup_environment # noqa F401 isort:skip + +import argparse +import os + +import torch +from maskrcnn_benchmark.config import cfg +from maskrcnn_benchmark.data import make_data_loader +from maskrcnn_benchmark.solver import make_lr_scheduler +from maskrcnn_benchmark.solver import make_optimizer +from maskrcnn_benchmark.engine.trainer import do_train +from maskrcnn_benchmark.modeling.detector import build_detection_model +from maskrcnn_benchmark.utils.checkpoint import DetectronCheckpointer +from maskrcnn_benchmark.utils.collect_env import collect_env_info +from maskrcnn_benchmark.utils.comm import synchronize, get_rank +from maskrcnn_benchmark.utils.imports import import_file +from maskrcnn_benchmark.utils.logging import setup_logger, Logger +from maskrcnn_benchmark.utils.miscellaneous import mkdir +# See if we can use apex.DistributedDataParallel instead of the torch default, +# and enable mixed-precision via apex.amp +try: + from apex import amp +except ImportError: + raise ImportError('Use APEX for multi-precision via apex.amp') + +def train(cfg, local_rank, distributed): + model = build_detection_model(cfg) + device = torch.device(cfg.MODEL.DEVICE) + model.to(device) + + optimizer = make_optimizer(cfg, model) + scheduler = make_lr_scheduler(cfg, optimizer) + + # Initialize mixed-precision training + use_mixed_precision = cfg.DTYPE == "float16" + amp_opt_level = 'O1' if use_mixed_precision else 'O0' + model, optimizer = amp.initialize(model, optimizer, opt_level=amp_opt_level) + + if distributed: + model = torch.nn.parallel.DistributedDataParallel( + model, device_ids=[local_rank], output_device=local_rank, + # this should be removed if we update BatchNorm stats + broadcast_buffers=False, + # find_unused_parameters=True + ) + + arguments = {} + arguments["iteration"] = 0 + + output_dir = cfg.OUTPUT_DIR + + save_to_disk = get_rank() == 0 + checkpointer = DetectronCheckpointer( + cfg, model, optimizer, scheduler, output_dir, save_to_disk + ) + extra_checkpoint_data = checkpointer.load(cfg.MODEL.WEIGHT, resume=cfg.SOLVER.RESUME) + if cfg.SOLVER.RESUME: + arguments.update(extra_checkpoint_data) + + data_loader = make_data_loader( + cfg, + is_train=True, + is_distributed=distributed, + start_iter=arguments["iteration"], + ) + + checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD + tb_logger = Logger(cfg.OUTPUT_DIR, local_rank) + do_train( + model, + data_loader, + optimizer, + scheduler, + checkpointer, + device, + checkpoint_period, + arguments, + tb_logger, + cfg, + local_rank, + ) + + return model + +def main(): + parser = argparse.ArgumentParser(description="PyTorch Object Detection Training") + parser.add_argument( + "--config-file", + default="", + metavar="FILE", + help="path to config file", + type=str, + ) + parser.add_argument("--local_rank", type=int, default=0) + parser.add_argument( + "--skip-test", + dest="skip_test", + help="Do not test the final model", + action="store_true", + ) + parser.add_argument( + "opts", + help="Modify config options using the command-line", + default=None, + nargs=argparse.REMAINDER, + ) + + args = parser.parse_args() + + num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1 + args.distributed = num_gpus > 1 + + if args.distributed: + torch.cuda.set_device(args.local_rank) + torch.distributed.init_process_group( + backend="nccl", init_method="env://" + ) + + cfg.merge_from_file(args.config_file) + cfg.merge_from_list(args.opts) + cfg.freeze() + + output_dir = cfg.OUTPUT_DIR + if output_dir: + mkdir(output_dir) + + local_rank = get_rank() + logger = setup_logger("maskrcnn_benchmark", output_dir, local_rank) + if local_rank == 0: + logger.info("Using {} GPUs".format(num_gpus)) + logger.info(args) + + logger.info("Collecting env info (might take some time)") + logger.info("\n" + collect_env_info()) + + logger.info("Loaded configuration file {}".format(args.config_file)) + with open(args.config_file, "r") as cf: + config_str = "\n" + cf.read() + logger.info(config_str) + logger.info("Running with config:\n{}".format(cfg)) + + model = train(cfg, args.local_rank, args.distributed) + + +if __name__ == "__main__": + main() diff --git a/train.sh b/train.sh new file mode 100644 index 0000000000000000000000000000000000000000..24588339c5bb4fdfd7bf7ea10b5b6fa70382946d --- /dev/null +++ b/train.sh @@ -0,0 +1,2 @@ +python -m torch.distributed.launch --nproc_per_node=8 tools/train_net.py --config-file configs/pretrain/seg_rec_poly_fuse_feature.yaml +# python -m torch.distributed.launch --nproc_per_node=8 tools/train_net.py --config-file configs/mixtrain/seg_rec_poly_fuse_feature.yaml \ No newline at end of file