|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from argparse import Namespace
from typing import NamedTuple, Optional

import torch
from torch import nn
import torch.nn.functional as F
|
|
|
|
|
class AdaptorInput(NamedTuple):
    """Immutable bundle of tensors handed to an adaptor's ``forward``.

    Field order is part of the positional-construction contract:
    ``AdaptorInput(images, summary, features)``.
    """

    # Input images; shape/dtype are not fixed here -- TODO confirm with callers.
    images: torch.Tensor
    # Summary tensor accompanying the images.
    summary: torch.Tensor
    # Feature tensor accompanying the images.
    features: torch.Tensor
|
|
|
|
|
class RadioOutput(NamedTuple):
    """Pair of output tensors; either field may be ``None``.

    ``to`` below explicitly tolerates missing tensors, so both fields are
    annotated as ``Optional``.
    """

    # Summary tensor, or None when not produced.
    summary: Optional[torch.Tensor]
    # Feature tensor, or None when not produced.
    features: Optional[torch.Tensor]

    def to(self, *args, **kwargs) -> "RadioOutput":
        """Return a copy with every non-None tensor passed through ``Tensor.to``.

        ``*args`` / ``**kwargs`` are forwarded unchanged to ``torch.Tensor.to``
        (device, dtype, ``non_blocking``, ...). ``None`` fields stay ``None``.
        ``_replace`` builds the result with ``type(self)``, so subclasses are
        preserved.
        """
        def _move(tensor: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
            return tensor.to(*args, **kwargs) if tensor is not None else None

        return self._replace(summary=_move(self.summary), features=_move(self.features))
|
|
|
|
|
class AdaptorBase(nn.Module):
    """Base class for adaptor modules.

    Concrete adaptors override ``forward`` to turn an ``AdaptorInput`` into a
    ``RadioOutput``; this base provides no implementation of its own.
    """

    def forward(self, input: AdaptorInput) -> RadioOutput:
        """Map ``input`` to a ``RadioOutput``.

        Raises:
            NotImplementedError: always -- subclasses are expected to override.
        """
        message = "Subclasses must implement this!"
        raise NotImplementedError(message)
|
|