|
""" |
|
Utilities for experiment tracking that ensure the code is saved alongside the results.
|
""" |
|
import datetime |
|
import os |
|
import shlex |
|
import shutil |
|
import subprocess |
|
import typing |
|
from pathlib import Path |
|
|
|
import randomname |
|
|
|
|
|
class Experiment:
    """Context manager for running code inside a per-experiment folder.

    On construction, the experiment folder ``exp_directory / exp_name`` is
    created and the files tracked by git at ``HEAD`` are recorded. Entering
    the context changes the working directory to the experiment folder;
    exiting restores the previous working directory. :meth:`snapshot` copies
    the tracked source files into the experiment folder so the code is saved
    alongside the results.

    Parameters
    ----------
    exp_directory : str
        Folder where all experiments are saved, by default "runs/".
    exp_name : str, optional
        Name of the experiment. By default a name is generated from the
        current date and a random adjective-noun pair (see
        :meth:`generate_exp_name`).

    Raises
    ------
    subprocess.CalledProcessError
        If the git commands fail, e.g. when not run inside a git repository
        that has at least one commit.
    """

    def __init__(
        self,
        exp_directory: str = "runs/",
        exp_name: typing.Optional[str] = None,
    ):
        if exp_name is None:
            exp_name = self.generate_exp_name()
        exp_dir = Path(exp_directory) / exp_name
        exp_dir.mkdir(parents=True, exist_ok=True)

        self.exp_dir = exp_dir
        self.exp_name = exp_name
        # `-z` makes git emit NUL-separated, *unquoted* paths; without it,
        # non-ASCII or unusual filenames come back C-quoted and unusable.
        listing = subprocess.check_output(
            shlex.split("git ls-tree --full-tree --name-only -r -z HEAD")
        ).decode("utf-8")
        self.git_tracked_files = [p for p in listing.split("\0") if p]
        # `--full-tree` paths are relative to the repository root, not the
        # current directory, so remember the root for snapshot().
        self.git_root = Path(
            subprocess.check_output(shlex.split("git rev-parse --show-toplevel"))
            .decode("utf-8")
            .strip()
        )
        # Directory the experiment was created from; a relative exp_dir is
        # interpreted relative to it.
        self.parent_directory = Path(".").absolute()

    def __enter__(self) -> "Experiment":
        """Change into the experiment folder, remembering the current one."""
        self.prev_dir = os.getcwd()
        # Resolve against the creation directory so entering still works if
        # the cwd changed since construction (no-op for absolute exp_dir).
        os.chdir(self.parent_directory / self.exp_dir)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Restore the working directory saved by ``__enter__``.

        Exceptions raised inside the context are not suppressed.
        """
        os.chdir(self.prev_dir)

    @staticmethod
    def generate_exp_name() -> str:
        """Generates a random experiment name based on the date
        and a randomly generated adjective-noun tuple.

        Returns
        -------
        str
            Randomly generated experiment name, formatted as
            ``"<yymmdd>-<random name>"``.
        """
        date = datetime.datetime.now().strftime("%y%m%d")
        name = f"{date}-{randomname.get_name()}"
        return name

    def snapshot(
        self, filter_fn: typing.Callable[[str], bool] = lambda f: True
    ) -> None:
        """Copy the files tracked by git into the experiment folder.

        Each tracked file (as listed at construction time) accepted by
        ``filter_fn`` is copied from the working tree of the repository to
        the same repository-relative path inside the experiment folder.
        Because the working-tree contents are copied, uncommitted
        modifications to tracked files are included. Tracked paths that are
        not regular files in the working tree (e.g. deleted files or
        submodule directories) are skipped.

        The destination is always the experiment folder, whether or not this
        is called inside the ``with`` block.

        Parameters
        ----------
        filter_fn : typing.Callable, optional
            Called with each repository-relative path. Files for which it
            returns a falsy value are excluded. By default, all files are
            kept.
        """
        destination_root = self.parent_directory / self.exp_dir
        for f in self.git_tracked_files:
            if not filter_fn(f):
                continue
            src = self.git_root / f
            if not src.is_file():
                continue
            dst = destination_root / f
            dst.parent.mkdir(parents=True, exist_ok=True)
            shutil.copyfile(src, dst)
|
|