tinyllava commited on
Commit
829cb99
1 Parent(s): 3b3297c

Update generate_model.py

Browse files
Files changed (1) hide show
  1. generate_model.py +2 -2
generate_model.py CHANGED
@@ -611,12 +611,12 @@ def generate(
611
  # print('loading image...')
612
  image = load_image(image)
613
  # print('load image over')
614
- image_tensor = process_images(image, image_processor, config).to(model.device, dtype=torch.float16)
615
 
616
  input_ids = (
617
  tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
618
  .unsqueeze(0)
619
- .to(model.device, dtype=torch.float16)
620
  )
621
  # Generate
622
  stime = time.time()
 
611
  # print('loading image...')
612
  image = load_image(image)
613
  # print('load image over')
614
+ image_tensor = process_images(image, image_processor, config).to(model.device)
615
 
616
  input_ids = (
617
  tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
618
  .unsqueeze(0)
619
+ .to(model.device)
620
  )
621
  # Generate
622
  stime = time.time()