alessandro trinca tornidor committed
Commit a84a5a1 · 1 Parent(s): b21c563

[refactor] start reverting to original app.py content/1

Files changed (1):
  1. app.py +88 -10
app.py CHANGED
@@ -1,10 +1,29 @@
+import argparse
+import cv2
 import gradio as gr
 import json
 import logging
-from fastapi import FastAPI
+import nh3
+import numpy as np
+import os
+import re
+import sys
+import torch
+import torch.nn.functional as F
+from fastapi import FastAPI, File, UploadFile, Request
+from fastapi.responses import HTMLResponse, RedirectResponse
+from fastapi.staticfiles import StaticFiles
+from fastapi.templating import Jinja2Templates
+from transformers import AutoTokenizer, BitsAndBytesConfig, CLIPImageProcessor
+from typing import Callable
 
+from model.LISA import LISAForCausalLM
+from model.llava import conversation as conversation_lib
+from model.llava.mm_utils import tokenizer_image_token
+from model.segment_anything.utils.transforms import ResizeLongestSide
 from utils import session_logger
-
+from utils.utils import (DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN,
+                         DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX)
 
 session_logger.change_logging(logging.DEBUG)
 
@@ -12,6 +31,11 @@ session_logger.change_logging(logging.DEBUG)
 CUSTOM_GRADIO_PATH = "/"
 app = FastAPI(title="lisa_app", version="1.0")
 
+FASTAPI_STATIC = os.getenv("FASTAPI_STATIC")
+os.makedirs(FASTAPI_STATIC, exist_ok=True)
+app.mount("/static", StaticFiles(directory=FASTAPI_STATIC), name="static")
+templates = Jinja2Templates(directory="templates")
+
 
 @app.get("/health")
 @session_logger.set_uuid_logging
@@ -25,20 +49,74 @@ def health() -> str:
 
 
 @session_logger.set_uuid_logging
-def request_formatter(text: str) -> str:
-    logging.info("start request formatting...")
-    formatted_text = f"transformed {text}."
-    logging.info(f"formatted request as {formatted_text}.")
-    return formatted_text
+def get_cleaned_input(input_str):
+    logging.info(f"start cleaning of input_str: {input_str}.")
+    input_str = nh3.clean(
+        input_str,
+        tags={
+            "a",
+            "abbr",
+            "acronym",
+            "b",
+            "blockquote",
+            "code",
+            "em",
+            "i",
+            "li",
+            "ol",
+            "strong",
+            "ul",
+        },
+        attributes={
+            "a": {"href", "title"},
+            "abbr": {"title"},
+            "acronym": {"title"},
+        },
+        url_schemes={"http", "https", "mailto"},
+        link_rel=None,
+    )
+    logging.info(f"cleaned input_str: {input_str}.")
+    return input_str
+
+
+@session_logger.set_uuid_logging
+def get_inference_model_by_args(args_to_parse):
+    logging.info(f"args_to_parse:{args_to_parse}.")
+
+    @session_logger.set_uuid_logging
+    def inference(input_str, input_image):
+        ## filter out special chars
 
+        input_str = get_cleaned_input(input_str)
+        logging.info(f"input_str type: {type(input_str)}, input_image type: {type(input_image)}.")
+        logging.info(f"input_str: {input_str}.")
 
-io = gr.Interface(
-    request_formatter,
+        return output_image, output_str
+
+    return inference
+
+
+@session_logger.set_uuid_logging
+def get_gradio_interface(fn_inference: Callable):
+    return gr.Interface(
+        fn_inference,
         inputs=[
-            gr.Textbox(lines=1, placeholder=None, label="Text input"),
+            gr.Textbox(lines=1, placeholder=None, label="Text Instruction"),
+            gr.Image(type="filepath", label="Input Image")
         ],
         outputs=[
+            gr.Image(type="pil", label="Segmentation Output"),
             gr.Textbox(lines=1, placeholder=None, label="Text Output"),
         ],
+        title=title,
+        description=description,
+        article=article,
+        examples=examples,
+        allow_flagging="auto",
    )
+
+
+args = parse_args(sys.argv[1:])
+inference_fn = get_inference_model_by_args(args)
+io = get_gradio_interface(inference_fn)
 app = gr.mount_gradio_app(app, io, path=CUSTOM_GRADIO_PATH)
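
Note: in this intermediate commit the new inference closure returns output_image and output_str, and the interface references title, description, article, examples and parse_args, none of which are defined in this hunk; they are presumably defined elsewhere in the file or restored by later commits of this revert series. As a quick illustration of what the added sanitizer does, the sketch below (not part of the commit) calls nh3.clean with the same allow-list used by get_cleaned_input; the example input and the expected result in the comment are illustrative assumptions, not output captured from the app.

import nh3

dirty = '<b>segment the</b> <script>alert("xss")</script> <a href="https://example.com">dog</a>'
cleaned = nh3.clean(
    dirty,
    tags={"a", "abbr", "acronym", "b", "blockquote", "code", "em", "i", "li", "ol", "strong", "ul"},
    attributes={"a": {"href", "title"}, "abbr": {"title"}, "acronym": {"title"}},
    url_schemes={"http", "https", "mailto"},
    link_rel=None,
)
# The <script> element should be dropped entirely, while the <b> tag and the
# https link survive, roughly: '<b>segment the</b>  <a href="https://example.com">dog</a>'
print(cleaned)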