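"""Dataset classes for the ImpactSet / CountixAV data used with the SpecVQGAN code
in FoleyCrafter.

Provides waveform and mel-spectrogram datasets (with and without loudness-based
rejection sampling) as well as variants conditioned on video frames. Split files
are expected under `splits_path` as `countixAV_{split}.json`.
"""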
import json
import os
import matplotlib.pyplot as plt
import torch
from torchvision import transforms
import numpy as np
from tqdm import tqdm
import random
from random import sample
from PIL import Image
import torchaudio
import logging
from glob import glob
import sys
import soundfile
import copy
import csv
import noisereduce as nr
sys.path.insert(0, '.') # nopep8
from train import instantiate_from_config
from foleycrafter.models.specvqgan.data.transforms import *
torchaudio.set_audio_backend("sox_io")  # deprecated in recent torchaudio releases; kept for compatibility with older versions
logger = logging.getLogger(f'main.{__name__}')
SR = 22050  # audio sample rate (Hz)
FPS = 15  # video frame rate assumed when converting between frame index and time
MAX_SAMPLE_ITER = 10  # max random windows to try when searching for a loud segment
def non_negative(x): return int(np.round(max(0, x), 0))  # round to int and clamp at zero
def rms(x): return np.sqrt(np.mean(x**2))  # root-mean-square of an array
def get_GH_data_identifier(video_name, start_idx, split='_'):
if isinstance(start_idx, str):
return video_name + split + start_idx
elif isinstance(start_idx, int):
return video_name + split + str(start_idx)
else:
raise NotImplementedError
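# draw_spec: save a spectrogram as an image (origin at the bottom, axes hidden).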
def draw_spec(spec, dest, cmap='magma'):
plt.imshow(spec, cmap=cmap, origin='lower')
plt.axis('off')
plt.savefig(dest, bbox_inches='tight', pad_inches=0., dpi=300)
plt.close()
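# convert_to_decibel: amplitude -> dB relative to a reference of 1; the small 1e-4
# offset keeps the logarithm finite for near-zero samples.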
def convert_to_decibel(arr):
ref = 1
return 20 * np.log10(abs(arr + 1e-4) / ref)
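# ResampleFrames: subsample a feature sequence to `feat_sample_size` evenly spaced,
# shifted indices, optionally repeating each kept index
# `times_to_repeat_after_resample` times.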
class ResampleFrames(object):
def __init__(self, feat_sample_size, times_to_repeat_after_resample=None):
self.feat_sample_size = feat_sample_size
self.times_to_repeat_after_resample = times_to_repeat_after_resample
def __call__(self, item):
feat_len = item['feature'].shape[0]
## resample
assert feat_len >= self.feat_sample_size
# evenly spaced points (abcdefghkl -> aoooofoooo)
        idx = np.linspace(0, feat_len, self.feat_sample_size, dtype=int, endpoint=False)  # np.int was removed in NumPy 1.24+
# xoooo xoooo -> ooxoo ooxoo
shift = feat_len // (self.feat_sample_size + 1)
idx = idx + shift
## repeat after resampling (abc -> aaaabbbbcccc)
if self.times_to_repeat_after_resample is not None and self.times_to_repeat_after_resample > 1:
idx = np.repeat(idx, self.times_to_repeat_after_resample)
item['feature'] = item['feature'][idx, :]
return item
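# ImpactSetWave: raw-waveform dataset over the CountixAV split videos. Each item is a
# random L-second window; up to MAX_SAMPLE_ITER windows are tried and the loudest one
# (by mean dB) is kept, with an early exit once the level exceeds -40 dB.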
class ImpactSetWave(torch.utils.data.Dataset):
def __init__(self, split, random_crop, mel_num, spec_crop_len,
L=2.0, denoise=False, splits_path='./data',
data_path='data/ImpactSet/impactset-proccess-resize'):
super().__init__()
self.split = split
self.splits_path = splits_path
self.data_path = data_path
self.L = L
self.denoise = denoise
video_name_split_path = os.path.join(splits_path, f'countixAV_{split}.json')
if not os.path.exists(video_name_split_path):
self.make_split_files()
video_name = json.load(open(video_name_split_path, 'r'))
self.video_frame_cnt = {v: len(os.listdir(os.path.join(self.data_path, v, 'frames'))) for v in video_name}
self.left_over = int(FPS * L + 1)
self.video_audio_path = {v: os.path.join(self.data_path, v, f'audio/{v}_resampled.wav') for v in video_name}
self.dataset = video_name
self.wav_transforms = transforms.Compose([
MakeMono(),
Padding(target_len=int(SR * self.L)),
])
self.spec_transforms = CropImage([mel_num, spec_crop_len], random_crop)
def __len__(self):
return len(self.dataset)
def __getitem__(self, idx):
item = {}
video = self.dataset[idx]
available_frame_idx = self.video_frame_cnt[video] - self.left_over
wav = None
spec = None
max_db = -np.inf
wave_path = ''
cur_wave_path = self.video_audio_path[video]
if self.denoise:
cur_wave_path = cur_wave_path.replace('.wav', '_denoised.wav')
        for _ in range(MAX_SAMPLE_ITER):
start_idx = torch.randint(0, available_frame_idx, (1,)).tolist()[0]
# target
start_t = (start_idx + 0.5) / FPS
start_audio_idx = non_negative(start_t * SR)
cur_wav, _ = soundfile.read(cur_wave_path, frames=int(SR * self.L), start=start_audio_idx)
decibel = convert_to_decibel(cur_wav)
if float(np.mean(decibel)) > max_db:
wav = cur_wav
wave_path = cur_wave_path
max_db = float(np.mean(decibel))
if max_db >= -40:
break
# print(max_db)
wav = self.wav_transforms(wav)
        item['image'] = wav  # waveform of length int(SR * self.L), e.g. (44100,) for L=2.0
# item['wav'] = wav
item['file_path_wav_'] = wave_path
item['label'] = 'None'
item['target'] = 'None'
return item
def make_split_files(self):
raise NotImplementedError
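# Split-specific wrappers that fix the `split` argument; the remaining settings are
# taken from the dataset config dict (typically built via instantiate_from_config).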
class ImpactSetWaveTrain(ImpactSetWave):
def __init__(self, specs_dataset_cfg):
super().__init__('train', **specs_dataset_cfg)
class ImpactSetWaveValidation(ImpactSetWave):
def __init__(self, specs_dataset_cfg):
super().__init__('val', **specs_dataset_cfg)
class ImpactSetWaveTest(ImpactSetWave):
def __init__(self, specs_dataset_cfg):
super().__init__('test', **specs_dataset_cfg)
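# ImpactSetSpec: same window-sampling strategy as ImpactSetWave, but returns an
# 80x173 mel spectrogram scaled to [-1, 1]; the window with the highest spectrogram
# RMS is kept, with an early exit once RMS >= 0.1.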
class ImpactSetSpec(torch.utils.data.Dataset):
def __init__(self, split, random_crop, mel_num, spec_crop_len,
L=2.0, denoise=False, splits_path='./data',
data_path='data/ImpactSet/impactset-proccess-resize'):
super().__init__()
self.split = split
self.splits_path = splits_path
self.data_path = data_path
self.L = L
self.denoise = denoise
video_name_split_path = os.path.join(splits_path, f'countixAV_{split}.json')
if not os.path.exists(video_name_split_path):
self.make_split_files()
video_name = json.load(open(video_name_split_path, 'r'))
self.video_frame_cnt = {v: len(os.listdir(os.path.join(self.data_path, v, 'frames'))) for v in video_name}
self.left_over = int(FPS * L + 1)
self.video_audio_path = {v: os.path.join(self.data_path, v, f'audio/{v}_resampled.wav') for v in video_name}
self.dataset = video_name
self.wav_transforms = transforms.Compose([
MakeMono(),
SpectrogramTorchAudio(nfft=1024, hoplen=1024//4, spec_power=1),
MelScaleTorchAudio(sr=SR, stft=513, fmin=125, fmax=7600, nmels=80),
LowerThresh(1e-5),
Log10(),
Multiply(20),
Subtract(20),
Add(100),
Divide(100),
Clip(0, 1.0),
TrimSpec(173),
])
self.spec_transforms = CropImage([mel_num, spec_crop_len], random_crop)
def __len__(self):
return len(self.dataset)
def __getitem__(self, idx):
item = {}
video = self.dataset[idx]
available_frame_idx = self.video_frame_cnt[video] - self.left_over
wav = None
spec = None
max_rms = -np.inf
wave_path = ''
cur_wave_path = self.video_audio_path[video]
if self.denoise:
cur_wave_path = cur_wave_path.replace('.wav', '_denoised.wav')
        for _ in range(MAX_SAMPLE_ITER):
start_idx = torch.randint(0, available_frame_idx, (1,)).tolist()[0]
# target
start_t = (start_idx + 0.5) / FPS
start_audio_idx = non_negative(start_t * SR)
cur_wav, _ = soundfile.read(cur_wave_path, frames=int(SR * self.L), start=start_audio_idx)
if self.wav_transforms is not None:
spec_tensor = self.wav_transforms(torch.tensor(cur_wav).float())
cur_spec = spec_tensor.numpy()
# zeros padding if not enough spec t steps
if cur_spec.shape[1] < 173:
pad = np.zeros((80, 173), dtype=cur_spec.dtype)
pad[:, :cur_spec.shape[1]] = cur_spec
cur_spec = pad
rms_val = rms(cur_spec)
if rms_val > max_rms:
wav = cur_wav
spec = cur_spec
wave_path = cur_wave_path
max_rms = rms_val
# print(rms_val)
if max_rms >= 0.1:
break
item['image'] = 2 * spec - 1 # (80, 173)
# item['wav'] = wav
item['file_path_wav_'] = wave_path
item['label'] = 'None'
item['target'] = 'None'
if self.spec_transforms is not None:
item = self.spec_transforms(item)
return item
def make_split_files(self):
raise NotImplementedError
class ImpactSetSpecTrain(ImpactSetSpec):
def __init__(self, specs_dataset_cfg):
super().__init__('train', **specs_dataset_cfg)
class ImpactSetSpecValidation(ImpactSetSpec):
def __init__(self, specs_dataset_cfg):
super().__init__('val', **specs_dataset_cfg)
class ImpactSetSpecTest(ImpactSetSpec):
def __init__(self, specs_dataset_cfg):
super().__init__('test', **specs_dataset_cfg)
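# ImpactSetWaveTestTime: evaluation dataset over a fixed list of raw recordings
# (stock videos, selected YouTube clips and self-recorded audio); optionally applies
# noisereduce denoising before computing the mel spectrogram.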
class ImpactSetWaveTestTime(torch.utils.data.Dataset):
def __init__(self, split, random_crop, mel_num, spec_crop_len,
L=2.0, denoise=False, splits_path='./data',
data_path='data/ImpactSet/impactset-proccess-resize'):
super().__init__()
self.split = split
self.splits_path = splits_path
self.data_path = data_path
self.L = L
self.denoise = denoise
self.video_list = glob('data/ImpactSet/RawVideos/StockVideo_sound/*.wav') + [
'data/ImpactSet/RawVideos/YouTube-impact-ccl/1_ckbCU5aQs/1_ckbCU5aQs_0013_0016_resize.wav',
'data/ImpactSet/RawVideos/YouTube-impact-ccl/GFmuVBiwz6k/GFmuVBiwz6k_0034_0054_resize.wav',
'data/ImpactSet/RawVideos/YouTube-impact-ccl/OsPcY316h1M/OsPcY316h1M_0000_0005_resize.wav',
'data/ImpactSet/RawVideos/YouTube-impact-ccl/SExIpBIBj_k/SExIpBIBj_k_0009_0019_resize.wav',
'data/ImpactSet/RawVideos/YouTube-impact-ccl/S6TkbV4B4QI/S6TkbV4B4QI_0028_0036_resize.wav',
'data/ImpactSet/RawVideos/YouTube-impact-ccl/2Ld24pPIn3k/2Ld24pPIn3k_0005_0011_resize.wav',
'data/ImpactSet/RawVideos/YouTube-impact-ccl/6d1YS7fdBK4/6d1YS7fdBK4_0007_0019_resize.wav',
'data/ImpactSet/RawVideos/YouTube-impact-ccl/JnBsmJgEkiw/JnBsmJgEkiw_0008_0016_resize.wav',
'data/ImpactSet/RawVideos/YouTube-impact-ccl/xcUyiXt0gjo/xcUyiXt0gjo_0015_0021_resize.wav',
'data/ImpactSet/RawVideos/YouTube-impact-ccl/4DRFJnZjpMM/4DRFJnZjpMM_0000_0010_resize.wav'
] + glob('data/ImpactSet/RawVideos/self_recorded/*_resize.wav')
self.wav_transforms = transforms.Compose([
MakeMono(),
SpectrogramTorchAudio(nfft=1024, hoplen=1024//4, spec_power=1),
MelScaleTorchAudio(sr=SR, stft=513, fmin=125, fmax=7600, nmels=80),
LowerThresh(1e-5),
Log10(),
Multiply(20),
Subtract(20),
Add(100),
Divide(100),
Clip(0, 1.0),
TrimSpec(173),
])
self.spec_transforms = CropImage([mel_num, spec_crop_len], random_crop)
def __len__(self):
return len(self.video_list)
def __getitem__(self, idx):
item = {}
wave_path = self.video_list[idx]
wav, _ = soundfile.read(wave_path)
        start_idx = random.randint(0, max(0, min(4, wav.shape[0] - int(SR * self.L))))  # clamp so the upper bound is never negative for short clips
wav = wav[start_idx:start_idx+int(SR * self.L)]
if self.denoise:
if len(wav.shape) == 1:
wav = wav[None, :]
wav = nr.reduce_noise(y=wav, sr=SR, n_fft=1024, hop_length=1024//4)
wav = wav.squeeze()
if self.wav_transforms is not None:
spec_tensor = self.wav_transforms(torch.tensor(wav).float())
spec = spec_tensor.numpy()
if spec.shape[1] < 173:
pad = np.zeros((80, 173), dtype=spec.dtype)
pad[:, :spec.shape[1]] = spec
spec = pad
item['image'] = 2 * spec - 1 # (80, 173)
# item['wav'] = wav
item['file_path_wav_'] = wave_path
item['label'] = 'None'
item['target'] = 'None'
if self.spec_transforms is not None:
item = self.spec_transforms(item)
return item
def make_split_files(self):
raise NotImplementedError
class ImpactSetWaveTestTimeTrain(ImpactSetWaveTestTime):
def __init__(self, specs_dataset_cfg):
super().__init__('train', **specs_dataset_cfg)
class ImpactSetWaveTestTimeValidation(ImpactSetWaveTestTime):
def __init__(self, specs_dataset_cfg):
super().__init__('val', **specs_dataset_cfg)
class ImpactSetWaveTestTimeTest(ImpactSetWaveTestTime):
def __init__(self, specs_dataset_cfg):
super().__init__('test', **specs_dataset_cfg)
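# ImpactSetWaveWithSilent: like ImpactSetWave but samples a single random window
# without loudness-based rejection, so silent segments can be returned as well.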
class ImpactSetWaveWithSilent(torch.utils.data.Dataset):
def __init__(self, split, random_crop, mel_num, spec_crop_len,
L=2.0, denoise=False, splits_path='./data',
data_path='data/ImpactSet/impactset-proccess-resize'):
super().__init__()
self.split = split
self.splits_path = splits_path
self.data_path = data_path
self.L = L
self.denoise = denoise
video_name_split_path = os.path.join(splits_path, f'countixAV_{split}.json')
if not os.path.exists(video_name_split_path):
self.make_split_files()
video_name = json.load(open(video_name_split_path, 'r'))
self.video_frame_cnt = {v: len(os.listdir(os.path.join(self.data_path, v, 'frames'))) for v in video_name}
self.left_over = int(FPS * L + 1)
self.video_audio_path = {v: os.path.join(self.data_path, v, f'audio/{v}_resampled.wav') for v in video_name}
self.dataset = video_name
self.wav_transforms = transforms.Compose([
MakeMono(),
Padding(target_len=int(SR * self.L)),
])
self.spec_transforms = CropImage([mel_num, spec_crop_len], random_crop)
def __len__(self):
return len(self.dataset)
def __getitem__(self, idx):
item = {}
video = self.dataset[idx]
available_frame_idx = self.video_frame_cnt[video] - self.left_over
wave_path = self.video_audio_path[video]
if self.denoise:
wave_path = wave_path.replace('.wav', '_denoised.wav')
start_idx = torch.randint(0, available_frame_idx, (1,)).tolist()[0]
# target
start_t = (start_idx + 0.5) / FPS
start_audio_idx = non_negative(start_t * SR)
wav, _ = soundfile.read(wave_path, frames=int(SR * self.L), start=start_audio_idx)
wav = self.wav_transforms(wav)
item['image'] = wav # (44100,)
# item['wav'] = wav
item['file_path_wav_'] = wave_path
item['label'] = 'None'
item['target'] = 'None'
return item
def make_split_files(self):
raise NotImplementedError
class ImpactSetWaveWithSilentTrain(ImpactSetWaveWithSilent):
def __init__(self, specs_dataset_cfg):
super().__init__('train', **specs_dataset_cfg)
class ImpactSetWaveWithSilentValidation(ImpactSetWaveWithSilent):
def __init__(self, specs_dataset_cfg):
super().__init__('val', **specs_dataset_cfg)
class ImpactSetWaveWithSilentTest(ImpactSetWaveWithSilent):
def __init__(self, specs_dataset_cfg):
super().__init__('test', **specs_dataset_cfg)
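# ImpactSetWaveCondOnImage: returns a target waveform together with a conditioning
# waveform and the corresponding video frames. With probability p_outside_cond the
# condition is drawn from another video of the same class (train/val only); otherwise
# it comes from a window of the same video whose start frame lies outside the target
# window.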
class ImpactSetWaveCondOnImage(torch.utils.data.Dataset):
def __init__(self, split,
L=2.0, frame_transforms=None, denoise=False, splits_path='./data',
data_path='data/ImpactSet/impactset-proccess-resize',
p_outside_cond=0.):
super().__init__()
self.split = split
self.splits_path = splits_path
self.frame_transforms = frame_transforms
self.data_path = data_path
self.L = L
self.denoise = denoise
self.p_outside_cond = torch.tensor(p_outside_cond)
video_name_split_path = os.path.join(splits_path, f'countixAV_{split}.json')
if not os.path.exists(video_name_split_path):
self.make_split_files()
video_name = json.load(open(video_name_split_path, 'r'))
self.video_frame_cnt = {v: len(os.listdir(os.path.join(self.data_path, v, 'frames'))) for v in video_name}
self.left_over = int(FPS * L + 1)
for v, cnt in self.video_frame_cnt.items():
if cnt - (3*self.left_over) <= 0:
video_name.remove(v)
self.video_audio_path = {v: os.path.join(self.data_path, v, f'audio/{v}_resampled.wav') for v in video_name}
self.dataset = video_name
video_timing_split_path = os.path.join(splits_path, f'countixAV_{split}_timing.json')
self.video_timing = json.load(open(video_timing_split_path, 'r'))
self.video_timing = {v: [int(float(t) * FPS) for t in ts] for v, ts in self.video_timing.items()}
if split != 'test':
video_class_path = os.path.join(splits_path, f'countixAV_{split}_class.json')
if not os.path.exists(video_class_path):
self.make_video_class()
self.video_class = json.load(open(video_class_path, 'r'))
self.class2video = {}
for v, c in self.video_class.items():
                if c not in self.class2video:
self.class2video[c] = []
self.class2video[c].append(v)
self.wav_transforms = transforms.Compose([
MakeMono(),
Padding(target_len=int(SR * self.L)),
])
        if self.frame_transforms is None:
self.frame_transforms = transforms.Compose([
Resize3D(128),
RandomResizedCrop3D(112, scale=(0.5, 1.0)),
RandomHorizontalFlip3D(),
ColorJitter3D(brightness=0.1, saturation=0.1),
ToTensor3D(),
Normalize3D(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]),
])
def make_video_class(self):
meta_path = f'data/ImpactSet/data-info/CountixAV_{self.split}.csv'
video_class = {}
with open(meta_path, 'r') as f:
reader = csv.reader(f)
for i, row in enumerate(reader):
if i == 0:
continue
vid, k_st, k_et = row[:3]
video_name = f'{vid}_{int(k_st):0>4d}_{int(k_et):0>4d}'
if video_name not in self.dataset:
continue
video_class[video_name] = row[-1]
with open(os.path.join(self.splits_path, f'countixAV_{self.split}_class.json'), 'w') as f:
json.dump(video_class, f)
def __len__(self):
return len(self.dataset)
def __getitem__(self, idx):
item = {}
video = self.dataset[idx]
available_frame_idx = self.video_frame_cnt[video] - self.left_over
rep_start_idx, rep_end_idx = self.video_timing[video]
rep_end_idx = min(available_frame_idx, rep_end_idx)
if available_frame_idx <= rep_start_idx + self.L * FPS:
idx_set = list(range(0, available_frame_idx))
else:
idx_set = list(range(rep_start_idx, rep_end_idx))
start_idx = sample(idx_set, k=1)[0]
wave_path = self.video_audio_path[video]
if self.denoise:
wave_path = wave_path.replace('.wav', '_denoised.wav')
# target
start_t = (start_idx + 0.5) / FPS
        end_idx = non_negative(start_idx + FPS * self.L)
start_audio_idx = non_negative(start_t * SR)
wav, sr = soundfile.read(wave_path, frames=int(SR * self.L), start=start_audio_idx)
assert sr == SR
wav = self.wav_transforms(wav)
frame_path = os.path.join(self.data_path, video, 'frames')
frames = [Image.open(os.path.join(
frame_path, f'frame{i+1:0>6d}.jpg')).convert('RGB') for i in
range(start_idx, end_idx)]
if torch.all(torch.bernoulli(self.p_outside_cond) == 1.) and self.split != 'test':
# outside from the same class
cur_class = self.video_class[video]
tmp_video = copy.copy(self.class2video[cur_class])
            if len(tmp_video) > 1:
                # exclude the current video from the candidates; if it is the only one
                # in its class, it conditions on itself
                tmp_video.remove(video)
cond_video = sample(tmp_video, k=1)[0]
cond_available_frame_idx = self.video_frame_cnt[cond_video] - self.left_over
cond_start_idx = torch.randint(0, cond_available_frame_idx, (1,)).tolist()[0]
else:
cond_video = video
idx_set = list(range(0, start_idx)) + list(range(end_idx, available_frame_idx))
cond_start_idx = random.sample(idx_set, k=1)[0]
cond_end_idx = non_negative(cond_start_idx + FPS * self.L)
cond_start_t = (cond_start_idx + 0.5) / FPS
cond_audio_idx = non_negative(cond_start_t * SR)
cond_frame_path = os.path.join(self.data_path, cond_video, 'frames')
cond_wave_path = self.video_audio_path[cond_video]
cond_frames = [Image.open(os.path.join(
cond_frame_path, f'frame{i+1:0>6d}.jpg')).convert('RGB') for i in
range(cond_start_idx, cond_end_idx)]
cond_wav, sr = soundfile.read(cond_wave_path, frames=int(SR * self.L), start=cond_audio_idx)
assert sr == SR
cond_wav = self.wav_transforms(cond_wav)
item['image'] = wav # (44100,)
item['cond_image'] = cond_wav # (44100,)
item['file_path_wav_'] = wave_path
item['file_path_cond_wav_'] = cond_wave_path
if self.frame_transforms is not None:
cond_frames = self.frame_transforms(cond_frames)
frames = self.frame_transforms(frames)
item['feature'] = np.stack(cond_frames + frames, axis=0) # (30 * L, 112, 112, 3)
item['file_path_feats_'] = (frame_path, start_idx)
item['file_path_cond_feats_'] = (cond_frame_path, cond_start_idx)
item['label'] = 'None'
item['target'] = 'None'
return item
def make_split_files(self):
raise NotImplementedError
class ImpactSetWaveCondOnImageTrain(ImpactSetWaveCondOnImage):
def __init__(self, dataset_cfg):
train_transforms = transforms.Compose([
Resize3D(128),
RandomResizedCrop3D(112, scale=(0.5, 1.0)),
RandomHorizontalFlip3D(),
ColorJitter3D(brightness=0.4, saturation=0.4, contrast=0.2, hue=0.1),
ToTensor3D(),
Normalize3D(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]),
])
super().__init__('train', frame_transforms=train_transforms, **dataset_cfg)
class ImpactSetWaveCondOnImageValidation(ImpactSetWaveCondOnImage):
def __init__(self, dataset_cfg):
valid_transforms = transforms.Compose([
Resize3D(128),
CenterCrop3D(112),
ToTensor3D(),
Normalize3D(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]),
])
super().__init__('val', frame_transforms=valid_transforms, **dataset_cfg)
class ImpactSetWaveCondOnImageTest(ImpactSetWaveCondOnImage):
def __init__(self, dataset_cfg):
test_transforms = transforms.Compose([
Resize3D(128),
CenterCrop3D(112),
ToTensor3D(),
Normalize3D(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]),
])
super().__init__('test', frame_transforms=test_transforms, **dataset_cfg)
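# ImpactSetCleanWaveCondOnImage: variant of ImpactSetWaveCondOnImage that samples
# windows around precomputed "clean" impact timings loaded from
# data/countixAV_{split}_timing_processed_0.20.json.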
class ImpactSetCleanWaveCondOnImage(ImpactSetWaveCondOnImage):
def __init__(self, split, L=2, frame_transforms=None, denoise=False, splits_path='./data', data_path='data/ImpactSet/impactset-proccess-resize', p_outside_cond=0):
super().__init__(split, L, frame_transforms, denoise, splits_path, data_path, p_outside_cond)
pred_timing_path = f'data/countixAV_{split}_timing_processed_0.20.json'
assert os.path.exists(pred_timing_path)
self.pred_timing = json.load(open(pred_timing_path, 'r'))
self.dataset = []
for v, ts in self.pred_timing.items():
if v in self.video_audio_path.keys():
for t in ts:
self.dataset.append([v, t])
def __getitem__(self, idx):
item = {}
video, start_t = self.dataset[idx]
available_frame_idx = self.video_frame_cnt[video] - self.left_over
available_timing = (available_frame_idx + 0.5) / FPS
start_t = float(start_t)
start_t = min(start_t, available_timing)
start_idx = non_negative(start_t * FPS - 0.5)
wave_path = self.video_audio_path[video]
if self.denoise:
wave_path = wave_path.replace('.wav', '_denoised.wav')
# target
        end_idx = non_negative(start_idx + FPS * self.L)
start_audio_idx = non_negative(start_t * SR)
wav, sr = soundfile.read(wave_path, frames=int(SR * self.L), start=start_audio_idx)
assert sr == SR
wav = self.wav_transforms(wav)
frame_path = os.path.join(self.data_path, video, 'frames')
frames = [Image.open(os.path.join(
frame_path, f'frame{i+1:0>6d}.jpg')).convert('RGB') for i in
range(start_idx, end_idx)]
if torch.all(torch.bernoulli(self.p_outside_cond) == 1.):
other_video = list(self.pred_timing.keys())
other_video.remove(video)
cond_video = sample(other_video, k=1)[0]
cond_available_frame_idx = self.video_frame_cnt[cond_video] - self.left_over
cond_available_timing = (cond_available_frame_idx + 0.5) / FPS
else:
cond_video = video
cond_available_timing = available_timing
cond_start_t = sample(self.pred_timing[cond_video], k=1)[0]
cond_start_t = float(cond_start_t)
cond_start_t = min(cond_start_t, cond_available_timing)
cond_start_idx = non_negative(cond_start_t * FPS - 0.5)
cond_end_idx = non_negative(cond_start_idx + FPS * self.L)
cond_audio_idx = non_negative(cond_start_t * SR)
cond_frame_path = os.path.join(self.data_path, cond_video, 'frames')
cond_wave_path = self.video_audio_path[cond_video]
cond_frames = [Image.open(os.path.join(
cond_frame_path, f'frame{i+1:0>6d}.jpg')).convert('RGB') for i in
range(cond_start_idx, cond_end_idx)]
cond_wav, sr = soundfile.read(cond_wave_path, frames=int(SR * self.L), start=cond_audio_idx)
assert sr == SR
cond_wav = self.wav_transforms(cond_wav)
item['image'] = wav # (44100,)
item['cond_image'] = cond_wav # (44100,)
item['file_path_wav_'] = wave_path
item['file_path_cond_wav_'] = cond_wave_path
if self.frame_transforms is not None:
cond_frames = self.frame_transforms(cond_frames)
frames = self.frame_transforms(frames)
item['feature'] = np.stack(cond_frames + frames, axis=0) # (30 * L, 112, 112, 3)
item['file_path_feats_'] = (frame_path, start_idx)
item['file_path_cond_feats_'] = (cond_frame_path, cond_start_idx)
item['label'] = 'None'
item['target'] = 'None'
return item
class ImpactSetCleanWaveCondOnImageTrain(ImpactSetCleanWaveCondOnImage):
def __init__(self, dataset_cfg):
train_transforms = transforms.Compose([
Resize3D(128),
RandomResizedCrop3D(112, scale=(0.5, 1.0)),
RandomHorizontalFlip3D(),
ColorJitter3D(brightness=0.4, saturation=0.4, contrast=0.2, hue=0.1),
ToTensor3D(),
Normalize3D(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]),
])
super().__init__('train', frame_transforms=train_transforms, **dataset_cfg)
class ImpactSetCleanWaveCondOnImageValidation(ImpactSetCleanWaveCondOnImage):
def __init__(self, dataset_cfg):
valid_transforms = transforms.Compose([
Resize3D(128),
CenterCrop3D(112),
ToTensor3D(),
Normalize3D(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]),
])
super().__init__('val', frame_transforms=valid_transforms, **dataset_cfg)
class ImpactSetCleanWaveCondOnImageTest(ImpactSetCleanWaveCondOnImage):
def __init__(self, dataset_cfg):
test_transforms = transforms.Compose([
Resize3D(128),
CenterCrop3D(112),
ToTensor3D(),
Normalize3D(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]),
])
super().__init__('test', frame_transforms=test_transforms, **dataset_cfg)
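# Quick smoke test: builds the datamodule from a config and prints the train split;
# the code after exit() is optional debugging that dumps example waveforms and
# spectrograms to tmp/.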
if __name__ == '__main__':
import sys
from omegaconf import OmegaConf
cfg = OmegaConf.load('configs/countixAV_transformer_denoise_clean.yaml')
data = instantiate_from_config(cfg.data)
data.prepare_data()
data.setup()
print(data.datasets['train'])
print(len(data.datasets['train']))
# print(data.datasets['train'][24])
exit()
stats = []
torch.manual_seed(0)
np.random.seed(0)
    random.seed(0)  # call the function; `random.seed = 0` would just overwrite it
for k in range(1):
x = np.arange(SR * 2)
for i in tqdm(range(len(data.datasets['train']))):
wav = data.datasets['train'][i]['wav']
spec = data.datasets['train'][i]['image']
spec = 0.5 * (spec + 1)
spec_rms = rms(spec)
stats.append(float(spec_rms))
# plt.plot(x, wav)
# plt.ylim(-1, 1)
# plt.savefig(f'tmp/th0.1_wav_e_{k}_{i}_{mean_val:.3f}_{spec_rms:.3f}.png')
# plt.close()
# plt.cla()
soundfile.write(f'tmp/wav_e_{k}_{i}_{spec_rms:.3f}.wav', wav, SR)
draw_spec(spec, f'tmp/wav_spec_e_{k}_{i}_{spec_rms:.3f}.png')
if i == 100:
break
# plt.hist(stats, bins=50)
# plt.savefig(f'tmp/rms_spec_stats.png')