import logging
import os
import pickle
import random
import zipfile
from typing import Any

import numpy as np
import psutil
import torch

logger = logging.getLogger(__name__)


def set_seed(seed: int = 1234) -> None:
    """Sets the random seed.

    Args:
        seed: seed value
    """

    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.benchmark = True


def set_environment(cfg):
    """Sets and checks environment settings."""

    if "GPT" in cfg.prediction.metric and os.getenv("OPENAI_API_KEY", "") == "":
        logger.warning("No OpenAI API key set. Setting metric to BLEU.")
        cfg.prediction.metric = "BLEU"
    return cfg


def kill_child_processes(parent_pid: int) -> bool:
    """Kills a process and all its child processes.

    Args:
        parent_pid: process id of the parent

    Returns:
        True on success, False on failure
    """

    logger.debug(f"Killing process id: {parent_pid}")
    try:
        parent = psutil.Process(parent_pid)
        if parent.status() == "zombie":
            return False
        children = parent.children(recursive=True)
        for child in children:
            child.kill()
        parent.kill()
        return True
    except psutil.NoSuchProcess:
        logger.warning(f"Cannot kill process id: {parent_pid}. No such process.")
        return False


def kill_ddp_processes() -> None:
    """Kills all DDP processes from a single process.

    First kills all children of the current DDP process (dataloader workers),
    then kills all other DDP processes,
    then kills the main parent DDP process.
    """

    pid = os.getpid()
    parent_pid = os.getppid()

    current_process = psutil.Process(pid)
    children = current_process.children(recursive=True)
    for child in children:
        child.kill()

    parent_process = psutil.Process(parent_pid)
    children = parent_process.children(recursive=True)[::-1]
    for child in children:
        if child.pid == pid:
            continue
        child.kill()

    parent_process.kill()
    current_process.kill()


def add_file_to_zip(zf: zipfile.ZipFile, path: str) -> None:
    """Adds a file to an existing zip. Does nothing if the file does not exist.

    Args:
        zf: zipfile object to add to
        path: path to the file to add
    """

    try:
        zf.write(path, os.path.basename(path))
    except Exception:
        logger.warning(f"File {path} could not be added to zip.")


def save_pickle(path: str, obj: Any, protocol: int = 4) -> None:
    """Saves an object as a pickle file.

    Args:
        path: path of the file to save to
        obj: object to save
        protocol: protocol to use when pickling
    """

    with open(path, "wb") as pickle_file:
        pickle.dump(obj, pickle_file, protocol=protocol)


class DisableLogger:
    """Context manager that suppresses log records at or below the given level."""

    def __init__(self, level: int = logging.INFO):
        self.level = level

    def __enter__(self):
        logging.disable(self.level)

    def __exit__(self, exit_type, exit_value, exit_traceback):
        logging.disable(logging.NOTSET)


class PatchedAttribute:
    """Patches an attribute of an object for the duration of this context manager.

    Similar to unittest.mock.patch, but also works for attributes
    that are not present in the original class.

    >>> class MyObj:
    ...     attr = 'original'
    >>> my_obj = MyObj()
    >>> with PatchedAttribute(my_obj, 'attr', 'patched'):
    ...     print(my_obj.attr)
    patched
    >>> print(my_obj.attr)
    original
    >>> with PatchedAttribute(my_obj, 'new_attr', 'new_patched'):
    ...     print(my_obj.new_attr)
    new_patched
    >>> assert not hasattr(my_obj, 'new_attr')
    """

    def __init__(self, obj, attribute, new_value):
        self.obj = obj
        self.attribute = attribute
        self.new_value = new_value
        self.original_exists = hasattr(obj, attribute)
        if self.original_exists:
            self.original_value = getattr(obj, attribute)

    def __enter__(self):
        setattr(self.obj, self.attribute, self.new_value)

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.original_exists:
            setattr(self.obj, self.attribute, self.original_value)
        else:
            delattr(self.obj, self.attribute)
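

# Minimal usage sketch for the two context managers above. This demo block is
# purely illustrative and not part of the module's public interface; the
# `_Dummy` class and the log messages are hypothetical placeholders.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    # Records at or below INFO are suppressed inside the context only.
    with DisableLogger(logging.INFO):
        logger.info("this record is suppressed")
    logger.info("logging is restored after the context exits")

    class _Dummy:
        attr = "original"

    dummy = _Dummy()

    # Temporarily override an existing attribute; the original value is restored on exit.
    with PatchedAttribute(dummy, "attr", "patched"):
        assert dummy.attr == "patched"
    assert dummy.attr == "original"

    # Temporarily add an attribute that did not exist; it is removed on exit.
    with PatchedAttribute(dummy, "new_attr", "added"):
        assert dummy.new_attr == "added"
    assert not hasattr(dummy, "new_attr")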