pvanand's picture
Update main.py
398aecd verified
raw
history blame
5.89 kB
import asyncio
import uvicorn
import os
import shutil
from fastapi import FastAPI, HTTPException, UploadFile, File
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from jupyter_client.manager import KernelManager # Updated import
from typing import Dict
from datetime import datetime, timedelta
import psutil
from typing import List
app = FastAPI()

# Absolute path of the directory holding this file; every session gets its own
# working folder underneath it, at <root_dir>/output/<session_token>.
root_dir = os.path.abspath(os.path.dirname(__file__))

# CORS is fully open: any origin, any method, any header, credentials allowed.
# NOTE(review): wildcard origins together with allow_credentials=True may let any
# site make credentialed cross-origin calls — confirm this is intended.
_cors_settings = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_settings)
# Model for input data: the JSON body accepted by POST /execute.
# Only '#' comments are used here because pydantic copies a class docstring
# into the generated schema description.
class CodeExecutionRequest(BaseModel):
    # Session identifier; selects the session's kernel and its
    # output/<session_token> working folder.
    session_token: str
    # Source code to run in that session's kernel.
    code: str
# Seconds of inactivity after which a session is reaped by check_idle_kernels;
# execute_code also uses it as its execution time limit.
TIMEOUT_DURATION = 10 * 60  # 10 minutes

# Session token -> kernel manager of the running kernel for that session.
kernel_managers: Dict[str, KernelManager] = {}
# Session token -> time the session was last created or used.
last_access_times: Dict[str, datetime] = {}
async def create_kernel(session_token: str):
    """Start a new kernel for *session_token* and register it.

    Ensures ``<root_dir>/output/<session_token>`` exists. If the directory is
    created here, it is given mode 0o777; an existing directory keeps its mode.
    The new manager is stored in ``kernel_managers``, overwriting any existing
    entry for the token without shutting that entry down, and the session's
    last-access time is set to now.
    """
    target_dir = os.path.join(root_dir, "output", session_token)
    missing = not os.path.exists(target_dir)
    if missing:
        os.makedirs(target_dir)
        # NOTE(review): world-writable; confirm another user/process really needs this.
        os.chmod(target_dir, 0o777)

    manager = KernelManager()
    manager.start_kernel()  # blocking call: launches the kernel process
    kernel_managers[session_token] = manager
    last_access_times[session_token] = datetime.now()
async def kill_kernel(session_token: str):
    """Shut down the kernel of *session_token* and delete its session directory.

    Works for tokens without a live kernel: the bookkeeping entry is still
    dropped and the directory, if present, is still removed.

    The directory is deleted only if it resolves to a path strictly inside
    ``<root_dir>/output``. Session tokens are caller-supplied. Without this
    check, ``""`` or ``"."`` would make ``rmtree`` wipe every session's
    output, and ``".."`` would delete the application directory itself.
    """
    km = kernel_managers.pop(session_token, None)
    if km:
        km.shutdown_kernel(now=True)
    last_access_times.pop(session_token, None)

    output_root = os.path.realpath(os.path.join(root_dir, "output"))
    session_dir = os.path.realpath(os.path.join(output_root, session_token))
    # Require a strict descendant of output_root (not output_root itself).
    if not session_dir.startswith(output_root + os.sep):
        return
    if os.path.exists(session_dir):
        shutil.rmtree(session_dir)
