"""WSL / utils.py: video classification utilities built on Lightning Flash."""
from dataclasses import dataclass
from typing import Callable

import flash
import kornia.augmentation as K
import torch
from flash.core.data.io.input import DataKeys
from flash.core.data.io.input_transform import InputTransform
from flash.core.data.transforms import ApplyToKeys
from flash.video import VideoClassificationData, VideoClassifier
from pytorchvideo.transforms import UniformTemporalSubsample
from torch import Tensor
from torchvision.transforms import CenterCrop, Compose, RandomCrop

# Trade a little float32 matmul precision for speed on recent GPUs.
torch.set_float32_matmul_precision("high")


def normalize(x: Tensor) -> Tensor:
    """Scale raw uint8 pixel values from [0, 255] to floats in [0, 1]."""
    return x / 255.0


@dataclass
class TransformDataModule(InputTransform):
    """Per-sample and per-batch transforms for video clips."""

    image_size: int = 256
    temporal_sub_sample: int = 16  # The only change from the stock Flash transform example.
    mean: Tensor = torch.tensor([0.45, 0.45, 0.45])
    std: Tensor = torch.tensor([0.225, 0.225, 0.225])
    data_format: str = "BCTHW"
    same_on_frame: bool = False

    def per_sample_transform(self) -> Callable:
        """Eval/predict pipeline: subsample frames, scale to [0, 1], center crop."""
        per_sample_transform = [CenterCrop(self.image_size)]
        return Compose(
            [
                ApplyToKeys(
                    DataKeys.INPUT,
                    Compose(
                        [UniformTemporalSubsample(self.temporal_sub_sample), normalize]
                        + per_sample_transform
                    ),
                ),
                ApplyToKeys(DataKeys.TARGET, torch.as_tensor),
            ]
        )

    def train_per_sample_transform(self) -> Callable:
        """Training pipeline: as above, but with a padded random crop for augmentation."""
        per_sample_transform = [RandomCrop(self.image_size, pad_if_needed=True)]
        return Compose(
            [
                ApplyToKeys(
                    DataKeys.INPUT,
                    Compose(
                        [UniformTemporalSubsample(self.temporal_sub_sample), normalize]
                        + per_sample_transform
                    ),
                ),
                ApplyToKeys(DataKeys.TARGET, torch.as_tensor),
            ]
        )

    def per_batch_transform_on_device(self) -> Callable:
        """Normalize whole batches on the accelerator via Kornia's VideoSequential."""
        return ApplyToKeys(
            DataKeys.INPUT,
            K.VideoSequential(
                K.Normalize(self.mean, self.std),
                data_format=self.data_format,
                same_on_frame=self.same_on_frame,
            ),
        )
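

# A minimal fine-tuning sketch showing how this transform could have produced the
# checkpoint loaded below. The folder layout, backbone, and hyperparameters here
# are assumptions for illustration, not the recorded training configuration.
def finetune_example():
    datamodule = VideoClassificationData.from_folders(
        train_folder="videos/train",  # hypothetical layout: one subfolder per class
        batch_size=8,  # assumed batch size
        transform=TransformDataModule(),
    )
    model = VideoClassifier(
        backbone="x3d_xs",  # Flash's default video backbone; an assumption here
        labels=datamodule.labels,
        pretrained=True,
    )
    trainer = flash.Trainer(max_epochs=100)
    trainer.finetune(model, datamodule=datamodule, strategy="freeze")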


# Load the fine-tuned classifier from its saved checkpoint.
model = VideoClassifier.load_from_checkpoint(
    "video_classfication/checkpoints/epoch=99-step=1000.ckpt"
)

# Prediction datamodule: clips are read one at a time from the local "videos" folder.
datamodule_p = VideoClassificationData.from_folders(
    predict_folder="videos",
    batch_size=1,
    transform=TransformDataModule(),
)

trainer = flash.Trainer(max_epochs=5)


def classfication():
    """Run prediction and return the label of the first clip in the first batch."""
    predictions = trainer.predict(model, datamodule=datamodule_p, output="labels")
    return predictions[0][0]
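

# Hypothetical smoke test (not part of the original app wiring): predict on the
# clips in "videos" and print the first label.
if __name__ == "__main__":
    print(classfication())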