Spaces:
Paused
Paused
Update models/pipeline_controlvideo.py
Browse files- models/pipeline_controlvideo.py +37 -37
models/pipeline_controlvideo.py
CHANGED
@@ -670,43 +670,43 @@ class ControlVideoPipeline(DiffusionPipeline):
|
|
670 |
return key_frame_indices, inter_frame_list
|
671 |
"""
|
672 |
def get_slide_window_indices(self, video_length, window_size):
|
673 |
-
|
674 |
-
|
675 |
-
|
676 |
-
|
677 |
-
|
678 |
-
|
679 |
-
|
680 |
-
|
681 |
-
|
682 |
-
|
683 |
-
|
684 |
-
|
685 |
-
|
686 |
-
|
687 |
-
|
688 |
-
|
689 |
-
|
690 |
-
|
691 |
-
|
692 |
-
|
693 |
-
|
694 |
-
|
695 |
-
|
696 |
-
|
697 |
-
|
698 |
-
|
699 |
-
|
700 |
-
|
701 |
-
|
702 |
-
|
703 |
-
|
704 |
-
|
705 |
-
|
706 |
-
|
707 |
-
|
708 |
-
|
709 |
-
|
710 |
@torch.no_grad()
|
711 |
def __call__(
|
712 |
self,
|
|
|
670 |
return key_frame_indices, inter_frame_list
|
671 |
"""
|
672 |
def get_slide_window_indices(self, video_length, window_size):
|
673 |
+
assert window_size >= 3
|
674 |
+
|
675 |
+
# Define the chunk size for processing
|
676 |
+
chunk_size = 4
|
677 |
+
|
678 |
+
# Calculate the number of chunks
|
679 |
+
num_chunks = (video_length - 1) // chunk_size + 1
|
680 |
+
|
681 |
+
# Initialize the lists to store the results
|
682 |
+
key_frame_indices = []
|
683 |
+
inter_frame_list = []
|
684 |
+
|
685 |
+
for chunk_index in range(num_chunks):
|
686 |
+
# Calculate the start and end indices for the current chunk
|
687 |
+
start_index = chunk_index * chunk_size
|
688 |
+
end_index = min((chunk_index + 1) * chunk_size, video_length)
|
689 |
+
|
690 |
+
# Generate key frame indices for the current chunk
|
691 |
+
chunk_key_frame_indices = np.arange(start_index, end_index, window_size - 1).tolist()
|
692 |
+
|
693 |
+
# Append the last index if it's not already included
|
694 |
+
if chunk_key_frame_indices[-1] != (end_index - 1):
|
695 |
+
chunk_key_frame_indices.append(end_index - 1)
|
696 |
+
|
697 |
+
# Append the key frame indices of the current chunk to the overall list
|
698 |
+
key_frame_indices.extend(chunk_key_frame_indices)
|
699 |
+
|
700 |
+
# Generate slices for the current chunk
|
701 |
+
chunk_slices = np.split(np.arange(start_index, end_index), chunk_key_frame_indices)
|
702 |
+
|
703 |
+
# Process each slice in the current chunk
|
704 |
+
for s in chunk_slices:
|
705 |
+
if len(s) < 2:
|
706 |
+
continue
|
707 |
+
inter_frame_list.append(s[1:].tolist())
|
708 |
+
|
709 |
+
return key_frame_indices, inter_frame_list
|
710 |
@torch.no_grad()
|
711 |
def __call__(
|
712 |
self,
|