Spaces:
Running
Running
import multiprocessing | |
import json | |
import os | |
import uvicorn | |
from fastapi import FastAPI, Request, HTTPException, Response | |
from fastapi.middleware.cors import CORSMiddleware | |
from fastapi.responses import JSONResponse | |
from fastapi.staticfiles import StaticFiles | |
from utils import extract_and_cache_document, service, cache_file_popup_url, cache_root, cache_file, code_interpreter_ws, update_pop_url, change_checkbox_state | |
from starlette.middleware.sessions import SessionMiddleware | |
# Core ASGI application object.
app = FastAPI()

# NOTE(review): allow_origins is not passed (it was commented out), so the
# CORS middleware's default origin list applies; confirm whether specific
# cross-origin callers were meant to be allowed.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_headers=['*'],
    allow_methods=['*'],
)

# Expose the code-interpreter workspace directory under /static.
app.mount(
    '/static',
    StaticFiles(directory=code_interpreter_ws),
    name='static',
)
async def access_token_auth(request: Request, call_next):
    """Reject requests that do not carry an access token for an enabled account.

    The token is taken from, in order of preference, the ``Authorization``
    header, the ``access_token`` query parameter, and the session. Its
    ``info.json`` record is loaded through ``service.get`` and must have a
    truthy ``enabled`` field.

    Args:
        request: The incoming request.
        call_next: The next handler in the middleware chain.

    Returns:
        A 401 ``Response`` when the token is missing, unknown, or disabled;
        otherwise the downstream response.
    """
    access_token: str = (
        request.headers.get("Authorization")
        or request.query_params.get("access_token")
        or request.session.get("access_token")
    )
    if not access_token:
        return Response(status_code=401, content="the token is not valid")

    raw_info = service.get(access_token, "info.json", False)
    # service.get appears to return a falsy value when nothing is stored
    # (cache_data guards for that too). json.loads(None) would raise TypeError
    # and surface as a 500 instead of a 401.
    account_info = json.loads(raw_info) if raw_info else None
    if not account_info or not account_info.get("enabled"):
        return Response(status_code=401, content="the token is not valid")

    # Store the token that was just validated. setdefault would keep an older
    # session token, so handlers reading the session (e.g. web_listening)
    # could act on a different account than the one checked here.
    request.session["access_token"] = access_token
    return await call_next(request)
async def healthz(request: Request):
    """Health-check endpoint; always answers with ``{"healthz": true}``."""
    status = {"healthz": True}
    return JSONResponse(content=status)
async def add_token(request: Request):
    """Save the JSON request body as ``info.json`` under the caller's token.

    Only an enabled account whose ``role`` is ``'admin'`` may call this.
    Any other caller (no token, unknown token, disabled or non-admin account)
    receives a 401.

    Returns:
        ``{"success": true}`` on success, otherwise a 401 ``Response``.
    """
    access_token: str = (
        request.headers.get("Authorization")
        or request.query_params.get("access_token")
        or request.session.get("access_token")
    )
    raw_info = service.get(access_token, "info.json", False) if access_token else None
    # Guard against a missing record: json.loads(None) raises TypeError.
    account_info = json.loads(raw_info) if raw_info else {}

    # The original condition was inverted: it rejected enabled admins and let
    # every other caller through, so any user could rewrite their own
    # info.json (including "role"). Only enabled admins may proceed.
    is_admin = bool(account_info.get("enabled")) and account_info.get("role") == 'admin'
    if not is_admin:
        return Response(status_code=401, content="the token is not valid")

    data = await request.json()
    # NOTE(review): this writes to the *caller's own* info.json, so an admin
    # overwrites their own record; confirm whether the target token should
    # come from the request body instead.
    service.upsert(access_token, "info.json", json.dumps(data, ensure_ascii=False), False)
    return JSONResponse({"success": True})
