alessandro trinca tornidor commited on
Commit
a5e4002
·
1 Parent(s): 00f8875

refactor: remove unuseful app_helpers.get_inference_model_by_args(args, inference_decorator=SPACES_GPU), fix logs, initialize gpu within infer_lisa_gradio()

Browse files
.idea/vcs.xml CHANGED
@@ -2,5 +2,6 @@
2
  <project version="4">
3
  <component name="VcsDirectoryMappings">
4
  <mapping directory="" vcs="Git" />
 
5
  </component>
6
  </project>
 
2
  <project version="4">
3
  <component name="VcsDirectoryMappings">
4
  <mapping directory="" vcs="Git" />
5
+ <mapping directory="$PROJECT_DIR$/sam-quantized" vcs="Git" />
6
  </component>
7
  </project>
app.py CHANGED
@@ -5,6 +5,7 @@ import uuid
5
  from typing import Callable, NoReturn
6
 
7
  import gradio as gr
 
8
  import uvicorn
9
  from fastapi import FastAPI, HTTPException, Request, status
10
  from fastapi.exceptions import RequestValidationError
@@ -13,8 +14,6 @@ from fastapi.staticfiles import StaticFiles
13
  from fastapi.templating import Jinja2Templates
14
  from lisa_on_cuda.utils import app_helpers, frontend_builder, create_folders_and_variables_if_not_exists
15
  from pydantic import ValidationError
16
- from spaces import GPU as SPACES_GPU
17
-
18
  from samgis_core.utilities.fastapi_logger import setup_logging
19
  from samgis_lisa_on_zero import PROJECT_ROOT_FOLDER, WORKDIR
20
  from samgis_lisa_on_zero.utilities.type_hints import ApiRequestBody, StringPromptApiRequestBody
@@ -31,6 +30,11 @@ FASTAPI_TITLE = "samgis-lisa-on-zero"
31
  app = FastAPI(title=FASTAPI_TITLE, version="1.0")
32
 
33
 
 
 
 
 
 
34
  def get_gradio_interface_geojson(
35
  fn_inference: Callable
36
  ):
@@ -143,13 +147,16 @@ def infer_lisa_gradio(request_input: StringPromptApiRequestBody) -> JSONResponse
143
  time_start_run = time.time()
144
  body_request = get_parsed_bbox_points_with_string_prompt(request_input)
145
  app_logger.info(f"lisa body_request:{body_request}.")
146
- app_logger.info(f"lisa module:{lisa}.")
147
  try:
148
- source_name = get_source_name(request_input.source_type)
149
- app_logger.info(f"source_name = {source_name}.")
 
 
 
 
150
  output = lisa.lisa_predict(
151
  bbox=body_request["bbox"], prompt=body_request["prompt"], zoom=body_request["zoom"],
152
- source=body_request["source"], source_name=source_name, inference_function_name_key=LISA_INFERENCE_FN
153
  )
154
  duration_run = time.time() - time_start_run
155
  app_logger.info(f"duration_run:{duration_run}.")
@@ -157,9 +164,10 @@ def infer_lisa_gradio(request_input: StringPromptApiRequestBody) -> JSONResponse
157
  "duration_run": duration_run,
158
  "output": output
159
  }
160
- app_logger.info(f"json.dumps(body):{json.dumps(body)}.")
161
- # return JSONResponse(status_code=200, content={"body": json.dumps(body)})
162
- return json.dumps(body)
 
163
  except Exception as inference_exception:
164
  handle_exception_response(inference_exception)
165
  except ValidationError as va1:
@@ -187,7 +195,7 @@ def infer_samgis(request_input: ApiRequestBody) -> JSONResponse:
187
  body_request = get_parsed_bbox_points_with_dictlist_prompt(request_input)
188
  app_logger.info(f"body_request:{body_request}.")
189
  try:
190
- source_name = get_source_name(request_input.source_type)
191
  app_logger.info(f"source_name = {source_name}.")
