File size: 33,045 Bytes
01664b3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5ceacf4
01664b3
5ceacf4
01664b3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5ceacf4
01664b3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5ceacf4
01664b3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5ceacf4
 
 
 
01664b3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5ceacf4
 
 
01664b3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5ceacf4
 
 
 
 
 
 
 
01664b3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5ceacf4
01664b3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5ceacf4
01664b3
5ceacf4
 
01664b3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
"""
inference.py
------------
Provides functionality to run the OPDMulti model on an input image, independent of dataset and ground truth, and 
visualize the output. Large portions of the code originate from get_prediction.py, rgbd_to_pcd_vis.py, 
evaluate_on_log.py, and other related files. The primary goal was a more standalone script that could be converted
easily into a public demo, so this script severs most dependencies on existing ground truth and datasets.

Example usage:
python inference.py \
    --rgb path/to/59-4860.png \
    --depth path/to/59-4860_d.png \
    --model path/to/model.pth \
    --output path/to/output_dir
"""

import argparse
import logging
import os
import time
from copy import deepcopy
from typing import Any

import imageio
import open3d as o3d
import numpy as np
import torch
import torch.nn as nn
from detectron2 import engine, evaluation
from detectron2.modeling import build_model
from detectron2.config import get_cfg, CfgNode
from detectron2.projects.deeplab import add_deeplab_config
from detectron2.structures import instances
from detectron2.utils import comm
from detectron2.utils.logger import setup_logger
from PIL import Image, ImageChops

from mask2former import (
    add_maskformer2_config,
    add_motionnet_config,
)
from utilities import prediction_to_json

# Import quantization helpers based on the installed torch version. They are required for model loading: the code is
# taken from fvcore.common.checkpoint in order to replicate model loading without the overhead of setting up an
# OPDTrainer.
# Only the leading "major.minor" components are parsed, so suffixes such as "+cu118" in "2.1.0+cu118" are ignored.
TORCH_VERSION: tuple[int, ...] = tuple(int(x) for x in torch.__version__.split(".")[:2])
if TORCH_VERSION >= (1, 11):
    from torch.ao import quantization
    from torch.ao.quantization import FakeQuantizeBase, ObserverBase
elif (
    TORCH_VERSION >= (1, 8)
    and hasattr(torch.quantization, "FakeQuantizeBase")
    and hasattr(torch.quantization, "ObserverBase")
):
    from torch import quantization
    from torch.quantization import FakeQuantizeBase, ObserverBase
else:
    # Older torch releases lack the observer base classes. Bind the names anyway so that load_model() can still
    # reference them: hasattr(None, ...) is False there, so its quantization special case is simply skipped instead
    # of raising NameError.
    quantization = None
    FakeQuantizeBase = ObserverBase = None

# TODO: find a global place for this instead of in many places in code
# Maps a motion-type index to its human-readable name (index 0 -> "rotation", index 1 -> "translation").
TYPE_CLASSIFICATION = dict(enumerate(("rotation", "translation")))

# RGB colours as 0/1 components.
POINT_COLOR = [1, 0, 0]  # red; assigned to the points of the moving part in the visualization
ARROW_COLOR = [0, 1, 0]  # green
# File-name suffixes treated as images when trimming white space (used with str.endswith).
IMAGE_EXTENSIONS = tuple(".png .jpg .jpeg".split())


