File size: 5,710 Bytes
1491666
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
import os
import sys
import torch
import torchaudio
import torchaudio.functional as F
import torchaudio.transforms as T
import re

def replace_low_freq_with_energy_matched(
    a_file: str,
    b_file: str,
    c_file: str,
    cutoff_freq: float = 5500.0,
    eps: float = 1e-10
) -> None:
    """
    Replace the low band of `b_file` with the energy-matched low band of `a_file`.

    Steps:
      1. Load both files; resample 'a' to the sample rate of 'b' if they differ.
      2. Low-pass both signals at `cutoff_freq`.
      3. Scale the low band of 'a' so its global RMS equals that of 'b's low band.
      4. High-pass 'b' at `cutoff_freq`.
      5. Sum the scaled low band of 'a' with the high band of 'b'
         (truncating both to the shorter length) and save to `c_file`.

    Args:
        a_file (str): Path to the low-band source (the caller passes a 16kHz file).
        b_file (str): Path to the high-band source (the caller passes a 48kHz file).
        c_file (str): Output path for the combined result.
        cutoff_freq (float): Cutoff frequency (Hz) shared by the low/high-pass filters.
        eps (float): Small value added to each RMS to avoid division-by-zero.

    Note:
        Dimension 1 of the loaded tensors is treated as the time axis.
        Channel counts of the two files are not checked here; the final sum
        presumably relies on them being equal or broadcastable -- TODO confirm.
    """

    # ----------------------------------------------------------
    # 1. Load the two files
    # ----------------------------------------------------------
    wave_a, sr_a = torchaudio.load(a_file)
    wave_b, sr_b = torchaudio.load(b_file)

    # Bring 'a' onto the sample rate of 'b'; everything below works at sr_b.
    if sr_a != sr_b:
        resampler = T.Resample(orig_freq=sr_a, new_freq=sr_b)
        wave_a = resampler(wave_a)

    # ----------------------------------------------------------
    # 2. Low-pass both signals to isolate low-frequency content
    # ----------------------------------------------------------
    wave_a_low = F.lowpass_biquad(
        wave_a,
        sample_rate=sr_b,
        cutoff_freq=cutoff_freq
    )
    wave_b_low = F.lowpass_biquad(
        wave_b,
        sample_rate=sr_b,
        cutoff_freq=cutoff_freq
    )

    # ----------------------------------------------------------
    # 3. Compute RMS of low-frequency portions
    # ----------------------------------------------------------
    # Simple global RMS (across channels & time), measured on the full
    # signals before any length truncation below.
    # If you need per-channel matching, handle each channel separately.
    a_rms = wave_a_low.pow(2).mean().sqrt().item() + eps
    b_rms = wave_b_low.pow(2).mean().sqrt().item() + eps

    # ----------------------------------------------------------
    # 4. Scale 'a_low' so its energy matches 'b_low'
    # ----------------------------------------------------------
    scale_factor = b_rms / a_rms
    wave_a_low_matched = wave_a_low * scale_factor

    # ----------------------------------------------------------
    # 5. High-pass 'b' to isolate high-frequency content
    # ----------------------------------------------------------
    wave_b_high = F.highpass_biquad(
        wave_b,
        sample_rate=sr_b,
        cutoff_freq=cutoff_freq
    )

    # ----------------------------------------------------------
    # 6. Combine: (scaled a_low) + (b_high)
    # ----------------------------------------------------------
    # Record both lengths BEFORE truncating so the reported number of
    # dropped samples is meaningful (measuring afterwards always gives 0).
    len_a = wave_a_low_matched.size(1)
    len_b = wave_b_high.size(1)
    if len_a != len_b:
        print(f"Original lengths: a_low={wave_a_low_matched.size()}, b_high={wave_b_high.size()}")
        min_length = min(len_a, len_b)
        wave_a_low_matched = wave_a_low_matched[:, :min_length]
        wave_b_high = wave_b_high[:, :min_length]

        print(f"After truncation: a_low={wave_a_low_matched.size()}, b_high={wave_b_high.size()}")
        print(f"Samples truncated: {max(len_a, len_b) - min_length}")

    wave_combined = wave_a_low_matched + wave_b_high

    # No normalization is applied here; the sum is written as-is.

    # ----------------------------------------------------------
    # 7. Save the combined signal at the sample rate of 'b'
    # ----------------------------------------------------------
    torchaudio.save(c_file, wave_combined, sample_rate=sr_b)

    print(f"Successfully created '{os.path.basename(c_file)}' with matched low-frequency energy.")

if __name__ == "__main__":
    # Usage: <script> <stage2_output_dir>
    # Pairs files from <dir>/recons/mix with files from <dir>/vocoder/mix by
    # the ID embedded in their names, and writes results to <dir>/post_process.
    if len(sys.argv) < 2:
        # Fail with a readable message instead of a bare IndexError.
        sys.exit(f"Usage: {os.path.basename(sys.argv[0])} <stage2_output_dir>")

    stage2_output_dir = sys.argv[1]
    recons_dir = os.path.join(stage2_output_dir, "recons", "mix")
    vocoder_dir = os.path.join(stage2_output_dir, "vocoder", "mix")
    save_dir = os.path.join(stage2_output_dir, "post_process")
    os.makedirs(save_dir, exist_ok=True)

    # Filename patterns; group 1 captures the ID shared by a file pair.
    recons_pattern = re.compile(r"mixed_([a-f0-9-]+)_xcodec_16k\.mp3$")
    vocoder_pattern = re.compile(r"__([a-f0-9-]+)\.mp3$")

    # Map IDs to filenames for recons/mix
    recons_files = {}
    for filename in os.listdir(recons_dir):
        match = recons_pattern.search(filename)
        if match:
            recons_files[match.group(1).lower()] = filename

    print(recons_files)

    # Map IDs to filenames for vocoder/mix
    vocoder_files = {}
    for filename in os.listdir(vocoder_dir):
        match = vocoder_pattern.search(filename)
        if match:
            vocoder_files[match.group(1).lower()] = filename

    # IDs present in both directories; sorted so the processing order is
    # deterministic from run to run.
    common_ids = sorted(recons_files.keys() & vocoder_files.keys())
    print(f"Found {len(common_ids)} matching file pairs")

    # Process only matching pairs
    for file_id in common_ids:
        a_path = os.path.join(recons_dir, recons_files[file_id])
        b_path = os.path.join(vocoder_dir, vocoder_files[file_id])
        # The output keeps the vocoder file's name inside save_dir.
        out_path = os.path.join(save_dir, os.path.basename(b_path))

        # Skip pairs whose output already exists, so reruns resume.
        if os.path.exists(out_path):
            continue

        replace_low_freq_with_energy_matched(
            a_file=a_path,     # 16kHz
            b_file=b_path,     # 48kHz
            c_file=out_path,
            cutoff_freq=5500.0
        )