192
  output = predictors.samexporter_predict(
193
  bbox=body_request["bbox"], prompt=body_request["prompt"], zoom=body_request["zoom"],
@@ -296,7 +304,7 @@ async def lisa() -> FileResponse:
296
  return FileResponse(path=static_dist_folder / "lisa.html", media_type="text/html")
297
 
298
 
299
- # # index.html (lisa.html copy)
300
  app.mount(VITE_INDEX_URL, StaticFiles(directory=static_dist_folder, html=True), name="index")
301
 
302
 
@@ -305,12 +313,7 @@ async def index() -> FileResponse:
305
  return FileResponse(path=static_dist_folder / "index.html", media_type="text/html")
306
 
307
 
308
- args = app_helpers.parse_args([])
309
- app_helpers.app_logger.info(f"prepared default arguments:{args}.")
310
- inference_fn = app_helpers.get_inference_model_by_args(args, inference_decorator=SPACES_GPU)
311
-
312
  app_helpers.app_logger.info(f"prepared inference_fn function:{inference_fn.__name__}, creating gradio interface...")
313
- # io_package = app_helpers.get_gradio_interface(inference_fn)
314
  io = get_gradio_interface_geojson(infer_lisa_gradio)
315
  app_helpers.app_logger.info(
316
  f"created gradio interface, mounting gradio app on url {VITE_GRADIO_URL} within FastAPI...")
 
5
  from typing import Callable, NoReturn
6
 
7
  import gradio as gr
8
+ import spaces
9
  import uvicorn
10
  from fastapi import FastAPI, HTTPException, Request, status
11
  from fastapi.exceptions import RequestValidationError
 
14
  from fastapi.templating import Jinja2Templates
15
  from lisa_on_cuda.utils import app_helpers, frontend_builder, create_folders_and_variables_if_not_exists
16
  from pydantic import ValidationError
 
 
17
  from samgis_core.utilities.fastapi_logger import setup_logging
18
  from samgis_lisa_on_zero import PROJECT_ROOT_FOLDER, WORKDIR
19
  from samgis_lisa_on_zero.utilities.type_hints import ApiRequestBody, StringPromptApiRequestBody
 
30
  app = FastAPI(title=FASTAPI_TITLE, version="1.0")
31
 
32
 
33
+ @spaces.GPU
34
+ def gpu_initialization() -> None:
35
+ app_logger.info("GPU initialization...")
36
+
37
+
38
  def get_gradio_interface_geojson(
39
  fn_inference: Callable
40
  ):
 
147
  time_start_run = time.time()
148
  body_request = get_parsed_bbox_points_with_string_prompt(request_input)
149
  app_logger.info(f"lisa body_request:{body_request}.")
 
150
  try:
151
+ source = body_request["source"]
152
+ source_name = body_request["source_name"]
153
+ app_logger.debug(f"body_request:type(source):{type(source)}, source:{source}.")
154
+ app_logger.debug(f"body_request:type(source_name):{type(source_name)}, source_name:{source_name}.")
155
+ app_logger.debug(f"lisa module:{lisa}.")
156
+ gpu_initialization()
157
  output = lisa.lisa_predict(
158
  bbox=body_request["bbox"], prompt=body_request["prompt"], zoom=body_request["zoom"],
159
+ source=source, source_name=source_name, inference_function_name_key=LISA_INFERENCE_FN
160
  )
161
  duration_run = time.time() - time_start_run
162
  app_logger.info(f"duration_run:{duration_run}.")
 
164
  "duration_run": duration_run,
165
  "output": output
166
  }
167
+ dumped = json.dumps(body)
168
+ app_logger.info(f"json.dumps(body) type:{type(dumped)}, len:{len(dumped)}.")
169
+ app_logger.debug(f"complete json.dumps(body):{dumped}.")
170
+ return dumped
171
  except Exception as inference_exception:
172
  handle_exception_response(inference_exception)
173
  except ValidationError as va1:
 
195
  body_request = get_parsed_bbox_points_with_dictlist_prompt(request_input)
196
  app_logger.info(f"body_request:{body_request}.")
197
  try:
198
+ source_name = body_request["source_name"]
199
  app_logger.info(f"source_name = {source_name}.")
200
  output = predictors.samexporter_predict(
201
  bbox=body_request["bbox"], prompt=body_request["prompt"], zoom=body_request["zoom"],
 
304
  return FileResponse(path=static_dist_folder / "lisa.html", media_type="text/html")
305
 
306
 
307
+ # index.html (lisa.html copy)
308
  app.mount(VITE_INDEX_URL, StaticFiles(directory=static_dist_folder, html=True), name="index")
309
 
310
 
 
313
  return FileResponse(path=static_dist_folder / "index.html", media_type="text/html")
314
 
315
 
 
 
 
 
316
  app_helpers.app_logger.info(f"prepared inference_fn function:{inference_fn.__name__}, creating gradio interface...")
 
317
  io = get_gradio_interface_geojson(infer_lisa_gradio)
318
  app_helpers.app_logger.info(
319
  f"created gradio interface, mounting gradio app on url {VITE_GRADIO_URL} within FastAPI...")
samgis_lisa_on_zero/io_package/wrappers_helpers.py CHANGED
@@ -83,7 +83,8 @@ def get_parsed_bbox_points_with_string_prompt(request_input: StringPromptApiRequ
83
  "bbox": [ne_latlng, sw_latlng],
84
  "prompt": cleaned_prompt,
85
  "zoom": new_zoom,
86
- "source": get_url_tile(request_input.source_type)
 
87
  }
88
 
89
 
@@ -119,7 +120,8 @@ def get_parsed_bbox_points_with_dictlist_prompt(request_input: ApiRequestBody) -
119
  "bbox": [ne_latlng, sw_latlng],
120
  "prompt": new_prompt_list,
121
  "zoom": new_zoom,
122
- "source": get_url_tile(request_input.source_type)
 
123
  }
124
 
125
 
 
83
  "bbox": [ne_latlng, sw_latlng],
84
  "prompt": cleaned_prompt,
85
  "zoom": new_zoom,
86
+ "source": get_url_tile(request_input.source_type),
87
+ "source_name": get_source_name(request_input.source_type)
88
  }
89
 
90
 
 
120
  "bbox": [ne_latlng, sw_latlng],