def get_parser() -> argparse.ArgumentParser:
    """
    Specify command-line arguments.

    The primary inputs to the script should be the image paths (RGBD) and camera intrinsics. Other arguments are
    provided to facilitate script testing and model changes. Run file with -h/--help to see all arguments.

    Numeric arguments are converted by argparse (intrinsics to float, sample/process counts to int), so values given
    on the command line have the same types as the defaults.

    :return: parser for extracting command-line arguments
    """
    parser = argparse.ArgumentParser(description="Inference for OPDMulti")
    # The main arguments which should be specified by the user
    parser.add_argument(
        "--rgb",
        dest="rgb_image",
        metavar="FILE",
        help="path to RGB image file on which to run model",
    )
    parser.add_argument(
        "--depth",
        dest="depth_image",
        metavar="FILE",
        help="path to depth image file on which to run model",
    )
    parser.add_argument(  # FIXME: might make more sense to make this a path
        "-i",
        "--intrinsics",
        nargs=9,
        type=float,  # without this, user-supplied values would be strings while the defaults are floats
        default=[
            214.85935872395834,
            0.0,
            0.0,
            0.0,
            214.85935872395834,
            0.0,
            125.90160319010417,
            95.13726399739583,
            1.0,
        ],
        dest="intrinsics",
        help="camera intrinsics matrix, as a list of values",
    )

    # optional parameters for user to specify
    parser.add_argument(
        "-n",
        "--num-samples",
        default=10,
        type=int,  # used as a sample count (e.g. np.linspace / range), so it must be an int
        dest="num_samples",
        metavar="NUM",
        help="number of sample states to generate in visualization",
    )
    parser.add_argument(
        "--crop",
        action="store_true",
        dest="crop",
        help="crop whitespace out of images for visualization",
    )

    # local script development arguments
    parser.add_argument(
        "-m",
        "--model",
        default="path/to/model/file",  # FIXME: set a good default path
        dest="model",
        metavar="FILE",
        help="path to model file to run",
    )
    parser.add_argument(
        "-c",
        "--config",
        default="configs/coco/instance-segmentation/swin/opd_v1_real.yaml",
        metavar="FILE",
        dest="config_file",
        help="path to config file",
    )
    parser.add_argument(
        "-o",
        "--output",
        default="output",  # FIXME: set a good default path
        dest="output",
        help="path to output directory in which to save results",
    )
    parser.add_argument(
        "--num-processes",
        default=1,
        type=int,
        dest="num_processes",
        help="number of processes per machine. When using GPUs, this should be the number of GPUs.",
    )
    parser.add_argument(
        "-s",
        "--score-threshold",
        default=0.8,
        type=float,
        dest="score_threshold",
        help="threshold between 0.0 and 1.0 by which to filter out bad predictions",
    )
    parser.add_argument(
        "--input-format",
        default="RGB",
        choices=["RGB", "RGBD", "depth"],  # the only formats setup_cfg accepts; reject others at parse time
        dest="input_format",
        help="input format of image. Must be one of RGB, RGBD, or depth",
    )
    parser.add_argument(
        "--cpu",
        action="store_true",
        help="flag to require code to use CPU only",
    )

    return parser


def setup_cfg(args: argparse.Namespace) -> CfgNode:
    """
    Create the detectron2 config from command-line arguments and perform basic setup.

    Registers the DeepLab, MaskFormer2 and MotionNet config extensions, merges ``args.config_file`` on top, overrides
    inference-specific settings from ``args``, freezes the config, runs detectron2's default setup and configures the
    "opdformer" logger.

    :param args: parsed command-line arguments (see get_parser)
    :return: frozen config node
    :raises ValueError: if ``args.input_format`` is not one of "RGB", "RGBD", or "depth"
    """
    # Validate up front so an invalid format fails before the config file is read and merged.
    if args.input_format not in ("RGB", "RGBD", "depth"):
        raise ValueError(f"Invalid input format: {args.input_format!r} (must be one of RGB, RGBD, or depth)")

    cfg = get_cfg()
    # add model configurations
    add_deeplab_config(cfg)
    add_maskformer2_config(cfg)
    add_motionnet_config(cfg)
    cfg.merge_from_file(args.config_file)

    # set additional config parameters
    cfg.MODEL.WEIGHTS = args.model
    cfg.OBJ_DETECT = False  # TODO: figure out if this is needed, and parameterize it
    cfg.MODEL.MOTIONNET.VOTING = "none"
    # Output directory
    cfg.OUTPUT_DIR = args.output
    cfg.MODEL.DEVICE = "cpu" if args.cpu else "cuda"

    cfg.MODEL.MODELATTRPATH = None

    # Input format. The slicing below implies PIXEL_MEAN/PIXEL_STD hold RGB values at indices 0-2 and the depth
    # value at index 3 (presumably set by the merged config file — confirm); keep only the channels the model sees.
    cfg.INPUT.FORMAT = args.input_format
    if args.input_format == "RGB":
        cfg.MODEL.PIXEL_MEAN = cfg.MODEL.PIXEL_MEAN[0:3]
        cfg.MODEL.PIXEL_STD = cfg.MODEL.PIXEL_STD[0:3]
    elif args.input_format == "depth":
        cfg.MODEL.PIXEL_MEAN = cfg.MODEL.PIXEL_MEAN[3:4]
        cfg.MODEL.PIXEL_STD = cfg.MODEL.PIXEL_STD[3:4]
    # "RGBD" keeps all channels unchanged.

    cfg.freeze()
    engine.default_setup(cfg, args)

    # Setup logger for the "opdformer" module
    setup_logger(output=cfg.OUTPUT_DIR, distributed_rank=comm.get_rank(), name="opdformer")
    return cfg


def format_input(rgb_path: str) -> list[dict[str, Any]]:
    """
    Read and format input image into detectron2 form so that it can be passed to the model.

    The image is converted to float32 and laid out channel-first. A single-channel image (read as a 2-D array, e.g. a
    grayscale or depth PNG) gets a channel axis of size 1 instead of failing on the transpose.

    :param rgb_path: path to RGB image file
    :return: list containing one dictionary, of the form
        {
            "file_name": path to RGB image,
            "image": float32 torch.Tensor of dimensions [channel, height, width] representing the image
        }
    """
    image = imageio.imread(rgb_path).astype(np.float32)
    if image.ndim == 2:
        # [height, width] -> [height, width, 1] so the channel-first transpose below works for 1-channel images.
        image = image[:, :, np.newaxis]
    image_tensor = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))  # dim: [channel, height, width]
    return [{"file_name": rgb_path, "image": image_tensor}]


