mpc001 commited on
Commit
425231f
·
1 Parent(s): c2d564e

Update pipelines/data/data_module.py

Browse files
Files changed (1) hide show
  1. pipelines/data/data_module.py +16 -11
pipelines/data/data_module.py CHANGED
@@ -5,6 +5,7 @@
5
  # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
6
 
7
  import torch
 
8
  import torchaudio
9
  import torchvision
10
  from .transforms import AudioTransform, VideoTransform
@@ -28,8 +29,9 @@ class AVSRDataLoader:
28
 
29
  def load_data(self, data_filename, landmarks=None, transform=True):
30
  if self.modality == "audio":
31
- audio, sample_rate = self.load_audio(data_filename)
32
- audio = self.audio_process(audio, sample_rate)
 
33
  return self.audio_transform(audio) if self.transform else audio
34
  if self.modality == "video":
35
  video = self.load_video(data_filename)
@@ -38,8 +40,9 @@ class AVSRDataLoader:
38
  return self.video_transform(video) if self.transform else video
39
  if self.modality == "audiovisual":
40
  rate_ratio = 640
41
- audio, sample_rate = self.load_audio(data_filename)
42
- audio = self.audio_process(audio, sample_rate)
 
43
  video = self.load_video(data_filename)
44
  video = self.video_process(video, landmarks)
45
  video = torch.tensor(video)
@@ -53,16 +56,18 @@ class AVSRDataLoader:
53
 
54
 
55
  def load_audio(self, data_filename):
56
- waveform, sample_rate = torchaudio.load(data_filename, normalize=True)
57
- return waveform, sample_rate
 
 
58
 
59
 
60
  def load_video(self, data_filename):
61
  return torchvision.io.read_video(data_filename, pts_unit='sec')[0].numpy()
62
 
63
 
64
- def audio_process(self, waveform, sample_rate, target_sample_rate=16000):
65
- if sample_rate != target_sample_rate:
66
- waveform = torchaudio.functional.resample(waveform, sample_rate, target_sample_rate)
67
- waveform = torch.mean(waveform, dim=0, keepdim=True)
68
- return waveform
 
5
  # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
6
 
7
  import torch
8
+ import whisper
9
  import torchaudio
10
  import torchvision
11
  from .transforms import AudioTransform, VideoTransform
 
29
 
30
  def load_data(self, data_filename, landmarks=None, transform=True):
31
  if self.modality == "audio":
32
+ # audio, sample_rate = self.load_audio(data_filename)
33
+ # audio = self.audio_process(audio, sample_rate)
34
+ audio = self.load_audio(data_filename)
35
  return self.audio_transform(audio) if self.transform else audio
36
  if self.modality == "video":
37
  video = self.load_video(data_filename)
 
40
  return self.video_transform(video) if self.transform else video
41
  if self.modality == "audiovisual":
42
  rate_ratio = 640
43
+ # audio, sample_rate = self.load_audio(data_filename)
44
+ # audio = self.audio_process(audio, sample_rate)
45
+ audio = self.load_audio(data_filename)
46
  video = self.load_video(data_filename)
47
  video = self.video_process(video, landmarks)
48
  video = torch.tensor(video)
 
56
 
57
 
58
  def load_audio(self, data_filename):
59
+ # rtype: [1, T]
60
+ waveform = torch.tensor(whisper.load_audio(data_filename)).unsqueeze(0)
61
+ # waveform, sample_rate = torchaudio.load(data_filename, normalize=True)
62
+ # return waveform, sample_rate
63
 
64
 
65
  def load_video(self, data_filename):
66
  return torchvision.io.read_video(data_filename, pts_unit='sec')[0].numpy()
67
 
68
 
69
+ # def audio_process(self, waveform, sample_rate, target_sample_rate=16000):
70
+ # if sample_rate != target_sample_rate:
71
+ # waveform = torchaudio.functional.resample(waveform, sample_rate, target_sample_rate)
72
+ # waveform = torch.mean(waveform, dim=0, keepdim=True)
73
+ # return waveform