alessandro trinca tornidor
commited on
Commit
·
a5e4002
1
Parent(s):
00f8875
refactor: remove unuseful app_helpers.get_inference_model_by_args(args, inference_decorator=SPACES_GPU), fix logs, initialize gpu within infer_lisa_gradio()
Browse files- .idea/vcs.xml +1 -0
- app.py +19 -16
- samgis_lisa_on_zero/io_package/wrappers_helpers.py +4 -2
- samgis_lisa_on_zero/prediction_api/lisa.py +15 -7
.idea/vcs.xml
CHANGED
@@ -2,5 +2,6 @@
|
|
2 |
<project version="4">
|
3 |
<component name="VcsDirectoryMappings">
|
4 |
<mapping directory="" vcs="Git" />
|
|
|
5 |
</component>
|
6 |
</project>
|
|
|
2 |
<project version="4">
|
3 |
<component name="VcsDirectoryMappings">
|
4 |
<mapping directory="" vcs="Git" />
|
5 |
+
<mapping directory="$PROJECT_DIR$/sam-quantized" vcs="Git" />
|
6 |
</component>
|
7 |
</project>
|
app.py
CHANGED
@@ -5,6 +5,7 @@ import uuid
|
|
5 |
from typing import Callable, NoReturn
|
6 |
|
7 |
import gradio as gr
|
|
|
8 |
import uvicorn
|
9 |
from fastapi import FastAPI, HTTPException, Request, status
|
10 |
from fastapi.exceptions import RequestValidationError
|
@@ -13,8 +14,6 @@ from fastapi.staticfiles import StaticFiles
|
|
13 |
from fastapi.templating import Jinja2Templates
|
14 |
from lisa_on_cuda.utils import app_helpers, frontend_builder, create_folders_and_variables_if_not_exists
|
15 |
from pydantic import ValidationError
|
16 |
-
from spaces import GPU as SPACES_GPU
|
17 |
-
|
18 |
from samgis_core.utilities.fastapi_logger import setup_logging
|
19 |
from samgis_lisa_on_zero import PROJECT_ROOT_FOLDER, WORKDIR
|
20 |
from samgis_lisa_on_zero.utilities.type_hints import ApiRequestBody, StringPromptApiRequestBody
|
@@ -31,6 +30,11 @@ FASTAPI_TITLE = "samgis-lisa-on-zero"
|
|
31 |
app = FastAPI(title=FASTAPI_TITLE, version="1.0")
|
32 |
|
33 |
|
|
|
|
|
|
|
|
|
|
|
34 |
def get_gradio_interface_geojson(
|
35 |
fn_inference: Callable
|
36 |
):
|
@@ -143,13 +147,16 @@ def infer_lisa_gradio(request_input: StringPromptApiRequestBody) -> JSONResponse
|
|
143 |
time_start_run = time.time()
|
144 |
body_request = get_parsed_bbox_points_with_string_prompt(request_input)
|
145 |
app_logger.info(f"lisa body_request:{body_request}.")
|
146 |
-
app_logger.info(f"lisa module:{lisa}.")
|
147 |
try:
|
148 |
-
|
149 |
-
|
|
|
|
|
|
|
|
|
150 |
output = lisa.lisa_predict(
|
151 |
bbox=body_request["bbox"], prompt=body_request["prompt"], zoom=body_request["zoom"],
|
152 |
-
source=
|
153 |
)
|
154 |
duration_run = time.time() - time_start_run
|
155 |
app_logger.info(f"duration_run:{duration_run}.")
|
@@ -157,9 +164,10 @@ def infer_lisa_gradio(request_input: StringPromptApiRequestBody) -> JSONResponse
|
|
157 |
"duration_run": duration_run,
|
158 |
"output": output
|
159 |
}
|
160 |
-
|
161 |
-
|
162 |
-
|
|
|
163 |
except Exception as inference_exception:
|
164 |
handle_exception_response(inference_exception)
|
165 |
except ValidationError as va1:
|
@@ -187,7 +195,7 @@ def infer_samgis(request_input: ApiRequestBody) -> JSONResponse:
|
|
187 |
body_request = get_parsed_bbox_points_with_dictlist_prompt(request_input)
|
188 |
app_logger.info(f"body_request:{body_request}.")
|
189 |
try:
|
190 |
-
source_name =
|
191 |
app_logger.info(f"source_name = {source_name}.")
|
192 |
output = predictors.samexporter_predict(
|
193 |
bbox=body_request["bbox"], prompt=body_request["prompt"], zoom=body_request["zoom"],
|
@@ -296,7 +304,7 @@ async def lisa() -> FileResponse:
|
|
296 |
return FileResponse(path=static_dist_folder / "lisa.html", media_type="text/html")
|
297 |
|
298 |
|
299 |
-
#
|
300 |
app.mount(VITE_INDEX_URL, StaticFiles(directory=static_dist_folder, html=True), name="index")
|
301 |
|
302 |
|
@@ -305,12 +313,7 @@ async def index() -> FileResponse:
|
|
305 |
return FileResponse(path=static_dist_folder / "index.html", media_type="text/html")
|
306 |
|
307 |
|
308 |
-
args = app_helpers.parse_args([])
|
309 |
-
app_helpers.app_logger.info(f"prepared default arguments:{args}.")
|
310 |
-
inference_fn = app_helpers.get_inference_model_by_args(args, inference_decorator=SPACES_GPU)
|
311 |
-
|
312 |
app_helpers.app_logger.info(f"prepared inference_fn function:{inference_fn.__name__}, creating gradio interface...")
|
313 |
-
# io_package = app_helpers.get_gradio_interface(inference_fn)
|
314 |
io = get_gradio_interface_geojson(infer_lisa_gradio)
|
315 |
app_helpers.app_logger.info(
|
316 |
f"created gradio interface, mounting gradio app on url {VITE_GRADIO_URL} within FastAPI...")
|
|
|
5 |
from typing import Callable, NoReturn
|
6 |
|
7 |
import gradio as gr
|
8 |
+
import spaces
|
9 |
import uvicorn
|
10 |
from fastapi import FastAPI, HTTPException, Request, status
|
11 |
from fastapi.exceptions import RequestValidationError
|
|
|
14 |
from fastapi.templating import Jinja2Templates
|
15 |
from lisa_on_cuda.utils import app_helpers, frontend_builder, create_folders_and_variables_if_not_exists
|
16 |
from pydantic import ValidationError
|
|
|
|
|
17 |
from samgis_core.utilities.fastapi_logger import setup_logging
|
18 |
from samgis_lisa_on_zero import PROJECT_ROOT_FOLDER, WORKDIR
|
19 |
from samgis_lisa_on_zero.utilities.type_hints import ApiRequestBody, StringPromptApiRequestBody
|
|
|
30 |
app = FastAPI(title=FASTAPI_TITLE, version="1.0")
|
31 |
|
32 |
|
33 |
+
@spaces.GPU
|
34 |
+
def gpu_initialization() -> None:
|
35 |
+
app_logger.info("GPU initialization...")
|
36 |
+
|
37 |
+
|
38 |
def get_gradio_interface_geojson(
|
39 |
fn_inference: Callable
|
40 |
):
|
|
|
147 |
time_start_run = time.time()
|
148 |
body_request = get_parsed_bbox_points_with_string_prompt(request_input)
|
149 |
app_logger.info(f"lisa body_request:{body_request}.")
|
|
|
150 |
try:
|
151 |
+
source = body_request["source"]
|
152 |
+
source_name = body_request["source_name"]
|
153 |
+
app_logger.debug(f"body_request:type(source):{type(source)}, source:{source}.")
|
154 |
+
app_logger.debug(f"body_request:type(source_name):{type(source_name)}, source_name:{source_name}.")
|
155 |
+
app_logger.debug(f"lisa module:{lisa}.")
|
156 |
+
gpu_initialization()
|
157 |
output = lisa.lisa_predict(
|
158 |
bbox=body_request["bbox"], prompt=body_request["prompt"], zoom=body_request["zoom"],
|
159 |
+
source=source, source_name=source_name, inference_function_name_key=LISA_INFERENCE_FN
|
160 |
)
|
161 |
duration_run = time.time() - time_start_run
|
162 |
app_logger.info(f"duration_run:{duration_run}.")
|
|
|
164 |
"duration_run": duration_run,
|
165 |
"output": output
|
166 |
}
|
167 |
+
dumped = json.dumps(body)
|
168 |
+
app_logger.info(f"json.dumps(body) type:{type(dumped)}, len:{len(dumped)}.")
|
169 |
+
app_logger.debug(f"complete json.dumps(body):{dumped}.")
|
170 |
+
return dumped
|
171 |
except Exception as inference_exception:
|
172 |
handle_exception_response(inference_exception)
|
173 |
except ValidationError as va1:
|
|
|
195 |
body_request = get_parsed_bbox_points_with_dictlist_prompt(request_input)
|
196 |
app_logger.info(f"body_request:{body_request}.")
|
197 |
try:
|
198 |
+
source_name = body_request["source_name"]
|
199 |
app_logger.info(f"source_name = {source_name}.")
|
200 |
output = predictors.samexporter_predict(
|
201 |
bbox=body_request["bbox"], prompt=body_request["prompt"], zoom=body_request["zoom"],
|
|
|
304 |
return FileResponse(path=static_dist_folder / "lisa.html", media_type="text/html")
|
305 |
|
306 |
|
307 |
+
# index.html (lisa.html copy)
|
308 |
app.mount(VITE_INDEX_URL, StaticFiles(directory=static_dist_folder, html=True), name="index")
|
309 |
|
310 |
|
|
|
313 |
return FileResponse(path=static_dist_folder / "index.html", media_type="text/html")
|
314 |
|
315 |
|
|
|
|
|
|
|
|
|
316 |
app_helpers.app_logger.info(f"prepared inference_fn function:{inference_fn.__name__}, creating gradio interface...")
|
|
|
317 |
io = get_gradio_interface_geojson(infer_lisa_gradio)
|
318 |
app_helpers.app_logger.info(
|
319 |
f"created gradio interface, mounting gradio app on url {VITE_GRADIO_URL} within FastAPI...")
|
samgis_lisa_on_zero/io_package/wrappers_helpers.py
CHANGED
@@ -83,7 +83,8 @@ def get_parsed_bbox_points_with_string_prompt(request_input: StringPromptApiRequ
|
|
83 |
"bbox": [ne_latlng, sw_latlng],
|
84 |
"prompt": cleaned_prompt,
|
85 |
"zoom": new_zoom,
|
86 |
-
"source": get_url_tile(request_input.source_type)
|
|
|
87 |
}
|
88 |
|
89 |
|
@@ -119,7 +120,8 @@ def get_parsed_bbox_points_with_dictlist_prompt(request_input: ApiRequestBody) -
|
|
119 |
"bbox": [ne_latlng, sw_latlng],
|
120 |
"prompt": new_prompt_list,
|
121 |
"zoom": new_zoom,
|
122 |
-
"source": get_url_tile(request_input.source_type)
|
|
|
123 |
}
|
124 |
|
125 |
|
|
|
83 |
"bbox": [ne_latlng, sw_latlng],
|
84 |
"prompt": cleaned_prompt,
|
85 |
"zoom": new_zoom,
|
86 |
+
"source": get_url_tile(request_input.source_type),
|
87 |
+
"source_name": get_source_name(request_input.source_type)
|
88 |
}
|
89 |
|
90 |
|
|
|
120 |
"bbox": [ne_latlng, sw_latlng],
|
121 |
"prompt": new_prompt_list,
|
122 |
"zoom": new_zoom,
|
123 |
+
"source": get_url_tile(request_input.source_type),
|
124 |
+
"source_name": get_source_name(request_input.source_type)
|
125 |
}
|
126 |
|
127 |
|
samgis_lisa_on_zero/prediction_api/lisa.py
CHANGED
@@ -16,7 +16,9 @@ def load_model_and_inference_fn(inference_function_name_key: str):
|
|
16 |
from samgis_lisa_on_zero.prediction_api.global_models import models_dict
|
17 |
|
18 |
if models_dict[inference_function_name_key]["inference"] is None:
|
19 |
-
|
|
|
|
|
20 |
parsed_args = app_helpers.parse_args([])
|
21 |
inference_fn = app_helpers.get_inference_model_by_args(
|
22 |
parsed_args,
|
@@ -57,10 +59,17 @@ def lisa_predict(
|
|
57 |
from samgis_lisa_on_zero import app_logger
|
58 |
from samgis_lisa_on_zero.prediction_api.global_models import models_dict
|
59 |
|
|
|
|
|
|
|
60 |
app_logger.info("start lisa inference...")
|
|
|
|
|
|
|
61 |
load_model_and_inference_fn(inference_function_name_key)
|
62 |
-
app_logger.debug(f"using a {inference_function_name_key} instance model...")
|
63 |
inference_fn = models_dict[inference_function_name_key]["inference"]
|
|
|
64 |
|
65 |
pt0, pt1 = bbox
|
66 |
app_logger.info(f"tile_source: {source}: downloading geo-referenced raster with bbox {bbox}, zoom {zoom}.")
|
@@ -80,15 +89,14 @@ def lisa_predict(
|
|
80 |
app_logger.info("keep all temp data in memory...")
|
81 |
|
82 |
app_logger.info(f"lisa_zero, source_name:{source_name}, source_name type:{type(source_name)}.")
|
83 |
-
app_logger.info(f"lisa_zero, prompt
|
84 |
app_logger.info(f"lisa_zero, prompt:{prompt}.")
|
85 |
prompt_str = str(prompt)
|
86 |
-
app_logger.info(f"lisa_zero, img
|
87 |
embedding_key = f"{source_name}_z{zoom}_{prefix}"
|
88 |
_, mask, output_string = inference_fn(input_str=prompt_str, input_image=img, embedding_key=embedding_key)
|
89 |
-
app_logger.info(f"lisa_zero, output_string
|
90 |
-
app_logger.info(f"lisa_zero,
|
91 |
-
app_logger.info(f"lisa_zero, mask_output tpye:{type(mask)}.")
|
92 |
app_logger.info(f"created output_string '{output_string}', preparing conversion to geojson...")
|
93 |
return {
|
94 |
"output_string": output_string,
|
|
|
16 |
from samgis_lisa_on_zero.prediction_api.global_models import models_dict
|
17 |
|
18 |
if models_dict[inference_function_name_key]["inference"] is None:
|
19 |
+
msg = f"missing inference function {inference_function_name_key}, "
|
20 |
+
msg += f"instantiating it now using inference_decorator {SPACES_GPU}!"
|
21 |
+
app_logger.info(msg)
|
22 |
parsed_args = app_helpers.parse_args([])
|
23 |
inference_fn = app_helpers.get_inference_model_by_args(
|
24 |
parsed_args,
|
|
|
59 |
from samgis_lisa_on_zero import app_logger
|
60 |
from samgis_lisa_on_zero.prediction_api.global_models import models_dict
|
61 |
|
62 |
+
if source_name is None:
|
63 |
+
source_name = str(source)
|
64 |
+
|
65 |
app_logger.info("start lisa inference...")
|
66 |
+
app_logger.debug(f"type(source):{type(source)}, source:{source},")
|
67 |
+
app_logger.debug(f"type(source_name):{type(source_name)}, source_name:{source_name}.")
|
68 |
+
|
69 |
load_model_and_inference_fn(inference_function_name_key)
|
70 |
+
app_logger.debug(f"using a '{inference_function_name_key}' instance model...")
|
71 |
inference_fn = models_dict[inference_function_name_key]["inference"]
|
72 |
+
app_logger.info(f"loaded inference function '{inference_fn.__name__}'.")
|
73 |
|
74 |
pt0, pt1 = bbox
|
75 |
app_logger.info(f"tile_source: {source}: downloading geo-referenced raster with bbox {bbox}, zoom {zoom}.")
|
|
|
89 |
app_logger.info("keep all temp data in memory...")
|
90 |
|
91 |
app_logger.info(f"lisa_zero, source_name:{source_name}, source_name type:{type(source_name)}.")
|
92 |
+
app_logger.info(f"lisa_zero, prompt type:{type(prompt)}.")
|
93 |
app_logger.info(f"lisa_zero, prompt:{prompt}.")
|
94 |
prompt_str = str(prompt)
|
95 |
+
app_logger.info(f"lisa_zero, img type:{type(img)}.")
|
96 |
embedding_key = f"{source_name}_z{zoom}_{prefix}"
|
97 |
_, mask, output_string = inference_fn(input_str=prompt_str, input_image=img, embedding_key=embedding_key)
|
98 |
+
app_logger.info(f"lisa_zero, output_string type:{type(output_string)}.")
|
99 |
+
app_logger.info(f"lisa_zero, mask_output type:{type(mask)}.")
|
|
|
100 |
app_logger.info(f"created output_string '{output_string}', preparing conversion to geojson...")
|
101 |
return {
|
102 |
"output_string": output_string,
|