alessandro trinca tornidor commited on
Commit
6d22b1d
·
1 Parent(s): 1186076

[fxi] get_inference_model_by_args: use args_to_parse argument

Browse files
Files changed (1) hide show
  1. app.py +6 -6
app.py CHANGED
@@ -239,12 +239,12 @@ def get_inference_model_by_args(args_to_parse):
239
  return output_image, output_str
240
 
241
  # Model Inference
242
- conv = conversation_lib.conv_templates[args.conv_type].copy()
243
  conv.messages = []
244
 
245
  prompt = input_str
246
  prompt = DEFAULT_IMAGE_TOKEN + "\n" + prompt
247
- if args.use_mm_start_end:
248
  replace_token = (
249
  DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
250
  )
@@ -265,9 +265,9 @@ def get_inference_model_by_args(args_to_parse):
265
  .unsqueeze(0)
266
  .cuda()
267
  )
268
- if args.precision == "bf16":
269
  image_clip = image_clip.bfloat16()
270
- elif args.precision == "fp16":
271
  image_clip = image_clip.half()
272
  else:
273
  image_clip = image_clip.float()
@@ -280,9 +280,9 @@ def get_inference_model_by_args(args_to_parse):
280
  .unsqueeze(0)
281
  .cuda()
282
  )
283
- if args.precision == "bf16":
284
  image = image.bfloat16()
285
- elif args.precision == "fp16":
286
  image = image.half()
287
  else:
288
  image = image.float()
 
239
  return output_image, output_str
240
 
241
  # Model Inference
242
+ conv = conversation_lib.conv_templates[args_to_parse.conv_type].copy()
243
  conv.messages = []
244
 
245
  prompt = input_str
246
  prompt = DEFAULT_IMAGE_TOKEN + "\n" + prompt
247
+ if args_to_parse.use_mm_start_end:
248
  replace_token = (
249
  DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
250
  )
 
265
  .unsqueeze(0)
266
  .cuda()
267
  )
268
+ if args_to_parse.precision == "bf16":
269
  image_clip = image_clip.bfloat16()
270
+ elif args_to_parse.precision == "fp16":
271
  image_clip = image_clip.half()
272
  else:
273
  image_clip = image_clip.float()
 
280
  .unsqueeze(0)
281
  .cuda()
282
  )
283
+ if args_to_parse.precision == "bf16":
284
  image = image.bfloat16()
285
+ elif args_to_parse.precision == "fp16":
286
  image = image.half()
287
  else:
288
  image = image.float()