@app.post("/upload/{session_token}")
async def upload_files(session_token: str, files: List[UploadFile] = File(...)):
    """Save the uploaded *files* into the working folder of *session_token*.

    Client-supplied file names are reduced to their final path component, so
    names such as ``"../../main.py"`` or ``"/etc/passwd"`` cannot write outside
    the session folder. The session folder must itself resolve to a path
    strictly inside ``<root_dir>/output``.

    Raises:
        HTTPException: 400 if the session token escapes the output folder, or
            if a file name is empty after sanitising. Nothing is written in
            either case.

    Returns:
        ``{"filenames": [...], "status": "success"}`` listing the names the
        files were saved under.
    """
    output_root = os.path.realpath(os.path.join(root_dir, "output"))
    session_dir = os.path.realpath(os.path.join(output_root, session_token))
    if not session_dir.startswith(output_root + os.sep):
        raise HTTPException(status_code=400, detail="Invalid session token")

    # Validate every name before writing anything, so a bad name cannot leave
    # a partially stored upload behind.
    safe_names = []
    for file in files:
        # Treat backslashes as separators as well (Windows-style client paths).
        name = os.path.basename((file.filename or "").replace("\\", "/"))
        if name in ("", ".", ".."):
            raise HTTPException(
                status_code=400,
                detail="Invalid filename: {!r}".format(file.filename),
            )
        safe_names.append(name)

    os.makedirs(session_dir, exist_ok=True)
    for file, name in zip(files, safe_names):
        with open(os.path.join(session_dir, name), "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)
    print("uploaded_files", safe_names)
    return {"filenames": safe_names, "status": "success"}
async def execute_code(session_token: str, code: str):
    """Run *code* in the kernel of *session_token* and collect its IOPub output.

    On first use, a kernel is started via ``create_kernel``. Before the user's
    code runs, a setup snippet enables inline matplotlib and changes the
    kernel's working directory to ``<root_dir>/output/<session_token>``. The
    whole request (setup plus user code) is bounded by ``TIMEOUT_DURATION``
    seconds.

    Returns:
        On success, ``{'status': 'success', 'value': output}``. ``output`` is a
        list of ``{"error": {...}}``, ``{"data": ...}`` and ``{"text": ...}``
        items, in the order the kernel published them for this request.
        If anything raised, including a timeout,
        ``{'status': 'error', 'value': message}``.
    """
    # get_iopub_msg(timeout=...) raises queue.Empty when nothing arrives in time.
    from queue import Empty

    session_dir = os.path.join(root_dir, "output", session_token)
    # repr() yields a correctly quoted and escaped Python string literal, so
    # quotes or backslashes in the path cannot break the snippet.
    setup_code = "\n%matplotlib inline\nimport os\nos.chdir({!r})\n".format(session_dir)
    os.makedirs(session_dir, exist_ok=True)

    if session_token not in kernel_managers:
        await create_kernel(session_token)
    km = kernel_managers[session_token]
    kc = km.client()
    try:
        # Only the kernel's cwd is changed (by setup_code). Calling os.chdir()
        # here would move the whole server process: it would affect every
        # concurrent request and could leave the server inside a session
        # directory that kill_kernel later deletes.
        deadline = datetime.now() + timedelta(seconds=TIMEOUT_DURATION)
        kc.execute_interactive(setup_code, store_history=False, timeout=TIMEOUT_DURATION)

        msg_id = kc.execute(code, store_history=False)
        output = []
        while True:
            remaining = (deadline - datetime.now()).total_seconds()
            if remaining <= 0:
                raise TimeoutError("Code execution timed out.")
            try:
                # Bounded wait: an unbounded get would block forever when the
                # kernel stays silent, and the deadline would never be checked.
                msg = kc.get_iopub_msg(timeout=remaining)
            except Empty:
                raise TimeoutError("Code execution timed out.") from None
            # IOPub carries messages for every request; keep only ours.
            if msg.get('parent_header', {}).get('msg_id') != msg_id:
                continue
            msg_type = msg['msg_type']
            content = msg['content']
            if msg_type == 'status' and content['execution_state'] == 'idle':
                break  # the kernel has finished this request
            if msg_type == 'error':
                output.append({"error": {
                    "ename": content['ename'],
                    "evalue": content['evalue'],
                    "traceback": content['traceback'],
                }})
            if 'data' in content:
                output.append({"data": content['data']})
            elif 'text' in content:
                output.append({"text": content['text']})
        print("Execution SUCCESS")
        print("#################")
        print("CODE:", code)
        print("OUTPUT:", output)
        return {'status': 'success', 'value': output}
    except Exception as e:
        return {'status': 'error', 'value': str(e)}
    finally:
        last_access_times[session_token] = datetime.now()
        # A new client (with its own sockets) is made per call; close them.
        kc.stop_channels()
async def check_idle_kernels():
    """Reap idle sessions forever, checking once a minute.

    A session whose last access is more than ``TIMEOUT_DURATION`` seconds old
    is torn down with ``kill_kernel``. A failure for one session is printed and
    the sweep continues. Without this, an exception escaping this coroutine
    would end the background task, and no session would ever be reaped again.
    """
    idle_limit = timedelta(seconds=TIMEOUT_DURATION)
    while True:
        now = datetime.now()
        # Iterate over a snapshot: kill_kernel removes entries from the dict.
        for session_token, last_access in list(last_access_times.items()):
            if now - last_access <= idle_limit:
                continue
            try:
                await kill_kernel(session_token)
            except Exception as exc:
                print("Failed to clean up idle session", repr(session_token), "-", repr(exc))
        await asyncio.sleep(60)  # Check every minute
@app.on_event("startup")
async def startup_event():
    """Start the idle-kernel reaper as a background task when the app starts."""
    # The event loop keeps only a weak reference to tasks. Store a strong one on
    # app.state so the reaper cannot be garbage-collected while it is running.
    app.state.idle_kernel_task = asyncio.create_task(check_idle_kernels())
@app.post("/execute")
async def execute(request: CodeExecutionRequest):
    """Run the submitted code in its session's kernel; return the execute_code result as-is."""
    return await execute_code(request.session_token, request.code)
@app.get("/info")
async def get_info():
    """Report the number of live kernels and host CPU/RAM utilisation.

    Returns:
        dict with ``active_kernels`` (number of entries in ``kernel_managers``),
        ``cpu_usage_percent`` (sampled over one second) and
        ``ram_usage_percent``.
    """
    active_kernels = len(kernel_managers)
    # psutil.cpu_percent(interval=1) blocks for a full second while sampling.
    # Run it in the default executor so the event loop keeps serving other
    # requests (and the idle-kernel reaper) in the meantime.
    loop = asyncio.get_running_loop()
    cpu_usage = await loop.run_in_executor(None, psutil.cpu_percent, 1)
    ram_usage = psutil.virtual_memory().percent
    return {
        "active_kernels": active_kernels,
        "cpu_usage_percent": cpu_usage,
        "ram_usage_percent": ram_usage,
    }
if __name__ == "__main__":
    # Running this module directly serves the app on all interfaces, port 7860.
    uvicorn.run(app, **dict(host="0.0.0.0", port=7860))