import json
import os
from pathlib import Path
from typing import Callable, NoReturn

from asgi_correlation_id import CorrelationIdMiddleware
import gradio as gr
from starlette.responses import JSONResponse
import structlog
import uvicorn
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, Request, status
from fastapi.exceptions import RequestValidationError
from fastapi.responses import FileResponse, HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from pydantic import ValidationError
from samgis_core.utilities import create_folders_if_not_exists
from samgis_core.utilities import frontend_builder
from samgis_core.utilities.session_logger import setup_logging
from samgis_web.utilities.constants import GRADIO_EXAMPLES_TEXT_LIST, GRADIO_MARKDOWN, GRADIO_EXAMPLE_BODY_STRING_PROMPT
from samgis_web.utilities.type_hints import StringPromptApiRequestBody


# --- configuration read from the environment (optionally from a .env file) ---
load_dotenv()
# falls back to "./_" when __file__ is undefined (e.g. interactive session), so the parent is the CWD
project_root_folder = Path(globals().get("__file__", "./_")).absolute().parent
workdir = Path(os.getenv("WORKDIR", project_root_folder))
model_folder = Path(project_root_folder / "machine_learning_models")

log_level = os.getenv("LOG_LEVEL", "INFO")
setup_logging(log_level=log_level)
app_logger = structlog.stdlib.get_logger()
app_logger.info(f"PROJECT_ROOT_FOLDER:{project_root_folder}, WORKDIR:{workdir}.")

folders_map = os.getenv("FOLDERS_MAP", "{}")
markdown_text = os.getenv("MARKDOWN_TEXT", "")
examples_text_list = os.getenv("EXAMPLES_TEXT_LIST", "").split("\n")
example_body = json.loads(os.getenv("EXAMPLE_BODY", "{}"))
# NOTE(review): bool() of any non-empty string is True, so "false" or "0" also enable this flag.
# NOTE(review): mount_gradio_app is not read anywhere in this file; the gradio app below is
# always mounted — confirm whether mounting was meant to depend on it.
mount_gradio_app = bool(os.getenv("MOUNT_GRADIO_APP", ""))

static_dist_folder = workdir / "static" / "dist"
input_css_path = os.getenv("INPUT_CSS_PATH", "src/input.css")
vite_gradio_url = os.getenv("VITE_GRADIO_URL", "/gradio")
vite_index_url = os.getenv("VITE_INDEX_URL", "/")
vite_samgis_url = os.getenv("VITE_SAMGIS_URL", "/samgis")
vite_lisa_url = os.getenv("VITE_LISA_URL", "/lisa")
fastapi_title = "samgis-lisa-on-cuda"
app = FastAPI(title=fastapi_title, version="1.0")


@app.middleware("http")
async def request_middleware(request, call_next):
    """Delegate every HTTP request to samgis_web's logging middleware (imported lazily)."""
    from samgis_web.web.middlewares import logging_middleware

    return await logging_middleware(request, call_next)


def get_example_complete(example_text):
    """Return a JSON payload string: the default example body with `string_prompt` set to `example_text`."""
    # shallow copy, so the shared constant is never mutated
    example_dict = dict(**GRADIO_EXAMPLE_BODY_STRING_PROMPT)
    example_dict["string_prompt"] = example_text
    return json.dumps(example_dict)


def get_gradio_interface_geojson(fn_inference: Callable):
    """Build the gradio Blocks UI: a payload textbox, a submit button and an output textbox.

    Args:
        fn_inference: callable invoked on submit with the input textbox value; its return
            value is shown in the output textbox.

    Returns:
        the `gr.Blocks` instance (not yet mounted).
    """
    with gr.Blocks() as gradio_app:
        gr.Markdown(GRADIO_MARKDOWN)
        with gr.Row():
            with gr.Column():
                text_input = gr.Textbox(lines=1, placeholder=None, label="Payload input")
                btn = gr.Button(value="Submit")
            with gr.Column():
                text_output = gr.Textbox(lines=1, placeholder=None, label="Geojson Output")
        gr.Examples(
            examples=[
                get_example_complete(example) for example in GRADIO_EXAMPLES_TEXT_LIST
            ],
            inputs=[text_input],
        )
        btn.click(
            fn_inference,
            inputs=[text_input],
            outputs=[text_output]
        )
    return gradio_app


