|
from flask import Flask, request, jsonify |
|
import os |
|
import uuid |
|
import time |
|
import docker |
|
import requests |
|
import atexit |
|
import socket |
|
import argparse |
|
import logging |
|
from pydantic import BaseModel, Field, ValidationError |
|
|
|
# Directory holding this file; used as the Docker build context for the
# kernel image in create_kernel_containers.
current_dir = os.path.split(os.path.abspath(__file__))[0]

# Flask application. The logger is set to INFO so the container lifecycle
# messages emitted throughout this module are actually shown.
app = Flask(__name__)
app.logger.setLevel(logging.INFO)
|
|
|
|
|
|
|
def parse_args(argv=None):
    """Parse the command-line options for the Jupyter server.

    Args:
        argv: Argument list to parse. When None, argparse reads sys.argv[1:].

    Returns:
        argparse.Namespace with the attributes n_instances, n_cpus, mem,
        execution_timeout and port.
    """
    parser = argparse.ArgumentParser(description="Jupyter server.")
    # Default of 1 matches the N_INSTANCES fallback used by init_app. Without a
    # default the value was None and create_kernel_containers failed on
    # range(None) whenever the flag was omitted.
    parser.add_argument('--n_instances', type=int, default=1, help="Number of Jupyter instances.")
    parser.add_argument('--n_cpus', type=int, default=2, help="Number of CPUs per Jupyter instance.")
    parser.add_argument('--mem', type=str, default="2g", help="Amount of memory per Jupyter instance.")
    parser.add_argument('--execution_timeout', type=int, default=10, help="Timeout period for a code execution.")
    parser.add_argument('--port', type=int, default=5001, help="Port of main server")
    return parser.parse_args(argv)
|
|
|
|
|
def get_unused_port(start=50000, end=65535, exclusion=None):
    """Return the first port in [start, end] that can currently be bound.

    A port is considered free if a TCP socket can bind to it on all interfaces
    and listen. The probe socket is closed again before returning, so another
    process could still grab the port before the caller uses it.

    Args:
        start: First port to try (inclusive).
        end: Last port to try (inclusive).
        exclusion: Optional iterable of ports to skip, e.g. ports already
            handed out but not yet bound by their consumer.

    Returns:
        The first bindable port number.

    Raises:
        IOError: if no port in the range is free.
    """
    # Avoid a shared mutable default; a set gives O(1) membership tests.
    excluded = set(exclusion) if exclusion is not None else set()
    for port in range(start, end + 1):
        if port in excluded:
            continue
        try:
            # The context manager closes the socket even when bind() fails,
            # which the original explicit close() did not.
            with socket.socket() as sock:
                sock.bind(("", port))
                sock.listen(1)
        except OSError:
            continue
        return port
    raise IOError("No free ports available in range {}-{}".format(start, end))
|
|
|
|
|
def _build_kernel_image(docker_client):
    """Build the 'jupyter-kernel:latest' image from the build context in current_dir."""
    app.logger.info("Building docker image...")
    docker_client.images.build(path=current_dir, tag='jupyter-kernel:latest')
    app.logger.info("Building docker image complete.")


def _start_kernel_container(docker_client, index, port, n_cpus, mem, execution_timeout):
    """Start one detached kernel container and publish its port 5000 on host `port`."""
    app.logger.info(f"Starting container {index} on port {port}...")
    first_cpu = index * n_cpus
    return docker_client.containers.run(
        "jupyter-kernel:latest",
        detach=True,
        mem_limit=mem,
        # Instance i gets CPUs [i*n_cpus, (i+1)*n_cpus - 1]; ranges are disjoint.
        cpuset_cpus=f"{first_cpu}-{first_cpu + n_cpus - 1}",
        # Docker deletes the container itself once it stops.
        remove=True,
        ports={'5000/tcp': port},
        environment={"EXECUTION_TIMEOUT": str(execution_timeout)},
    )