def load_model(model: nn.Module, checkpoint: Any) -> None:
    """
    Load weights from a checkpoint.

    The majority of the function definition is taken from the DetectionCheckpointer implementation provided in
    detectron2. While not all of this code is necessarily needed for model loading, it was ported with the intention
    of keeping the implementation and output as close to the original as possible, and reusing the checkpoint class here
    in isolation was determined to be infeasible.

    Checkpoint entries whose shape does not match the model are skipped, except for quantization observer buffers,
    whose mismatches are expected. Skipped entries are logged as warnings. Model keys missing from the checkpoint and
    checkpoint keys the model does not use are logged as warnings too, instead of being silently ignored.

    :param model: model for which to load weights
    :param checkpoint: dict-like checkpoint whose "model" entry maps parameter names to torch.Tensor or np.ndarray
        weights. The "model" entry is popped, so the caller's dict is modified.
    :raises ValueError: if a checkpoint entry is neither a torch.Tensor nor an np.ndarray
    """
    logger = logging.getLogger("detectron2")

    def _strip_prefix_if_present(state_dict: dict[str, Any], prefix: str) -> None:
        """If prefix is found on all keys in state dict, remove prefix."""
        keys = sorted(state_dict.keys())
        if not all(len(key) == 0 or key.startswith(prefix) for key in keys):
            return

        for key in keys:
            newkey = key[len(prefix) :]
            state_dict[newkey] = state_dict.pop(key)

    def _get_module_for_key(root: nn.Module, key: str) -> nn.Module:
        """Return the submodule owning a parameter/buffer: "foo.bar.param_or_buffer_name" -> root.foo.bar."""
        cur_module = root
        for key_part in key.split(".")[:-1]:
            cur_module = getattr(cur_module, key_part)
        return cur_module

    checkpoint_state_dict = checkpoint.pop("model")

    # convert from numpy to tensor
    for k, v in checkpoint_state_dict.items():
        if not isinstance(v, np.ndarray) and not isinstance(v, torch.Tensor):
            raise ValueError("Unsupported type found in checkpoint! {}: {}".format(k, type(v)))
        if not isinstance(v, torch.Tensor):
            checkpoint_state_dict[k] = torch.from_numpy(v)

    # if the state_dict comes from a model that was wrapped in a
    # DataParallel or DistributedDataParallel during serialization,
    # remove the "module" prefix before performing the matching.
    _strip_prefix_if_present(checkpoint_state_dict, "module.")

    # workaround https://github.com/pytorch/pytorch/issues/24139
    model_state_dict = model.state_dict()
    incorrect_shapes = []
    for k in list(checkpoint_state_dict.keys()):  # state dict is modified in loop, so list op is necessary
        if k in model_state_dict:
            model_param = model_state_dict[k]
            # Allow mismatch for uninitialized parameters
            if TORCH_VERSION >= (1, 8) and isinstance(model_param, nn.parameter.UninitializedParameter):
                continue
            shape_model = tuple(model_param.shape)
            shape_checkpoint = tuple(checkpoint_state_dict[k].shape)
            if shape_model != shape_checkpoint:
                has_observer_base_classes = (
                    TORCH_VERSION >= (1, 8)
                    and hasattr(quantization, "ObserverBase")
                    and hasattr(quantization, "FakeQuantizeBase")
                )
                if has_observer_base_classes:
                    # Handle the special case of quantization per channel observers,
                    # where buffer shape mismatches are expected.
                    cls_to_skip = (
                        ObserverBase,
                        FakeQuantizeBase,
                    )
                    target_module = _get_module_for_key(model, k)
                    if isinstance(target_module, cls_to_skip):
                        # Do not remove modules with expected shape mismatches
                        # them from the state_dict loading. They have special logic
                        # in _load_from_state_dict to handle the mismatches.
                        continue

                incorrect_shapes.append((k, shape_checkpoint, shape_model))
                checkpoint_state_dict.pop(k)

    for key, shape_checkpoint, shape_model in incorrect_shapes:
        logger.warning(
            "Skip loading parameter '%s' to the model due to incompatible shapes: %s in the checkpoint but %s in the "
            "model! You might want to double check if this is expected.",
            key,
            shape_checkpoint,
            shape_model,
        )

    incompatible = model.load_state_dict(checkpoint_state_dict, strict=False)
    # load_state_dict(strict=False) reports missing/unexpected keys instead of raising; surface them so that a
    # mismatched checkpoint does not silently produce a partially initialized model.
    if incompatible is not None:
        skipped = {key for key, _, _ in incorrect_shapes}  # already reported above
        missing = [key for key in incompatible.missing_keys if key not in skipped]
        if missing:
            logger.warning("Some model parameters or buffers are not found in the checkpoint: %s", ", ".join(missing))
        if incompatible.unexpected_keys:
            logger.warning(
                "The checkpoint contains keys not used by the model: %s", ", ".join(incompatible.unexpected_keys)
            )


