File size: 6,504 Bytes
fa4458a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import argparse
import math
import os
import shlex
import subprocess
import uuid
from distutils.util import strtobool

import requests


def _str_to_bool(value: str) -> bool:
    """Convert a command-line truth string to a bool.

    Accepts the same spellings as the removed ``distutils.util.strtobool``
    (``distutils`` is gone from Python 3.12+): y/yes/t/true/on/1 and
    n/no/f/false/off/0, case-insensitively.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised spelling,
            so argparse reports a clean usage error.
    """
    normalized = value.lower()
    if normalized in ("y", "yes", "t", "true", "on", "1"):
        return True
    if normalized in ("n", "no", "f", "false", "off", "0"):
        return False
    raise argparse.ArgumentTypeError(f"invalid truth value {value!r}")


def parse_args() -> argparse.Namespace:
    """Parse the benchmark launcher's command-line options from ``sys.argv``.

    Returns:
        The populated ``argparse.Namespace`` with the attributes ``command``,
        ``num_seeds``, ``start_seed``, ``workers``, ``auto_tag``,
        ``slurm_template_path``, ``slurm_gpus_per_task``, ``slurm_total_cpus``,
        ``slurm_ntasks`` and ``slurm_nodes``.
    """
    # fmt: off
    parser = argparse.ArgumentParser()
    parser.add_argument("--command", type=str, default="",
        help="the command to run")
    parser.add_argument("--num-seeds", type=int, default=3,
        help="the number of random seeds")
    parser.add_argument("--start-seed", type=int, default=1,
        help="the number of the starting seed")
    parser.add_argument("--workers", type=int, default=0,
        help="the number of workers to run benchmark experiments")
    # `nargs="?"` with `const=True` lets a bare `--auto-tag` mean True while
    # `--auto-tag false` still turns the feature off.
    parser.add_argument("--auto-tag", type=_str_to_bool, default=True, nargs="?", const=True,
        help="if toggled, the runs will be tagged with git tags, commit, and pull request number if possible")
    parser.add_argument("--slurm-template-path", type=str, default=None,
        help="the path to the slurm template file (see docs for more details)")
    parser.add_argument("--slurm-gpus-per-task", type=int, default=1,
        help="the number of gpus per task to use for slurm jobs")
    parser.add_argument("--slurm-total-cpus", type=int, default=50,
        help="the total number of cpus to use for slurm jobs")
    parser.add_argument("--slurm-ntasks", type=int, default=1,
        help="the number of tasks to use for slurm jobs")
    parser.add_argument("--slurm-nodes", type=int, default=None,
        help="the number of nodes to use for slurm jobs")
    args = parser.parse_args()
    # fmt: on
    return args


def run_experiment(command: str) -> str:
    """Run ``command`` as a subprocess and return its captured standard output.

    The command string is split with ``shlex.split`` and executed without a
    shell. Both stdout and stderr are captured rather than streamed.

    Args:
        command: The full command line to execute.

    Returns:
        The command's stdout decoded as UTF-8 with surrounding whitespace
        stripped.

    Raises:
        AssertionError: if the command exits with a non-zero return code; the
            message carries the captured stderr.
    """
    command_list = shlex.split(command)
    print(f"running {command}")

    # Capture both streams so the output can be returned and stderr reported.
    completed = subprocess.run(command_list, stdout=subprocess.PIPE, stderr=subprocess.PIPE)

    if completed.returncode != 0:
        # Raised explicitly instead of via `assert` so the check survives
        # `python -O`; the exception type is kept for existing callers.
        errors = completed.stderr.decode("utf-8", errors="replace")
        raise AssertionError(f"Command failed with error: {errors}")

    # Undecodable bytes are replaced rather than aborting a finished run.
    return completed.stdout.decode("utf-8", errors="replace").strip()


