zhanghaoji commited on
Commit
cc32390
·
1 Parent(s): 5d42cba

fix tensor

Browse files
Files changed (2) hide show
  1. app.py +3 -3
  2. flash_vstream/serve/demo.py +1 -2
app.py CHANGED
@@ -53,11 +53,11 @@ def generate(video, textbox_in, first_run, state, state_, images_tensor):
53
 
54
  if os.path.exists(video):
55
  video_tensor = handler._get_rawvideo_dec(video, image_processor, max_frames=MAX_IMAGE_LENGTH)
56
- for img in video_tensor:
57
- images_tensor.append(image_processor(img, return_tensors='pt')['pixel_values'][0].to(handler.model.device, dtype=torch.float16))
58
 
59
  if os.path.exists(video):
60
- text_en_in = DEFAULT_IMAGE_TOKEN * len(video_tensor) + '\n' + text_en_in
61
 
62
  text_en_out, state_ = handler.generate(images_tensor, text_en_in, first_run=first_run, state=state_)
63
  state_.messages[-1] = (state_.roles[1], text_en_out)
 
53
 
54
  if os.path.exists(video):
55
  video_tensor = handler._get_rawvideo_dec(video, image_processor, max_frames=MAX_IMAGE_LENGTH)
56
+ images_tensor = image_processor(video_tensor, return_tensors='pt')['pixel_values'].to(handler.model.device, dtype=torch.float16)
57
+ print("video_tensor", video_tensor.shape)
58
 
59
  if os.path.exists(video):
60
+ text_en_in = DEFAULT_IMAGE_TOKEN + '\n' + text_en_in
61
 
62
  text_en_out, state_ = handler.generate(images_tensor, text_en_in, first_run=first_run, state=state_)
63
  state_.messages[-1] = (state_.roles[1], text_en_out)
flash_vstream/serve/demo.py CHANGED
@@ -75,14 +75,13 @@ class Chat:
75
  return patch_images
76
 
77
  @torch.inference_mode()
78
- def generate(self, images_tensor: list, prompt: str, first_run: bool, state):
79
  tokenizer, model, image_processor = self.tokenizer, self.model, self.image_processor
80
 
81
  state = self.get_prompt(prompt, state)
82
  prompt = state.get_prompt()
83
  print(prompt)
84
 
85
- images_tensor = torch.stack(images_tensor, dim=0)
86
  input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
87
 
88
  temperature = 0.2
 
75
  return patch_images
76
 
77
  @torch.inference_mode()
78
+ def generate(self, images_tensor, prompt, first_run, state):
79
  tokenizer, model, image_processor = self.tokenizer, self.model, self.image_processor
80
 
81
  state = self.get_prompt(prompt, state)
82
  prompt = state.get_prompt()
83
  print(prompt)
84
 
 
85
  input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
86
 
87
  temperature = 0.2