import os
import json
import requests
import time
from threading import Thread, Lock
import re
import multiprocessing
import subprocess

ERROR_PATTERN = re.compile(r"ERROR:")


def get_image_name():
    """Return an image name derived from the current working directory.

    The directory's base name is returned unchanged when it already contains
    "cog"; otherwise it is prefixed with "cog-".
    """
    current_dir = os.path.basename(os.getcwd())
    if "cog" in current_dir:
        return current_dir
    else:
        return f"cog-{current_dir}"


def process_log_line(line):
    """Decode one raw log line and pretty-print it when it is JSON.

    Args:
    - line: A bytes object as read from a subprocess pipe.

    Returns:
    The stripped text re-serialized with 2-space indentation if it parses as
    JSON, otherwise the stripped text itself.
    """
    # errors="replace": a single non-UTF-8 byte must not kill the reader,
    # otherwise the pipe stops being drained and the child may block forever
    # on a full pipe buffer.
    line = line.decode("utf-8", errors="replace").strip()
    try:
        log_data = json.loads(line)
        return json.dumps(log_data, indent=2)
    except json.JSONDecodeError:
        return line


def capture_output(pipe, print_lock, logs=None, error_detected=None):
    """Echo every line read from *pipe* until EOF, optionally recording it.

    Args:
    - pipe: A binary file object (e.g. ``Popen.stdout``) read line by line.
    - print_lock: Lock held while printing so lines from concurrent readers
      are not interleaved mid-line.
    - logs: Optional list-like; each formatted line is appended to it.
    - error_detected: Optional one-element list-like; element 0 is set to
      True when a formatted line matches ERROR_PATTERN.
    """
    for line in iter(pipe.readline, b""):
        formatted_line = process_log_line(line)
        with print_lock:
            print(formatted_line)
        if logs is not None:
            logs.append(formatted_line)
        if error_detected is not None and ERROR_PATTERN.search(formatted_line):
            error_detected[0] = True


def wait_for_server_to_be_ready(url, timeout=300, request_timeout=10):
    """
    Waits for the server to be ready.

    Args:
    - url: The health check URL to poll.
    - timeout: Maximum time (in seconds) to wait for the server to be ready.
    - request_timeout: Maximum time (in seconds) a single poll request may
      take, so one hung connection cannot block past the overall deadline.

    Raises:
    - RuntimeError: If the server reports status "SETUP_FAILED".
    - TimeoutError: If the server is not "READY" within *timeout* seconds.
    """
    # monotonic() is immune to wall-clock adjustments during the wait.
    start_time = time.monotonic()
    while True:
        try:
            response = requests.get(url, timeout=request_timeout)
            data = response.json()
        except (requests.RequestException, ValueError):
            # Not reachable yet, or not yet serving JSON (older requests
            # versions raise a plain ValueError here): keep polling.
            data = None

        # A body that is not an object or lacks "status" counts as "not ready".
        status = data.get("status") if isinstance(data, dict) else None
        if status == "READY":
            return
        if status == "SETUP_FAILED":
            raise RuntimeError(
                "Server initialization failed with status: SETUP_FAILED"
            )

        if time.monotonic() - start_time > timeout:
            raise TimeoutError("Server did not become ready in the expected time.")
        time.sleep(5)  # Poll every 5 seconds


def run_training_subprocess(command):
    """Run *command*, streaming its stdout and stderr to this process's stdout.

    Both streams are drained concurrently by reader threads. Threads are used
    rather than processes because the work is I/O-bound, and because file
    objects cannot be pickled for the "spawn"/"forkserver" start methods.

    Args:
    - command: The program and arguments, in any form subprocess.Popen accepts.

    Returns:
    A list of every formatted output line from both streams, in the order
    the readers recorded them.

    Raises:
    - Exception: If any formatted output line matched ERROR_PATTERN.

    NOTE: the child's exit code is not inspected; only log content decides
    failure.
    """
    print_lock = Lock()
    logs = []
    error_detected = [False]

    # The context manager closes the pipes and reaps the child on exit.
    with subprocess.Popen(
        command, stdout=subprocess.PIPE, stderr=subprocess.PIPE
    ) as process:
        readers = [
            Thread(
                target=capture_output,
                args=(stream, print_lock, logs, error_detected),
                daemon=True,
            )
            for stream in (process.stdout, process.stderr)
        ]
        for reader in readers:
            reader.start()

        # Wait for the subprocess to finish, then for the readers to hit EOF.
        process.wait()
        for reader in readers:
            reader.join()

    if error_detected[0]:
        raise Exception("Error detected in training logs! Check logs for details")
    return list(logs)