# executor/jupyter/jupyter_server.py (6.15 kB)
# Last commit: fea1872 "Update jupyter/jupyter_server.py" by lvwerra (HF staff), verified.
from flask import Flask, request, jsonify
import os
import uuid
import time
import docker
import requests
import atexit
import socket
import argparse
import logging
from pydantic import BaseModel, Field, ValidationError
# Flask application serving the /execute gateway; INFO-level logging so the
# container lifecycle messages below are visible.
app = Flask(__name__)
app.logger.setLevel("INFO")

# Directory holding this file; it is used as the Docker build context for the
# kernel image in create_kernel_containers.
current_dir = os.path.split(os.path.abspath(__file__))[0]
# CLI function to parse arguments
def parse_args():
parser = argparse.ArgumentParser(description="Jupyter server.")
parser.add_argument('--n_instances', type=int, help="Number of Jupyter instances.")
parser.add_argument('--n_cpus', type=int, default=2, help="Number of CPUs per Jupyter instance.")
parser.add_argument('--mem', type=str, default="2g", help="Amount of memory per Jupyter instance.")
parser.add_argument('--execution_timeout', type=int, default=10, help="Timeout period for a code execution.")
parser.add_argument('--port', type=int, default=5001, help="Port of main server")
return parser.parse_args()
def get_unused_port(start=50000, end=65535, exclusion=[]):
for port in range(start, end + 1):
if port in exclusion:
continue
try:
sock = socket.socket()
sock.bind(("", port))
sock.listen(1)
sock.close()
return port
except OSError:
continue
raise IOError("No free ports available in range {}-{}".format(start, end))
def create_kernel_containers(n_instances, n_cpus=2, mem="2g", execution_timeout=10):
docker_client = docker.from_env()
app.logger.info("Buidling docker image...")
image, logs = docker_client.images.build(path=current_dir, tag='jupyter-kernel:latest')
app.logger.info("Building docker image complete.")
containers = []
port_exclusion = []
for i in range(n_instances):
free_port = get_unused_port(exclusion=port_exclusion)
port_exclusion.append(free_port) # it takes a while to startup so we don't use the same port twice
app.logger.info(f"Starting container {i} on port {free_port}...")
container = docker_client.containers.run(
"jupyter-kernel:latest",
detach=True,
mem_limit=mem,
cpuset_cpus=f"{i*n_cpus}-{(i+1)*n_cpus-1}", # Limit to CPU cores 0 and 1
remove=True,
ports={'5000/tcp': free_port},
environment={"EXECUTION_TIMEOUT": execution_timeout},
)
containers.append({"container": container, "port": free_port})
start_time = time.time()
containers_ready = []
while len(containers_ready) < n_instances:
app.logger.info("Pinging Jupyter containers to check readiness.")
if time.time() - start_time > 60:
raise TimeoutError("Container took too long to startup.")
for i in range(n_instances):
if i in containers_ready:
continue
url = f"http://localhost:{containers[i]['port']}/health"
try:
# TODO: dedicated health endpoint
response = requests.get(url)
if response.status_code == 200:
containers_ready.append(i)
except Exception as e:
# Catch any other errors that might occur
pass
time.sleep(0.5)
app.logger.info("Containers ready!")
return containers
def shutdown_cleanup():
app.logger.info("Shutting down. Stopping and removing all containers...")
for instance in app.containers:
try:
instance['container'].stop()
instance['container'].remove()
except Exception as e:
app.logger.info(f"Error stopping/removing container: {str(e)}")
app.logger.info("All containers stopped and removed.")
class ServerRequest(BaseModel):
code: str = Field(..., example="print('Hello World!')")
instance_id: int = Field(0, example=0)
restart: bool = Field(False, example=False)
@app.route('/execute', methods=['POST'])
def execute_code():
try:
input = ServerRequest(**request.json)
except ValidationError as e:
return jsonify(e.errors()), 400
port = app.containers[input.instance_id]["port"]
app.logger.info(f"Received request for instance {input.instance_id} (port={port}).")
try:
if input.restart:
response = requests.post(f'http://localhost:{port}/restart', json={})
if response.status_code==200:
app.logger.info(f"Kernel for instance {input.instance_id} restarted.")
else:
app.logger.info(f"Error when restarting kernel of instance {input.instance_id}: {response.json()}.")
response = requests.post(f'http://localhost:{port}/execute', json={'code': input.code})
result = response.json()
return result
except Exception as e:
app.logger.info(f"Error in execute_code: {str(e)}")
return jsonify({
'result': 'error',
'output': str(e)
}), 500
def init_app(app, args=None):
if args is None:
# When run through Gunicorn, use environment variables
args = argparse.Namespace(
n_instances=int(os.getenv('N_INSTANCES', 1)),
n_cpus=int(os.getenv('N_CPUS', 1)),
mem=os.getenv('MEM', '1g'),
execution_timeout=int(os.getenv('EXECUTION_TIMEOUT', 60))
)
app.containers = create_kernel_containers(
args.n_instances,
n_cpus=args.n_cpus,
mem=args.mem,
execution_timeout=args.execution_timeout
)
return app, args
atexit.register(shutdown_cleanup)
if __name__ == '__main__':
args = parse_args()
app, args = init_app(app, args=args)
# don't use debug=True --> it will run main twice and thus start double the containers
app.run(debug=False, host='0.0.0.0', port=args.port)
else:
app, args = init_app(app)
# TODO: How can data be made available inside a container at runtime?
# Idea: mount a (read-only) folder into each container at startup and copy the
# data into that folder. Before the kernel starts, the required files could
# then be copied into the kernel's current working directory.