Update pipelines/data/data_module.py
pipelines/data/data_module.py  CHANGED  +16 -11
@@ -5,6 +5,7 @@
 # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
 
 import torch
+import whisper
 import torchaudio
 import torchvision
 from .transforms import AudioTransform, VideoTransform
@@ -28,8 +29,9 @@ class AVSRDataLoader:
 
     def load_data(self, data_filename, landmarks=None, transform=True):
         if self.modality == "audio":
-            audio, sample_rate = self.load_audio(data_filename)
-            audio = self.audio_process(audio, sample_rate)
+            # audio, sample_rate = self.load_audio(data_filename)
+            # audio = self.audio_process(audio, sample_rate)
+            audio = self.load_audio(data_filename)
             return self.audio_transform(audio) if self.transform else audio
         if self.modality == "video":
             video = self.load_video(data_filename)
@@ -38,8 +40,9 @@ class AVSRDataLoader:
             return self.video_transform(video) if self.transform else video
         if self.modality == "audiovisual":
             rate_ratio = 640
-            audio, sample_rate = self.load_audio(data_filename)
-            audio = self.audio_process(audio, sample_rate)
+            # audio, sample_rate = self.load_audio(data_filename)
+            # audio = self.audio_process(audio, sample_rate)
+            audio = self.load_audio(data_filename)
             video = self.load_video(data_filename)
             video = self.video_process(video, landmarks)
             video = torch.tensor(video)
@@ -53,16 +56,18 @@ class AVSRDataLoader:
 
 
     def load_audio(self, data_filename):
-        waveform, sample_rate = torchaudio.load(data_filename, normalize=True)
-        return waveform, sample_rate
+        # rtype: [1, T]
+        waveform = torch.tensor(whisper.load_audio(data_filename)).unsqueeze(0)
+        # waveform, sample_rate = torchaudio.load(data_filename, normalize=True)
+        # return waveform, sample_rate
 
 
     def load_video(self, data_filename):
         return torchvision.io.read_video(data_filename, pts_unit='sec')[0].numpy()
 
 
-    def audio_process(self, waveform, sample_rate, target_sample_rate=16000):
-        if sample_rate != target_sample_rate:
-            waveform = torchaudio.functional.resample(waveform, sample_rate, target_sample_rate)
-        waveform = torch.mean(waveform, dim=0, keepdim=True)
-        return waveform
+    # def audio_process(self, waveform, sample_rate, target_sample_rate=16000):
+    #     if sample_rate != target_sample_rate:
+    #         waveform = torchaudio.functional.resample(waveform, sample_rate, target_sample_rate)
+    #     waveform = torch.mean(waveform, dim=0, keepdim=True)
+    #     return waveform
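One thing worth flagging in the new load_audio: as the hunk reads, the method assigns waveform but never returns it, so audio = self.load_audio(data_filename) in the audio and audiovisual branches would bind None and the following transform would fail at runtime. Below is a minimal sketch of the method with an explicit return; the final return line is an assumption added here, not part of the committed hunk.

    def load_audio(self, data_filename):
        # rtype: [1, T]
        # whisper.load_audio decodes via ffmpeg and yields mono float32 samples at 16 kHz
        waveform = torch.tensor(whisper.load_audio(data_filename)).unsqueeze(0)
        return waveform  # assumed fix; the committed version omits this line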
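For context on why audio_process could be commented out: whisper.load_audio already returns mono float32 samples resampled to 16 kHz, which is what the removed torchaudio path computed explicitly (load, resample to 16 kHz, mix down to mono). The sketch below contrasts the two loaders; load_audio_whisper and load_audio_torchaudio are illustrative names, not functions in this repository.

    import torch
    import torchaudio
    import whisper

    def load_audio_whisper(path):
        # new path: ffmpeg decode, mono float32 at 16 kHz -> [1, T]
        return torch.tensor(whisper.load_audio(path)).unsqueeze(0)

    def load_audio_torchaudio(path, target_sample_rate=16000):
        # previous path: load, resample if needed, mix down to mono -> [1, T]
        waveform, sample_rate = torchaudio.load(path, normalize=True)
        if sample_rate != target_sample_rate:
            waveform = torchaudio.functional.resample(waveform, sample_rate, target_sample_rate)
        return torch.mean(waveform, dim=0, keepdim=True)

For 16 kHz mono inputs the two should agree closely; outputs may differ slightly for other formats, since ffmpeg and torchaudio resample and mix down with different filters.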