def handle_exception_response(exception: Exception) -> NoReturn:
    """Log `ls -l` of the project root and workdir plus the error, then raise an HTTP 500.

    Args:
        exception: the error that occurred during inference.

    Raises:
        HTTPException: always, status 500 with a generic detail (chained to `exception`).
    """
    import subprocess

    def _list_folder(folder):
        """Return (stdout, stderr) of `ls -l <folder>/`; an OSError is reported via stderr instead of raised."""
        try:
            # argv list instead of a shell string: paths with spaces/metacharacters are passed verbatim,
            # and stderr is captured so it can actually be logged
            completed = subprocess.run(
                ["ls", "-l", f"{folder}/"],
                universal_newlines=True,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                check=False,
            )
        except OSError as os_err:
            # e.g. `ls` not available: keep going so the original inference error is still reported
            return "", str(os_err)
        return completed.stdout, completed.stderr

    project_root_stdout, _ = _list_folder(project_root_folder)
    app_logger.error(f"project_root folder 'ls -l' command output: {project_root_stdout}.")
    workdir_stdout, workdir_stderr = _list_folder(workdir)
    app_logger.error(f"workdir folder 'ls -l' command stdout: {workdir_stdout}.")
    app_logger.error(f"workdir folder 'ls -l' command stderr: {workdir_stderr}.")
    app_logger.error(f"inference error:{exception}.")
    raise HTTPException(
        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error on inference"
    ) from exception


@app.get("/health")
async def health() -> JSONResponse:
    """Liveness probe: log the installed package versions and return HTTP 200."""
    from samgis_web.__version__ import __version__ as version_web
    from samgis_core.__version__ import __version__ as version_core
    from lisa_on_cuda.__version__ import __version__ as version_lisa_on_cuda
    from samgis_lisa.__version__ import __version__ as version_samgis_lisa

    app_logger.info(f"still alive, version_web:{version_web}, version_core:{version_core}.")
    app_logger.info(f"still alive, version_lisa_on_cuda:{version_lisa_on_cuda}, version_samgis_lisa:{version_samgis_lisa}.")
    return JSONResponse(status_code=200, content={"msg": "still alive..."})


@app.get("/health_models")
async def health_models() -> JSONResponse:
    """Readiness probe: ensure the LISA inference function is loaded, loading it on first call.

    Returns:
        HTTP 200 JSON message once the inference function is available.

    Raises:
        HTTPException: 500 when the model entry is missing or still None after loading.
    """
    from samgis_lisa.prediction_api import lisa
    from samgis_lisa.utilities.constants import LISA_INFERENCE_FN
    from samgis_web.__version__ import __version__ as version_web
    from samgis_core.__version__ import __version__ as version_core
    from lisa_on_cuda.__version__ import __version__ as version_lisa_on_cuda
    from samgis_lisa.__version__ import __version__ as version_samgis_lisa
    from samgis_lisa.prediction_api.global_models import models_dict

    app_logger.info(f"still alive, version_web:{version_web}, version_core:{version_core}.")
    app_logger.info(f"still alive, version_lisa_on_cuda:{version_lisa_on_cuda}, version_samgis_lisa:{version_samgis_lisa}.")
    app_logger.info(f"try to load inference function for '{LISA_INFERENCE_FN}' model...")
    try:
        inference_fn = models_dict[LISA_INFERENCE_FN]["inference"]
        if inference_fn is None:
            app_logger.info(f"model not found, loading inference function for '{LISA_INFERENCE_FN}' model. This could take some minutes...")
            lisa.load_model_and_inference_fn(LISA_INFERENCE_FN, inference_decorator=None, device_map="auto", device="cuda")
            inference_fn = models_dict[LISA_INFERENCE_FN]["inference"]
        # AttributeError here means the loader left the entry as None
        app_logger.info(f"inference function for '{LISA_INFERENCE_FN}' model => '{inference_fn.__name__}' found and loaded...")
    except (KeyError, AttributeError) as ke:
        app_logger.error(f"model not found, error:{ke}.")
        raise HTTPException(status_code=500, detail="Internal Server Error") from ke
    return JSONResponse(status_code=200, content={"msg": f"still alive, inference function for '{LISA_INFERENCE_FN}' model loaded..."})