def predict(model: nn.Module, inp: list[dict[str, Any]]) -> list[dict[str, instances.Instances]]:
    """
    Compute model predictions.

    Runs without gradient tracking and inside detectron2's ``evaluation.inference_context``, which puts the model in
    eval mode for the duration of the call.

    :param model: model to run on input
    :param inp: list with one dictionary per image (as produced by format_input), each of the form
        {
            "file_name": path to image,
            "image": float32 torch.Tensor of dimensions [channel, height, width] as RGB/RGBD/depth image
        }
    :return: list with one dictionary per input image, holding the detected instances (under "instances") and
        predicted openable parameters
    """
    with torch.no_grad(), evaluation.inference_context(model):
        return model(inp)


def generate_rotation_visualization(
    pcd: o3d.geometry.PointCloud,
    axis_arrow: o3d.geometry.TriangleMesh,
    mask: np.ndarray,
    axis_vector: np.ndarray,
    origin: np.ndarray,
    range_min: float,
    range_max: float,
    num_samples: int,
    output_dir: str,
) -> None:
    """
    Generate visualization files for a rotation motion of a part.

    For each of ``num_samples`` evenly spaced angles from ``range_min`` to ``range_max`` (inclusive), the masked part of
    a copy of ``pcd`` is rotated about the given axis. The result is rendered together with a copy of ``axis_arrow`` and
    saved as ``<output_dir>/<idx>.png``. The input geometries are not modified.

    :param pcd: point cloud object representing 2D image input (RGBD) as a point cloud
    :param axis_arrow: mesh object representing axis arrow of rotation to be rendered in visualization
    :param mask: mask np.array of dimensions (height, width) representing the part to be rotated in the image
    :param axis_vector: np.array of dimensions (3, ) representing the vector of the axis of rotation
    :param origin: np.array of dimensions (3, ) representing the origin point of the axis of rotation
    :param range_min: float representing the minimum range of motion in radians
    :param range_max: float representing the maximum range of motion in radians
    :param num_samples: number of sample states to visualize in between range_min and range_max of motion
    :param output_dir: string path to directory in which to save visualization output
    """
    # Additional rotation around the x-axis (5.5 * pi / 5 rad = 198 degrees) applied to every frame; loop-invariant.
    view_rotation = o3d.geometry.get_rotation_matrix_from_axis_angle(np.asarray([1, 0, 0]) * (np.pi * 5.5 / 5))

    # The range is already in radians, so the sampled angles can be passed to rotate_part directly.
    for idx, angle_rad in enumerate(np.linspace(range_min, range_max, num_samples)):
        # Work on copies so every frame starts from the unmodified input geometry.
        rotated_pcd = rotate_part(deepcopy(pcd), mask, axis_vector, origin, float(angle_rad))
        rotated_arrow = deepcopy(axis_arrow)

        vis = o3d.visualization.Visualizer()
        vis.create_window()
        try:
            vis.add_geometry(rotated_pcd)
            vis.add_geometry(rotated_arrow)

            # Rotate both geometries about the point cloud's centre so the arrow stays aligned with the part.
            center = rotated_pcd.get_center()
            rotated_pcd.rotate(view_rotation, center=center)
            rotated_arrow.rotate(view_rotation, center=center)

            vis.capture_screen_image(os.path.join(output_dir, f"{idx}.png"), do_render=True)
        finally:
            # Release the window even if rendering or saving fails.
            vis.destroy_window()


