# Video-classification inference script.
#
# Defines a custom Flash ``InputTransform`` for video clips, loads a trained
# ``VideoClassifier`` checkpoint and exposes ``classfication()`` to run
# prediction over the ``videos`` folder.

from dataclasses import dataclass
from typing import Callable

import flash
import kornia.augmentation as K
import torch
from flash.core.data.io.input import DataKeys
from flash.core.data.io.input_transform import InputTransform
from flash.core.data.transforms import ApplyToKeys
from flash.core.data.utils import download_data
from flash.core.utilities.imports import (
    _KORNIA_AVAILABLE,
    _PYTORCHVIDEO_AVAILABLE,
    requires,
)
from flash.video import VideoClassificationData, VideoClassifier
from flash.video.classification.input_transform import VideoClassificationInputTransform
from pytorchvideo.transforms import (
    ApplyTransformToKey,
    ShortSideScale,
    UniformCropVideo,
    UniformTemporalSubsample,
)
from torch import Tensor, nn
from torchvision import transforms as T
from torchvision.transforms import CenterCrop, Compose, RandomCrop

torch.set_float32_matmul_precision("high")


def normalize(x: Tensor) -> Tensor:
    """Return ``x`` divided by 255.0.

    Presumably maps 0-255 pixel intensities into the [0, 1] range; the
    input range is not established by this file.
    """
    return x / 255.0


def _build_per_sample_transform(temporal_sub_sample: int, crop: Callable) -> Callable:
    """Build the per-sample pipeline shared by the train and default stages.

    Args:
        temporal_sub_sample: Value handed to ``UniformTemporalSubsample``.
        crop: Spatial crop applied last to the ``DataKeys.INPUT`` entry.

    Returns:
        A ``Compose`` that, for the ``DataKeys.INPUT`` entry, applies the
        temporal sub-sampling, then :func:`normalize`, then ``crop``, and
        that converts the ``DataKeys.TARGET`` entry with ``torch.as_tensor``.
    """
    input_pipeline = Compose(
        [UniformTemporalSubsample(temporal_sub_sample), normalize, crop]
    )
    return Compose(
        [
            ApplyToKeys(DataKeys.INPUT, input_pipeline),
            ApplyToKeys(DataKeys.TARGET, torch.as_tensor),
        ]
    )


class TransformDataModule(InputTransform):
    """Flash input transform for video clips with custom settings.

    The per-sample stages sub-sample the clip in time, divide by 255.0 and
    crop to ``image_size``; the on-device batch stage normalises with
    ``mean`` / ``std`` through a Kornia ``VideoSequential``.
    """

    # NOTE(review): these are plain annotated class attributes. ``dataclass``
    # is imported in this file but not applied to this class; confirm whether
    # an ``@dataclass`` decorator was intended so the values can be set per
    # instance.
    image_size: int = 256  # size handed to CenterCrop / RandomCrop
    temporal_sub_sample: int = 16  # This is the only change in our custom transform
    mean: Tensor = torch.tensor([0.45, 0.45, 0.45])
    std: Tensor = torch.tensor([0.225, 0.225, 0.225])
    data_format: str = "BCTHW"  # forwarded to K.VideoSequential
    same_on_frame: bool = False  # forwarded to K.VideoSequential

    def per_sample_transform(self) -> Callable:
        """Return the per-sample pipeline that crops with ``CenterCrop``."""
        return _build_per_sample_transform(
            self.temporal_sub_sample,
            CenterCrop(self.image_size),
        )

    def train_per_sample_transform(self) -> Callable:
        """Return the training per-sample pipeline that crops with ``RandomCrop``.

        ``pad_if_needed=True`` is passed to ``RandomCrop`` together with
        ``image_size``.
        """
        return _build_per_sample_transform(
            self.temporal_sub_sample,
            RandomCrop(self.image_size, pad_if_needed=True),
        )

    def per_batch_transform_on_device(self) -> Callable:
        """Return the batch-level normalisation applied to ``DataKeys.INPUT``.

        Wraps ``K.Normalize(mean, std)`` in a ``K.VideoSequential`` configured
        with this transform's ``data_format`` and ``same_on_frame``.
        """
        return ApplyToKeys(
            DataKeys.INPUT,
            K.VideoSequential(
                K.Normalize(self.mean, self.std),
                data_format=self.data_format,
                same_on_frame=self.same_on_frame,
            ),
        )


# NOTE: the objects below are created when this module is imported, so the
# checkpoint file and the ``videos`` folder are accessed at import time.
model = VideoClassifier.load_from_checkpoint(
    "video_classfication/checkpoints/epoch=99-step=1000.ckpt"
)

datamodule_p = VideoClassificationData.from_folders(
    predict_folder="videos",
    batch_size=1,
    transform=TransformDataModule(),
)

# NOTE(review): only ``trainer.predict`` is called in the visible code, so
# ``max_epochs`` looks unused here; kept in case other code trains with it.
trainer = flash.Trainer(
    max_epochs=5,
)


def classfication():
    """Run prediction over ``datamodule_p`` and return one result.

    Calls ``trainer.predict`` with ``output="labels"`` and returns element
    ``[0][0]`` of what it yields, presumably the label of the first sample
    in the first batch (``batch_size`` is 1) — confirm against the Flash
    version in use.

    Raises:
        IndexError: If the prediction result, or its first entry, is empty.
    """
    predictions = trainer.predict(model, datamodule=datamodule_p, output="labels")
    return predictions[0][0]