import os

# Set these before importing any other modules to be on the safe side
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["OPENBLAS_NUM_THREADS"] = "1"
os.environ["VECLIB_MAXIMUM_THREADS"] = "1"
os.environ["NUMEXPR_NUM_THREADS"] = "1"
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import argparse
import logging
import sys
import time

import psutil


def check_for_done(process_queue):
    """Checks for finished process ids

    Args:
        process_queue: list of process ids
    Returns:
        (True, process_idx) if there is any finished process
        (False, False) if there are no finished processes
    """

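    # A finished child process may linger as a zombie until it is reaped,
    # so treat zombies the same as processes that no longer exist.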
    for i, pid in enumerate(process_queue):
        zombie = False
        try:
            p = psutil.Process(pid)
            zombie = p.status() == "zombie"
        except psutil.NoSuchProcess:
            pass
        if not psutil.pid_exists(pid) or zombie:
            return True, i

    return False, False


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="")
    parser.add_argument(
        "-C", "--config", help="config filename", default=argparse.SUPPRESS
    )
    parser.add_argument("-Y", "--yaml", help="yaml filename", default=argparse.SUPPRESS)
    parser.add_argument(
        "-Q",
        "--process-queue",
        help="process queue to wait for",
        default=argparse.SUPPRESS,
    )
    parser_args, _ = parser.parse_known_args(sys.argv)

    process_queue = []
    if "process_queue" in parser_args and parser_args.process_queue != "":
        process_queue = [int(x) for x in parser_args.process_queue.split(",")]

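    # Block until every process in the queue has finished, polling every 30 seconds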
    while True:
        if len(process_queue) == 0:
            break
        done, num = check_for_done(process_queue)
        if done:
            process_queue.pop(num)
        else:
            time.sleep(30)

    # delayed imports from llm_studio, only after we want to start training
    import subprocess

    import torch

    from llm_studio.src.utils.config_utils import load_config_py, load_config_yaml
    from llm_studio.src.utils.exceptions import (
        LLMAugmentationsException,
        LLMDataException,
        LLMMetricException,
        LLMModelException,
        LLMTrainingException,
    )
    from llm_studio.src.utils.gpu_utils import is_oom_error
    from llm_studio.src.utils.logging_utils import initialize_logging, write_flag
    from llm_studio.src.utils.utils import kill_ddp_processes
    from train import run

    if "config" in parser_args:
        cfg = load_config_py(parser_args.config)
    elif "yaml" in parser_args:
        cfg = load_config_yaml(parser_args.yaml)

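    # Template for status flag files; the placeholder is filled with the local rank
    # (empty string for the main flag)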
    flag_path = os.path.join(cfg.output_directory, "flags{}.json")

    # Check if running in DDP mode (torchrun sets WORLD_SIZE)
    if "WORLD_SIZE" in os.environ:
        local_rank = int(os.environ["LOCAL_RANK"])
        if local_rank == 0:
            write_flag(flag_path.format(""), "status", "running")
    else:
        write_flag(flag_path.format(""), "status", "running")
        local_rank = 0

    initialize_logging(cfg)

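    # Run the experiment; map known exception types to status/info flags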
    try:
        run(cfg=cfg)
    except Exception as exception:
        write_flag(flag_path.format(local_rank), "status", "failed")
        if is_oom_error(exception):
            logging.error(
                "GPU Out-of-Memory (OOM) error occurred. "
                "Please, reduce the batch size, or input data size, "
                "or model size. Or try gradient checkpointing.",
                exc_info=True,
            )
            write_flag(flag_path.format(local_rank), "info", "OOM error")

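            # Log GPU utilization (nvidia-smi) and the torch memory summary to help diagnose the OOM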
            logging.info(
                "<pre>"
                + subprocess.check_output(["nvidia-smi"]).decode("utf-8")
                + "</pre>"
            )

            if torch.cuda.is_available():
                logging.info(
                    "<pre>" + torch.cuda.memory_summary().replace("-", "=") + "</pre>"
                )

        elif isinstance(exception, LLMDataException):
            logging.error(
                "Data error occurred during H2O LLM Studio run:", exc_info=True
            )
            write_flag(flag_path.format(local_rank), "info", "Data error")
        elif isinstance(exception, LLMTrainingException):
            logging.error(
                "Training error occurred during H2O LLM Studio run:", exc_info=True
            )
            write_flag(flag_path.format(local_rank), "info", "Training error")
        elif isinstance(exception, LLMMetricException):
            logging.error(
                "Validation metric failed. Please make sure selected validation "
                "metric is suitable for your current problem setup.",
                exc_info=True,
            )
            write_flag(flag_path.format(local_rank), "info", "Metric error")
        elif isinstance(exception, LLMAugmentationsException):
            logging.error(
                "Custom augmentations error occurred during " "H2O LLM Studio run:",
                exc_info=True,
            )
            write_flag(flag_path.format(local_rank), "info", "Augmentations error")
        elif isinstance(exception, LLMModelException):
            logging.error(
                "Model error occurred during H2O LLM Studio run:",
                exc_info=True,
            )
            write_flag(flag_path.format(local_rank), "info", "Model error")
        else:
            logging.error(
                "Exception occurred during H2O LLM Studio run:", exc_info=True
            )
            write_flag(flag_path.format(local_rank), "info", "See logs")
        kill_ddp_processes()