alessandro trinca tornidor committed on
Commit
37a5f04
·
1 Parent(s): e789db0

[refactor] set image precision with an external function

Browse files
Files changed (1) hide show
  1. app.py +13 -12
app.py CHANGED
@@ -100,6 +100,16 @@ def parse_args(args_to_parse):
100
  return parser.parse_args(args_to_parse)
101
 
102
 
 
 
 
 
 
 
 
 
 
 
103
  def preprocess(
104
  x,
105
  pixel_mean=torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1),
@@ -267,12 +277,7 @@ def get_inference_model_by_args(args_to_parse):
267
  .cuda()
268
  )
269
  logging.info(f"image_clip type: {type(image_clip)}.")
270
- if args_to_parse.precision == "bf16":
271
- image_clip = image_clip.bfloat16()
272
- elif args_to_parse.precision == "fp16":
273
- image_clip = image_clip.half()
274
- else:
275
- image_clip = image_clip.float()
276
 
277
  image = transform.apply_image(image_np)
278
  resize_list = [image.shape[:2]]
@@ -283,12 +288,7 @@ def get_inference_model_by_args(args_to_parse):
283
  .cuda()
284
  )
285
  logging.info(f"image_clip type: {type(image_clip)}.")
286
- if args_to_parse.precision == "bf16":
287
- image = image.bfloat16()
288
- elif args_to_parse.precision == "fp16":
289
- image = image.half()
290
- else:
291
- image = image.float()
292
 
293
  input_ids = tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
294
  input_ids = input_ids.unsqueeze(0).cuda()
@@ -330,6 +330,7 @@ def get_inference_model_by_args(args_to_parse):
330
  ## no seg output
331
  output_image = cv2.imread("./resources/no_seg_out.png")[:, :, ::-1]
332
  return output_image, output_str
 
333
  return inference
334
 
335
 
 
100
  return parser.parse_args(args_to_parse)
101
 
102
 
103
def set_image_precision_by_args(input_image, precision):
    """Cast ``input_image`` to the floating-point precision named by ``precision``.

    Args:
        input_image: Object exposing ``bfloat16()``, ``half()`` and ``float()``
            conversion methods (presumably a ``torch.Tensor`` -- confirm
            against the callers).
        precision: Precision name. ``"bf16"`` selects ``bfloat16()`` and
            ``"fp16"`` selects ``half()``.

    Returns:
        The result of the selected conversion method. Any ``precision`` value
        other than ``"bf16"`` or ``"fp16"`` falls back to ``float()``.
    """
    if precision == "bf16":
        return input_image.bfloat16()
    if precision == "fp16":
        return input_image.half()
    # Every other value is treated as the default precision.
    return input_image.float()
111
+
112
+
113
  def preprocess(
114
  x,
115
  pixel_mean=torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1),
 
277
  .cuda()
278
  )
279
  logging.info(f"image_clip type: {type(image_clip)}.")
280
+ image_clip = set_image_precision_by_args(image_clip, args_to_parse.precision)
 
 
 
 
 
281
 
282
  image = transform.apply_image(image_np)
283
  resize_list = [image.shape[:2]]
 
288
  .cuda()
289
  )
290
  logging.info(f"image_clip type: {type(image_clip)}.")
291
+ image = set_image_precision_by_args(image, args_to_parse.precision)
 
 
 
 
 
292
 
293
  input_ids = tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
294
  input_ids = input_ids.unsqueeze(0).cuda()
 
330
  ## no seg output
331
  output_image = cv2.imread("./resources/no_seg_out.png")[:, :, ::-1]
332
  return output_image, output_str
333
+
334
  return inference
335
 
336