alessandro trinca tornidor committed
Commit 1186076 · 1 Parent(s): 8facf64

[refactor] return the inference function to inject the model

Files changed (1): app.py (+133, -131)
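
This commit applies a small factory pattern: instead of loading the model at import time and having a module-level inference() read it from globals, get_inference_model_by_args loads the model once and returns an inference closure that captures it, so the entry point builds the function and hands it to the server. Below is a minimal, self-contained sketch of the pattern, where FakeModel, get_inference_fn, and run are illustrative stand-ins rather than the app's real API:

    from dataclasses import dataclass

    @dataclass
    class FakeModel:
        """Illustrative stand-in for the heavyweight model that get_model() builds."""
        precision: str

        def run(self, prompt: str) -> str:
            return f"[{self.precision}] {prompt}"

    def get_inference_fn(precision: str):
        model = FakeModel(precision)  # expensive setup runs exactly once

        def inference(prompt: str) -> str:
            return model.run(prompt)  # `model` is captured by the closure

        return inference

    inference_fn = get_inference_fn("fp16")
    print(inference_fn("segment the dog"))  # -> [fp16] segment the dog

Because callers receive a plain function, a test can inject a cheap stand-in model the same way, with no globals to patch.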
app.py CHANGED
@@ -196,138 +196,138 @@ def get_model(args_to_parse):
     return _model, _clip_image_processor, _tokenizer, _transform
 
 
-args = parse_args(sys.argv[1:])
-model, clip_image_processor, tokenizer, transform = get_model(args)
-
-
-## to be implemented
-def inference(input_str, input_image):
-    ## filter out special chars
-
-    input_str = nh3.clean(
-        input_str,
-        tags={
-            "a",
-            "abbr",
-            "acronym",
-            "b",
-            "blockquote",
-            "code",
-            "em",
-            "i",
-            "li",
-            "ol",
-            "strong",
-            "ul",
-        },
-        attributes={
-            "a": {"href", "title"},
-            "abbr": {"title"},
-            "acronym": {"title"},
-        },
-        url_schemes={"http", "https", "mailto"},
-        link_rel=None,
-    )
-
-    print("input_str: ", input_str, "input_image: ", input_image)
-
-    ## input valid check
-    if not re.match(r"^[A-Za-z ,.!?\'\"]+$", input_str) or len(input_str) < 1:
-        output_str = "[Error] Invalid input: ", input_str
-        # output_image = np.zeros((128, 128, 3))
-        ## error happened
-        output_image = cv2.imread("./resources/error_happened.png")[:, :, ::-1]
-        return output_image, output_str
-
-    # Model Inference
-    conv = conversation_lib.conv_templates[args.conv_type].copy()
-    conv.messages = []
-
-    prompt = input_str
-    prompt = DEFAULT_IMAGE_TOKEN + "\n" + prompt
-    if args.use_mm_start_end:
-        replace_token = (
-            DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
-        )
-        prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)
-
-    conv.append_message(conv.roles[0], prompt)
-    conv.append_message(conv.roles[1], "")
-    prompt = conv.get_prompt()
-
-    image_np = cv2.imread(input_image)
-    image_np = cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB)
-    original_size_list = [image_np.shape[:2]]
-
-    image_clip = (
-        clip_image_processor.preprocess(image_np, return_tensors="pt")[
-            "pixel_values"
-        ][0]
-        .unsqueeze(0)
-        .cuda()
-    )
-    if args.precision == "bf16":
-        image_clip = image_clip.bfloat16()
-    elif args.precision == "fp16":
-        image_clip = image_clip.half()
-    else:
-        image_clip = image_clip.float()
-
-    image = transform.apply_image(image_np)
-    resize_list = [image.shape[:2]]
-
-    image = (
-        preprocess(torch.from_numpy(image).permute(2, 0, 1).contiguous())
-        .unsqueeze(0)
-        .cuda()
-    )
-    if args.precision == "bf16":
-        image = image.bfloat16()
-    elif args.precision == "fp16":
-        image = image.half()
-    else:
-        image = image.float()
-
-    input_ids = tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
-    input_ids = input_ids.unsqueeze(0).cuda()
-
-    output_ids, pred_masks = model.evaluate(
-        image_clip,
-        image,
-        input_ids,
-        resize_list,
-        original_size_list,
-        max_new_tokens=512,
-        tokenizer=tokenizer,
-    )
-    output_ids = output_ids[0][output_ids[0] != IMAGE_TOKEN_INDEX]
-
-    text_output = tokenizer.decode(output_ids, skip_special_tokens=False)
-    text_output = text_output.replace("\n", "").replace("  ", " ")
-    text_output = text_output.split("ASSISTANT: ")[-1]
-
-    print("text_output: ", text_output)
-    save_img = None
-    for i, pred_mask in enumerate(pred_masks):
-        if pred_mask.shape[0] == 0:
-            continue
-
-        pred_mask = pred_mask.detach().cpu().numpy()[0]
-        pred_mask = pred_mask > 0
-
-        save_img = image_np.copy()
-        save_img[pred_mask] = (
-            image_np * 0.5
-            + pred_mask[:, :, None].astype(np.uint8) * np.array([255, 0, 0]) * 0.5
-        )[pred_mask]
-
-    output_str = "ASSITANT: " + text_output  # input_str
-    if save_img is not None:
-        output_image = save_img  # input_image
-    else:
-        ## no seg output
-        output_image = cv2.imread("./resources/no_seg_out.png")[:, :, ::-1]
-    return output_image, output_str
+def get_inference_model_by_args(args_to_parse):
+    model, clip_image_processor, tokenizer, transform = get_model(args_to_parse)
+
+    ## to be implemented
+    def inference(input_str, input_image):
+        ## filter out special chars
+
+        input_str = nh3.clean(
+            input_str,
+            tags={
+                "a",
+                "abbr",
+                "acronym",
+                "b",
+                "blockquote",
+                "code",
+                "em",
+                "i",
+                "li",
+                "ol",
+                "strong",
+                "ul",
+            },
+            attributes={
+                "a": {"href", "title"},
+                "abbr": {"title"},
+                "acronym": {"title"},
+            },
+            url_schemes={"http", "https", "mailto"},
+            link_rel=None,
+        )
+
+        print("input_str: ", input_str, "input_image: ", input_image)
+
+        ## input valid check
+        if not re.match(r"^[A-Za-z ,.!?\'\"]+$", input_str) or len(input_str) < 1:
+            output_str = "[Error] Invalid input: ", input_str
+            # output_image = np.zeros((128, 128, 3))
+            ## error happened
+            output_image = cv2.imread("./resources/error_happened.png")[:, :, ::-1]
+            return output_image, output_str
+
+        # Model Inference
+        conv = conversation_lib.conv_templates[args.conv_type].copy()
+        conv.messages = []
+
+        prompt = input_str
+        prompt = DEFAULT_IMAGE_TOKEN + "\n" + prompt
+        if args.use_mm_start_end:
+            replace_token = (
+                DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
+            )
+            prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)
+
+        conv.append_message(conv.roles[0], prompt)
+        conv.append_message(conv.roles[1], "")
+        prompt = conv.get_prompt()
+
+        image_np = cv2.imread(input_image)
+        image_np = cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB)
+        original_size_list = [image_np.shape[:2]]
+
+        image_clip = (
+            clip_image_processor.preprocess(image_np, return_tensors="pt")[
+                "pixel_values"
+            ][0]
+            .unsqueeze(0)
+            .cuda()
+        )
+        if args.precision == "bf16":
+            image_clip = image_clip.bfloat16()
+        elif args.precision == "fp16":
+            image_clip = image_clip.half()
+        else:
+            image_clip = image_clip.float()
+
+        image = transform.apply_image(image_np)
+        resize_list = [image.shape[:2]]
+
+        image = (
+            preprocess(torch.from_numpy(image).permute(2, 0, 1).contiguous())
+            .unsqueeze(0)
+            .cuda()
+        )
+        if args.precision == "bf16":
+            image = image.bfloat16()
+        elif args.precision == "fp16":
+            image = image.half()
+        else:
+            image = image.float()
+
+        input_ids = tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
+        input_ids = input_ids.unsqueeze(0).cuda()
+
+        output_ids, pred_masks = model.evaluate(
+            image_clip,
+            image,
+            input_ids,
+            resize_list,
+            original_size_list,
+            max_new_tokens=512,
+            tokenizer=tokenizer,
+        )
+        output_ids = output_ids[0][output_ids[0] != IMAGE_TOKEN_INDEX]
+
+        text_output = tokenizer.decode(output_ids, skip_special_tokens=False)
+        text_output = text_output.replace("\n", "").replace("  ", " ")
+        text_output = text_output.split("ASSISTANT: ")[-1]
+
+        print("text_output: ", text_output)
+        save_img = None
+        for i, pred_mask in enumerate(pred_masks):
+            if pred_mask.shape[0] == 0:
+                continue
+
+            pred_mask = pred_mask.detach().cpu().numpy()[0]
+            pred_mask = pred_mask > 0
+
+            save_img = image_np.copy()
+            save_img[pred_mask] = (
+                image_np * 0.5
+                + pred_mask[:, :, None].astype(np.uint8) * np.array([255, 0, 0]) * 0.5
+            )[pred_mask]
+
+        output_str = "ASSITANT: " + text_output  # input_str
+        if save_img is not None:
+            output_image = save_img  # input_image
+        else:
+            ## no seg output
+            output_image = cv2.imread("./resources/no_seg_out.png")[:, :, ::-1]
+        return output_image, output_str
+    return inference
 
 
 def server_runner(
@@ -361,8 +361,10 @@ def server_runner(
 
 
 if __name__ == '__main__':
+    args = parse_args(sys.argv[1:])
+    inference_fn = get_inference_model_by_args(args)
     server_runner(
-        inference,
+        inference_fn,
         debug=True,
         server_name="0.0.0.0"
     )
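
server_runner itself is not shown in this diff. Judging by the debug and server_name keywords it forwards and by the signature of inference (text plus image path in, image plus text out), it plausibly wraps a Gradio Interface; the following is only a hypothetical sketch of such a wrapper under that assumption, not the file's actual code:

    import gradio as gr

    def server_runner(fn, debug=False, server_name="127.0.0.1"):
        # Hypothetical wiring: the injected inference function becomes the UI callback.
        demo = gr.Interface(
            fn=fn,
            inputs=[gr.Textbox(label="prompt"), gr.Image(type="filepath")],
            outputs=[gr.Image(label="segmentation"), gr.Textbox(label="answer")],
        )
        demo.launch(debug=debug, server_name=server_name)

With the factory in place, the entry point stays a thin shell: parse the args, build the closure, pass it to the server.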