def generate_translation_visualization(
    pcd: o3d.geometry.PointCloud,
    axis_arrow: o3d.geometry.TriangleMesh,
    mask: np.ndarray,
    end: np.ndarray,
    range_min: float,
    range_max: float,
    num_samples: int,
    output_dir: str,
) -> None:
    """
    Generate visualization files for a translation motion of a part.

    For each of ``num_samples`` evenly spaced distances from ``range_min`` to ``range_max`` (inclusive), the masked
    part of a copy of ``pcd`` is translated along ``end``. The result is rendered together with a copy of
    ``axis_arrow`` and saved as ``<output_dir>/<idx>.png``. The input geometries are not modified.

    :param pcd: point cloud object representing 2D image input (RGBD) as a point cloud
    :param axis_arrow: mesh object representing axis arrow of translation to be rendered in visualization
    :param mask: mask np.array of dimensions (height, width) representing the part to be translated in the image
    :param end: np.array of dimensions (3, ) giving the direction of translation; it is passed to translate_part as
        the axis vector and normalized there, so only its direction matters
    :param range_min: float representing the minimum range of motion
    :param range_max: float representing the maximum range of motion
    :param num_samples: number of sample states to visualize in between range_min and range_max of motion
    :param output_dir: string path to directory in which to save visualization output
    """
    # Additional rotation around the x-axis (5.5 * pi / 5 rad = 198 degrees) applied to every frame; loop-invariant.
    # TODO: not sure why we need this rotation for the translation, and when it would be desired
    view_rotation = o3d.geometry.get_rotation_matrix_from_axis_angle(np.asarray([1, 0, 0]) * (np.pi * 5.5 / 5))

    for idx, translate_distance in enumerate(np.linspace(range_min, range_max, num_samples)):
        # Work on copies so every frame starts from the unmodified input geometry.
        translated_pcd = translate_part(deepcopy(pcd), mask, end, float(translate_distance))
        translated_arrow = deepcopy(axis_arrow)

        vis = o3d.visualization.Visualizer()
        vis.create_window()
        try:
            vis.add_geometry(translated_pcd)
            vis.add_geometry(translated_arrow)

            # Rotate both geometries about the point cloud's centre so the arrow stays aligned with the part.
            center = translated_pcd.get_center()
            translated_pcd.rotate(view_rotation, center=center)
            translated_arrow.rotate(view_rotation, center=center)

            vis.capture_screen_image(os.path.join(output_dir, f"{idx}.png"), do_render=True)
        finally:
            # Release the window even if rendering or saving fails.
            vis.destroy_window()


def get_rotation_matrix_from_vectors(vec1: np.ndarray, vec2: np.ndarray) -> np.ndarray:
    """
    Find the rotation matrix that aligns vec1 to vec2.

    Uses Rodrigues' formula for the general case. Parallel and anti-parallel inputs are handled explicitly, because
    there the cross product vanishes and the general formula would divide by zero (yielding NaNs).

    :param vec1: A 3d "source" vector (non-zero)
    :param vec2: A 3d "destination" vector (non-zero)
    :return: A transform matrix (3x3) which when applied to vec1, aligns it with vec2.
    """
    a = (vec1 / np.linalg.norm(vec1)).reshape(3)
    b = (vec2 / np.linalg.norm(vec2)).reshape(3)
    v = np.cross(a, b)
    c = np.dot(a, b)
    s = np.linalg.norm(v)

    if s < 1e-8:
        if c > 0:
            # Already aligned: no rotation needed.
            return np.eye(3)
        # Opposite directions: rotate 180 degrees about any unit axis u perpendicular to a (R = 2 u u^T - I).
        helper = np.array([1.0, 0.0, 0.0]) if abs(a[0]) < 0.9 else np.array([0.0, 1.0, 0.0])
        u = np.cross(a, helper)
        u = u / np.linalg.norm(u)
        return 2.0 * np.outer(u, u) - np.eye(3)

    kmat = np.array([[0, -v[2], v[1]], [v[2], 0, -v[0]], [-v[1], v[0], 0]])
    rotation_matrix = np.eye(3) + kmat + kmat.dot(kmat) * ((1 - c) / (s**2))
    return rotation_matrix


def draw_line(start_point: np.ndarray, end_point: np.ndarray) -> o3d.geometry.TriangleMesh:
    """
    Generate a 3D arrow mesh anchored at start_point and pointing towards end_point.

    Only the direction from start_point to end_point is used; the arrow has a fixed size. Open3D creates the cylinder
    centred on the origin along +z, and the cone with its base at the origin and its tip at z = height. After
    rotation and translation, the shaft therefore spans from -0.45 to +0.45 along the direction around start_point.
    The head's base sits at +0.4 and its tip at +0.5.

    :param start_point: np.ndarray of dimensions (3, ) representing the start point of the axis
    :param end_point: np.ndarray of dimensions (3, ) representing a point along the desired arrow direction
    :return: mesh object (shaft + head) representing the axis
    :raises ValueError: if start_point and end_point coincide, since the direction is then undefined
    """
    # Compute direction vector and normalize it (a zero vector would otherwise silently produce NaN geometry)
    direction_vector = np.asarray(end_point) - np.asarray(start_point)
    direction_length = np.linalg.norm(direction_vector)
    if direction_length == 0:
        raise ValueError("start_point and end_point must differ to define an arrow direction")
    normalized_vector = direction_vector / direction_length

    # Rotation that maps the +z axis (the axis of Open3D's primitives) onto the arrow direction
    target_vector = np.array([0, 0, 1])
    rot_mat = get_rotation_matrix_from_vectors(target_vector, normalized_vector)

    # Shaft: created centred on the origin, rotated about the origin, then moved so it is centred on start_point
    cylinder_length = 0.9  # adjust as needed
    cylinder_radius = 0.01  # thickness of the arrow shaft
    cylinder = o3d.geometry.TriangleMesh.create_cylinder(radius=cylinder_radius, height=cylinder_length)
    cylinder.rotate(rot_mat, center=[0, 0, 0])
    cylinder.translate(start_point)

    # Head: created with its base at the origin, rotated, then moved so its base is 0.4 along the direction
    cone_height = 0.1  # adjust as needed
    cone_radius = 0.03  # size of the arrowhead
    cone = o3d.geometry.TriangleMesh.create_cone(radius=cone_radius, height=cone_height)
    cone.rotate(rot_mat, center=[0, 0, 0])
    cone.translate(start_point + normalized_vector * 0.4)

    return cylinder + cone