def _wait_for_containers(containers, startup_timeout, poll_interval=0.5, request_timeout=2.0):
    """Poll each container's /health endpoint until every one answers HTTP 200.

    Raises:
        TimeoutError: if some container is still not healthy after
            `startup_timeout` seconds.
    """
    deadline = time.time() + startup_timeout
    pending = set(range(len(containers)))
    while pending:
        app.logger.info("Pinging Jupyter containers to check readiness.")
        if time.time() > deadline:
            raise TimeoutError("Container took too long to startup.")
        for i in sorted(pending):
            url = f"http://localhost:{containers[i]['port']}/health"
            try:
                # A per-request timeout keeps one unresponsive container from
                # blocking the loop past the overall deadline.
                response = requests.get(url, timeout=request_timeout)
            except requests.RequestException:
                # Not accepting connections yet (or timed out); retry next pass.
                continue
            if response.status_code == 200:
                pending.discard(i)
        if pending:
            time.sleep(poll_interval)


def _stop_containers_quietly(containers):
    """Best-effort stop of containers that were started before a startup failure."""
    for instance in containers:
        try:
            instance["container"].stop()
        except Exception as e:
            app.logger.warning(f"Error stopping container after failed startup: {e}")


def create_kernel_containers(n_instances, n_cpus=2, mem="2g", execution_timeout=10, startup_timeout=60):
    """Build the kernel image, start `n_instances` containers and wait until they are healthy.

    Args:
        n_instances: Number of kernel containers to start.
        n_cpus: CPUs pinned to each container (instance i gets CPUs
            i*n_cpus .. (i+1)*n_cpus - 1).
        mem: Docker memory limit per container, e.g. "2g".
        execution_timeout: Passed to each container as the EXECUTION_TIMEOUT
            environment variable.
        startup_timeout: Seconds to wait for every container's /health
            endpoint to return HTTP 200.

    Returns:
        A list with one {"container": <docker container object>, "port": int}
        dict per instance, in instance order.

    Raises:
        TimeoutError: if the containers do not become healthy in time. Any
            containers already started are stopped before the error propagates.
    """
    docker_client = docker.from_env()
    _build_kernel_image(docker_client)

    containers = []
    port_exclusion = []
    try:
        for i in range(n_instances):
            # Exclude ports already handed out: earlier containers may not have
            # bound theirs yet, so the free-port probe could return them again.
            free_port = get_unused_port(exclusion=port_exclusion)
            port_exclusion.append(free_port)
            container = _start_kernel_container(
                docker_client, i, free_port, n_cpus, mem, execution_timeout
            )
            containers.append({"container": container, "port": free_port})
        _wait_for_containers(containers, startup_timeout)
    except BaseException:
        # The caller never receives this list, so shutdown_cleanup cannot see
        # these containers; stop them here so they are not leaked.
        _stop_containers_quietly(containers)
        raise

    app.logger.info("Containers ready!")
    return containers
|
|
|
def shutdown_cleanup():
    """Stop and remove every kernel container listed in app.containers.

    Registered with atexit. Failures are logged rather than raised, so one
    misbehaving container does not prevent cleanup of the others.
    """
    app.logger.info("Shutting down. Stopping and removing all containers...")
    # atexit is registered before init_app runs, so this can execute even when
    # app.containers was never assigned (e.g. container startup failed).
    for instance in getattr(app, "containers", []):
        container = instance['container']
        try:
            container.stop()
            container.remove()
        except docker.errors.APIError as e:
            # Containers are started with remove=True, so after stop() the
            # daemon deletes them itself. remove() then typically reports 404
            # (already gone) or 409 (removal in progress); both are expected.
            if e.status_code not in (404, 409):
                app.logger.warning(f"Error stopping/removing container: {str(e)}")
        except Exception as e:
            app.logger.warning(f"Error stopping/removing container: {str(e)}")
    app.logger.info("All containers stopped and removed.")
