|
|
|
"""Contains some utility operators.""" |
|
|
|
import math |
|
import torch |
|
import torch.distributed as dist |
|
import torch.nn.functional as F |
|
|
|
# Names exported by `from <module> import *`.
__all__ = ['all_gather', 'upsample', 'downsample']
|
|
|
|
|
def all_gather(tensor):
    """Averages a tensor across all processes of the default process group.

    Args:
        tensor: The local tensor to synchronize.

    Returns:
        The input tensor itself when `torch.distributed` has not been
        initialized. Otherwise, the element-wise mean of the tensors gathered
        from every rank.
    """
    if dist.is_initialized():
        num_replicas = dist.get_world_size()
        # One receive buffer per rank, each shaped like the local tensor.
        gathered = [torch.ones_like(tensor) for _ in range(num_replicas)]
        dist.all_gather(gathered, tensor, async_op=False)
        stacked = torch.stack(gathered, dim=0)
        return stacked.mean(dim=0)

    # Nothing to synchronize with outside of distributed mode.
    return tensor
|
|
|
|
|
def upsample(img_nerf, size, filter=None):
    """Upsamples an image tensor towards the target resolution.

    The scaling factor is `size // img_nerf.size(-1)`, i.e. it is derived from
    the last dimension only. When the factor is not larger than 1, the input
    is returned untouched.

    Args:
        img_nerf: Image tensor to upsample. Presumably laid out as
            `(N, C, H, W)` with `H == W` -- TODO confirm against callers.
        size: Target resolution, used for both height and width.
        filter: Optional FIR filter forwarded to `upfirdn2d.upsample2d()`.
            If `None`, bilinear interpolation is used instead.

    Returns:
        The upsampled tensor, or `img_nerf` itself when no upsampling is
        needed.
    """
    up = size // img_nerf.size(-1)
    if up <= 1:
        return img_nerf

    if filter is None:
        return F.interpolate(img_nerf, (size, size),
                             mode='bilinear',
                             align_corners=False)

    # Imported lazily so that the (compiled) op is only required when a filter
    # is actually provided.
    from third_party.stylegan2_official_ops import upfirdn2d

    # The filtered path doubles the resolution `floor(log2(up))` times, so the
    # output only matches `size` exactly when `up` is a power of two.
    for _ in range(int(math.log2(up))):
        # `upsample2d` is the op that takes the `up` factor; `downsample2d`
        # takes `down` and rejects an `up` keyword argument.
        img_nerf = upfirdn2d.upsample2d(img_nerf, filter, up=2)
    return img_nerf
|
|
|
|
|
def downsample(img0, size, filter=None):
    """Downsamples an image tensor towards the target resolution.

    The reduction factor is `img0.size(-1) // size`, i.e. it is derived from
    the last dimension only. When the factor is not larger than 1, the input
    is returned untouched.

    Args:
        img0: Image tensor to downsample. Presumably laid out as
            `(N, C, H, W)` with `H == W` -- TODO confirm against callers.
        size: Target resolution, used for both height and width.
        filter: Optional FIR filter forwarded to `upfirdn2d.downsample2d()`.
            If `None`, bilinear interpolation is used instead.

    Returns:
        The downsampled tensor, or `img0` itself when no downsampling is
        needed.
    """
    factor = img0.size(-1) // size
    if factor <= 1:
        return img0

    if filter is None:
        return F.interpolate(img0, (size, size),
                             mode='bilinear',
                             align_corners=False)

    # Imported lazily so that the op is only required when a filter is given.
    from third_party.stylegan2_official_ops import upfirdn2d

    # Halve the resolution once per power of two contained in `factor`.
    num_halvings = int(math.log2(factor))
    result = img0
    for _ in range(num_halvings):
        result = upfirdn2d.downsample2d(result, filter, down=2)
    return result