|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import Optional, Tuple, Union

import torch
from torch import nn
|
|
|
|
|
# Accepted form for normalization statistics: either a 3-tuple of floats
# (presumably one value per color channel — confirm with callers) or a tensor.
# `_to_tensor` converts either form to float32 and reshapes it to (-1, 1, 1).
norm_t = Union[Tuple[float, float, float], torch.Tensor]
|
|
|
class InputConditioner(nn.Module):
    """Normalize input tensors with fixed per-channel mean and std statistics.

    ``forward`` computes ``(x - norm_mean / input_scale) / (norm_std / input_scale)``,
    which is algebraically ``(x * input_scale - norm_mean) / norm_std``. In other
    words, ``input_scale`` maps raw inputs onto the scale the statistics were
    expressed in.

    The statistics are stored as float32 buffers of shape ``(C, 1, 1)`` via
    ``_to_tensor``. ``x`` must therefore broadcast against that shape, e.g.
    ``(C, H, W)`` or ``(B, C, H, W)``.

    Args:
        input_scale: Factor relating raw inputs to the scale of ``norm_mean`` /
            ``norm_std``. Must be non-zero.
        norm_mean: Per-channel mean, as a tuple of floats or a tensor.
        norm_std: Per-channel standard deviation, as a tuple of floats or a tensor.
        dtype: If given, the normalized output is cast to this dtype.

    Raises:
        ValueError: If ``input_scale`` is zero.
    """

    def __init__(
        self,
        input_scale: float,
        norm_mean: norm_t,
        norm_std: norm_t,
        dtype: Optional[torch.dtype] = None,
    ):
        super().__init__()

        if input_scale == 0:
            # Both buffers are divided by input_scale; zero would silently
            # produce inf/nan statistics instead of failing loudly.
            raise ValueError("input_scale must be non-zero")

        self.dtype = dtype

        # Buffers (not parameters): they follow the module across .to()/device
        # moves and are saved in state_dict, but are not trainable.
        self.register_buffer("norm_mean", _to_tensor(norm_mean) / input_scale)
        self.register_buffer("norm_std", _to_tensor(norm_std) / input_scale)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``(x - norm_mean) / norm_std``, cast to ``self.dtype`` if set."""
        # The arithmetic runs in the promoted dtype of x and the float32 buffers
        # (e.g. float32 for a half-precision x); the optional cast happens after.
        y = (x - self.norm_mean) / self.norm_std
        if self.dtype is not None:
            y = y.to(self.dtype)
        return y
|
|
|
|
|
def get_default_conditioner() -> "InputConditioner":
    """Build the default input conditioner from OpenAI CLIP normalization stats.

    Uses ``input_scale=1.0``, so inputs are normalized at their native scale.
    Presumably callers pass images already scaled to [0, 1]; confirm this. No
    output dtype cast is configured.

    ``timm`` is imported inside the function, so it is only required when this
    helper is actually called.
    """
    from timm.data.constants import OPENAI_CLIP_MEAN as clip_mean
    from timm.data.constants import OPENAI_CLIP_STD as clip_std

    conditioner_kwargs = dict(
        input_scale=1.0,
        norm_mean=clip_mean,
        norm_std=clip_std,
    )
    return InputConditioner(**conditioner_kwargs)
|
|
|
|
|
def _to_tensor(v: "norm_t") -> torch.Tensor:
    """Convert normalization statistics to a float32 tensor of shape ``(N, 1, 1)``.

    The trailing singleton dims let the result broadcast over the spatial
    dimensions of a ``(..., N, H, W)`` input.

    Args:
        v: A tuple of floats or a tensor. All of its elements are flattened into
            the leading dimension.

    Returns:
        A float32 tensor of shape ``(v.numel(), 1, 1)``. It may share memory
        with ``v`` when ``v`` is already a float32 tensor, because
        ``torch.as_tensor`` avoids copying in that case.
    """
    # reshape (not view): view raises on non-contiguous tensors (e.g. a
    # transposed stats tensor); reshape copies only when it has to.
    return torch.as_tensor(v, dtype=torch.float32).reshape(-1, 1, 1)
|
|