121
  "prompt": new_prompt_list,
122
  "zoom": new_zoom,
123
+ "source": get_url_tile(request_input.source_type),
124
+ "source_name": get_source_name(request_input.source_type)
125
  }
126
 
127
 
samgis_lisa_on_zero/prediction_api/lisa.py CHANGED
@@ -16,7 +16,9 @@ def load_model_and_inference_fn(inference_function_name_key: str):
16
  from samgis_lisa_on_zero.prediction_api.global_models import models_dict
17
 
18
  if models_dict[inference_function_name_key]["inference"] is None:
19
- app_logger.info(f"missing inference function {inference_function_name_key}, instantiating it now using inference_decorator {SPACES_GPU}!")
 
 
20
  parsed_args = app_helpers.parse_args([])
21
  inference_fn = app_helpers.get_inference_model_by_args(
22
  parsed_args,
@@ -57,10 +59,17 @@ def lisa_predict(
57
  from samgis_lisa_on_zero import app_logger
58
  from samgis_lisa_on_zero.prediction_api.global_models import models_dict
59
 
 
 
 
60
  app_logger.info("start lisa inference...")
 
 
 
61
  load_model_and_inference_fn(inference_function_name_key)
62
- app_logger.debug(f"using a {inference_function_name_key} instance model...")
63
  inference_fn = models_dict[inference_function_name_key]["inference"]
 
64
 
65
  pt0, pt1 = bbox
66
  app_logger.info(f"tile_source: {source}: downloading geo-referenced raster with bbox {bbox}, zoom {zoom}.")
@@ -80,15 +89,14 @@ def lisa_predict(
80
  app_logger.info("keep all temp data in memory...")
81
 
82
  app_logger.info(f"lisa_zero, source_name:{source_name}, source_name type:{type(source_name)}.")
83
- app_logger.info(f"lisa_zero, prompt tpye:{type(prompt)}.")
84
  app_logger.info(f"lisa_zero, prompt:{prompt}.")
85
  prompt_str = str(prompt)
86
- app_logger.info(f"lisa_zero, img tpye:{type(img)}.")
87
  embedding_key = f"{source_name}_z{zoom}_{prefix}"
88
  _, mask, output_string = inference_fn(input_str=prompt_str, input_image=img, embedding_key=embedding_key)
89
- app_logger.info(f"lisa_zero, output_string tpye:{type(output_string)}.")
90
- app_logger.info(f"lisa_zero, output_string:{output_string}.")
91
- app_logger.info(f"lisa_zero, mask_output tpye:{type(mask)}.")
92
  app_logger.info(f"created output_string '{output_string}', preparing conversion to geojson...")
93
  return {
94
  "output_string": output_string,
 
16
  from samgis_lisa_on_zero.prediction_api.global_models import models_dict
17
 
18
  if models_dict[inference_function_name_key]["inference"] is None:
19
+ msg = f"missing inference function {inference_function_name_key}, "
20
+ msg += f"instantiating it now using inference_decorator {SPACES_GPU}!"
21
+ app_logger.info(msg)
22
  parsed_args = app_helpers.parse_args([])
23
  inference_fn = app_helpers.get_inference_model_by_args(
24
  parsed_args,
 
59
  from samgis_lisa_on_zero import app_logger
60
  from samgis_lisa_on_zero.prediction_api.global_models import models_dict
61
 
62
+ if source_name is None:
63
+ source_name = str(source)
64
+
65
  app_logger.info("start lisa inference...")
66
+ app_logger.debug(f"type(source):{type(source)}, source:{source},")
67
+ app_logger.debug(f"type(source_name):{type(source_name)}, source_name:{source_name}.")
68
+
69
  load_model_and_inference_fn(inference_function_name_key)
70
+ app_logger.debug(f"using a '{inference_function_name_key}' instance model...")
71
  inference_fn = models_dict[inference_function_name_key]["inference"]
72
+ app_logger.info(f"loaded inference function '{inference_fn.__name__}'.")
73
 
74
  pt0, pt1 = bbox
75
  app_logger.info(f"tile_source: {source}: downloading geo-referenced raster with bbox {bbox}, zoom {zoom}.")
 
89
  app_logger.info("keep all temp data in memory...")
90
 
91
  app_logger.info(f"lisa_zero, source_name:{source_name}, source_name type:{type(source_name)}.")
92
+ app_logger.info(f"lisa_zero, prompt type:{type(prompt)}.")
93
  app_logger.info(f"lisa_zero, prompt:{prompt}.")
94
  prompt_str = str(prompt)
95
+ app_logger.info(f"lisa_zero, img type:{type(img)}.")
96
  embedding_key = f"{source_name}_z{zoom}_{prefix}"
97
  _, mask, output_string = inference_fn(input_str=prompt_str, input_image=img, embedding_key=embedding_key)
98
+ app_logger.info(f"lisa_zero, output_string type:{type(output_string)}.")
99
+ app_logger.info(f"lisa_zero, mask_output type:{type(mask)}.")
 
100
  app_logger.info(f"created output_string '{output_string}', preparing conversion to geojson...")
101
  return {
102
  "output_string": output_string,