AK391
files
d380b77
raw
history blame
689 Bytes
from enum import Enum
import yaml
from easydict import EasyDict as edict
import torch.nn as nn
import torch
def load_yaml(path):
    """Parse the YAML file at *path* and return the result wrapped in an ``EasyDict``.

    The document is parsed with ``yaml.safe_load``, which only builds plain
    YAML types and never constructs arbitrary Python objects.

    Args:
        path: filesystem path to the YAML file (anything ``open()`` accepts).

    Returns:
        ``EasyDict`` constructed from the parsed document.
    """
    # Decode explicitly as UTF-8 (YAML's default encoding) rather than the
    # platform's locale encoding, which varies by OS (e.g. cp1252 on many
    # Windows setups) and would mis-decode non-ASCII configuration values.
    with open(path, 'r', encoding='utf-8') as f:
        return edict(yaml.safe_load(f))
def move_to_device(obj, device):
    """Recursively transfer *obj*, and everything nested inside it, to *device*.

    Leaves (``nn.Module`` instances and tensors) are moved with their own
    ``.to(device)``. Tuples and lists are rebuilt as *lists* of moved items.
    Dicts are rebuilt as plain ``dict`` objects with the same keys and moved
    values.

    Args:
        obj: a module, a tensor, or a (possibly nested) tuple/list/dict whose
            leaves are modules or tensors.
        device: target passed straight through to ``.to()``.

    Returns:
        The moved object. Containers are always returned as ``list``/``dict``.

    Raises:
        ValueError: if *obj*, or anything nested inside it, is of any other type.
    """
    def _recurse(child):
        # Same device for every nested element.
        return move_to_device(child, device)

    if isinstance(obj, nn.Module) or torch.is_tensor(obj):
        return obj.to(device)
    if isinstance(obj, (tuple, list)):
        # NOTE: tuple inputs come back as lists, not tuples.
        return list(map(_recurse, obj))
    if isinstance(obj, dict):
        return {key: _recurse(value) for key, value in obj.items()}
    raise ValueError(f'Unexpected type {type(obj)}')
class SmallMode(Enum):
    """Closed set of modes, each identified by a lowercase string value.

    NOTE(review): judging only by the name, this presumably selects how inputs
    that are too small are handled (discard them vs. upscale them). Confirm
    against the callers.
    """

    # Member order is significant: iteration yields DROP first, then UPSCALE.
    DROP = 'drop'
    UPSCALE = 'upscale'