from typing import List

import cv2
import torch
import numpy as np
from tqdm import tqdm
import supervision as sv
import torch.nn.functional as F
from transformers import AutoModel
from sklearn.decomposition import PCA
from torchvision import transforms as T
from sklearn.preprocessing import MinMaxScaler


def load_video_frames(video_path: str) -> List[np.ndarray]:
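    """Decode a video with supervision and return its frames as RGB numpy arrays."""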
    frames = []
    for frame in tqdm(sv.get_video_frames_generator(source_path=video_path), unit=" frames"):
        frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))

    return frames

def preprocess(image: np.ndarray, n_patches: int, device: str, patch_size: int = 14) -> torch.Tensor:
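    """Resize a frame to an (n_patches * patch_size)-pixel square and apply ImageNet normalization."""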
    IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
    IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)

    transform = T.Compose([
        T.Resize((n_patches * patch_size, n_patches * patch_size)),
        T.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
    ])

    img = torch.from_numpy(image).type(torch.float).permute(2, 0, 1) / 255
    img_tensor = transform(img).unsqueeze(0).to(device)

    return img_tensor


def process_video(
    model: AutoModel,
    video: str | List[np.ndarray],
    is_larger: bool = True,
    batch_size: int = 4,
    threshold: float = 0.5,
    n_patches: int = 40,
    interpolate: bool = False,
    device: str = "cpu"
) -> List[np.ndarray]:
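    """Turn DINOv2-style patch features into a per-frame RGB visualization.

    A PCA over all patch features gives a first component that separates
    foreground from background (thresholded by `threshold`; `is_larger`
    sets the direction). A second PCA over foreground patches alone
    provides the RGB coloring; background patches are rendered black.
    """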
    # NP = N_PATCHES
    # P = PATCH_SIZE
    if isinstance(video, str):
        frames = load_video_frames(video)
    else:
        frames = video
    patch_size = model.config.patch_size

    original_height = frames[0].shape[0]  # frames are H, W, C
    original_width = frames[0].shape[1]

    final_frames = []
    pca = PCA(n_components=3)
    scaler = MinMaxScaler(clip=True)

    # Note: trailing frames that do not fill a complete batch are dropped.
    for i in range(len(frames) // batch_size):
        batch = frames[i * batch_size:(i + 1) * batch_size]
        pixel_values = [
            preprocess(f, n_patches, device, patch_size).squeeze(0) for f in batch
        ]
        pixel_values = torch.stack(pixel_values) # B, C, NP * P, NP * P

        with torch.no_grad():
            out = model(pixel_values=pixel_values)

        features = out.last_hidden_state[:, 1:] # B, NP * NP, HIDDEN_DIM (CLS token dropped)
        features = features.cpu().numpy()
        features = features.reshape(batch_size * n_patches * n_patches, -1) # B * NP * NP, HIDDEN_DIM

        pca_features = pca.fit_transform(features)
        pca_features = scaler.fit_transform(pca_features)

        if is_larger:
            pca_features_bg = pca_features[:, 0] > threshold
        else:
            pca_features_bg = pca_features[:, 0] < threshold

        pca_features_fg = ~pca_features_bg

        # Re-fit PCA on foreground patches only for richer color separation
        pca_features_fg_seg = pca.fit_transform(features[pca_features_fg])
        pca_features_fg_seg = scaler.fit_transform(pca_features_fg_seg)

        pca_features_rgb = np.zeros((batch_size * n_patches * n_patches, 3))
        pca_features_rgb[pca_features_bg] = 0
        pca_features_rgb[pca_features_fg] = pca_features_fg_seg
        pca_features_rgb = pca_features_rgb.reshape(batch_size, n_patches, n_patches, 3)

        if interpolate:
            # To torch tensor: B, NP, NP, 3 -> B, 3, NP, NP for F.interpolate
            pca_features_rgb = torch.from_numpy(pca_features_rgb).permute(0, 3, 1, 2)
            # Upsample the patch grid to the original resolution (B, 3, H, W),
            # move channels last (B, H, W, 3), and split into per-frame numpy arrays
            pca_features_rgb = F.interpolate(
                pca_features_rgb,
                size=(original_height, original_width),
                mode='bilinear',
                align_corners=False
            ).permute(0, 2, 3, 1)
            pca_features_rgb = [f.numpy() for f in pca_features_rgb.unbind(0)]
        else:
            pca_features_rgb = [f for f in pca_features_rgb]
        # Adding to final_frames list
        final_frames.extend(pca_features_rgb)

    return final_frames


def create_video_from_frames_rgb(
    frame_list: List[np.ndarray], 
    output_filename: str = "animation.mp4", 
    fps: int = 15
) -> str:
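    """Encode a list of float RGB frames in [0, 1] as an mp4 and return its filename."""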
    # Get the shape of the frames to determine video dimensions
    frame_height, frame_width, _ = frame_list[0].shape

    # Define the codec and create a VideoWriter object
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')  # You can change the codec as needed
    out = cv2.VideoWriter(output_filename, fourcc, fps, (frame_width, frame_height))

    for frame in frame_list:
        # Frames are float RGB in [0, 1]; VideoWriter expects uint8 BGR
        frame_uint8 = (np.clip(frame, 0, 1) * 255).astype(np.uint8)
        out.write(cv2.cvtColor(frame_uint8, cv2.COLOR_RGB2BGR))

    # Release the VideoWriter object
    out.release()

    return output_filename
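

if __name__ == "__main__":
    # Minimal usage sketch. The checkpoint "facebook/dinov2-base" and the
    # file paths below are illustrative placeholders, not part of this script.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = AutoModel.from_pretrained("facebook/dinov2-base").to(device)
    model.eval()

    frames = process_video(
        model,
        video="input.mp4",
        batch_size=4,
        n_patches=40,
        interpolate=True,
        device=device,
    )
    create_video_from_frames_rgb(frames, output_filename="pca_features.mp4", fps=15)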