import logging
import os
import pickle
import random
import zipfile
from typing import Any
import numpy as np
import psutil
import torch
logger = logging.getLogger(__name__)
def set_seed(seed: int = 1234) -> None:
"""Sets the random seed.
Args:
seed: seed value
"""
random.seed(seed)
os.environ["PYTHONHASHSEED"] = str(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
    # Keep cuDNN autotuning enabled; this favors speed over strict determinism.
    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.benchmark = True
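
# Illustrative usage (a sketch, not part of the original module): call set_seed
# once at program start, before models or dataloaders are constructed, e.g.
#
#     set_seed(42)
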
def set_environment(cfg):
"""Sets and checks environment settings"""
if "GPT" in cfg.prediction.metric and os.getenv("OPENAI_API_KEY", "") == "":
logger.warning("No OpenAI API Key set. Setting metric to BLEU. ")
cfg.prediction.metric = "BLEU"
return cfg
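
# Illustrative usage (a sketch; `load_config` is a hypothetical loader, the real
# cfg object comes from the surrounding application):
#
#     cfg = load_config()
#     cfg = set_environment(cfg)  # falls back to BLEU if OPENAI_API_KEY is unset
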
def kill_child_processes(parent_pid: int) -> bool:
"""Killing a process and all its child processes
Args:
parent_pid: process id of parent
Returns:
True or False in case of success or failure
"""
logger.debug(f"Killing process id: {parent_pid}")
try:
parent = psutil.Process(parent_pid)
        if parent.status() == psutil.STATUS_ZOMBIE:
return False
children = parent.children(recursive=True)
for child in children:
child.kill()
parent.kill()
return True
except psutil.NoSuchProcess:
logger.warning(f"Cannot kill process id: {parent_pid}. No such process.")
return False
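
# Illustrative usage (a sketch; `worker` is a hypothetical multiprocessing.Process):
#
#     if not kill_child_processes(worker.pid):
#         logger.warning("Worker already gone or a zombie.")
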
def kill_ddp_processes() -> None:
"""
Killing all DDP processes from a single process.
Firstly kills all children of a single DDP process (dataloader workers)
Then kills all other DDP processes
Then kills main parent DDP process
"""
    pid = os.getpid()
    parent_pid = os.getppid()

    # Kill the children of the current process first (e.g. dataloader workers).
    current_process = psutil.Process(pid)
    children = current_process.children(recursive=True)
    for child in children:
        child.kill()

    # Kill the remaining processes under the same parent (the sibling DDP ranks),
    # in reverse order so that descendants die before their parents, skipping
    # the current process.
    parent_process = psutil.Process(parent_pid)
    children = parent_process.children(recursive=True)[::-1]
    for child in children:
        if child.pid == pid:
            continue
        child.kill()

    # Finally kill the parent process and the current process itself.
    parent_process.kill()
    current_process.kill()
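
# Illustrative usage (a sketch): call from any DDP rank to tear down the whole
# process group, e.g. after an unrecoverable error. Note that the calling
# process is killed as well, so nothing after the call will run.
#
#     kill_ddp_processes()
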
def add_file_to_zip(zf: zipfile.ZipFile, path: str) -> None:
"""Adds a file to the existing zip. Does nothing if file does not exist.
Args:
zf: zipfile object to add to
path: path to the file to add
"""
try:
zf.write(path, os.path.basename(path))
except Exception:
logger.warning(f"File {path} could not be added to zip.")
def save_pickle(path: str, obj: Any, protocol: int = 4) -> None:
"""Saves object as pickle file
Args:
path: path of file to save
obj: object to save
protocol: protocol to use when saving pickle
"""
with open(path, "wb") as pickle_file:
pickle.dump(obj, pickle_file, protocol=protocol)
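
# Illustrative usage (a sketch; path and payload are hypothetical):
#
#     save_pickle("/tmp/metrics.pkl", {"bleu": 0.31})
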
class DisableLogger:
    """Context manager that suppresses all log messages of severity `level` and below."""

    def __init__(self, level: int = logging.INFO):
self.level = level
def __enter__(self):
logging.disable(self.level)
def __exit__(self, exit_type, exit_value, exit_traceback):
logging.disable(logging.NOTSET)
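
# Illustrative usage (a sketch; `noisy_call` is a hypothetical function):
#
#     with DisableLogger():
#         noisy_call()  # INFO and DEBUG messages emitted here are suppressed
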
class PatchedAttribute:
"""
Patches an attribute of an object for the duration of this context manager.
    Similar to unittest.mock.patch, but also works for properties
    that are not present in the original class.
>>> class MyObj:
... attr = 'original'
>>> my_obj = MyObj()
>>> with PatchedAttribute(my_obj, 'attr', 'patched'):
... print(my_obj.attr)
patched
>>> print(my_obj.attr)
original
>>> with PatchedAttribute(my_obj, 'new_attr', 'new_patched'):
... print(my_obj.new_attr)
new_patched
>>> assert not hasattr(my_obj, 'new_attr')
"""
def __init__(self, obj, attribute, new_value):
self.obj = obj
self.attribute = attribute
self.new_value = new_value
self.original_exists = hasattr(obj, attribute)
if self.original_exists:
self.original_value = getattr(obj, attribute)
def __enter__(self):
setattr(self.obj, self.attribute, self.new_value)
def __exit__(self, exc_type, exc_val, exc_tb):
if self.original_exists:
setattr(self.obj, self.attribute, self.original_value)
else:
delattr(self.obj, self.attribute)