def rotate_part(
    pcd: o3d.geometry.PointCloud, mask: np.ndarray, axis_vector: np.ndarray, origin: np.ndarray, angle_rad: float
) -> o3d.geometry.PointCloud:
    """
    Generate rotated point cloud of mask based on provided angle around axis.

    The masked points are rotated about the axis through ``origin`` and recoloured with POINT_COLOR. The point cloud
    is modified in place and also returned. Point ``row * width + col`` is treated as the point of mask pixel
    ``(row, col)``, i.e. the cloud is assumed to hold one point per pixel in row-major order.

    :param pcd: point cloud object representing points of image
    :param mask: mask np.array of dimensions (height, width) representing the part to be rotated in the image
    :param axis_vector: np.array of dimensions (3, ) giving the direction of the axis of rotation; it is normalized
        here, so any non-zero length is accepted
    :param origin: np.array of dimensions (3, ) representing the origin point of the axis of rotation
    :param angle_rad: angle in radians to rotate mask part
    :return: point cloud object after rotation of masked part
    :raises ValueError: if axis_vector has zero length
    """
    points_np = np.asarray(pcd.points)
    colors_np = np.asarray(pcd.colors)

    # Rodrigues' formula only yields a rotation for a unit-length axis, so normalize first
    # (translate_part normalizes its axis the same way).
    axis_norm = np.linalg.norm(axis_vector)
    if axis_norm == 0:
        raise ValueError("axis_vector must have non-zero length")
    kx, ky, kz = np.asarray(axis_vector, dtype=np.float64).reshape(3) / axis_norm

    # Skew-symmetric cross-product matrix of the unit axis, then Rodrigues' rotation matrix
    K = np.array([[0, -kz, ky], [kz, 0, -kx], [-ky, kx, 0]])
    R = np.eye(3) + np.sin(angle_rad) * K + (1 - np.cos(angle_rad)) * (K @ K)

    # Flat (row-major) indices of the object pixels, i.e. i * width + j for every mask[i, j] > 0.
    part_indices = np.flatnonzero(np.asarray(mask).reshape(-1) > 0)

    # Row-vector form of R @ (p - origin) + origin, applied to all part points at once.
    points_np[part_indices] = (points_np[part_indices] - origin) @ R.T + origin
    colors_np[part_indices] = POINT_COLOR

    pcd.points = o3d.utility.Vector3dVector(points_np)
    pcd.colors = o3d.utility.Vector3dVector(colors_np)
    return pcd


def translate_part(
    pcd: o3d.geometry.PointCloud, mask: np.ndarray, axis_vector: np.ndarray, distance: float
) -> o3d.geometry.PointCloud:
    """
    Generate translated point cloud of mask based on provided distance along axis.

    The masked points are shifted by ``distance`` along the normalized ``axis_vector`` and recoloured with
    POINT_COLOR. The point cloud is modified in place and also returned. Point ``row * width + col`` is treated as
    the point of mask pixel ``(row, col)``, i.e. the cloud is assumed to hold one point per pixel in row-major order.

    :param pcd: point cloud object representing points of image
    :param mask: mask np.array of dimensions (height, width) representing the part to be translated in the image
    :param axis_vector: np.array of dimensions (3, ) representing the vector of the axis of translation (only its
        direction is used)
    :param distance: distance within coordinate system to translate mask part
    :return: point cloud object after translation of masked part
    :raises ValueError: if axis_vector has zero length
    """
    axis_norm = np.linalg.norm(axis_vector)
    if axis_norm == 0:
        # The direction would be undefined (division by zero would silently produce NaN coordinates).
        raise ValueError("axis_vector must have non-zero length")
    translation_vector = np.asarray(axis_vector) / axis_norm * distance

    colors_np = np.asarray(pcd.colors)
    points_np = np.asarray(pcd.points)

    # Flat (row-major) indices of the object pixels, i.e. i * width + j for every mask[i, j] > 0.
    part_indices = np.flatnonzero(np.asarray(mask).reshape(-1) > 0)
    colors_np[part_indices] = POINT_COLOR
    points_np[part_indices] += translation_vector

    pcd.colors = o3d.utility.Vector3dVector(colors_np)
    pcd.points = o3d.utility.Vector3dVector(points_np)
    return pcd


