#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2023 Imperial College London (Pingchuan Ma)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

import torch
import torchaudio
import torchvision


class FunctionalModule(torch.nn.Module):
    """Wraps an arbitrary callable so it can be composed in torch.nn.Sequential."""

    def __init__(self, functional):
        super().__init__()
        self.functional = functional

    def forward(self, input):
        return self.functional(input)


class VideoTransform:
    """Preprocessing pipeline for a T x H x W grayscale video tensor."""

    def __init__(self, speed_rate):
        self.video_pipeline = torch.nn.Sequential(
            # Add a trailing channel axis: T x H x W -> T x H x W x 1.
            FunctionalModule(lambda x: x.unsqueeze(-1)),
            # Speed perturbation: resample frames along the time axis to
            # change playback speed by a factor of speed_rate.
            FunctionalModule(
                lambda x: x
                if speed_rate == 1
                else torch.index_select(
                    x,
                    dim=0,
                    index=torch.linspace(0, x.shape[0] - 1, int(x.shape[0] / speed_rate), dtype=torch.int64),
                )
            ),
            # Channel-first layout (C x T x H x W), pixel values scaled to [0, 1].
            FunctionalModule(lambda x: x.permute(3, 0, 1, 2)),
            FunctionalModule(lambda x: x / 255.0),
            # Crop to the 88 x 88 mouth region and normalize with precomputed statistics.
            torchvision.transforms.CenterCrop(88),
            torchvision.transforms.Normalize(0.421, 0.165),
        )

    def __call__(self, sample):
        return self.video_pipeline(sample)


class AudioTransform:
    """Preprocessing pipeline for a raw, time-major waveform tensor."""

    def __init__(self):
        self.audio_pipeline = torch.nn.Sequential(
            # Normalize the whole utterance to zero mean and unit variance.
            FunctionalModule(lambda x: torch.nn.functional.layer_norm(x, x.shape, eps=0)),
            FunctionalModule(lambda x: x.transpose(0, 1)),
        )

    def __call__(self, sample):
        return self.audio_pipeline(sample)
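

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file). The input shapes are
# assumptions inferred from the pipelines above: video as a T x H x W uint8
# tensor of grayscale frames, audio as a T x 1 float waveform; the dummy
# sizes below (25 frames of 96 x 96, one second of 16 kHz audio) are
# hypothetical.
if __name__ == "__main__":
    video_transform = VideoTransform(speed_rate=1.0)
    audio_transform = AudioTransform()

    dummy_video = torch.randint(0, 256, (25, 96, 96), dtype=torch.uint8)
    dummy_audio = torch.randn(16000, 1)

    video_features = video_transform(dummy_video)  # -> 1 x 25 x 88 x 88
    audio_features = audio_transform(dummy_audio)  # -> 1 x 16000

    print(video_features.shape)
    print(audio_features.shape)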