def infer_lisa_gradio(request_input: StringPromptApiRequestBody) -> str:
    """Run a LISA prediction for a string-prompt request and return the result as a JSON string.

    Args:
        request_input: request payload; parsed by samgis_lisa's
            `get_parsed_bbox_points_with_string_prompt` into bbox/prompt/zoom/source/source_name.

    Returns:
        `json.dumps({"duration_run": <seconds>, "output": <lisa_predict result>})`.

    Raises:
        HTTPException: 500 for any error raised after parsing (prediction, missing keys...).
        RequestValidationError: when parsing raises a pydantic ValidationError.
    """
    import time
    from samgis_lisa.io_package.wrappers_helpers import get_parsed_bbox_points_with_string_prompt
    from samgis_lisa.prediction_api import lisa
    from samgis_lisa.utilities.constants import LISA_INFERENCE_FN

    app_logger.info("starting lisa inference request...")
    try:
        time_start_run = time.time()
        body_request = get_parsed_bbox_points_with_string_prompt(request_input)
        app_logger.info(f"lisa body_request:{body_request}.")
        try:
            source = body_request["source"]
            source_name = body_request["source_name"]
            app_logger.debug(f"body_request:type(source):{type(source)}, source:{source}.")
            app_logger.debug(f"body_request:type(source_name):{type(source_name)}, source_name:{source_name}.")
            app_logger.debug(f"lisa module:{lisa}.")
            output = lisa.lisa_predict(
                bbox=body_request["bbox"], prompt=body_request["prompt"], zoom=body_request["zoom"],
                source=source, source_name=source_name, inference_function_name_key=LISA_INFERENCE_FN
            )
            duration_run = time.time() - time_start_run
            app_logger.info(f"duration_run:{duration_run}.")
            body = {
                "duration_run": duration_run,
                "output": output
            }
            dumped = json.dumps(body)
            app_logger.info(f"json.dumps(body) type:{type(dumped)}, len:{len(dumped)}.")
            app_logger.debug(f"complete json.dumps(body):{dumped}.")
            return dumped
        except Exception as inference_exception:
            app_logger.error(f"inference_exception:{inference_exception}.")
            app_logger.error(f"inference_exception, request_input:{request_input}.")
            raise HTTPException(status_code=500, detail="Internal Server Error") from inference_exception
    except ValidationError as va1:
        app_logger.error(f"validation error: {str(va1)}.")
        app_logger.error(f"ValidationError, request_input:{request_input}.")
        # NOTE(review): RequestValidationError expects a sequence of error entries; a plain string
        # is kept here for compatibility with the registered handler — confirm before changing.
        raise RequestValidationError("Unprocessable Entity") from va1


@app.post("/infer_lisa")
def infer_lisa(request_input: StringPromptApiRequestBody) -> JSONResponse:
    """HTTP endpoint: run `infer_lisa_gradio` and wrap its JSON string in `{"body": ...}`."""
    dumped = infer_lisa_gradio(request_input=request_input)
    app_logger.info(f"json.dumps(body) type:{type(dumped)}, len:{len(dumped)}.")
    app_logger.debug(f"complete json.dumps(body):{dumped}.")
    return JSONResponse(status_code=200, content={"body": dumped})


@app.exception_handler(RequestValidationError)
def request_validation_exception_handler(request: Request, exc: RequestValidationError) -> JSONResponse:
    """Delegate RequestValidationError handling to samgis_web's exception handlers."""
    from samgis_web.web import exception_handlers

    return exception_handlers.request_validation_exception_handler(request, exc)


@app.exception_handler(HTTPException)
def http_exception_handler(request: Request, exc: HTTPException) -> JSONResponse:
    """Delegate HTTPException handling to samgis_web's exception handlers."""
    from samgis_web.web import exception_handlers

    return exception_handlers.http_exception_handler(request, exc)


