alessandro trinca tornidor commited on
Commit
c5fe4a2
·
1 Parent(s): f623930

[feat] restart from an example app to check what's working

Browse files
Files changed (2) hide show
  1. app.py +26 -377
  2. utils/session_logger.py +36 -0
app.py CHANGED
@@ -1,390 +1,39 @@
1
- import argparse
2
- import os
3
- import re
4
- import sys
5
- import logging
6
- from typing import Callable
7
-
8
- from fastapi import FastAPI, File, UploadFile, Request
9
- from fastapi.responses import HTMLResponse, RedirectResponse
10
- from fastapi.staticfiles import StaticFiles
11
- from fastapi.templating import Jinja2Templates
12
-
13
- import cv2
14
  import gradio as gr
15
- import nh3
16
- import numpy as np
17
- import torch
18
- import torch.nn.functional as F
19
- from transformers import AutoTokenizer, BitsAndBytesConfig, CLIPImageProcessor
20
-
21
- from model.LISA import LISAForCausalLM
22
- from model.llava import conversation as conversation_lib
23
- from model.llava.mm_utils import tokenizer_image_token
24
- from model.segment_anything.utils.transforms import ResizeLongestSide
25
- from utils.utils import (DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN,
26
- DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX)
27
 
 
28
 
29
- CUSTOM_GRADIO_PATH = "/gradio"
30
- app = FastAPI()
31
 
32
- FASTAPI_STATIC = os.getenv("FASTAPI_STATIC")
33
- os.makedirs(FASTAPI_STATIC, exist_ok=True)
34
- app.mount("/static", StaticFiles(directory=FASTAPI_STATIC), name="static")
35
- templates = Jinja2Templates(directory="templates")
36
 
37
 
38
- def get_cleaned_input(input_str):
39
- input_str = nh3.clean(
40
- input_str,
41
- tags={
42
- "a",
43
- "abbr",
44
- "acronym",
45
- "b",
46
- "blockquote",
47
- "code",
48
- "em",
49
- "i",
50
- "li",
51
- "ol",
52
- "strong",
53
- "ul",
54
- },
55
- attributes={
56
- "a": {"href", "title"},
57
- "abbr": {"title"},
58
- "acronym": {"title"},
59
- },
60
- url_schemes={"http", "https", "mailto"},
61
- link_rel=None,
62
- )
63
- return input_str
64
 
65
 
66
- @app.get("/", response_class=HTMLResponse)
67
- async def home(request: Request):
68
- logging.info(f"Request raw: {request}.")
69
- clean_request = get_cleaned_input(str(request))
70
- logging.info(f"clean_request: {request}.")
71
- return templates.TemplateResponse(
72
- "home.html", {"clean_request": clean_request}
73
- )
74
 
75
 
