import os
import random

import numpy as np
import torch
import torchaudio
from torch.utils.data import Dataset


def pad_audio(audio, sample_rate=16000, target_duration=7.98):
    """Pad or truncate a waveform to exactly `target_duration` seconds.

    Args:
        audio: waveform tensor of shape (channels, samples).
        sample_rate: sampling rate of `audio` in Hz.
        target_duration: desired clip length in seconds. The default of
            7.98 s corresponds to 798 frames of 10 ms at 16 kHz.

    Returns:
        Tensor of shape (channels, int(sample_rate * target_duration)).
    """
    target_length = int(sample_rate * target_duration)
    current_length = audio.shape[1]
    if current_length < target_length:
        # Zero-pad short clips on the right.
        padding = torch.zeros(audio.shape[0], target_length - current_length)
        audio = torch.cat((audio, padding), dim=1)
    elif current_length > target_length:
        # BUG FIX: the original code *appended* one zero sample when the clip
        # was exactly one sample too long, yielding target_length + 2 samples
        # (inconsistent with every other path, and it breaks batch collation).
        # Always truncate so the returned length is deterministic.
        audio = audio[:, :target_length]
    return audio


def parse_labels(file_path, audio_length, sample_rate, frame_duration=0.010):
    """Build per-frame fake/real labels from a boundary annotation file.

    The file is expected to contain a header line followed by lines of the
    form "start-end-authenticity" (times in seconds, authenticity 'F' for
    fake). Frames overlapping a fake segment are labeled 1; in addition, the
    4 frames adjacent to each fake-segment boundary are labeled 1 to give the
    model slack around transition points.

    Args:
        file_path: path to the label text file.
        audio_length: clip length in seconds.
        sample_rate: unused here; kept for interface compatibility.
        frame_duration: frame hop in seconds (default 10 ms).

    Returns:
        np.ndarray of float32, one entry per frame (0 = real, 1 = fake).
    """
    frames_per_audio = int(audio_length / frame_duration)
    labels = np.zeros(frames_per_audio, dtype=np.float32)
    with open(file_path, 'r') as f:
        lines = f.readlines()[1:]  # Skip header line.
    for line in lines:
        start, end, authenticity = line.strip().split('-')
        if authenticity != 'F':
            continue
        start_frame = int(float(start) / frame_duration)
        end_frame = int(float(end) / frame_duration)
        labels[start_frame:end_frame] = 1
        # Widen each boundary by up to 4 frames on both sides.
        for offset in range(1, 5):
            if start_frame - offset >= 0:
                labels[start_frame - offset] = 1
            if end_frame + offset < frames_per_audio:
                labels[end_frame + offset] = 1
    return labels


class AudioDataset(Dataset):
    """Dataset of fixed-length audio clips with per-frame fake/real labels."""

    def __init__(self, audio_files, label_dir, sample_rate=16000, target_length=7.98):
        self.audio_files = audio_files
        self.label_dir = label_dir
        self.sample_rate = sample_rate
        # Target length in samples, plus the raw value in seconds for
        # pad_audio / parse_labels.
        self.target_length = target_length * sample_rate
        self.raw_target_length = target_length

    def __len__(self):
        return len(self.audio_files)

    def __getitem__(self, idx):
        """Return (waveform, frame_labels) for clip `idx`.

        Unreadable files fall back to a randomly chosen other clip. The
        number of attempts is bounded (BUG FIX: the original recursed with no
        depth limit, so a dataset with many broken files could exhaust the
        Python stack).
        """
        for _ in range(len(self.audio_files) + 1):
            audio_path = self.audio_files[idx]
            try:
                waveform, sr = torchaudio.load(audio_path)
                if sr != self.sample_rate:
                    # Only resample when needed; building a Resample module
                    # per item for an identity conversion is wasted work.
                    waveform = torchaudio.transforms.Resample(sr, self.sample_rate)(waveform)
                waveform = pad_audio(waveform, self.sample_rate, self.raw_target_length)
                audio_filename = os.path.basename(audio_path).replace(".wav", "")
                if audio_filename.startswith("RFP_R"):
                    # Presumably fully-real recordings carry no label file:
                    # all-zero (real) frame labels. TODO confirm naming scheme.
                    labels = np.zeros(int(self.raw_target_length / 0.010), dtype=np.float32)
                else:
                    label_path = os.path.join(self.label_dir, f"{audio_filename}.wav_labels.txt")
                    labels = parse_labels(label_path, self.raw_target_length, self.sample_rate).astype(np.float32)
                return waveform, torch.tensor(labels, dtype=torch.float32)
            except (OSError, IOError) as e:
                print(f"Error opening file {audio_path}: {e}")
                idx = random.randint(0, len(self.audio_files) - 1)
        raise RuntimeError("AudioDataset: too many unreadable files; giving up.")


def get_audio_file_paths(extrinsic_dir, intrinsic_dir, real_dir):
    """Collect training .wav paths.

    Lists .wav files (excluding "partial_fake*") from all three directories,
    then returns only those whose basename starts with "extrinsic".

    NOTE(review): `intrinsic_files` is listed but never used, and the final
    filter also drops everything from `real_dir` unless its name happens to
    start with "extrinsic" — confirm whether intrinsic/real files were meant
    to be included. Original behavior is preserved as-is.
    """
    def _wavs(directory):
        # One-line purpose: non-partial-fake .wav paths in `directory`.
        return [
            os.path.join(directory, f)
            for f in os.listdir(directory)
            if f.endswith(".wav") and not f.startswith("partial_fake")
        ]

    extrinsic_files = _wavs(extrinsic_dir)
    intrinsic_files = _wavs(intrinsic_dir)  # unused — see NOTE(review) above
    real_files = _wavs(real_dir)

    # NB: in the original, ("extrinsic") was a parenthesized string, not a
    # tuple; a plain string argument makes the intent explicit.
    return [
        f for f in extrinsic_files + real_files
        if os.path.basename(f).startswith("extrinsic")
    ]