create_folders_if_not_exists.folders_creation(folders_map)

# optional browsable folder of temporary outputs, enabled by a non-empty WRITE_TMP_ON_DISK
write_tmp_on_disk = os.getenv("WRITE_TMP_ON_DISK", "")
app_logger.info(f"write_tmp_on_disk:{write_tmp_on_disk}.")
if bool(write_tmp_on_disk):
    try:
        # explicit check instead of `assert`, which is stripped when running with -O
        if not Path(write_tmp_on_disk).is_dir():
            raise NotADirectoryError(f"WRITE_TMP_ON_DISK is not an existing directory: '{write_tmp_on_disk}'")
        app.mount("/vis_output", StaticFiles(directory=write_tmp_on_disk), name="vis_output")
        # dedicated name: the module-level `templates` is rebound further down to another directory,
        # which list_files() would otherwise pick up at request time
        vis_output_templates = Jinja2Templates(directory=str(project_root_folder / "static"))

        @app.get("/vis_output", response_class=HTMLResponse)
        def list_files(request: Request):
            """Render list_files.html with the sorted URLs of the files in WRITE_TMP_ON_DISK."""
            files = os.listdir(write_tmp_on_disk)
            files_paths = sorted(f"{request.url}/{f}" for f in files)
            app_logger.debug(f"vis_output files_paths:{files_paths}.")
            return vis_output_templates.TemplateResponse(
                "list_files.html", {"request": request, "files": files_paths}
            )
    except (AssertionError, NotADirectoryError, RuntimeError) as rerr:
        app_logger.error(f"{rerr} while loading the folder write_tmp_on_disk:{write_tmp_on_disk}...")
        raise

frontend_builder.build_frontend(
    project_root_folder=workdir,
    input_css_path=input_css_path,
    output_dist_folder=static_dist_folder
)
app_logger.info("build_frontend ok!")

templates = Jinja2Templates(directory="templates")

app.mount("/static", StaticFiles(directory=static_dist_folder, html=True), name="static")
# important: the index() function and the app.mount MUST be at the end
# samgis.html
app.mount(vite_samgis_url, StaticFiles(directory=static_dist_folder, html=True), name="samgis")


@app.get(vite_samgis_url)
async def samgis() -> FileResponse:
    """Serve samgis.html from the built frontend dist folder."""
    return FileResponse(path=str(static_dist_folder / "samgis.html"), media_type="text/html")


# lisa.html
app.mount(vite_lisa_url, StaticFiles(directory=static_dist_folder, html=True), name="lisa")


@app.get(vite_lisa_url)
async def lisa() -> FileResponse:
    """Serve lisa.html from the built frontend dist folder."""
    return FileResponse(path=str(static_dist_folder / "lisa.html"), media_type="text/html")


# index.html (lisa.html copy)
app.mount(vite_index_url, StaticFiles(directory=static_dist_folder, html=True), name="index")


@app.get(vite_index_url)
async def index() -> FileResponse:
    """Serve index.html from the built frontend dist folder."""
    # NOTE(review): the mount registered just above at the same path may match first (its
    # StaticFiles(html=True) also serves index.html) — confirm this route is actually reached.
    return FileResponse(path=str(static_dist_folder / "index.html"), media_type="text/html")


app_logger.info("creating gradio interface...")
gr_interface = get_gradio_interface_geojson(infer_lisa_gradio)
app_logger.info(f"gradio interface created, mounting gradio app on url {vite_gradio_url} within FastAPI...")
app = gr.mount_gradio_app(app, gr_interface, path=vite_gradio_url)
app_logger.info("mounted gradio app within fastapi")
# add the CorrelationIdMiddleware AFTER the @app.middleware("http") decorated function to avoid missing request id
app.add_middleware(CorrelationIdMiddleware)


if __name__ == '__main__':
    try:
        uvicorn.run(host="0.0.0.0", port=7860, app=app)
    except Exception as ex:
        app_logger.error(f"fastapi/gradio application {fastapi_title}, exception:{ex}.")
        print(f"fastapi/gradio application {fastapi_title}, exception:{ex}.")
        raise