import gc
import os
import random

import numpy as np
import torch


def set_seed(seed: int) -> np.random.RandomState:
    """Seed Python, NumPy and PyTorch RNGs so runs are reproducible.

    Seeds the global ``random`` and ``numpy.random`` generators, the PyTorch
    CPU and CUDA generators (all devices), and forces cuDNN into
    deterministic, non-benchmarking mode.

    Args:
        seed: Seed value applied to every generator.

    Returns:
        A fresh ``np.random.RandomState`` seeded with ``seed``, for callers
        that prefer a local generator over NumPy's global one.
    """
    np.random.seed(seed)
    random_state = np.random.RandomState(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not just the current device; this is a lazy,
    # safe no-op when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    # Trade cuDNN autotuning speed for run-to-run determinism.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # NOTE: hash randomization of the *current* interpreter is fixed at
    # startup; this only affects subprocesses launched after this call.
    os.environ['PYTHONHASHSEED'] = str(seed)
    return random_state


def flatten_list(lis: list) -> list:
    """Return ``lis`` flattened, expanding nested lists to any depth.

    Only ``list`` instances (including subclasses) are expanded; tuples,
    strings and other iterables are kept as single items. Element order is
    preserved. Uses an explicit iterator stack rather than recursion, so very
    deep nesting does not hit Python's recursion limit.

    Args:
        lis: The (possibly nested) list to flatten.

    Returns:
        A new flat list; the input is not modified.
    """
    flat = []
    # Each stack entry is the partially consumed iterator of one nesting level.
    stack = [iter(lis)]
    while stack:
        for item in stack[-1]:
            if isinstance(item, list):
                # Descend into the sublist and resume this level afterwards.
                stack.append(iter(item))
                break
            flat.append(item)
        else:
            # Current level exhausted: return to the enclosing list.
            stack.pop()
    return flat


def clear_torch_cache() -> None:
    """Release unreferenced Python objects and cached CUDA memory.

    Runs the garbage collector first, so tensors kept alive only by reference
    cycles are freed before the CUDA caching allocator is flushed. CUDA calls
    are skipped entirely when no GPU is available.
    """
    gc.collect()
    # is_available must be *called*; the bare function object is always truthy,
    # which made ipc_collect() attempt CUDA initialisation on CPU-only hosts.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()