alessandro trinca tornidor committed on
Commit 3bd20e4 · 1 Parent(s): 37a5f04

[refactor] try mounting the Gradio app within a FastAPI app to prepare for logging the session id

Files changed (1)
app.py +56 -41
app.py CHANGED
@@ -5,6 +5,11 @@ import sys
 import logging
 from typing import Callable
 
+from fastapi import FastAPI, File, UploadFile, Request
+from fastapi.responses import HTMLResponse, RedirectResponse
+from fastapi.staticfiles import StaticFiles
+from fastapi.templating import Jinja2Templates
+
 import cv2
 import gradio as gr
 import nh3
@@ -21,6 +26,51 @@ from utils.utils import (DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN,
                          DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX)
 
 
+CUSTOM_GRADIO_PATH = "/gradio"
+app = FastAPI()
+os.makedirs("static", exist_ok=True)
+app.mount("/static", StaticFiles(directory="static"), name="static")
+templates = Jinja2Templates(directory="templates")
+
+
+def get_cleaned_input(input_str):
+    input_str = nh3.clean(
+        input_str,
+        tags={
+            "a",
+            "abbr",
+            "acronym",
+            "b",
+            "blockquote",
+            "code",
+            "em",
+            "i",
+            "li",
+            "ol",
+            "strong",
+            "ul",
+        },
+        attributes={
+            "a": {"href", "title"},
+            "abbr": {"title"},
+            "acronym": {"title"},
+        },
+        url_schemes={"http", "https", "mailto"},
+        link_rel=None,
+    )
+    return input_str
+
+
+@app.get("/", response_class=HTMLResponse)
+async def home(request: Request):
+    logging.info(f"Request raw: {request}.")
+    clean_request = get_cleaned_input(str(request))
+    logging.info(f"clean_request: {request}.")
+    return templates.TemplateResponse(
+        "home.html", {"clean_request": clean_request}
+    )
+
+
 # Gradio
 examples = [
     [
@@ -214,30 +264,7 @@ def get_inference_model_by_args(args_to_parse):
     def inference(input_str, input_image):
         ## filter out special chars
 
-        input_str = nh3.clean(
-            input_str,
-            tags={
-                "a",
-                "abbr",
-                "acronym",
-                "b",
-                "blockquote",
-                "code",
-                "em",
-                "i",
-                "li",
-                "ol",
-                "strong",
-                "ul",
-            },
-            attributes={
-                "a": {"href", "title"},
-                "abbr": {"title"},
-                "acronym": {"title"},
-            },
-            url_schemes={"http", "https", "mailto"},
-            link_rel=None,
-        )
+        input_str = get_cleaned_input(input_str)
         logging.info(f"input_str type: {type(input_str)}, input_image type: {type(input_image)}.")
         logging.info(f"input_str: {input_str}.")
 
@@ -334,12 +361,10 @@ def get_inference_model_by_args(args_to_parse):
     return inference
 
 
-def server_runner(
-    fn_inference: Callable,
-    debug: bool = False,
-    server_name: str = "0.0.0.0"
+def get_gradio_interface(
+    fn_inference: Callable
 ):
-    inference_app = gr.Interface(
+    return gr.Interface(
         fn_inference,
         inputs=[
             gr.Textbox(lines=1, placeholder=None, label="Text Instruction"),
@@ -356,19 +381,9 @@ def server_runner(
         allow_flagging="auto",
     )
 
-    inference_app.queue()
-    inference_app.launch(
-        share=False,
-        debug=debug,
-        server_name=server_name
-    )
-
 
 if __name__ == '__main__':
     args = parse_args(sys.argv[1:])
     inference_fn = get_inference_model_by_args(args)
-    server_runner(
-        inference_fn,
-        debug=True,
-        server_name="0.0.0.0"
-    )
+    io = get_gradio_interface(inference_fn)
+    app = gr.mount_gradio_app(app, io, path=CUSTOM_GRADIO_PATH)
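
This commit only mounts the Gradio interface onto the FastAPI app at /gradio; the session-id logging mentioned in the commit message is not implemented yet. Below is a minimal, self-contained sketch (not part of the commit) of how it could be wired up, assuming a FastAPI HTTP middleware, a hypothetical "session_id" cookie name, a placeholder Gradio interface standing in for the real inference one, and uvicorn as the ASGI server:

import logging
import uuid

import gradio as gr
import uvicorn
from fastapi import FastAPI, Request

CUSTOM_GRADIO_PATH = "/gradio"
app = FastAPI()


@app.middleware("http")
async def log_session_id(request: Request, call_next):
    # Reuse the session cookie when present, otherwise mint a fresh id;
    # the "session_id" cookie name is an illustrative assumption.
    session_id = request.cookies.get("session_id") or str(uuid.uuid4())
    logging.info(f"session_id: {session_id}, path: {request.url.path}")
    response = await call_next(request)
    response.set_cookie("session_id", session_id)
    return response


# Stand-in for the interface returned by get_gradio_interface() in app.py.
io = gr.Interface(fn=lambda text: text, inputs="text", outputs="text")
app = gr.mount_gradio_app(app, io, path=CUSTOM_GRADIO_PATH)

if __name__ == "__main__":
    # Once mounted, the app is served by an ASGI server instead of
    # gr.Interface.launch(); host and port are illustrative defaults.
    uvicorn.run(app, host="0.0.0.0", port=7860)

With this setup every request to the mounted UI passes through the middleware, so the session id could later be attached to the inference logs.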