76
- # Gradio
77
- examples = [
78
- [
79
- "Where can the driver see the car speed in this image? Please output segmentation mask.",
80
- "./resources/imgs/example1.jpg",
81
  ],
82
- [
83
- "Can you segment the food that tastes spicy and hot?",
84
- "./resources/imgs/example2.jpg",
85
  ],
86
- [
87
- "Assuming you are an autonomous driving robot, what part of the diagram would you manipulate to control the direction of travel? Please output segmentation mask and explain why.",
88
- "./resources/imgs/example1.jpg",
89
- ],
90
- [
91
- "What can make the woman stand higher? Please output segmentation mask and explain why.",
92
- "./resources/imgs/example3.jpg",
93
- ],
94
- ]
95
- output_labels = ["Segmentation Output"]
96
-
97
- title = "LISA: Reasoning Segmentation via Large Language Model"
98
-
99
- description = """
100
- <font size=4>
101
- This is the online demo of LISA. \n
102
- If multiple users are using it at the same time, they will enter a queue, which may delay some time. \n
103
- **Note**: **Different prompts can lead to significantly varied results**. \n
104
- **Note**: Please try to **standardize** your input text prompts to **avoid ambiguity**, and also pay attention to whether the **punctuations** of the input are correct. \n
105
- **Note**: Current model is **LISA-13B-llama2-v0-explanatory**, and 4-bit quantization may impair text-generation quality. \n
106
- **Usage**: <br>
107
- &ensp;(1) To let LISA **segment something**, input prompt like: "Can you segment xxx in this image?", "What is xxx in this image? Please output segmentation mask."; <br>
108
- &ensp;(2) To let LISA **output an explanation**, input prompt like: "What is xxx in this image? Please output segmentation mask and explain why."; <br>
109
- &ensp;(3) To obtain **solely language output**, you can input like what you should do in current multi-modal LLM (e.g., LLaVA). <br>
110
- Hope you can enjoy our work!
111
- </font>
112
- """
113
-
114
- article = """
115
- <p style='text-align: center'>
116
- <a href='https://arxiv.org/abs/2308.00692' target='_blank'>
117
- Preprint Paper
118
- </a>
119
- \n
120
- <p style='text-align: center'>
121
- <a href='https://github.com/dvlab-research/LISA' target='_blank'> Github Repo </a></p>
122
- """
123
-
124
-
125
- def parse_args(args_to_parse):
126
- parser = argparse.ArgumentParser(description="LISA chat")
127
- parser.add_argument("--version", default="xinlai/LISA-13B-llama2-v1")
128
- parser.add_argument("--vis_save_path", default="./vis_output", type=str)
129
- parser.add_argument(
130
- "--precision",
131
- default="fp16",
132
- type=str,
133
- choices=["fp32", "bf16", "fp16"],
134
- help="precision for inference",
135
- )
136
- parser.add_argument("--image_size", default=1024, type=int, help="image size")
137
- parser.add_argument("--model_max_length", default=512, type=int)
138
- parser.add_argument("--lora_r", default=8, type=int)
139
- parser.add_argument(
140
- "--vision-tower", default="openai/clip-vit-large-patch14", type=str
141
- )
142
- parser.add_argument("--local-rank", default=0, type=int, help="node rank")
143
- parser.add_argument("--load_in_8bit", action="store_true", default=False)
144
- parser.add_argument("--load_in_4bit", action="store_true", default=False)
145
- parser.add_argument("--use_mm_start_end", action="store_true", default=True)
146
- parser.add_argument(
147
- "--conv_type",
148
- default="llava_v1",
149
- type=str,
150
- choices=["llava_v1", "llava_llama_2"],
151
- )
152
- return parser.parse_args(args_to_parse)
153
-
154
-
155
- def set_image_precision_by_args(input_image, precision):
156
- if precision == "bf16":
157
- input_image = input_image.bfloat16()
158
- elif precision == "fp16":
159
- input_image = input_image.half()
160
- else:
161
- input_image = input_image.float()
162
- return input_image
163
-
164
-
165
- def preprocess(
166
- x,
167
- pixel_mean=torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1),
168
- pixel_std=torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1),
169
- img_size=1024,
170
- ) -> torch.Tensor:
171
- """Normalize pixel values and pad to a square input."""
172
- # Normalize colors
173
- x = (x - pixel_mean) / pixel_std
174
- # Pad
175
- h, w = x.shape[-2:]
176
- padh = img_size - h
177
- padw = img_size - w
178
- x = F.pad(x, (0, padw, 0, padh))
179
- return x
180
-
181
-
182
- def get_model(args_to_parse):
183
- os.makedirs(args_to_parse.vis_save_path, exist_ok=True)
184
-
185
- # global tokenizer, tokenizer
186
- # Create model
187
- _tokenizer = AutoTokenizer.from_pretrained(
188
- args_to_parse.version,
189
- cache_dir=None,
190
- model_max_length=args_to_parse.model_max_length,
191
- padding_side="right",
192
- use_fast=False,
193
- )
194
- _tokenizer.pad_token = _tokenizer.unk_token
195
- args_to_parse.seg_token_idx = _tokenizer("[SEG]", add_special_tokens=False).input_ids[0]
196
- torch_dtype = torch.float32
197
- if args_to_parse.precision == "bf16":
198
- torch_dtype = torch.bfloat16
199
- elif args_to_parse.precision == "fp16":
200
- torch_dtype = torch.half
201
- kwargs = {"torch_dtype": torch_dtype}
202
- if args_to_parse.load_in_4bit:
203
- kwargs.update(
204
- {
205
- "torch_dtype": torch.half,
206
- "load_in_4bit": True,
207
- "quantization_config": BitsAndBytesConfig(
208
- load_in_4bit=True,
209
- bnb_4bit_compute_dtype=torch.float16,
210
- bnb_4bit_use_double_quant=True,
211
- bnb_4bit_quant_type="nf4",
212
- llm_int8_skip_modules=["visual_model"],
213
- ),
214
- }
215
- )
216
- elif args_to_parse.load_in_8bit:
217
- kwargs.update(
218
- {
219
- "torch_dtype": torch.half,
220
- "quantization_config": BitsAndBytesConfig(
221
- llm_int8_skip_modules=["visual_model"],
222
- load_in_8bit=True,
223
- ),
224
- }
225
- )
226
- _model = LISAForCausalLM.from_pretrained(
227
- args_to_parse.version, low_cpu_mem_usage=True, vision_tower=args_to_parse.vision_tower, seg_token_idx=args_to_parse.seg_token_idx, **kwargs
228
- )
229
- _model.config.eos_token_id = _tokenizer.eos_token_id
230
- _model.config.bos_token_id = _tokenizer.bos_token_id
231
- _model.config.pad_token_id = _tokenizer.pad_token_id
232
- _model.get_model().initialize_vision_modules(_model.get_model().config)
233
- vision_tower = _model.get_model().get_vision_tower()
234
- vision_tower.to(dtype=torch_dtype)
235
- if args_to_parse.precision == "bf16":
236
- _model = _model.bfloat16().cuda()
237
- elif (
238
- args_to_parse.precision == "fp16" and (not args_to_parse.load_in_4bit) and (not args_to_parse.load_in_8bit)
239
- ):
240
- vision_tower = _model.get_model().get_vision_tower()
241
- _model.model.vision_tower = None
242
- import deepspeed
243
-
244
- model_engine = deepspeed.init_inference(
245
- model=_model,
246
- dtype=torch.half,
247
- replace_with_kernel_inject=True,
248
- replace_method="auto",
249
- )
250
- _model = model_engine.module
251
- _model.model.vision_tower = vision_tower.half().cuda()
252
- elif args_to_parse.precision == "fp32":
253
- _model = _model.float().cuda()
254
- vision_tower = _model.get_model().get_vision_tower()
255
- vision_tower.to(device=args_to_parse.local_rank)
256
- _clip_image_processor = CLIPImageProcessor.from_pretrained(_model.config.vision_tower)
257
- _transform = ResizeLongestSide(args_to_parse.image_size)
258
- _model.eval()
259
- return _model, _clip_image_processor, _tokenizer, _transform
260
-
261
-
262
- def get_inference_model_by_args(args_to_parse):
263
- model, clip_image_processor, tokenizer, transform = get_model(args_to_parse)
264
-
265
- ## to be implemented
266
- def inference(input_str, input_image):
267
- ## filter out special chars
268
-
269
- input_str = get_cleaned_input(input_str)
270
- logging.info(f"input_str type: {type(input_str)}, input_image type: {type(input_image)}.")
271
- logging.info(f"input_str: {input_str}.")
272
-
273
- ## input valid check
274
- if not re.match(r"^[A-Za-z ,.!?\'\"]+$", input_str) or len(input_str) < 1:
275
- output_str = "[Error] Invalid input: ", input_str
276
- # output_image = np.zeros((128, 128, 3))
277
- ## error happened
278
- output_image = cv2.imread("./resources/error_happened.png")[:, :, ::-1]
279
- return output_image, output_str
280
-
281
- # Model Inference
282
- conv = conversation_lib.conv_templates[args_to_parse.conv_type].copy()
283
- conv.messages = []
284
-
285
- prompt = input_str
286
- prompt = DEFAULT_IMAGE_TOKEN + "\n" + prompt
287
- if args_to_parse.use_mm_start_end:
288
- replace_token = (
289
- DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
290
- )
291
- prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)
292
-
293
- conv.append_message(conv.roles[0], prompt)
294
- conv.append_message(conv.roles[1], "")
295
- prompt = conv.get_prompt()
296
-
297
- image_np = cv2.imread(input_image)
298
- image_np = cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB)
299
- original_size_list = [image_np.shape[:2]]
300
-
301
- image_clip = (
302
- clip_image_processor.preprocess(image_np, return_tensors="pt")[
303
- "pixel_values"
304
- ][0]
305
- .unsqueeze(0)
306
- .cuda()
307
- )
308
- logging.info(f"image_clip type: {type(image_clip)}.")
309
- image_clip = set_image_precision_by_args(image_clip, args_to_parse.precision)
310
-
311
- image = transform.apply_image(image_np)
312
- resize_list = [image.shape[:2]]
313
-
314
- image = (
315
- preprocess(torch.from_numpy(image).permute(2, 0, 1).contiguous())
316
- .unsqueeze(0)
317
- .cuda()
318
- )
319
- logging.info(f"image_clip type: {type(image_clip)}.")
320
- image = set_image_precision_by_args(image, args_to_parse.precision)
321
-
322
- input_ids = tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
323
- input_ids = input_ids.unsqueeze(0).cuda()
324
-
325
- output_ids, pred_masks = model.evaluate(
326
- image_clip,
327
- image,
328
- input_ids,
329
- resize_list,
330
- original_size_list,
331
- max_new_tokens=512,
332
- tokenizer=tokenizer,
333
- )
334
- output_ids = output_ids[0][output_ids[0] != IMAGE_TOKEN_INDEX]
335
-
336
- text_output = tokenizer.decode(output_ids, skip_special_tokens=False)
337
- text_output = text_output.replace("\n", "").replace(" ", " ")
338
- text_output = text_output.split("ASSISTANT: ")[-1]
339
-
340
- logging.info(f"text_output type: {type(text_output)}, text_output: {text_output}.")
341
- save_img = None
342
- for i, pred_mask in enumerate(pred_masks):
343
- if pred_mask.shape[0] == 0:
344
- continue
345
-
346
- pred_mask = pred_mask.detach().cpu().numpy()[0]
347
- pred_mask = pred_mask > 0
348
-
349
- save_img = image_np.copy()
350
- save_img[pred_mask] = (
351
- image_np * 0.5
352
- + pred_mask[:, :, None].astype(np.uint8) * np.array([255, 0, 0]) * 0.5
353
- )[pred_mask]
354
-
355
- output_str = "ASSITANT: " + text_output # input_str
356
- if save_img is not None:
357
- output_image = save_img # input_image
358
- else:
359
- ## no seg output
360
- output_image = cv2.imread("./resources/no_seg_out.png")[:, :, ::-1]
361
- return output_image, output_str
362
-
363
- return inference
364
-
365
-
366
- def get_gradio_interface(
367
- fn_inference: Callable
368
- ):
369
- return gr.Interface(
370
- fn_inference,
371
- inputs=[
372
- gr.Textbox(lines=1, placeholder=None, label="Text Instruction"),
373
- gr.Image(type="filepath", label="Input Image")
374
- ],
375
- outputs=[
376
- gr.Image(type="pil", label="Segmentation Output"),
377
- gr.Textbox(lines=1, placeholder=None, label="Text Output"),
378
- ],
379
- title=title,
380
- description=description,
381
- article=article,
382
- examples=examples,
383
- allow_flagging="auto",
384
- )
385
-
386
-
387
- args = parse_args(sys.argv[1:])
388
- inference_fn = get_inference_model_by_args(args)
389
- io = get_gradio_interface(inference_fn)
390
  app = gr.mount_gradio_app(app, io, path=CUSTOM_GRADIO_PATH)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ from fastapi import FastAPI
 
 
 
 
 
 
 
 
 
 
 
