|
import flash |
|
from flash.core.data.utils import download_data |
|
from flash.video import VideoClassificationData, VideoClassifier |
|
import torch |
|
from flash.video.classification.input_transform import VideoClassificationInputTransform |
|
from pytorchvideo.transforms import ( |
|
ApplyTransformToKey, |
|
ShortSideScale, |
|
UniformTemporalSubsample, |
|
UniformCropVideo, |
|
) |
|
from dataclasses import dataclass |
|
from typing import Callable |
|
|
|
import torch |
|
from torch import Tensor |
|
|
|
from flash.core.data.io.input import DataKeys |
|
from flash.core.data.io.input_transform import InputTransform |
|
from flash.core.data.transforms import ApplyToKeys |
|
from flash.core.utilities.imports import ( |
|
_KORNIA_AVAILABLE, |
|
_PYTORCHVIDEO_AVAILABLE, |
|
requires, |
|
) |
|
from torchvision.transforms import Compose, CenterCrop |
|
from torchvision.transforms import RandomCrop |
|
from torch import nn |
|
import kornia.augmentation as K |
|
from torchvision import transforms as T |
|
# Allow TF32 / reduced-precision float32 matmuls on supported GPUs for speed
# (trades a small amount of matmul precision for throughput).
torch.set_float32_matmul_precision('high')
|
|
|
def normalize(x: Tensor) -> Tensor:
    """Scale a 0-255 pixel-valued tensor into the [0, 1] range."""
    pixel_max = 255.0
    return x / pixel_max
|
|
|
class TransformDataModule(InputTransform):
    """Input transform for video classification.

    Per sample (CPU): subsample frames, scale pixels to [0, 1], and crop
    (center crop for eval/predict, random crop for training).
    Per batch (device): normalize with Kornia using ``mean``/``std``.
    """

    image_size: int = 256             # spatial crop size in pixels
    temporal_sub_sample: int = 16     # frames kept per clip by UniformTemporalSubsample
    mean: Tensor = torch.tensor([0.45, 0.45, 0.45])    # per-channel mean for K.Normalize
    std: Tensor = torch.tensor([0.225, 0.225, 0.225])  # per-channel std for K.Normalize
    data_format: str = "BCTHW"        # batch, channels, time, height, width
    same_on_frame: bool = False       # apply augmentations independently per frame

    def _per_sample_pipeline(self, crop: Callable) -> Callable:
        """Build the per-sample pipeline shared by train and eval transforms.

        Subsamples ``temporal_sub_sample`` frames, scales pixels to [0, 1],
        applies ``crop`` to the input clip, and converts the target to a
        tensor. Only the crop transform differs between train and eval.
        """
        return Compose(
            [
                ApplyToKeys(
                    DataKeys.INPUT,
                    Compose(
                        [
                            UniformTemporalSubsample(self.temporal_sub_sample),
                            normalize,
                            crop,
                        ]
                    ),
                ),
                ApplyToKeys(DataKeys.TARGET, torch.as_tensor),
            ]
        )

    def per_sample_transform(self) -> Callable:
        """Eval/predict transform: deterministic center crop."""
        return self._per_sample_pipeline(CenterCrop(self.image_size))

    def train_per_sample_transform(self) -> Callable:
        """Training transform: random crop, padding clips smaller than
        ``image_size`` so the crop is always valid."""
        return self._per_sample_pipeline(RandomCrop(self.image_size, pad_if_needed=True))

    def per_batch_transform_on_device(self) -> Callable:
        """Normalize the batched clip on the accelerator with Kornia."""
        return ApplyToKeys(
            DataKeys.INPUT,
            K.VideoSequential(
                K.Normalize(self.mean, self.std),
                data_format=self.data_format,
                same_on_frame=self.same_on_frame,
            ),
        )
|
|
|
|
|
|
|
# Restore the fine-tuned classifier weights from a Lightning checkpoint.
# NOTE(review): the directory name "video_classfication" looks misspelled —
# confirm it matches the actual checkpoint directory on disk before changing it.
model = VideoClassifier.load_from_checkpoint("video_classfication/checkpoints/epoch=99-step=1000.ckpt")
|
|
|
|
|
# Prediction datamodule: each video under ``videos/`` becomes one sample,
# preprocessed with the custom transform defined above; one clip per batch.
datamodule_p = VideoClassificationData.from_folders(

    predict_folder="videos",

    batch_size=1,

    transform=TransformDataModule()

)

# Trainer is used only for ``predict`` below; ``max_epochs`` has no effect
# during inference.
trainer = flash.Trainer(

    max_epochs=5,

)
|
def classfication():
    """Predict labels for every clip in ``datamodule_p`` and return the
    label of the first sample of the first batch.

    Note: the (misspelled) function name is kept — it is the public entry
    point callers rely on.
    """
    label_batches = trainer.predict(model, datamodule=datamodule_p, output="labels")
    first_batch = label_batches[0]
    return first_batch[0]
|
|