import os
# Set this before importing any other modules to be on the safe side
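# (they cap the OpenMP/MKL/OpenBLAS/vecLib/numexpr thread pools to a single thread
# and disable Hugging Face tokenizers parallelism)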
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["OPENBLAS_NUM_THREADS"] = "1"
os.environ["VECLIB_MAXIMUM_THREADS"] = "1"
os.environ["NUMEXPR_NUM_THREADS"] = "1"
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import argparse
import logging
import sys
import time

import psutil


def check_for_done(process_queue):
"""Checks for finished process ids
Args:
process_queue: list of process ids
Returns:
(True, process_idx) if there is any finished process
(False, False) if there is not finished processes
"""
for i, pid in enumerate(process_queue):
zombie = False
try:
p = psutil.Process(pid)
zombie = p.status() == "zombie"
except psutil.NoSuchProcess:
pass
if not psutil.pid_exists(pid) or zombie:
return True, i
    return False, False


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="")
parser.add_argument(
"-C", "--config", help="config filename", default=argparse.SUPPRESS
)
parser.add_argument("-Y", "--yaml", help="yaml filename", default=argparse.SUPPRESS)
parser.add_argument(
"-Q",
"--process-queue",
help="process queue to wait for",
default=argparse.SUPPRESS,
)
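    # Example invocation (the path and PIDs below are purely illustrative):
    #   python train_wave.py -Y /path/to/cfg.yaml -Q 1234,5678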
parser_args, _ = parser.parse_known_args(sys.argv)
process_queue = []
if "process_queue" in parser_args and parser_args.process_queue != "":
process_queue = [int(x) for x in parser_args.process_queue.split(",")]
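    # Block until every queued predecessor process has exited (or become a zombie),
    # polling roughly every 30 seconds, before starting this training run.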
while True:
if len(process_queue) == 0:
break
done, num = check_for_done(process_queue)
if done:
process_queue.pop(num)
else:
time.sleep(30)
# delayed imports from llm_studio, only after we want to start training
import subprocess
import torch
from llm_studio.src.utils.config_utils import load_config_py, load_config_yaml
from llm_studio.src.utils.exceptions import (
LLMAugmentationsException,
LLMDataException,
LLMMetricException,
LLMModelException,
LLMTrainingException,
)
from llm_studio.src.utils.gpu_utils import is_oom_error
from llm_studio.src.utils.logging_utils import initialize_logging, write_flag
from llm_studio.src.utils.utils import kill_ddp_processes
from train import run
if "config" in parser_args:
cfg = load_config_py(parser_args.config)
elif "yaml" in parser_args:
cfg = load_config_yaml(parser_args.yaml)
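    # flags{}.json: the placeholder is filled with the local rank for failure flags
    # and left empty for the shared "running" flag written at startup.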
flag_path = os.path.join(cfg.output_directory, "flags{}.json")
    # Check if this is a DDP run; only local rank 0 writes the shared "running" flag
if "WORLD_SIZE" in os.environ:
local_rank = int(os.environ["LOCAL_RANK"])
if local_rank == 0:
write_flag(flag_path.format(""), "status", "running")
else:
write_flag(flag_path.format(""), "status", "running")
local_rank = 0
initialize_logging(cfg)
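    # Run training; on any failure, record the failure type in this rank's flag file.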
try:
run(cfg=cfg)
except Exception as exception:
write_flag(flag_path.format(local_rank), "status", "failed")
if is_oom_error(exception):
            logging.error(
                "GPU Out-of-Memory (OOM) error occurred. "
                "Please reduce the batch size, input data size, or model size, "
                "or try gradient checkpointing.",
                exc_info=True,
            )
write_flag(flag_path.format(local_rank), "info", "OOM error")
logging.info(
"<pre>"
+ subprocess.check_output(["nvidia-smi"]).decode("utf-8")
+ "</pre>"
)
if torch.cuda.is_available():
logging.info(
"<pre>" + torch.cuda.memory_summary().replace("-", "=") + "</pre>"
)
elif isinstance(exception, LLMDataException):
logging.error(
"Data error occurred during H2O LLM Studio run:", exc_info=True
)
write_flag(flag_path.format(local_rank), "info", "Data error")
elif isinstance(exception, LLMTrainingException):
logging.error(
"Training error occurred during H2O LLM Studio run:", exc_info=True
)
write_flag(flag_path.format(local_rank), "info", "Training error")
elif isinstance(exception, LLMMetricException):
logging.error(
"Validation metric failed. Please make sure selected validation "
"metric is suitable for your current problem setup.",
exc_info=True,
)
write_flag(flag_path.format(local_rank), "info", "Metric error")
elif isinstance(exception, LLMAugmentationsException):
logging.error(
"Custom augmentations error occurred during " "H2O LLM Studio run:",
exc_info=True,
)
write_flag(flag_path.format(local_rank), "info", "Augmentations error")
elif isinstance(exception, LLMModelException):
logging.error(
"Model error occurred during H2O LLM Studio run:",
exc_info=True,
)
write_flag(flag_path.format(local_rank), "info", "Model error")
else:
logging.error(
"Exception occurred during H2O LLM Studio run:", exc_info=True
)
write_flag(flag_path.format(local_rank), "info", "See logs")
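        # Clean up any remaining DDP worker processes after the failure.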
kill_ddp_processes()