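"""Entry point for a single H2O LLM Studio training run.

Optionally waits for earlier runs (a comma-separated PID list passed via
-Q/--process-queue) to finish, then loads the experiment configuration and
starts training.
"""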
import os
# Set this before importing any other modules to be on the safe side
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["OPENBLAS_NUM_THREADS"] = "1"
os.environ["VECLIB_MAXIMUM_THREADS"] = "1"
os.environ["NUMEXPR_NUM_THREADS"] = "1"
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import argparse
import logging
import sys
import time
import psutil
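
# Example invocations (script name, paths, and PIDs below are illustrative,
# not taken from the repository):
#   python run.py -Y /path/to/cfg.yaml
#   python run.py -C /path/to/cfg.py -Q 1234,5678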

def check_for_done(process_queue):
    """Checks for finished process ids.

    Args:
        process_queue: list of process ids

    Returns:
        (True, process_idx) if any process has finished
        (False, False) if no process has finished yet
    """
    for i, pid in enumerate(process_queue):
        zombie = False
        try:
            p = psutil.Process(pid)
            # A finished child that has not been reaped yet still "exists"
            # as a zombie, so check the status in addition to pid_exists.
            zombie = p.status() == "zombie"
        except psutil.NoSuchProcess:
            pass
        if not psutil.pid_exists(pid) or zombie:
            return True, i
    return False, False

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="")
    parser.add_argument(
        "-C", "--config", help="config filename", default=argparse.SUPPRESS
    )
    parser.add_argument("-Y", "--yaml", help="yaml filename", default=argparse.SUPPRESS)
    parser.add_argument(
        "-Q",
        "--process-queue",
        help="process queue to wait for",
        default=argparse.SUPPRESS,
    )
    # parse_known_args tolerates extra arguments (including argv[0] itself)
    parser_args, _ = parser.parse_known_args(sys.argv)

    process_queue = []
    if "process_queue" in parser_args and parser_args.process_queue != "":
        process_queue = [int(x) for x in parser_args.process_queue.split(",")]

    # Block until every queued PID has exited, polling every 30 seconds,
    # so that queued experiments run one after another.
    while process_queue:
        done, num = check_for_done(process_queue)
        if done:
            process_queue.pop(num)
        else:
            time.sleep(30)
    # Delayed imports from llm_studio: load torch and friends only once we
    # actually want to start training, keeping the polling loop lightweight.
    import subprocess

    import torch

    from llm_studio.src.utils.config_utils import load_config_py, load_config_yaml
    from llm_studio.src.utils.exceptions import (
        LLMAugmentationsException,
        LLMDataException,
        LLMMetricException,
        LLMModelException,
        LLMTrainingException,
    )
    from llm_studio.src.utils.gpu_utils import is_oom_error
    from llm_studio.src.utils.logging_utils import initialize_logging, write_flag
    from llm_studio.src.utils.utils import kill_ddp_processes
    from train import run
if "config" in parser_args:
cfg = load_config_py(parser_args.config)
elif "yaml" in parser_args:
cfg = load_config_yaml(parser_args.yaml)
flag_path = os.path.join(cfg.output_directory, "flags{}.json")
# Check if DDP
if "WORLD_SIZE" in os.environ:
local_rank = int(os.environ["LOCAL_RANK"])
if local_rank == 0:
write_flag(flag_path.format(""), "status", "running")
else:
write_flag(flag_path.format(""), "status", "running")
local_rank = 0
initialize_logging(cfg)
    try:
        run(cfg=cfg)
    except Exception as exception:
        # Record the failure and a short human-readable reason, then make
        # sure any remaining DDP worker processes are torn down.
        write_flag(flag_path.format(local_rank), "status", "failed")
        if is_oom_error(exception):
            logging.error(
                "GPU Out-of-Memory (OOM) error occurred. "
                "Please reduce the batch size, input data size, "
                "or model size. Or try gradient checkpointing.",
                exc_info=True,
            )
            write_flag(flag_path.format(local_rank), "info", "OOM error")
            logging.info(
                "<pre>"
                + subprocess.check_output(["nvidia-smi"]).decode("utf-8")
                + "</pre>"
            )
            if torch.cuda.is_available():
                logging.info(
                    "<pre>" + torch.cuda.memory_summary().replace("-", "=") + "</pre>"
                )
        elif isinstance(exception, LLMDataException):
            logging.error(
                "Data error occurred during H2O LLM Studio run:", exc_info=True
            )
            write_flag(flag_path.format(local_rank), "info", "Data error")
        elif isinstance(exception, LLMTrainingException):
            logging.error(
                "Training error occurred during H2O LLM Studio run:", exc_info=True
            )
            write_flag(flag_path.format(local_rank), "info", "Training error")
        elif isinstance(exception, LLMMetricException):
            logging.error(
                "Validation metric failed. Please make sure the selected "
                "validation metric is suitable for your current problem setup.",
                exc_info=True,
            )
            write_flag(flag_path.format(local_rank), "info", "Metric error")
        elif isinstance(exception, LLMAugmentationsException):
            logging.error(
                "Custom augmentations error occurred during H2O LLM Studio run:",
                exc_info=True,
            )
            write_flag(flag_path.format(local_rank), "info", "Augmentations error")
        elif isinstance(exception, LLMModelException):
            logging.error(
                "Model error occurred during H2O LLM Studio run:",
                exc_info=True,
            )
            write_flag(flag_path.format(local_rank), "info", "Model error")
        else:
            logging.error(
                "Exception occurred during H2O LLM Studio run:", exc_info=True
            )
            write_flag(flag_path.format(local_rank), "info", "See logs")
        kill_ddp_processes()