def batch_trim(images_path: str, save_path: str, identical: bool = False) -> None:
    """
    Trim white spaces from all images in the given path and save new images to folder.

    "White space" is any border whose colour differs from the image's top-left pixel by at most 100 (per channel).

    :param images_path: local path to folder containing all images. Only files whose names end in ".png", ".jpg", or
        ".jpeg" are processed.
    :param save_path: local path to folder in which to save trimmed images (created if missing). Each trimmed image is
        saved under its original file name.
    :param identical: if True, will apply the same crop (the union of all images' content boxes) to all images, so
        they keep equal sizes; else each image will have its whitespace trimmed independently. Note that in the latter
        case, each image may have a slightly different size.
    """

    def get_trim(im: Image.Image):
        """Return the bounding box of the content differing from the top-left pixel, or None if there is none."""
        bg = Image.new(im.mode, im.size, im.getpixel((0, 0)))
        diff = ImageChops.difference(im, bg)
        # ImageChops.add computes (diff + diff) / 2.0 - 100, i.e. it drops differences of up to 100 so that
        # near-background noise is not treated as content.
        diff = ImageChops.add(diff, diff, 2.0, -100)
        return diff.getbbox()

    os.makedirs(save_path, exist_ok=True)
    image_files = sorted(name for name in os.listdir(images_path) if name.endswith(IMAGE_EXTENSIONS))

    if not identical:
        # Trim each image separately.
        for image_file in image_files:
            with Image.open(os.path.join(images_path, image_file)) as im:
                bbox = get_trim(im)
                trimmed = im.crop(bbox) if bbox else im
                trimmed.save(os.path.join(save_path, image_file))
        return

    images = [Image.open(os.path.join(images_path, image_file)) for image_file in image_files]
    try:
        # Union of all content boxes, so that no image loses content under the shared crop.
        optimal_box = None
        for im in images:
            bbox = get_trim(im)
            if bbox is None:
                bbox = (0, 0, im.size[0], im.size[1])  # bound entire image
            if optimal_box is None:
                optimal_box = bbox
            else:
                optimal_box = (
                    min(optimal_box[0], bbox[0]),
                    min(optimal_box[1], bbox[1]),
                    max(optimal_box[2], bbox[2]),
                    max(optimal_box[3], bbox[3]),
                )

        # Image.crop returns a new image rather than cropping in place, so save the returned image. Keeping the
        # original file names avoids reordering frames (lexicographic sorting puts "10.png" before "2.png").
        for image_file, im in zip(image_files, images):
            im.crop(optimal_box).save(os.path.join(save_path, image_file))
    finally:
        for im in images:
            im.close()


def create_gif(
    image_folder_path: str, num_samples: int, gif_filename: str = "output.gif", duration: float = 0.1
) -> None:
    """
    Create gif out of folder of images and save to file.

    :param image_folder_path: path to folder containing images (non-recursive). Assumes images are named as {i}.png for
        each i from 0 to num_samples - 1. The gif is written to this folder as well.
    :param num_samples: number of sampled images to compile into gif; must be positive.
    :param gif_filename: filename for gif, defaults to "output.gif"
    :param duration: per-frame duration passed to imageio.mimsave, defaults to 0.1
    :raises ValueError: if num_samples is not positive or the images do not all have the same shape
    """
    if num_samples <= 0:
        raise ValueError(f"num_samples must be positive, got {num_samples}")

    # Read the frames in numeric order: 0.png, 1.png, ...
    image_files = [os.path.join(image_folder_path, f"{i}.png") for i in range(num_samples)]
    images = [imageio.imread(image_file) for image_file in image_files]

    # A gif needs equally sized frames. Use an explicit check rather than assert, which is stripped under -O.
    if any(im.shape != images[0].shape for im in images):
        raise ValueError(f"Found some images with a different shape: {[im.shape for im in images]}")

    imageio.mimsave(os.path.join(image_folder_path, gif_filename), images, duration=duration)


