"""Flask front-end that runs a pool of Jupyter kernel containers and proxies code to them.

On startup the ``jupyter-kernel:latest`` image is built from this file's
directory and ``n_instances`` containers are started, each pinned to its own
CPU range and published on a free host port. ``POST /execute`` forwards code to
the container selected by ``instance_id``.
"""
import argparse
import atexit
import logging
import os
import socket
import time
import uuid

import docker
import requests
from flask import Flask, jsonify, request
from pydantic import BaseModel, Field, ValidationError

current_dir = os.path.dirname(os.path.abspath(__file__))

app = Flask(__name__)
app.logger.setLevel(logging.INFO)

# Seconds a single /health probe may take before the container is treated as
# "not ready yet" for this polling round.
_HEALTH_PROBE_TIMEOUT = 2


# CLI function to parse arguments
def parse_args():
    """Parse command-line options used when the server is run directly (not via Gunicorn)."""
    parser = argparse.ArgumentParser(description="Jupyter server.")
    # Default of 1 matches N_INSTANCES in init_app; without it the value was
    # None and create_kernel_containers crashed on range(None).
    parser.add_argument('--n_instances', type=int, default=1, help="Number of Jupyter instances.")
    parser.add_argument('--n_cpus', type=int, default=2, help="Number of CPUs per Jupyter instance.")
    parser.add_argument('--mem', type=str, default="2g", help="Amount of memory per Jupyter instance.")
    parser.add_argument('--execution_timeout', type=int, default=10, help="Timeout period for a code execution.")
    parser.add_argument('--port', type=int, default=5001, help="Port of main server")
    return parser.parse_args()


def get_unused_port(start=50000, end=65535, exclusion=None):
    """Return the first port in ``[start, end]`` that can currently be bound on all interfaces.

    Args:
        start: First port to try (inclusive).
        end: Last port to try (inclusive).
        exclusion: Ports to skip even if they are currently free (e.g. ports
            already handed out but not yet bound by their owner).

    Raises:
        IOError: If no port in the range could be bound.
    """
    excluded = set(exclusion) if exclusion is not None else set()
    for port in range(start, end + 1):
        if port in excluded:
            continue
        try:
            # The context manager closes the probe socket even when bind fails.
            with socket.socket() as sock:
                sock.bind(("", port))
                sock.listen(1)
        except OSError:
            continue
        return port
    raise IOError("No free ports available in range {}-{}".format(start, end))


def _stop_containers(containers):
    """Stop, then remove, each container in ``containers``; log failures instead of raising."""
    for instance in containers:
        container = instance['container']
        try:
            container.stop()
            container.remove()
        except docker.errors.NotFound:
            # Expected: containers are started with remove=True, so Docker may
            # already have deleted them once they stopped.
            pass
        except Exception as e:
            app.logger.info(f"Error stopping/removing container: {str(e)}")


def _wait_until_ready(containers, startup_timeout):
    """Poll every container's /health endpoint until all answer 200.

    Raises:
        TimeoutError: If not all containers are ready within ``startup_timeout`` seconds.
    """
    deadline = time.monotonic() + startup_timeout
    pending = list(range(len(containers)))
    while pending:
        app.logger.info("Pinging Jupyter containers to check readiness.")
        if time.monotonic() > deadline:
            raise TimeoutError("Container took too long to startup.")
        still_pending = []
        for i in pending:
            url = f"http://localhost:{containers[i]['port']}/health"
            try:
                # A per-probe timeout keeps one hung probe from defeating the deadline.
                ready = requests.get(url, timeout=_HEALTH_PROBE_TIMEOUT).status_code == 200
            except requests.RequestException:
                ready = False  # the server in the container is not accepting requests yet
            if not ready:
                still_pending.append(i)
        pending = still_pending
        if pending:
            time.sleep(0.5)


def create_kernel_containers(n_instances, n_cpus=2, mem="2g", execution_timeout=10, startup_timeout=60):
    """Build the kernel image, start ``n_instances`` containers and wait until all are healthy.

    Args:
        n_instances: Number of containers to start.
        n_cpus: CPU cores per container; instance ``i`` is pinned to cores
            ``i*n_cpus`` through ``(i+1)*n_cpus - 1``.
        mem: Docker memory limit per container (e.g. ``"2g"``).
        execution_timeout: Passed to each container as the ``EXECUTION_TIMEOUT``
            environment variable.
        startup_timeout: Seconds to wait for every container's ``/health`` to return 200.

    Returns:
        A list of ``{"container": <docker Container>, "port": <host port>}``
        dicts; the list index is the instance id used by ``/execute``.

    Raises:
        TimeoutError: If the containers are not all healthy in time. Any
            containers already started are stopped before any exception propagates.
    """
    docker_client = docker.from_env()
    app.logger.info("Building docker image...")
    docker_client.images.build(path=current_dir, tag='jupyter-kernel:latest')
    app.logger.info("Building docker image complete.")

    containers = []
    try:
        port_exclusion = []
        for i in range(n_instances):
            free_port = get_unused_port(exclusion=port_exclusion)
            # A container takes a while to bind its published port, so remember
            # ports already handed out; otherwise the next probe could reuse one.
            port_exclusion.append(free_port)
            app.logger.info(f"Starting container {i} on port {free_port}...")
            container = docker_client.containers.run(
                "jupyter-kernel:latest",
                detach=True,
                mem_limit=mem,
                # Pin instance i to its own contiguous block of n_cpus cores.
                cpuset_cpus=f"{i*n_cpus}-{(i+1)*n_cpus-1}",
                remove=True,  # Docker deletes the container once it stops
                ports={'5000/tcp': free_port},
                environment={"EXECUTION_TIMEOUT": execution_timeout},
            )
            containers.append({"container": container, "port": free_port})
        _wait_until_ready(containers, startup_timeout)
    except BaseException:
        # The caller never receives the list on failure, so shutdown_cleanup
        # could not find these containers; stop them here to avoid leaking them.
        _stop_containers(containers)
        raise

    app.logger.info("Containers ready!")
    return containers


