zR
commited on
Commit
•
d2415e6
1
Parent(s):
8e2af17
remove load video
Browse files- modeling_cogvlm.py +0 -57
modeling_cogvlm.py
CHANGED
@@ -404,7 +404,6 @@ class CogVLMVideoModel(CogVLMPreTrainedModel):
|
|
404 |
images_features = self.encode_images(images)
|
405 |
images_features = rearrange(images_features, 'b n d -> (b n) d')
|
406 |
images_features = images_features.to(dtype=inputs_embeds.dtype, device=inputs_embeds.device)
|
407 |
-
|
408 |
inputs_embeds = inputs_embeds.index_put([token_type_ids == VISION_TOKEN_TYPE], images_features)
|
409 |
else: # single-modality
|
410 |
if token_type_ids is None:
|
@@ -580,62 +579,6 @@ def _history_to_prompt(signal_type, history, query):
|
|
580 |
prompt += 'Question: {} {}'.format(query, answer_format)
|
581 |
return prompt
|
582 |
|
583 |
-
def load_video(video_path):
|
584 |
-
mp4_stream = None
|
585 |
-
decord.bridge.set_bridge('torch')
|
586 |
-
with open(video_path, 'rb') as f:
|
587 |
-
mp4_stream = f.read()
|
588 |
-
clip_end_sec = 60 # clip video to <= 1 minute
|
589 |
-
clip_start_sec = 0
|
590 |
-
num_frames = 24
|
591 |
-
# decord.bridge.set_bridge('torch')
|
592 |
-
if mp4_stream is not None:
|
593 |
-
decord_vr = VideoReader(io.BytesIO(mp4_stream), ctx=cpu(0))
|
594 |
-
else:
|
595 |
-
decord_vr = VideoReader(video_path, ctx=cpu(0))
|
596 |
-
duration = len(decord_vr) # duration in terms of frames
|
597 |
-
start_frame = int(clip_start_sec * decord_vr.get_avg_fps())
|
598 |
-
end_frame = min(duration, int(clip_end_sec*decord_vr.get_avg_fps())) if \
|
599 |
-
clip_end_sec is not None else duration
|
600 |
-
frame_id_list = np.linspace(start_frame, end_frame-1, num_frames, dtype=int)
|
601 |
-
# frame_id_list = np.linspace(0, duration-1, num_frames, dtype=int)
|
602 |
-
video_data = decord_vr.get_batch(frame_id_list)
|
603 |
-
video_data = video_data.permute(3, 0, 1, 2) # (T, H, W, C) -> (C, T, H, W)
|
604 |
-
# video_outputs = transform(video_data)
|
605 |
-
return video_data
|
606 |
-
|
607 |
-
def load_video_1fps(video_path):
|
608 |
-
mp4_stream = None
|
609 |
-
decord.bridge.set_bridge('torch')
|
610 |
-
with open(video_path, 'rb') as f:
|
611 |
-
mp4_stream = f.read()
|
612 |
-
|
613 |
-
num_frames = 24
|
614 |
-
# decord.bridge.set_bridge('torch')
|
615 |
-
if mp4_stream is not None:
|
616 |
-
decord_vr = VideoReader(io.BytesIO(mp4_stream), ctx=cpu(0))
|
617 |
-
else:
|
618 |
-
decord_vr = VideoReader(video_path, ctx=cpu(0))
|
619 |
-
|
620 |
-
total_frames = len(decord_vr)
|
621 |
-
timestamps = decord_vr.get_frame_timestamp(np.arange(total_frames))
|
622 |
-
timestamps = [i[0] for i in timestamps]
|
623 |
-
|
624 |
-
max_second = round(max(timestamps)) + 1
|
625 |
-
frame_id_list = []
|
626 |
-
for second in range(max_second):
|
627 |
-
closest_num = min(timestamps, key=lambda x: abs(x - second))
|
628 |
-
index = timestamps.index(closest_num)
|
629 |
-
frame_id_list.append(index)
|
630 |
-
if len(frame_id_list) > num_frames:
|
631 |
-
break
|
632 |
-
|
633 |
-
video_data = decord_vr.get_batch(frame_id_list)
|
634 |
-
video_data = video_data.permute(3, 0, 1, 2) # (T, H, W, C) -> (C, T, H, W)
|
635 |
-
# video_outputs = transform(video_data)
|
636 |
-
return video_data
|
637 |
-
|
638 |
-
|
639 |
|
640 |
class CogVLMVideoForCausalLM(CogVLMPreTrainedModel):
|
641 |
_auto_class = "AutoModelForCausalLM"
|
|
|
404 |
images_features = self.encode_images(images)
|
405 |
images_features = rearrange(images_features, 'b n d -> (b n) d')
|
406 |
images_features = images_features.to(dtype=inputs_embeds.dtype, device=inputs_embeds.device)
|
|
|
407 |
inputs_embeds = inputs_embeds.index_put([token_type_ids == VISION_TOKEN_TYPE], images_features)
|
408 |
else: # single-modality
|
409 |
if token_type_ids is None:
|
|
|
579 |
prompt += 'Question: {} {}'.format(query, answer_format)
|
580 |
return prompt
|
581 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
582 |
|
583 |
class CogVLMVideoForCausalLM(CogVLMPreTrainedModel):
|
584 |
_auto_class = "AutoModelForCausalLM"
|