pg56714's picture
Upload 267 files
564565f verified
raw
history blame
717 Bytes
from enum import Enum
import yaml
from easydict import EasyDict as edict
import torch.nn as nn
import torch
def load_yaml(path):
    """Read the YAML file at ``path`` and return its contents as an EasyDict.

    The document is parsed with ``yaml.safe_load``, so only standard YAML tags
    are accepted (no arbitrary Python object construction).

    Args:
        path: Filesystem path of the YAML file to read.

    Returns:
        An ``EasyDict`` built from the parsed document, giving attribute-style
        access to its keys. NOTE(review): this assumes the document's top level
        is a mapping — confirm against the config files callers load.
    """
    # Decode explicitly as UTF-8 rather than relying on the platform's locale
    # default encoding, which differs between systems (e.g. Windows) and can
    # garble or reject non-ASCII config values.
    with open(path, 'r', encoding='utf-8') as f:
        return edict(yaml.safe_load(f))
def move_to_device(obj, device):
    """Recursively transfer ``obj`` (and anything nested inside it) to ``device``.

    Handling by type:
      * ``nn.Module`` / tensor: returned from ``obj.to(device)``.
      * ``dict``: a new dict with the same keys and each value transferred.
      * ``tuple`` / ``list``: a new ``list`` with each element transferred.

    Raises:
        ValueError: if ``obj``, or any nested element, is of any other type.
    """
    if isinstance(obj, nn.Module) or torch.is_tensor(obj):
        return obj.to(device)

    if isinstance(obj, dict):
        return {key: move_to_device(value, device) for key, value in obj.items()}

    if isinstance(obj, (tuple, list)):
        # NOTE(review): tuples come back as lists, not tuples — confirm whether
        # callers depend on this before changing it.
        moved = []
        for item in obj:
            moved.append(move_to_device(item, device))
        return moved

    raise ValueError(f'Unexpected type {type(obj)}')
class SmallMode(Enum):
    """Named modes for handling "small" inputs, keyed by their config strings.

    Only the names and string values are defined here. What each mode actually
    does is up to the consuming code (presumably discarding small items vs.
    upscaling them — confirm against usage).
    """

    DROP = "drop"
    UPSCALE = "upscale"