def shutdown_cleanup():
    """atexit hook: stop every container recorded on ``app.containers``."""
    app.logger.info("Shutting down. Stopping and removing all containers...")
    # app.containers does not exist if init_app failed or never ran.
    _stop_containers(getattr(app, 'containers', []))
    app.logger.info("All containers stopped and removed.")


class ServerRequest(BaseModel):
    """Body of ``POST /execute``."""

    code: str = Field(..., example="print('Hello World!')")
    instance_id: int = Field(0, example=0)  # index into app.containers
    restart: bool = Field(False, example=False)  # restart the kernel before executing


@app.route('/execute', methods=['POST'])
def execute_code():
    """Execute ``code`` on the kernel container selected by ``instance_id``.

    If ``restart`` is true, the container's ``/restart`` endpoint is called
    first; a failed restart is logged but does not abort execution.

    Returns:
        The container's ``/execute`` JSON response on success. Returns 400 for
        a body that is not a JSON object, fails validation, or names an
        unknown ``instance_id``. Returns 500 with
        ``{'result': 'error', 'output': <message>}`` if talking to the container fails.
    """
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict):
        return jsonify({'result': 'error', 'output': 'Request body must be a JSON object.'}), 400
    try:
        req = ServerRequest(**payload)
    except ValidationError as e:
        return jsonify(e.errors()), 400

    n_containers = len(app.containers)
    # Reject out-of-range ids explicitly: a negative id would otherwise index
    # from the end of the list, and a too-large one would raise IndexError.
    if not 0 <= req.instance_id < n_containers:
        return jsonify({
            'result': 'error',
            'output': f"instance_id {req.instance_id} out of range; {n_containers} instance(s) available.",
        }), 400

    port = app.containers[req.instance_id]["port"]
    app.logger.info(f"Received request for instance {req.instance_id} (port={port}).")
    try:
        if req.restart:
            response = requests.post(f'http://localhost:{port}/restart', json={})
            if response.status_code == 200:
                app.logger.info(f"Kernel for instance {req.instance_id} restarted.")
            else:
                # Log the raw body: an error response may not be JSON, and a
                # decode error here would abort the execute request as well.
                app.logger.info(f"Error when restarting kernel of instance {req.instance_id}: {response.text}.")
        # NOTE(review): no client-side timeout here; a hung container blocks
        # this worker until it responds.
        response = requests.post(f'http://localhost:{port}/execute', json={'code': req.code})
        return response.json()
    except Exception as e:
        app.logger.exception(f"Error in execute_code: {str(e)}")
        return jsonify({
            'result': 'error',
            'output': str(e)
        }), 500


def init_app(app, args=None):
    """Start the kernel containers and attach them to ``app.containers``.

    Args:
        app: Flask app that receives the ``containers`` attribute.
        args: Parsed CLI arguments. When None (e.g. under Gunicorn), settings
            are read from the N_INSTANCES, N_CPUS, MEM and EXECUTION_TIMEOUT
            environment variables.

    Returns:
        ``(app, args)``.
    """
    if args is None:
        # When run through Gunicorn, use environment variables
        args = argparse.Namespace(
            n_instances=int(os.getenv('N_INSTANCES', 1)),
            n_cpus=int(os.getenv('N_CPUS', 1)),
            mem=os.getenv('MEM', '1g'),
            execution_timeout=int(os.getenv('EXECUTION_TIMEOUT', 60))
        )
    app.containers = create_kernel_containers(
        args.n_instances,
        n_cpus=args.n_cpus,
        mem=args.mem,
        execution_timeout=args.execution_timeout
    )
    return app, args


atexit.register(shutdown_cleanup)

if __name__ == '__main__':
    args = parse_args()
    app, args = init_app(app, args=args)
    # Don't use debug=True: it would run main twice and thus start double the containers.
    app.run(debug=False, host='0.0.0.0', port=args.port)
else:
    # NOTE(review): this runs at import time in every process that imports the
    # module (e.g. each Gunicorn worker), and each starts its own container
    # pool; confirm the deployment uses a single worker.
    app, args = init_app(app)

# TODO: How can data be made available inside the containers at runtime? Idea:
# mount a (read-only) folder into each container at startup, and before starting
# the kernel copy the necessary data from it into the kernel's working directory.