import flash
import torch
from torch import Tensor

from dataclasses import dataclass
from typing import Callable

import kornia.augmentation as K
from pytorchvideo.transforms import UniformTemporalSubsample
from torchvision.transforms import CenterCrop, Compose, RandomCrop

from flash.core.data.io.input import DataKeys
from flash.core.data.io.input_transform import InputTransform
from flash.core.data.transforms import ApplyToKeys
from flash.video import VideoClassificationData, VideoClassifier

# Trade a little float32 matmul precision for TF32 speed on recent GPUs.
torch.set_float32_matmul_precision("high")

def normalize(x: Tensor) -> Tensor:
    # Scale raw uint8 pixel values to floats in [0, 1].
    return x / 255.0

# Custom InputTransform: uniformly subsample 16 frames, scale pixels to [0, 1],
# crop each sample to a fixed size, then normalize per batch on the device.
@dataclass
class TransformDataModule(InputTransform):
    image_size: int = 256
    temporal_sub_sample: int = 16  # This is the only change in our custom transform
    mean: Tensor = torch.tensor([0.45, 0.45, 0.45])
    std: Tensor = torch.tensor([0.225, 0.225, 0.225])
    data_format: str = "BCTHW"
    same_on_frame: bool = False

    def per_sample_transform(self) -> Callable:
        # Used for val/test/predict (train overrides below): deterministic center crop.
        per_sample_transform = [CenterCrop(self.image_size)]

        return Compose(
            [
                ApplyToKeys(
                    DataKeys.INPUT,
                    Compose(
                        [UniformTemporalSubsample(self.temporal_sub_sample), normalize]
                        + per_sample_transform
                    ),
                ),
                ApplyToKeys(DataKeys.TARGET, torch.as_tensor),
            ]
        )

    def train_per_sample_transform(self) -> Callable:
        # Training-only override: random crop for augmentation, padded if needed.
        per_sample_transform = [RandomCrop(self.image_size, pad_if_needed=True)]

        return Compose(
            [
                ApplyToKeys(
                    DataKeys.INPUT,
                    Compose(
                        [UniformTemporalSubsample(self.temporal_sub_sample), normalize]
                        + per_sample_transform
                    ),
                ),
                ApplyToKeys(DataKeys.TARGET, torch.as_tensor),
            ]
        )

    def per_batch_transform_on_device(self) -> Callable:
        # Runs on the accelerator after collation: channel-wise normalization
        # applied across the frames of each (B, C, T, H, W) clip.
        return ApplyToKeys(
            DataKeys.INPUT,
            K.VideoSequential(
                K.Normalize(self.mean, self.std),
                data_format=self.data_format,
                same_on_frame=self.same_on_frame,
            ),
        )
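
# A minimal sanity-check sketch (an addition, not part of the original script):
# the predict-time per-sample transform should keep 16 of the 30 fake frames,
# since UniformTemporalSubsample samples along the T axis of a (C, T, H, W)
# clip. The dummy clip shape below is an illustrative assumption.
_sample = {
    DataKeys.INPUT: torch.randint(0, 255, (3, 30, 256, 256), dtype=torch.uint8),
    DataKeys.TARGET: 0,
}
_out = TransformDataModule().per_sample_transform()(_sample)
assert _out[DataKeys.INPUT].shape == (3, 16, 256, 256)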



# Load the fine-tuned weights from the checkpoint written during training
# (the path is kept exactly as produced by the training run).
model = VideoClassifier.load_from_checkpoint("video_classfication/checkpoints/epoch=99-step=1000.ckpt")


datamodule_p = VideoClassificationData.from_folders(
    predict_folder="videos",
    batch_size=1,
    transform=TransformDataModule(),
)
# max_epochs has no effect on prediction; it is kept from the training setup.
trainer = flash.Trainer(
    max_epochs=5,
)
def classification():
    # trainer.predict returns one list of labels per predict batch; with
    # batch_size=1, predictions[0][0] is the label of the first (only) video.
    predictions = trainer.predict(model, datamodule=datamodule_p, output="labels")
    return predictions[0][0]
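

# Minimal usage sketch, assuming the checkpoint and the "videos" folder above
# exist; this entry point is an addition for running the file as a script.
if __name__ == "__main__":
    print(classification())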