Spaces:
Running
on
Zero
Running
on
Zero
"""
Test script for S3DIS 6-fold cross-validation.

Collect the Area_X.pth file from the result folder of each area's experiment
record, and place all six files in one directory (RECORDS_PATH, passed to this
script as --record_root) laid out as follows:

|- RECORDS_PATH
    |- Area_1.pth
    |- Area_2.pth
    |- Area_3.pth
    |- Area_4.pth
    |- Area_5.pth
    |- Area_6.pth

Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com)
Please cite our work if the code is helpful to you.
"""
import argparse | |
import os | |
import torch | |
import numpy as np | |
import glob | |
from pointcept.utils.logger import get_root_logger | |
# The 13 S3DIS semantic classes. Position i names the i-th entry of the
# per-class metric arrays consumed by `evaluation`.
CLASS_NAMES = (
    "ceiling floor wall beam column window door "
    "table chair sofa bookcase board clutter"
).split()
def evaluation(intersection, union, target, logger=None):
    """Compute (and optionally log) S3DIS semantic-segmentation metrics.

    Args:
        intersection: Per-class intersection counts; anything ``np.asarray``
            accepts (NumPy array, list, CPU tensor).
        union: Per-class union counts, same length as ``intersection``.
        target: Per-class ground-truth counts, same length as
            ``intersection``.
        logger: Optional logger. When given, the overall metrics and one
            line per entry of ``CLASS_NAMES`` are written via ``logger.info``.

    Returns:
        Tuple ``(mIoU, mAcc, allAcc)`` of Python floats.

    Raises:
        ValueError: If ``logger`` is given and the number of per-class
            entries differs from ``len(CLASS_NAMES)``.
    """
    # Keeps classes with zero union/target from dividing by zero
    # (their IoU/accuracy then evaluates to 0).
    eps = 1e-10
    # Coerce up front so lists / tensors work too (``list + float`` raises).
    intersection = np.asarray(intersection, dtype=np.float64)
    union = np.asarray(union, dtype=np.float64)
    target = np.asarray(target, dtype=np.float64)

    iou_class = intersection / (union + eps)
    accuracy_class = intersection / (target + eps)
    mIoU = float(np.mean(iou_class))
    mAcc = float(np.mean(accuracy_class))
    allAcc = float(np.sum(intersection) / (np.sum(target) + eps))

    if logger is not None:
        # Fail before logging anything rather than silently truncating or
        # crashing halfway through the per-class breakdown.
        if len(iou_class) != len(CLASS_NAMES):
            raise ValueError(
                f"Expected {len(CLASS_NAMES)} per-class entries, "
                f"got {len(iou_class)}"
            )
        logger.info(
            "Val result: mIoU/mAcc/allAcc {:.4f}/{:.4f}/{:.4f}".format(
                mIoU, mAcc, allAcc
            )
        )
        for idx, (name, iou, accuracy) in enumerate(
            zip(CLASS_NAMES, iou_class, accuracy_class)
        ):
            logger.info(
                "Class_{idx} - {name} Result: iou/accuracy {iou:.4f}/{accuracy:.4f}".format(
                    idx=idx,
                    name=name,
                    iou=iou,
                    accuracy=accuracy,
                )
            )
    return mIoU, mAcc, allAcc
def _load_record(path):
    """Load one ``Area_X.pth`` record dict onto the CPU.

    The records may hold NumPy arrays, which ``torch.load`` refuses to
    unpickle under ``weights_only=True`` (the default since PyTorch 2.6), so
    full unpickling is requested explicitly where the argument exists.
    NOTE: this executes pickle code -- only load records you produced.
    """
    import inspect

    kwargs = {"map_location": "cpu"}
    # PyTorch < 1.13 has no ``weights_only`` parameter at all.
    if "weights_only" in inspect.signature(torch.load).parameters:
        kwargs["weights_only"] = False
    return torch.load(path, **kwargs)


def main():
    """Aggregate the six per-area S3DIS records into 6-fold metrics.

    Reads every ``Area_*.pth`` under ``--record_root``, logs each area's
    metrics, then logs metrics computed from the per-class
    intersection/union/target counts summed over all areas. The log is
    written to ``<record_root>/6-fold.log``.

    Raises:
        RuntimeError: If ``--record_root`` does not contain exactly six
            ``Area_*.pth`` records.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--record_root",
        required=True,
        help="Path to the S3DIS record of each split",
    )
    config = parser.parse_args()

    # Validate inputs before the logger truncates 6-fold.log (file_mode="w").
    # An explicit error (not ``assert``) survives ``python -O``.
    records = sorted(glob.glob(os.path.join(config.record_root, "Area_*.pth")))
    if len(records) != 6:
        raise RuntimeError(
            f"Expected 6 Area_*.pth records in {config.record_root!r}, "
            f"found {len(records)}: {records}"
        )

    logger = get_root_logger(
        log_file=os.path.join(config.record_root, "6-fold.log"),
        file_mode="w",
    )

    num_classes = len(CLASS_NAMES)
    intersection_total = np.zeros(num_classes, dtype=int)
    union_total = np.zeros(num_classes, dtype=int)
    target_total = np.zeros(num_classes, dtype=int)
    for record in records:
        area = os.path.basename(record).split(".")[0]
        info = _load_record(record)
        logger.info(f"<<<<<<<<<<<<<<<<< Parsing {area} <<<<<<<<<<<<<<<<<")
        intersection = np.asarray(info["intersection"])
        union = np.asarray(info["union"])
        target = np.asarray(info["target"])
        evaluation(intersection, union, target, logger=logger)
        # Out-of-place addition lets the dtype upcast if a record stores
        # float counts; in-place ``+=`` on an int array would raise.
        intersection_total = intersection_total + intersection
        union_total = union_total + union
        target_total = target_total + target

    logger.info("<<<<<<<<<<<<<<<<< Parsing 6-fold <<<<<<<<<<<<<<<<<")
    evaluation(intersection_total, union_total, target_total, logger=logger)


if __name__ == "__main__":
    main()