# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
"""Containers passed to / returned from adaptor modules, plus the adaptor base class."""

from argparse import Namespace
from typing import NamedTuple, Optional

import torch
from torch import nn
import torch.nn.functional as F


class AdaptorInput(NamedTuple):
    """Bundle of everything an adaptor's ``forward`` receives."""

    # Input images. NOTE(review): shape/normalization not established here — confirm with caller.
    images: torch.Tensor
    # Summary tensor produced upstream of the adaptor.
    summary: torch.Tensor
    # Feature tensor produced upstream of the adaptor; its layout is described by ``feature_fmt``
    # (presumably — TODO confirm the set of valid format strings against callers).
    features: torch.Tensor
    # String tag describing the layout of ``features``.
    feature_fmt: str
    # Patch size associated with ``features``.
    patch_size: int


class RadioOutput(NamedTuple):
    """Pair of ``(summary, features)`` tensors; either field may be ``None``."""

    summary: Optional[torch.Tensor]
    features: Optional[torch.Tensor]

    def to(self, *args, **kwargs) -> "RadioOutput":
        """Return a copy with each non-``None`` tensor passed through ``Tensor.to``.

        All positional and keyword arguments are forwarded unchanged to
        ``torch.Tensor.to`` (e.g. a device and/or dtype). Fields that are
        ``None`` remain ``None``; the original instance is not modified.
        """

        def _move(t: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
            # Skip absent outputs instead of failing on ``None.to``.
            return t.to(*args, **kwargs) if t is not None else None

        return self._replace(summary=_move(self.summary), features=_move(self.features))


class AdaptorBase(nn.Module):
    """Base class for adaptors that map an ``AdaptorInput`` to a ``RadioOutput``.

    Subclasses must override :meth:`forward`.
    """

    def forward(self, input: AdaptorInput) -> RadioOutput:
        """Transform ``input`` into a ``RadioOutput``.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError("Subclasses must implement this!")