async def cache_data(request: Request, file_name: str):
    """Return the JSON stored as *file_name* for the caller's access token.

    Restricted to enabled accounts whose ``role`` is ``'admin'``; anyone else
    receives a 401.

    Args:
        request: The incoming request (token source).
        file_name: Name of the stored file to read via ``service.get``.

    Returns:
        A ``JSONResponse`` with the decoded content, or ``""`` when nothing
        is stored under that name.
    """
    access_token: str = (
        request.headers.get("Authorization")
        or request.query_params.get("access_token")
        or request.session.get("access_token")
    )
    raw_info = service.get(access_token, "info.json", False) if access_token else None
    # Guard against a missing record: json.loads(None) raises TypeError.
    account_info = json.loads(raw_info) if raw_info else {}

    # The original condition was inverted (it rejected enabled admins with
    # "the token is not valid"); reject everyone who is NOT an enabled admin.
    is_admin = bool(account_info.get("enabled")) and account_info.get("role") == 'admin'
    if not is_admin:
        return Response(status_code=401, content="the token is not valid")

    # NOTE(review): file_name comes straight from the request; confirm that
    # service.get confines lookups to this token's storage (no path traversal).
    data = service.get(access_token, file_name, False)
    content = json.loads(data) if data else ""
    return JSONResponse(content)
async def web_listening(request: Request):
    """Dispatch a JSON event from the browser to the matching handler.

    The body's ``task`` field selects the action:

    * ``'change_checkbox'``: toggle checkbox ``ckid`` in ``cache_file``.
    * ``'cache'``: start ``extract_and_cache_document`` in a separate process
      and immediately answer ``'caching'``.
    * ``'pop_url'``: record a URL via ``update_pop_url``. "pop" refers to the
      pop-up UI; this adds a URL, it does not remove one.

    The access token is read from the session.

    Raises:
        HTTPException: 400 when ``task`` is missing or not recognized.
    """
    data = await request.json()
    msg_type = data.get('task') if isinstance(data, dict) else None
    access_token = request.session.get("access_token")

    if msg_type == 'change_checkbox':
        rsp = change_checkbox_state(data['ckid'], cache_file, access_token)
    elif msg_type == 'cache':
        # Run extraction outside the request so the client gets an immediate
        # acknowledgement rather than waiting for caching to finish.
        cache_obj = multiprocessing.Process(
            target=extract_and_cache_document,
            args=(data, cache_root, access_token),
        )
        cache_obj.start()
        rsp = 'caching'
    elif msg_type == 'pop_url':
        rsp = update_pop_url(data, cache_file_popup_url, access_token)
    else:
        # Previously NotImplementedError, which surfaced as a 500; a bad or
        # missing task is a client error.
        raise HTTPException(status_code=400, detail=f"unsupported task: {msg_type!r}")
    return JSONResponse(content=rsp)
import secrets

import gradio as gr
from assistant_server import demo as assistant_app
from workstation_server import demo as workstation_app

# Mount the two Gradio UIs onto the FastAPI app.
app = gr.mount_gradio_app(app, assistant_app, path="/assistant")
app = gr.mount_gradio_app(app, workstation_app, path="/workstation")

# If SECRET_KEY is unset, os.getenv returns None and Starlette's
# SessionMiddleware signs cookies with str(None) == "None". That key is
# trivially guessable, so session cookies (which carry the access token) could
# be forged. Fall back to a random per-process key instead; sessions then do
# not survive a restart.
_session_secret = os.getenv("SECRET_KEY") or secrets.token_urlsafe(32)
# Added last, so it is the outermost middleware and request.session is
# available to every middleware and route added earlier. max_age is in
# seconds (25200 s = 7 h).
app.add_middleware(SessionMiddleware, secret_key=_session_secret, max_age=25200)
if __name__ == '__main__':
    # Pass the app object instead of the 'database_server:app' import string.
    # The string makes uvicorn import this file a second time under another
    # module name, rebuilding the FastAPI/Gradio apps, and it breaks if the
    # file is renamed. An object is allowed because reload is off and there is
    # a single worker.
    uvicorn.run(app=app, host='0.0.0.0', port=7860, reload=False, workers=1)