def main(
    cfg: CfgNode,
    rgb_image: str,
    depth_image: str,
    intrinsics: list[float],
    num_samples: int,
    crop: bool,
    score_threshold: float,
) -> None:
    """
    Main inference method.

    Runs the model on a single RGB image and saves the raw predictions to
    ``cfg.OUTPUT_DIR/<image_id>_prediction.pth``. It then renders a motion visualization for every
    predicted instance whose score exceeds ``score_threshold``, in descending score order, into
    ``cfg.OUTPUT_DIR/<instance index>``.

    :param cfg: configuration object
    :param rgb_image: local path to RGB image
    :param depth_image: local path to depth image
    :param intrinsics: camera intrinsics matrix as a list of 9 values, read in column-major order
    :param num_samples: number of sample visualization states to generate
    :param crop: if True, images will be cropped to remove whitespace before visualization; cropped copies are
        written to ``<instance output dir>_cropped``
    :param score_threshold: float between 0 and 1 representing threshold at which to filter instances based on score
    :raises ValueError: if a selected prediction has a motion type other than "rotation" or "translation"
    """
    logger = logging.getLogger("detectron2")

    # setup data
    logger.info("Loading image.")
    inp = format_input(rgb_image)

    # setup model
    logger.info("Loading model.")
    model = build_model(cfg)
    # NOTE(review): torch.load unpickles the checkpoint file; only load weights from a trusted source.
    weights = torch.load(cfg.MODEL.WEIGHTS, map_location=torch.device("cpu"))
    if "model" not in weights:
        # accept a bare state dict as well as a checkpoint wrapped under a "model" key
        weights = {"model": weights}
    load_model(model, weights)

    # run model on data
    logger.info("Running model.")
    prediction = predict(model, inp)[0]  # index 0 since there is only one image
    pred_instances = prediction["instances"]

    # log results
    image_id = os.path.splitext(os.path.basename(rgb_image))[0]
    instances = pred_instances.to(torch.device("cpu"))
    pred_dict = {"image_id": image_id, "instances": prediction_to_json(instances, image_id)}
    torch.save(pred_dict, os.path.join(cfg.OUTPUT_DIR, f"{image_id}_prediction.pth"))

    # rank instances by descending score, keeping only those strictly above the threshold
    scores = [pred_instances[i].scores.item() for i in range(len(pred_instances))]
    score_ranking = [idx for idx in np.argsort([-1 * s for s in scores]) if scores[int(idx)] > score_threshold]
    if not score_ranking:
        logger.warning("The model did not predict any moving parts above the score threshold.")
        return

    # The input images and camera model are identical for every instance, so load them only once.
    color = o3d.io.read_image(rgb_image)
    depth = o3d.io.read_image(depth_image)
    rgbd_image = o3d.geometry.RGBDImage.create_from_color_and_depth(color, depth, convert_rgb_to_intensity=False)
    height, width = np.asarray(color).shape[:2]

    # generate intrinsics (the 9 input values are laid out column-major, hence order="F")
    intrinsic_matrix = np.reshape(intrinsics, (3, 3), order="F")
    intrinsic_obj = o3d.camera.PinholeCameraIntrinsic(
        width,
        height,
        intrinsic_matrix[0, 0],
        intrinsic_matrix[1, 1],
        intrinsic_matrix[0, 2],
        intrinsic_matrix[1, 2],
    )

    for idx in score_ranking:  # highest-scoring instance first
        pred = pred_instances[int(idx)]
        logger.info("Rendering prediction for instance %d", int(idx))
        output_dir = os.path.join(cfg.OUTPUT_DIR, str(idx))
        os.makedirs(output_dir, exist_ok=True)

        # extract predicted values for visualization
        mask = np.squeeze(pred.pred_masks.cpu().numpy())  # dim: [height, width]
        origin = pred.morigin.cpu().numpy().flatten()  # dim: [3, ]
        axis_vector = pred.maxis.cpu().numpy().flatten()  # dim: [3, ]
        pred_type = TYPE_CLASSIFICATION.get(pred.mtype.item())
        # motion range expressed relative to the predicted current state
        state = pred.mstate.cpu().numpy()
        range_min = 0 - state
        range_max = pred.mstatemax.cpu().numpy() - state

        # Build a fresh point cloud per instance in case the visualization helpers modify it in place.
        pcd = o3d.geometry.PointCloud.create_from_rgbd_image(rgbd_image, intrinsic_obj)

        # Create a LineSet to visualize the direction vector
        axis_arrow = draw_line(origin, axis_vector + origin)
        axis_arrow.paint_uniform_color(ARROW_COLOR)

        if pred_type == "rotation":
            generate_rotation_visualization(
                pcd,
                axis_arrow,
                mask,
                axis_vector,
                origin,
                range_min,
                range_max,
                num_samples,
                output_dir,
            )
        elif pred_type == "translation":
            generate_translation_visualization(
                pcd,
                axis_arrow,
                mask,
                axis_vector,
                range_min,
                range_max,
                num_samples,
                output_dir,
            )
        else:
            raise ValueError(f"Invalid motion prediction type: {pred_type}")

        # NOTE: gif assembly via create_gif(...) is currently not invoked here.
        if crop:  # crop images to remove shared extraneous whitespace
            output_dir_cropped = f"{output_dir}_cropped"
            os.makedirs(output_dir_cropped, exist_ok=True)
            batch_trim(output_dir, output_dir_cropped, identical=True)


if __name__ == "__main__":
    # Use a monotonic clock for the elapsed-time measurement; time.time() follows the
    # system wall clock and can jump if it is adjusted mid-run.
    start_time = time.perf_counter()

    # parse arguments
    args = get_parser().parse_args()
    cfg = setup_cfg(args)

    # run main code; engine.launch is given args.num_processes as the process count
    engine.launch(
        main,
        args.num_processes,
        args=(
            cfg,
            args.rgb_image,
            args.depth_image,
            args.intrinsics,
            args.num_samples,
            args.crop,
            args.score_threshold,
        ),
    )
    elapsed = time.perf_counter() - start_time
    print(f"Inference time: {elapsed:.2f} seconds")