# h2ogpt-chatbot / utils.py
# Commit efe0924 by pseudotensor: "Add application file and dependencies" (971 bytes).
import os
import gc
import random
import numpy as np
import torch
def set_seed(seed: int):
    """Seed the random number generators used by this project.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (``torch.manual_seed`` plus ``torch.cuda.manual_seed``). It also sets
    cuDNN to deterministic, non-benchmarking mode to favour reproducible runs.

    Args:
        seed: Integer seed applied to every generator.

    Returns:
        A new ``np.random.RandomState`` seeded with ``seed``. Use it when you
        want a private generator instead of NumPy's global one.
    """
    # Standard-library and NumPy global generators.
    random.seed(seed)
    np.random.seed(seed)

    # PyTorch default generators.
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    # Ask cuDNN for deterministic algorithms and disable its autotuner, whose
    # algorithm choice may differ between runs.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    # NOTE: the interpreter reads PYTHONHASHSEED only at start-up. Setting it
    # here affects child processes launched later, not str/bytes hashing in
    # the current process.
    os.environ['PYTHONHASHSEED'] = str(seed)

    rng = np.random.RandomState(seed)
    return rng
def flatten_list(lis):
    """Given a list, possibly nested to any level, return it flattened.

    Items are emitted depth-first, left to right. Only ``list`` items
    (including subclasses of ``list``) are expanded. Tuples, strings and all
    other values are kept as single leaves.

    Args:
        lis: Iterable whose ``list`` elements are flattened recursively.

    Returns:
        A new flat ``list``. The input is not modified.

    Raises:
        RecursionError: if a list contains itself, or nesting exceeds the
            interpreter's recursion limit.
    """
    flat = []

    def _walk(seq):
        # Append leaves straight into the shared output. Building a list per
        # level and extending would copy deep items once per nesting level.
        for item in seq:
            if isinstance(item, list):
                _walk(item)
            else:
                flat.append(item)

    _walk(lis)
    return flat
def clear_torch_cache():
    """Free unreferenced Python objects and, when CUDA is present, cached GPU memory.

    Runs the Python garbage collector first, so tensors kept alive only by
    reference cycles are released. Then, only if CUDA is actually available,
    it empties PyTorch's CUDA caching allocator and collects CUDA IPC memory.
    Safe to call on CPU-only machines.
    """
    gc.collect()
    # is_available must be *called*. The bare function object is always
    # truthy, which made the CUDA calls below run without a GPU, and
    # ipc_collect() can raise there.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()