|
import argparse |
|
import math |
|
import os |
|
import shlex |
|
import subprocess |
|
import uuid |
|
from distutils.util import strtobool |
|
|
|
import requests |
|
|
|
|
|
def parse_args(argv=None):
    """Parse the benchmark runner's command-line options.

    Args:
        argv: optional list of argument strings to parse instead of
            ``sys.argv[1:]``. The default ``None`` keeps the original behavior
            of reading the process command line.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--command", type=str, default="",
        help="the command to run")
    parser.add_argument("--num-seeds", type=int, default=3,
        help="the number of random seeds")
    parser.add_argument("--start-seed", type=int, default=1,
        help="the number of the starting seed")
    parser.add_argument("--workers", type=int, default=0,
        help="the number of workers to run benchmark experiments")
    # `--auto-tag` alone means True (const); `--auto-tag false` disables it.
    parser.add_argument("--auto-tag", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
        help="if toggled, the runs will be tagged with git tags, commit, and pull request number if possible")
    parser.add_argument("--slurm-template-path", type=str, default=None,
        help="the path to the slurm template file (see docs for more details)")
    parser.add_argument("--slurm-gpus-per-task", type=int, default=1,
        help="the number of gpus per task to use for slurm jobs")
    parser.add_argument("--slurm-total-cpus", type=int, default=50,
        help="the total number of cpus to use for slurm jobs")
    parser.add_argument("--slurm-ntasks", type=int, default=1,
        help="the number of tasks to use for slurm jobs")
    parser.add_argument("--slurm-nodes", type=int, default=None,
        help="the number of nodes to use for slurm jobs")
    return parser.parse_args(argv)
|
|
|
|
|
def run_experiment(command: str) -> str:
    """Run ``command`` in a subprocess and return its captured stdout.

    The command string is tokenized with :func:`shlex.split` and executed
    directly (no shell), so shell syntax such as pipes or ``&&`` is not
    interpreted.

    Args:
        command: the full command line to execute.

    Returns:
        The subprocess's stdout decoded as UTF-8 (undecodable bytes replaced),
        with surrounding whitespace stripped.

    Raises:
        AssertionError: if the command exits with a non-zero return code; the
            message contains the command's decoded stderr.
    """
    command_list = shlex.split(command)
    print(f"running {command}")

    completed = subprocess.run(command_list, stdout=subprocess.PIPE, stderr=subprocess.PIPE)

    if completed.returncode != 0:
        # Raise explicitly instead of using `assert`: asserts are stripped under
        # `python -O`, which would make failed commands look successful.
        # AssertionError is kept so callers relying on the old type still work.
        errors = completed.stderr.decode("utf-8", errors="replace")
        raise AssertionError(f"Command failed with error: {errors}")

    # errors="replace" so stray non-UTF-8 bytes don't mask the real result.
    return completed.stdout.decode("utf-8", errors="replace").strip()
|
|
|
|
|
def autotag() -> str:
    """Build a comma-separated W&B tag string describing the current git checkout.

    The result contains, when they can be determined:

    * the output of ``git describe --tags``, or, if that yields nothing,
      ``no-tag-<commit count>-g<short hash>``;
    * ``pr-<number>`` for the first huggingface/trl pull request that GitHub's
      issue search returns for the current commit.

    Each lookup is best-effort: a failing or missing ``git`` command, or any
    error during the GitHub request, is printed and the corresponding tag is
    simply omitted.

    Returns:
        The tags joined with ``","``; an empty string if none were found.
    """
    print("autotag feature is enabled")
    tags = []

    git_tag = ""
    try:
        git_tag = subprocess.check_output(["git", "describe", "--tags"]).decode("utf-8").strip()
        print(f"identified git tag: {git_tag}")
    except (subprocess.CalledProcessError, OSError) as e:
        # OSError covers a missing git executable (FileNotFoundError).
        print(e)

    if not git_tag:
        # No reachable tag: synthesize one from the commit count and short hash.
        try:
            count = int(subprocess.check_output(["git", "rev-list", "--count", "HEAD"]).decode("utf-8").strip())
            short_hash = subprocess.check_output(["git", "rev-parse", "--short", "HEAD"]).decode("utf-8").strip()
            git_tag = f"no-tag-{count}-g{short_hash}"
            print(f"identified git tag: {git_tag}")
        except (subprocess.CalledProcessError, OSError) as e:
            print(e)

    if git_tag:
        tags.append(git_tag)

    try:
        git_commit = subprocess.check_output(["git", "rev-parse", "--verify", "HEAD"]).decode("utf-8").strip()
    except (subprocess.CalledProcessError, OSError) as e:
        # Without a commit there is nothing to search GitHub for.
        print(e)
        return ",".join(tags)

    try:
        # A timeout keeps an unresponsive network from hanging the launcher.
        response = requests.get(
            f"https://api.github.com/search/issues?q=repo:huggingface/trl+is:pr+{git_commit}",
            timeout=10,
        )
        if response.status_code == 200:
            items = response.json().get("items", [])
            if items:
                pr_number = items[0]["number"]
                tags.append(f"pr-{pr_number}")
                print(f"identified github pull request: {pr_number}")
    except Exception as e:
        # Deliberately broad: PR tagging is optional and must never abort a run.
        print(e)

    # Joining a list avoids the leading comma the old string concatenation
    # produced when no git tag was found.
    return ",".join(tags)
|
|
|
|
|
if __name__ == "__main__": |
|
args = parse_args() |
|
if args.auto_tag: |
|
existing_wandb_tag = os.environ.get("WANDB_TAGS", "") |
|
wandb_tag = autotag() |
|
if len(wandb_tag) > 0: |
|
if len(existing_wandb_tag) > 0: |
|
os.environ["WANDB_TAGS"] = ",".join([existing_wandb_tag, wandb_tag]) |
|
else: |
|
os.environ["WANDB_TAGS"] = wandb_tag |
|
print("WANDB_TAGS: ", os.environ.get("WANDB_TAGS", "")) |
|
commands = [] |
|
for seed in range(0, args.num_seeds): |
|
commands += [" ".join([args.command, "--seed", str(args.start_seed + seed)])] |
|
|
|
print("======= commands to run:") |
|
for command in commands: |
|
print(command) |
|
|
|
if args.workers > 0 and args.slurm_template_path is None: |
|
from concurrent.futures import ThreadPoolExecutor |
|
|
|
executor = ThreadPoolExecutor(max_workers=args.workers, thread_name_prefix="cleanrl-benchmark-worker-") |
|
for command in commands: |
|
executor.submit(run_experiment, command) |
|
executor.shutdown(wait=True) |
|
else: |
|
print("not running the experiments because --workers is set to 0; just printing the commands to run") |
|
|
|
|
|
if args.slurm_template_path is not None: |
|
if not os.path.exists("slurm"): |
|
os.makedirs("slurm") |
|
if not os.path.exists("slurm/logs"): |
|
os.makedirs("slurm/logs") |
|
print("======= slurm commands to run:") |
|
with open(args.slurm_template_path) as f: |
|
slurm_template = f.read() |
|
slurm_template = slurm_template.replace("{{array}}", f"0-{len(commands) - 1}%{args.workers}") |
|
slurm_template = slurm_template.replace( |
|
"{{seeds}}", f"({' '.join([str(args.start_seed + int(seed)) for seed in range(args.num_seeds)])})" |
|
) |
|
slurm_template = slurm_template.replace("{{len_seeds}}", f"{args.num_seeds}") |
|
slurm_template = slurm_template.replace("{{command}}", args.command) |
|
slurm_template = slurm_template.replace("{{gpus_per_task}}", f"{args.slurm_gpus_per_task}") |
|
total_gpus = args.slurm_gpus_per_task * args.slurm_ntasks |
|
slurm_cpus_per_gpu = math.ceil(args.slurm_total_cpus / total_gpus) |
|
slurm_template = slurm_template.replace("{{cpus_per_gpu}}", f"{slurm_cpus_per_gpu}") |
|
slurm_template = slurm_template.replace("{{ntasks}}", f"{args.slurm_ntasks}") |
|
if args.slurm_nodes is not None: |
|
slurm_template = slurm_template.replace("{{nodes}}", f"#SBATCH --nodes={args.slurm_nodes}") |
|
else: |
|
slurm_template = slurm_template.replace("{{nodes}}", "") |
|
filename = str(uuid.uuid4()) |
|
open(os.path.join("slurm", f"{filename}.slurm"), "w").write(slurm_template) |
|
slurm_path = os.path.join("slurm", f"{filename}.slurm") |
|
print(f"saving command in {slurm_path}") |
|
if args.workers > 0: |
|
job_id = run_experiment(f"sbatch --parsable {slurm_path}") |
|
print(f"Job ID: {job_id}") |
|
|