def autotag() -> str:
    """Build a comma-separated W&B tag string from git and GitHub metadata.

    The first tag is the output of ``git describe --tags``; if that fails, a
    ``no-tag-<commit count>-g<short hash>`` tag is synthesised instead. If the
    GitHub search API reports a pull request matching the current commit, a
    ``pr-<number>`` tag is appended.

    Every lookup is best-effort: failures are printed and the tags collected
    so far are still returned.

    Returns:
        The tags joined by commas, or an empty string if none were found.
    """
    print("autotag feature is enabled")
    git_tag = ""
    try:
        git_tag = subprocess.check_output(["git", "describe", "--tags"]).decode("ascii").strip()
        print(f"identified git tag: {git_tag}")
    except subprocess.CalledProcessError as e:
        print(e)
    if not git_tag:
        # No reachable tag: fall back to a describe-like synthetic tag.
        try:
            count = int(subprocess.check_output(["git", "rev-list", "--count", "HEAD"]).decode("ascii").strip())
            short_hash = subprocess.check_output(["git", "rev-parse", "--short", "HEAD"]).decode("ascii").strip()
            git_tag = f"no-tag-{count}-g{short_hash}"
            print(f"identified git tag: {git_tag}")
        except subprocess.CalledProcessError as e:
            print(e)
    tags = [git_tag] if git_tag else []

    try:
        git_commit = subprocess.check_output(["git", "rev-parse", "--verify", "HEAD"]).decode("ascii").strip()
    except subprocess.CalledProcessError as e:
        # Without a commit hash there is nothing to look up on GitHub.
        print(e)
        return ",".join(tags)

    try:
        # try finding the pull request number on github
        response = requests.get(
            f"https://api.github.com/search/issues?q=repo:huggingface/trl+is:pr+{git_commit}",
            timeout=10,  # never let a stalled connection block the benchmark
        )
        if response.status_code == 200:
            items = response.json()["items"]
            if items:
                pr_number = items[0]["number"]
                tags.append(f"pr-{pr_number}")
                # Only report a PR when one was actually found.
                print(f"identified github pull request: {pr_number}")
    except Exception as e:
        # Best-effort boundary: network or payload problems must not abort the run.
        print(e)

    return ",".join(tags)


if __name__ == "__main__":
    # Script entry point: build one command per seed, then either run them
    # locally in a thread pool or render (and optionally submit) a SLURM
    # array job from a template.
    args = parse_args()
    if args.auto_tag:
        existing_wandb_tag = os.environ.get("WANDB_TAGS", "")
        wandb_tag = autotag()
        if wandb_tag:
            # Keep any tags the user already exported and append the detected ones.
            os.environ["WANDB_TAGS"] = ",".join(tag for tag in (existing_wandb_tag, wandb_tag) if tag)
    print("WANDB_TAGS: ", os.environ.get("WANDB_TAGS", ""))

    seeds = [args.start_seed + offset for offset in range(args.num_seeds)]
    commands = [" ".join([args.command, "--seed", str(seed)]) for seed in seeds]

    print("======= commands to run:")
    for command in commands:
        print(command)

    if args.workers > 0 and args.slurm_template_path is None:
        from concurrent.futures import ThreadPoolExecutor

        # Leaving the `with` block waits for every submitted command to finish.
        with ThreadPoolExecutor(max_workers=args.workers, thread_name_prefix="cleanrl-benchmark-worker-") as executor:
            submitted = [(command, executor.submit(run_experiment, command)) for command in commands]
        # An exception raised inside a worker stays hidden in its future unless
        # it is inspected, so report each failure explicitly.
        for command, future in submitted:
            error = future.exception()
            if error is not None:
                print(f"command failed: {command}\n{error}")
    elif args.workers <= 0:
        # Only claim that nothing runs when that is actually the case; with a
        # SLURM template and workers > 0 the job is submitted below.
        print("not running the experiments because --workers is set to 0; just printing the commands to run")

    # SLURM logic
    if args.slurm_template_path is not None:
        # Creates both `slurm/` and `slurm/logs/` when they are missing.
        os.makedirs(os.path.join("slurm", "logs"), exist_ok=True)
        print("======= slurm commands to run:")
        with open(args.slurm_template_path) as f:
            slurm_template = f.read()

        # NOTE: raises ZeroDivisionError when gpus-per-task * ntasks is 0.
        total_gpus = args.slurm_gpus_per_task * args.slurm_ntasks
        slurm_cpus_per_gpu = math.ceil(args.slurm_total_cpus / total_gpus)
        nodes_directive = f"#SBATCH --nodes={args.slurm_nodes}" if args.slurm_nodes is not None else ""

        # Placeholders are substituted in this order (dicts keep insertion order).
        replacements = {
            # `%N` caps how many array tasks run at the same time.
            "{{array}}": f"0-{len(commands) - 1}%{args.workers}",
            "{{seeds}}": f"({' '.join(str(seed) for seed in seeds)})",
            "{{len_seeds}}": f"{args.num_seeds}",
            "{{command}}": args.command,
            "{{gpus_per_task}}": f"{args.slurm_gpus_per_task}",
            "{{cpus_per_gpu}}": f"{slurm_cpus_per_gpu}",
            "{{ntasks}}": f"{args.slurm_ntasks}",
            "{{nodes}}": nodes_directive,
        }
        for placeholder, value in replacements.items():
            slurm_template = slurm_template.replace(placeholder, value)

        # A random file name keeps repeated invocations from overwriting each other.
        slurm_path = os.path.join("slurm", f"{uuid.uuid4()}.slurm")
        # Close the file before sbatch reads it.
        with open(slurm_path, "w") as f:
            f.write(slurm_template)
        print(f"saving command in {slurm_path}")
        if args.workers > 0:
            job_id = run_experiment(f"sbatch --parsable {slurm_path}")
            print(f"Job ID: {job_id}")