3
 
4
+ from utils import session_logger
5
 
 
 
6
 
7
+ CUSTOM_GRADIO_PATH = "/"
8
+ app = FastAPI(title="lisa_app", version="1.0")
 
 
9
 
10
 
11
+ @app.get("/health")
12
+ @session_logger.set_uuid_logging
13
+ def health() -> str:
14
+ try:
15
+ logging.info("health check")
16
+ return json.dumps({"msg": "ok"})
17
+ except Exception as e:
18
+ logging.error(f"exception:{e}.")
19
+ return json.dumps({"msg": "request failed"})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
 
22
+ @session_logger.set_uuid_logging
23
+ def request_formatter(text: str) -> str:
24
+ logging.info("start request formatting...")
25
+ formatted_text = f"transformed {text}."
26
+ logging.info(f"formatted request as {formatted_text}.")
27
+ return formatted_text
 
 
28
 
29
 
30
+ io = gr.Interface(
31
+ request_formatter,
32
+ inputs=[
33
+ gr.Textbox(lines=1, placeholder=None, label="Text input"),
 
34
  ],
35
+ outputs=[
36
+ gr.Textbox(lines=1, placeholder=None, label="Text Output"),
 
37
  ],
38
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  app = gr.mount_gradio_app(app, io, path=CUSTOM_GRADIO_PATH)
utils/session_logger.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import contextvars
2
+ import logging
3
+ from functools import wraps
4
+ from typing import Callable
5
+
6
+ logging_uuid = contextvars.ContextVar("uuid")
7
+ formatter = '%(asctime)s | %(uuid)s [%(pathname)s:%(module)s %(lineno)d] %(levelname)s | %(message)s'
8
+
9
+
10
+ loggingType = logging.CRITICAL | logging.ERROR | logging.WARNING | logging.INFO | logging.DEBUG
11
+
12
+
13
+ def change_logging(level_log: loggingType = logging.INFO) -> None:
14
+ old_factory = logging.getLogRecordFactory()
15
+
16
+ def record_factory(*args, **kwargs):
17
+ record = old_factory(*args, **kwargs)
18
+ record.uuid = logging_uuid.get("uuid")
19
+ if isinstance(record.msg, str):
20
+ record.msg = record.msg.replace("\\", "\\\\").replace("\n", "\\n")
21
+ return record
22
+
23
+ logging.setLogRecordFactory(record_factory)
24
+ logging.basicConfig(level=level_log, format=formatter, force=True)
25
+
26
+
27
+ def set_uuid_logging(func: Callable) -> Callable:
28
+ @wraps(func)
29
+ def wrapper(*args, **kwargs):
30
+ import uuid
31
+
32
+ current_uuid = f"{uuid.uuid4()}"
33
+ logging_uuid.set(current_uuid)
34
+ return func(*args, **kwargs)
35
+
36
+ return wrapper