import json
import os
import matplotlib.pyplot as plt
import torch
from torchvision import transforms
import numpy as np
from tqdm import tqdm
import random
from random import sample
import torchaudio
import logging
from glob import glob
import sys
import soundfile
import copy
import csv
import noisereduce as nr
from PIL import Image

sys.path.insert(0, '.')  # nopep8
from train import instantiate_from_config
from foleycrafter.models.specvqgan.data.transforms import *

torchaudio.set_audio_backend("sox_io")
logger = logging.getLogger(f'main.{__name__}')

SR = 22050
FPS = 15
MAX_SAMPLE_ITER = 10


def non_negative(x):
    return int(np.round(max(0, x), 0))


def rms(x):
    return np.sqrt(np.mean(x**2))


def get_GH_data_identifier(video_name, start_idx, split='_'):
    if isinstance(start_idx, str):
        return video_name + split + start_idx
    elif isinstance(start_idx, int):
        return video_name + split + str(start_idx)
    else:
        raise NotImplementedError


def draw_spec(spec, dest, cmap='magma'):
    plt.imshow(spec, cmap=cmap, origin='lower')
    plt.axis('off')
    plt.savefig(dest, bbox_inches='tight', pad_inches=0., dpi=300)
    plt.close()


def convert_to_decibel(arr):
    ref = 1
    return 20 * np.log10(abs(arr + 1e-4) / ref)


class ResampleFrames(object):
    def __init__(self, feat_sample_size, times_to_repeat_after_resample=None):
        self.feat_sample_size = feat_sample_size
        self.times_to_repeat_after_resample = times_to_repeat_after_resample

    def __call__(self, item):
        feat_len = item['feature'].shape[0]

        ## resample
        assert feat_len >= self.feat_sample_size
        # evenly spaced points (abcdefghkl -> aoooofoooo)
        idx = np.linspace(0, feat_len, self.feat_sample_size, dtype=int, endpoint=False)
        # xoooo xoooo -> ooxoo ooxoo
        shift = feat_len // (self.feat_sample_size + 1)
        idx = idx + shift

        ## repeat after resampling (abc -> aaaabbbbcccc)
        if self.times_to_repeat_after_resample is not None and self.times_to_repeat_after_resample > 1:
            idx = np.repeat(idx, self.times_to_repeat_after_resample)

        item['feature'] = item['feature'][idx, :]
        return item


class ImpactSetWave(torch.utils.data.Dataset):

    def __init__(self, split, random_crop, mel_num, spec_crop_len,
                 L=2.0, denoise=False, splits_path='./data',
                 data_path='data/ImpactSet/impactset-proccess-resize'):
        super().__init__()
        self.split = split
        self.splits_path = splits_path
        self.data_path = data_path
        self.L = L
        self.denoise = denoise

        video_name_split_path = os.path.join(splits_path, f'countixAV_{split}.json')
        if not os.path.exists(video_name_split_path):
            self.make_split_files()
        video_name = json.load(open(video_name_split_path, 'r'))
        self.video_frame_cnt = {v: len(os.listdir(os.path.join(self.data_path, v, 'frames'))) for v in video_name}
        self.left_over = int(FPS * L + 1)
        self.video_audio_path = {v: os.path.join(self.data_path, v, f'audio/{v}_resampled.wav') for v in video_name}
        self.dataset = video_name

        self.wav_transforms = transforms.Compose([
            MakeMono(),
            Padding(target_len=int(SR * self.L)),
        ])
        self.spec_transforms = CropImage([mel_num, spec_crop_len], random_crop)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = {}
        video = self.dataset[idx]
        available_frame_idx = self.video_frame_cnt[video] - self.left_over
        wav = None
        spec = None
        max_db = -np.inf
        wave_path = ''
        cur_wave_path = self.video_audio_path[video]
        if self.denoise:
            cur_wave_path = cur_wave_path.replace('.wav', '_denoised.wav')
        for _ in range(10):
            start_idx = torch.randint(0, available_frame_idx, (1,)).tolist()[0]
            # target
            start_t = (start_idx + 0.5) / FPS
            start_audio_idx = non_negative(start_t * SR)

            cur_wav, _ = soundfile.read(cur_wave_path, frames=int(SR * self.L), start=start_audio_idx)

            decibel = convert_to_decibel(cur_wav)
            if float(np.mean(decibel)) > max_db:
                wav = cur_wav
                wave_path = cur_wave_path
                max_db = float(np.mean(decibel))

            if max_db >= -40:
                break
        # print(max_db)
        wav = self.wav_transforms(wav)

        item['image'] = wav  # (44100,) mono waveform
        # item['wav'] = wav
        item['file_path_wav_'] = wave_path

        item['label'] = 'None'
        item['target'] = 'None'
        return item

    def make_split_files(self):
        raise NotImplementedError
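# The datasets in this module are exposed to the config system (instantiate_from_config)
# through thin split-specific wrappers that fix the split name and forward the remaining
# keyword arguments from the config dict.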
class ImpactSetWaveTrain(ImpactSetWave):
    def __init__(self, specs_dataset_cfg):
        super().__init__('train', **specs_dataset_cfg)


class ImpactSetWaveValidation(ImpactSetWave):
    def __init__(self, specs_dataset_cfg):
        super().__init__('val', **specs_dataset_cfg)


class ImpactSetWaveTest(ImpactSetWave):
    def __init__(self, specs_dataset_cfg):
        super().__init__('test', **specs_dataset_cfg)


class ImpactSetSpec(torch.utils.data.Dataset):

    def __init__(self, split, random_crop, mel_num, spec_crop_len,
                 L=2.0, denoise=False, splits_path='./data',
                 data_path='data/ImpactSet/impactset-proccess-resize'):
        super().__init__()
        self.split = split
        self.splits_path = splits_path
        self.data_path = data_path
        self.L = L
        self.denoise = denoise

        video_name_split_path = os.path.join(splits_path, f'countixAV_{split}.json')
        if not os.path.exists(video_name_split_path):
            self.make_split_files()
        video_name = json.load(open(video_name_split_path, 'r'))
        self.video_frame_cnt = {v: len(os.listdir(os.path.join(self.data_path, v, 'frames'))) for v in video_name}
        self.left_over = int(FPS * L + 1)
        self.video_audio_path = {v: os.path.join(self.data_path, v, f'audio/{v}_resampled.wav') for v in video_name}
        self.dataset = video_name

        self.wav_transforms = transforms.Compose([
            MakeMono(),
            SpectrogramTorchAudio(nfft=1024, hoplen=1024//4, spec_power=1),
            MelScaleTorchAudio(sr=SR, stft=513, fmin=125, fmax=7600, nmels=80),
            LowerThresh(1e-5),
            Log10(),
            Multiply(20),
            Subtract(20),
            Add(100),
            Divide(100),
            Clip(0, 1.0),
            TrimSpec(173),
        ])
        self.spec_transforms = CropImage([mel_num, spec_crop_len], random_crop)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = {}
        video = self.dataset[idx]
        available_frame_idx = self.video_frame_cnt[video] - self.left_over
        wav = None
        spec = None
        max_rms = -np.inf
        wave_path = ''
        cur_wave_path = self.video_audio_path[video]
        if self.denoise:
            cur_wave_path = cur_wave_path.replace('.wav', '_denoised.wav')
        for _ in range(10):
            start_idx = torch.randint(0, available_frame_idx, (1,)).tolist()[0]
            # target
            start_t = (start_idx + 0.5) / FPS
            start_audio_idx = non_negative(start_t * SR)

            cur_wav, _ = soundfile.read(cur_wave_path, frames=int(SR * self.L), start=start_audio_idx)

            if self.wav_transforms is not None:
                spec_tensor = self.wav_transforms(torch.tensor(cur_wav).float())
                cur_spec = spec_tensor.numpy()
            # zeros padding if not enough spec t steps
            if cur_spec.shape[1] < 173:
                pad = np.zeros((80, 173), dtype=cur_spec.dtype)
                pad[:, :cur_spec.shape[1]] = cur_spec
                cur_spec = pad

            rms_val = rms(cur_spec)
            if rms_val > max_rms:
                wav = cur_wav
                spec = cur_spec
                wave_path = cur_wave_path
                max_rms = rms_val
            # print(rms_val)

            if max_rms >= 0.1:
                break

        item['image'] = 2 * spec - 1  # (80, 173)
        # item['wav'] = wav
        item['file_path_wav_'] = wave_path

        item['label'] = 'None'
        item['target'] = 'None'

        if self.spec_transforms is not None:
            item = self.spec_transforms(item)
        return item

    def make_split_files(self):
        raise NotImplementedError
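# A minimal sketch of the dict the wrappers below expect as `specs_dataset_cfg`;
# the keys mirror ImpactSetSpec.__init__, and the values here are illustrative
# assumptions rather than project defaults:
#   {'random_crop': False, 'mel_num': 80, 'spec_crop_len': 173, 'L': 2.0, 'denoise': False}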
class ImpactSetSpecTrain(ImpactSetSpec):
    def __init__(self, specs_dataset_cfg):
        super().__init__('train', **specs_dataset_cfg)


class ImpactSetSpecValidation(ImpactSetSpec):
    def __init__(self, specs_dataset_cfg):
        super().__init__('val', **specs_dataset_cfg)


class ImpactSetSpecTest(ImpactSetSpec):
    def __init__(self, specs_dataset_cfg):
        super().__init__('test', **specs_dataset_cfg)


class ImpactSetWaveTestTime(torch.utils.data.Dataset):

    def __init__(self, split, random_crop, mel_num, spec_crop_len,
                 L=2.0, denoise=False, splits_path='./data',
                 data_path='data/ImpactSet/impactset-proccess-resize'):
        super().__init__()
        self.split = split
        self.splits_path = splits_path
        self.data_path = data_path
        self.L = L
        self.denoise = denoise

        self.video_list = glob('data/ImpactSet/RawVideos/StockVideo_sound/*.wav') + [
            'data/ImpactSet/RawVideos/YouTube-impact-ccl/1_ckbCU5aQs/1_ckbCU5aQs_0013_0016_resize.wav',
            'data/ImpactSet/RawVideos/YouTube-impact-ccl/GFmuVBiwz6k/GFmuVBiwz6k_0034_0054_resize.wav',
            'data/ImpactSet/RawVideos/YouTube-impact-ccl/OsPcY316h1M/OsPcY316h1M_0000_0005_resize.wav',
            'data/ImpactSet/RawVideos/YouTube-impact-ccl/SExIpBIBj_k/SExIpBIBj_k_0009_0019_resize.wav',
            'data/ImpactSet/RawVideos/YouTube-impact-ccl/S6TkbV4B4QI/S6TkbV4B4QI_0028_0036_resize.wav',
            'data/ImpactSet/RawVideos/YouTube-impact-ccl/2Ld24pPIn3k/2Ld24pPIn3k_0005_0011_resize.wav',
            'data/ImpactSet/RawVideos/YouTube-impact-ccl/6d1YS7fdBK4/6d1YS7fdBK4_0007_0019_resize.wav',
            'data/ImpactSet/RawVideos/YouTube-impact-ccl/JnBsmJgEkiw/JnBsmJgEkiw_0008_0016_resize.wav',
            'data/ImpactSet/RawVideos/YouTube-impact-ccl/xcUyiXt0gjo/xcUyiXt0gjo_0015_0021_resize.wav',
            'data/ImpactSet/RawVideos/YouTube-impact-ccl/4DRFJnZjpMM/4DRFJnZjpMM_0000_0010_resize.wav'
        ] + glob('data/ImpactSet/RawVideos/self_recorded/*_resize.wav')

        self.wav_transforms = transforms.Compose([
            MakeMono(),
            SpectrogramTorchAudio(nfft=1024, hoplen=1024//4, spec_power=1),
            MelScaleTorchAudio(sr=SR, stft=513, fmin=125, fmax=7600, nmels=80),
            LowerThresh(1e-5),
            Log10(),
            Multiply(20),
            Subtract(20),
            Add(100),
            Divide(100),
            Clip(0, 1.0),
            TrimSpec(173),
        ])
        self.spec_transforms = CropImage([mel_num, spec_crop_len], random_crop)

    def __len__(self):
        return len(self.video_list)

    def __getitem__(self, idx):
        item = {}

        wave_path = self.video_list[idx]
        wav, _ = soundfile.read(wave_path)
        start_idx = random.randint(0, min(4, wav.shape[0] - int(SR * self.L)))
        wav = wav[start_idx:start_idx+int(SR * self.L)]

        if self.denoise:
            if len(wav.shape) == 1:
                wav = wav[None, :]
            wav = nr.reduce_noise(y=wav, sr=SR, n_fft=1024, hop_length=1024//4)
            wav = wav.squeeze()

        if self.wav_transforms is not None:
            spec_tensor = self.wav_transforms(torch.tensor(wav).float())
            spec = spec_tensor.numpy()
        if spec.shape[1] < 173:
            pad = np.zeros((80, 173), dtype=spec.dtype)
            pad[:, :spec.shape[1]] = spec
            spec = pad

        item['image'] = 2 * spec - 1  # (80, 173)
        # item['wav'] = wav
        item['file_path_wav_'] = wave_path

        item['label'] = 'None'
        item['target'] = 'None'

        if self.spec_transforms is not None:
            item = self.spec_transforms(item)
        return item

    def make_split_files(self):
        raise NotImplementedError


class ImpactSetWaveTestTimeTrain(ImpactSetWaveTestTime):
    def __init__(self, specs_dataset_cfg):
        super().__init__('train', **specs_dataset_cfg)


class ImpactSetWaveTestTimeValidation(ImpactSetWaveTestTime):
    def __init__(self, specs_dataset_cfg):
        super().__init__('val', **specs_dataset_cfg)


class ImpactSetWaveTestTimeTest(ImpactSetWaveTestTime):
    def __init__(self, specs_dataset_cfg):
        super().__init__('test', **specs_dataset_cfg)
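# ImpactSetWaveWithSilent returns raw-waveform crops like ImpactSetWave, but takes a
# single random window without the loudness-based resampling, so silent segments are kept.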
class ImpactSetWaveWithSilent(torch.utils.data.Dataset):

    def __init__(self, split, random_crop, mel_num, spec_crop_len,
                 L=2.0, denoise=False, splits_path='./data',
                 data_path='data/ImpactSet/impactset-proccess-resize'):
        super().__init__()
        self.split = split
        self.splits_path = splits_path
        self.data_path = data_path
        self.L = L
        self.denoise = denoise

        video_name_split_path = os.path.join(splits_path, f'countixAV_{split}.json')
        if not os.path.exists(video_name_split_path):
            self.make_split_files()
        video_name = json.load(open(video_name_split_path, 'r'))
        self.video_frame_cnt = {v: len(os.listdir(os.path.join(self.data_path, v, 'frames'))) for v in video_name}
        self.left_over = int(FPS * L + 1)
        self.video_audio_path = {v: os.path.join(self.data_path, v, f'audio/{v}_resampled.wav') for v in video_name}
        self.dataset = video_name

        self.wav_transforms = transforms.Compose([
            MakeMono(),
            Padding(target_len=int(SR * self.L)),
        ])
        self.spec_transforms = CropImage([mel_num, spec_crop_len], random_crop)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = {}
        video = self.dataset[idx]
        available_frame_idx = self.video_frame_cnt[video] - self.left_over
        wave_path = self.video_audio_path[video]
        if self.denoise:
            wave_path = wave_path.replace('.wav', '_denoised.wav')
        start_idx = torch.randint(0, available_frame_idx, (1,)).tolist()[0]
        # target
        start_t = (start_idx + 0.5) / FPS
        start_audio_idx = non_negative(start_t * SR)

        wav, _ = soundfile.read(wave_path, frames=int(SR * self.L), start=start_audio_idx)
        wav = self.wav_transforms(wav)

        item['image'] = wav  # (44100,)
        # item['wav'] = wav
        item['file_path_wav_'] = wave_path

        item['label'] = 'None'
        item['target'] = 'None'
        return item

    def make_split_files(self):
        raise NotImplementedError


class ImpactSetWaveWithSilentTrain(ImpactSetWaveWithSilent):
    def __init__(self, specs_dataset_cfg):
        super().__init__('train', **specs_dataset_cfg)


class ImpactSetWaveWithSilentValidation(ImpactSetWaveWithSilent):
    def __init__(self, specs_dataset_cfg):
        super().__init__('val', **specs_dataset_cfg)


class ImpactSetWaveWithSilentTest(ImpactSetWaveWithSilent):
    def __init__(self, specs_dataset_cfg):
        super().__init__('test', **specs_dataset_cfg)
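# ImpactSetWaveCondOnImage pairs a target waveform crop with a conditioning clip
# (frames + waveform). The conditioning clip comes from a non-overlapping window of
# the same video, or, with probability p_outside_cond, from another video of the same
# action class (train/val only). Videos with too few frames are filtered out in __init__.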
class ImpactSetWaveCondOnImage(torch.utils.data.Dataset):

    def __init__(self, split,
                 L=2.0,
                 frame_transforms=None,
                 denoise=False,
                 splits_path='./data',
                 data_path='data/ImpactSet/impactset-proccess-resize',
                 p_outside_cond=0.):
        super().__init__()
        self.split = split
        self.splits_path = splits_path
        self.frame_transforms = frame_transforms
        self.data_path = data_path
        self.L = L
        self.denoise = denoise
        self.p_outside_cond = torch.tensor(p_outside_cond)

        video_name_split_path = os.path.join(splits_path, f'countixAV_{split}.json')
        if not os.path.exists(video_name_split_path):
            self.make_split_files()
        video_name = json.load(open(video_name_split_path, 'r'))
        self.video_frame_cnt = {v: len(os.listdir(os.path.join(self.data_path, v, 'frames'))) for v in video_name}
        self.left_over = int(FPS * L + 1)
        for v, cnt in self.video_frame_cnt.items():
            if cnt - (3*self.left_over) <= 0:
                video_name.remove(v)
        self.video_audio_path = {v: os.path.join(self.data_path, v, f'audio/{v}_resampled.wav') for v in video_name}
        self.dataset = video_name

        video_timing_split_path = os.path.join(splits_path, f'countixAV_{split}_timing.json')
        self.video_timing = json.load(open(video_timing_split_path, 'r'))
        self.video_timing = {v: [int(float(t) * FPS) for t in ts] for v, ts in self.video_timing.items()}

        if split != 'test':
            video_class_path = os.path.join(splits_path, f'countixAV_{split}_class.json')
            if not os.path.exists(video_class_path):
                self.make_video_class()
            self.video_class = json.load(open(video_class_path, 'r'))
            self.class2video = {}
            for v, c in self.video_class.items():
                if c not in self.class2video.keys():
                    self.class2video[c] = []
                self.class2video[c].append(v)

        self.wav_transforms = transforms.Compose([
            MakeMono(),
            Padding(target_len=int(SR * self.L)),
        ])
        if self.frame_transforms is None:
            self.frame_transforms = transforms.Compose([
                Resize3D(128),
                RandomResizedCrop3D(112, scale=(0.5, 1.0)),
                RandomHorizontalFlip3D(),
                ColorJitter3D(brightness=0.1, saturation=0.1),
                ToTensor3D(),
                Normalize3D(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ])

    def make_video_class(self):
        meta_path = f'data/ImpactSet/data-info/CountixAV_{self.split}.csv'
        video_class = {}
        with open(meta_path, 'r') as f:
            reader = csv.reader(f)
            for i, row in enumerate(reader):
                if i == 0:
                    continue
                vid, k_st, k_et = row[:3]
                video_name = f'{vid}_{int(k_st):0>4d}_{int(k_et):0>4d}'
                if video_name not in self.dataset:
                    continue
                video_class[video_name] = row[-1]
        with open(os.path.join(self.splits_path, f'countixAV_{self.split}_class.json'), 'w') as f:
            json.dump(video_class, f)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = {}
        video = self.dataset[idx]
        available_frame_idx = self.video_frame_cnt[video] - self.left_over
        rep_start_idx, rep_end_idx = self.video_timing[video]
        rep_end_idx = min(available_frame_idx, rep_end_idx)
        if available_frame_idx <= rep_start_idx + self.L * FPS:
            idx_set = list(range(0, available_frame_idx))
        else:
            idx_set = list(range(rep_start_idx, rep_end_idx))
        start_idx = sample(idx_set, k=1)[0]

        wave_path = self.video_audio_path[video]
        if self.denoise:
            wave_path = wave_path.replace('.wav', '_denoised.wav')

        # target
        start_t = (start_idx + 0.5) / FPS
        end_idx = non_negative(start_idx + FPS * self.L)
        start_audio_idx = non_negative(start_t * SR)

        wav, sr = soundfile.read(wave_path, frames=int(SR * self.L), start=start_audio_idx)
        assert sr == SR
        wav = self.wav_transforms(wav)

        frame_path = os.path.join(self.data_path, video, 'frames')
        frames = [Image.open(os.path.join(
            frame_path, f'frame{i+1:0>6d}.jpg')).convert('RGB') for i in range(start_idx, end_idx)]
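        # Choose where the conditioning clip comes from: with probability p_outside_cond
        # (train/val only), sample another video of the same class; otherwise sample a
        # window of the same video that does not overlap the target window.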
        if torch.all(torch.bernoulli(self.p_outside_cond) == 1.) and self.split != 'test':
            # outside from the same class
            cur_class = self.video_class[video]
            tmp_video = copy.copy(self.class2video[cur_class])
            if len(tmp_video) > 1:
                # if only 1 video in the class, use itself
                tmp_video.remove(video)
            cond_video = sample(tmp_video, k=1)[0]
            cond_available_frame_idx = self.video_frame_cnt[cond_video] - self.left_over
            cond_start_idx = torch.randint(0, cond_available_frame_idx, (1,)).tolist()[0]
        else:
            cond_video = video
            idx_set = list(range(0, start_idx)) + list(range(end_idx, available_frame_idx))
            cond_start_idx = random.sample(idx_set, k=1)[0]

        cond_end_idx = non_negative(cond_start_idx + FPS * self.L)
        cond_start_t = (cond_start_idx + 0.5) / FPS
        cond_audio_idx = non_negative(cond_start_t * SR)
        cond_frame_path = os.path.join(self.data_path, cond_video, 'frames')
        cond_wave_path = self.video_audio_path[cond_video]

        cond_frames = [Image.open(os.path.join(
            cond_frame_path, f'frame{i+1:0>6d}.jpg')).convert('RGB') for i in range(cond_start_idx, cond_end_idx)]
        cond_wav, sr = soundfile.read(cond_wave_path, frames=int(SR * self.L), start=cond_audio_idx)
        assert sr == SR
        cond_wav = self.wav_transforms(cond_wav)

        item['image'] = wav  # (44100,)
        item['cond_image'] = cond_wav  # (44100,)
        item['file_path_wav_'] = wave_path
        item['file_path_cond_wav_'] = cond_wave_path

        if self.frame_transforms is not None:
            cond_frames = self.frame_transforms(cond_frames)
            frames = self.frame_transforms(frames)

        item['feature'] = np.stack(cond_frames + frames, axis=0)  # (30 * L, 112, 112, 3)
        item['file_path_feats_'] = (frame_path, start_idx)
        item['file_path_cond_feats_'] = (cond_frame_path, cond_start_idx)

        item['label'] = 'None'
        item['target'] = 'None'
        return item

    def make_split_files(self):
        raise NotImplementedError


class ImpactSetWaveCondOnImageTrain(ImpactSetWaveCondOnImage):
    def __init__(self, dataset_cfg):
        train_transforms = transforms.Compose([
            Resize3D(128),
            RandomResizedCrop3D(112, scale=(0.5, 1.0)),
            RandomHorizontalFlip3D(),
            ColorJitter3D(brightness=0.4, saturation=0.4, contrast=0.2, hue=0.1),
            ToTensor3D(),
            Normalize3D(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        super().__init__('train', frame_transforms=train_transforms, **dataset_cfg)


class ImpactSetWaveCondOnImageValidation(ImpactSetWaveCondOnImage):
    def __init__(self, dataset_cfg):
        valid_transforms = transforms.Compose([
            Resize3D(128),
            CenterCrop3D(112),
            ToTensor3D(),
            Normalize3D(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        super().__init__('val', frame_transforms=valid_transforms, **dataset_cfg)


class ImpactSetWaveCondOnImageTest(ImpactSetWaveCondOnImage):
    def __init__(self, dataset_cfg):
        test_transforms = transforms.Compose([
            Resize3D(128),
            CenterCrop3D(112),
            ToTensor3D(),
            Normalize3D(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        super().__init__('test', frame_transforms=test_transforms, **dataset_cfg)
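# ImpactSetCleanWaveCondOnImage re-indexes the dataset by predicted impact timings:
# each item is a (video, predicted start time) pair read from
# data/countixAV_{split}_timing_processed_0.20.json, and the conditioning start time
# is likewise drawn from the predicted timings of the chosen conditioning video.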
class ImpactSetCleanWaveCondOnImage(ImpactSetWaveCondOnImage):
    def __init__(self, split,
                 L=2.0,
                 frame_transforms=None,
                 denoise=False,
                 splits_path='./data',
                 data_path='data/ImpactSet/impactset-proccess-resize',
                 p_outside_cond=0.):
        super().__init__(split, L, frame_transforms, denoise, splits_path, data_path, p_outside_cond)
        pred_timing_path = f'data/countixAV_{split}_timing_processed_0.20.json'
        assert os.path.exists(pred_timing_path)
        self.pred_timing = json.load(open(pred_timing_path, 'r'))

        self.dataset = []
        for v, ts in self.pred_timing.items():
            if v in self.video_audio_path.keys():
                for t in ts:
                    self.dataset.append([v, t])

    def __getitem__(self, idx):
        item = {}
        video, start_t = self.dataset[idx]
        available_frame_idx = self.video_frame_cnt[video] - self.left_over
        available_timing = (available_frame_idx + 0.5) / FPS
        start_t = float(start_t)
        start_t = min(start_t, available_timing)
        start_idx = non_negative(start_t * FPS - 0.5)

        wave_path = self.video_audio_path[video]
        if self.denoise:
            wave_path = wave_path.replace('.wav', '_denoised.wav')

        # target
        end_idx = non_negative(start_idx + FPS * self.L)
        start_audio_idx = non_negative(start_t * SR)

        wav, sr = soundfile.read(wave_path, frames=int(SR * self.L), start=start_audio_idx)
        assert sr == SR
        wav = self.wav_transforms(wav)

        frame_path = os.path.join(self.data_path, video, 'frames')
        frames = [Image.open(os.path.join(
            frame_path, f'frame{i+1:0>6d}.jpg')).convert('RGB') for i in range(start_idx, end_idx)]

        if torch.all(torch.bernoulli(self.p_outside_cond) == 1.):
            other_video = list(self.pred_timing.keys())
            other_video.remove(video)
            cond_video = sample(other_video, k=1)[0]
            cond_available_frame_idx = self.video_frame_cnt[cond_video] - self.left_over
            cond_available_timing = (cond_available_frame_idx + 0.5) / FPS
        else:
            cond_video = video
            cond_available_timing = available_timing

        cond_start_t = sample(self.pred_timing[cond_video], k=1)[0]
        cond_start_t = float(cond_start_t)
        cond_start_t = min(cond_start_t, cond_available_timing)
        cond_start_idx = non_negative(cond_start_t * FPS - 0.5)
        cond_end_idx = non_negative(cond_start_idx + FPS * self.L)
        cond_audio_idx = non_negative(cond_start_t * SR)
        cond_frame_path = os.path.join(self.data_path, cond_video, 'frames')
        cond_wave_path = self.video_audio_path[cond_video]

        cond_frames = [Image.open(os.path.join(
            cond_frame_path, f'frame{i+1:0>6d}.jpg')).convert('RGB') for i in range(cond_start_idx, cond_end_idx)]
        cond_wav, sr = soundfile.read(cond_wave_path, frames=int(SR * self.L), start=cond_audio_idx)
        assert sr == SR
        cond_wav = self.wav_transforms(cond_wav)

        item['image'] = wav  # (44100,)
        item['cond_image'] = cond_wav  # (44100,)
        item['file_path_wav_'] = wave_path
        item['file_path_cond_wav_'] = cond_wave_path

        if self.frame_transforms is not None:
            cond_frames = self.frame_transforms(cond_frames)
            frames = self.frame_transforms(frames)

        item['feature'] = np.stack(cond_frames + frames, axis=0)  # (30 * L, 112, 112, 3)
        item['file_path_feats_'] = (frame_path, start_idx)
        item['file_path_cond_feats_'] = (cond_frame_path, cond_start_idx)

        item['label'] = 'None'
        item['target'] = 'None'
        return item


class ImpactSetCleanWaveCondOnImageTrain(ImpactSetCleanWaveCondOnImage):
    def __init__(self, dataset_cfg):
        train_transforms = transforms.Compose([
            Resize3D(128),
            RandomResizedCrop3D(112, scale=(0.5, 1.0)),
            RandomHorizontalFlip3D(),
            ColorJitter3D(brightness=0.4, saturation=0.4, contrast=0.2, hue=0.1),
            ToTensor3D(),
            Normalize3D(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        super().__init__('train', frame_transforms=train_transforms, **dataset_cfg)


class ImpactSetCleanWaveCondOnImageValidation(ImpactSetCleanWaveCondOnImage):
    def __init__(self, dataset_cfg):
        valid_transforms = transforms.Compose([
            Resize3D(128),
            CenterCrop3D(112),
            ToTensor3D(),
            Normalize3D(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        super().__init__('val', frame_transforms=valid_transforms, **dataset_cfg)


class ImpactSetCleanWaveCondOnImageTest(ImpactSetCleanWaveCondOnImage):
    def __init__(self, dataset_cfg):
        test_transforms = transforms.Compose([
            Resize3D(128),
            CenterCrop3D(112),
            ToTensor3D(),
            Normalize3D(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        super().__init__('test', frame_transforms=test_transforms, **dataset_cfg)
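# Quick smoke test: build the data module from a config and inspect the train split.
# Everything after exit() is leftover analysis code for dumping waveform/spectrogram
# samples to tmp/ and is not reached.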
if __name__ == '__main__':
    import sys
    from omegaconf import OmegaConf

    cfg = OmegaConf.load('configs/countixAV_transformer_denoise_clean.yaml')
    data = instantiate_from_config(cfg.data)
    data.prepare_data()
    data.setup()
    print(data.datasets['train'])
    print(len(data.datasets['train']))
    # print(data.datasets['train'][24])
    exit()

    stats = []
    torch.manual_seed(0)
    np.random.seed(0)
    random.seed(0)
    for k in range(1):
        x = np.arange(SR * 2)
        for i in tqdm(range(len(data.datasets['train']))):
            wav = data.datasets['train'][i]['wav']
            spec = data.datasets['train'][i]['image']
            spec = 0.5 * (spec + 1)
            spec_rms = rms(spec)
            stats.append(float(spec_rms))
            # plt.plot(x, wav)
            # plt.ylim(-1, 1)
            # plt.savefig(f'tmp/th0.1_wav_e_{k}_{i}_{mean_val:.3f}_{spec_rms:.3f}.png')
            # plt.close()
            # plt.cla()
            soundfile.write(f'tmp/wav_e_{k}_{i}_{spec_rms:.3f}.wav', wav, SR)
            draw_spec(spec, f'tmp/wav_spec_e_{k}_{i}_{spec_rms:.3f}.png')
            if i == 100:
                break
    # plt.hist(stats, bins=50)
    # plt.savefig(f'tmp/rms_spec_stats.png')