|
|
|
|
|
class ServerRequest(BaseModel):
    # Request body accepted by the POST /execute endpoint.

    # Source code to run in the selected kernel; required (default=...).
    code: str = Field(default=..., example="print('Hello World!')")
    # Position in app.containers of the kernel instance to use.
    instance_id: int = Field(default=0, example=0)
    # When true, the kernel's /restart endpoint is called before `code` runs.
    restart: bool = Field(default=False, example=False)
|
|
|
|
|
@app.route('/execute', methods=['POST'])
def execute_code():
    """Forward a code-execution request to one of the kernel containers.

    The JSON body must match ServerRequest. If `restart` is true, the target
    kernel's /restart endpoint is called first. A failed restart is logged,
    and the code is still executed.

    Returns:
        The JSON returned by the container's /execute endpoint.
        400 with pydantic validation errors, or with
        {'result': 'error', 'output': ...}, when the body is not a JSON object
        or `instance_id` does not name an existing instance.
        500 with {'result': 'error', 'output': ...} when forwarding fails.
    """
    # silent=True yields None instead of raising for non-JSON bodies, so the
    # client gets a 400 rather than a crash.
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict):
        return jsonify({
            'result': 'error',
            'output': 'Request body must be a JSON object.'
        }), 400
    try:
        req = ServerRequest(**payload)
    except ValidationError as e:
        return jsonify(e.errors()), 400

    # Reject out-of-range ids explicitly: a negative index would silently pick
    # a kernel from the end of the list, and a too-large one raised IndexError
    # outside the error handling below.
    n_instances = len(app.containers)
    if not 0 <= req.instance_id < n_instances:
        return jsonify({
            'result': 'error',
            'output': f"Unknown instance_id {req.instance_id}; {n_instances} instance(s) available."
        }), 400

    port = app.containers[req.instance_id]["port"]
    app.logger.info(f"Received request for instance {req.instance_id} (port={port}).")

    try:
        if req.restart:
            response = requests.post(f'http://localhost:{port}/restart', json={})
            if response.status_code == 200:
                app.logger.info(f"Kernel for instance {req.instance_id} restarted.")
            else:
                # Use the raw text: the error body may not be JSON, and calling
                # .json() here would abort the whole request.
                app.logger.info(f"Error when restarting kernel of instance {req.instance_id}: {response.text}.")

        response = requests.post(f'http://localhost:{port}/execute', json={'code': req.code})
        return response.json()

    except Exception as e:
        app.logger.exception(f"Error in execute_code: {str(e)}")
        return jsonify({
            'result': 'error',
            'output': str(e)
        }), 500
|
|
|
|
|
def init_app(app, args=None):
    """Start the kernel containers and attach them to the Flask app.

    Args:
        app: The Flask application. Its `containers` attribute is set to the
            list returned by create_kernel_containers.
        args: Parsed options (n_instances, n_cpus, mem, execution_timeout).
            When None, the values come from the environment variables
            N_INSTANCES, N_CPUS, MEM and EXECUTION_TIMEOUT, with defaults
            1, 1, '1g' and 60.

    Returns:
        Tuple (app, args), where args is the namespace actually used.
    """
    if args is None:
        # Module imported rather than run as a script: no CLI is available.
        env = os.getenv
        args = argparse.Namespace(
            n_instances=int(env('N_INSTANCES', 1)),
            n_cpus=int(env('N_CPUS', 1)),
            mem=env('MEM', '1g'),
            execution_timeout=int(env('EXECUTION_TIMEOUT', 60)),
        )

    kernel_options = {
        "n_cpus": args.n_cpus,
        "mem": args.mem,
        "execution_timeout": args.execution_timeout,
    }
    app.containers = create_kernel_containers(args.n_instances, **kernel_options)
    return app, args
|
|
|
# Stop the kernel containers when the interpreter exits.
atexit.register(shutdown_cleanup)

if __name__ != '__main__':
    # Imported as a module: configuration comes from environment variables.
    app, args = init_app(app)
else:
    # Run as a script: configure from the command line and serve on --port.
    app, args = init_app(app, args=parse_args())
    app.run(host='0.0.0.0', port=args.port, debug=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|