KingNish commited on
Commit
1491666
·
verified ·
1 Parent(s): 5fdf8b7

Upload ./post_process_audio.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. post_process_audio.py +153 -0
post_process_audio.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import torch
4
+ import torchaudio
5
+ import torchaudio.functional as F
6
+ import torchaudio.transforms as T
7
+ import re
8
+
9
+ def replace_low_freq_with_energy_matched(
10
+ a_file: str,
11
+ b_file: str,
12
+ c_file: str,
13
+ cutoff_freq: float = 5500.0,
14
+ eps: float = 1e-10
15
+ ):
16
+ """
17
+ 1. Load a_file (16kHz) and b_file (48kHz).
18
+ 2. Resample 'a' to 48kHz if needed.
19
+ 3. Match the low-frequency energy of 'a' to that of 'b'.
20
+ 4. Replace the low-frequency of 'b' with the matched low-frequency of 'a'.
21
+ 5. Save the result to c_file.
22
+
23
+ Args:
24
+ a_file (str): Path to a.mp3 (16kHz).
25
+ b_file (str): Path to b.mp3 (48kHz).
26
+ c_file (str): Output path for combined result.
27
+ cutoff_freq (float): Cutoff frequency for low/highpass filters.
28
+ eps (float): Small value to avoid division-by-zero.
29
+ """
30
+
31
+ # ----------------------------------------------------------
32
+ # 1. Load the two files
33
+ # ----------------------------------------------------------
34
+ wave_a, sr_a = torchaudio.load(a_file)
35
+ wave_b, sr_b = torchaudio.load(b_file)
36
+
37
+ # If 'a' doesn't match 'b' sample rate, resample it
38
+ if sr_a != sr_b:
39
+ resampler = T.Resample(orig_freq=sr_a, new_freq=sr_b)
40
+ wave_a = resampler(wave_a)
41
+ sr_a = sr_b # Now they match
42
+
43
+ # ----------------------------------------------------------
44
+ # 2. Low-pass both signals to isolate low-frequency content
45
+ # ----------------------------------------------------------
46
+ wave_a_low = F.lowpass_biquad(
47
+ wave_a,
48
+ sample_rate=sr_b,
49
+ cutoff_freq=cutoff_freq
50
+ )
51
+ wave_b_low = F.lowpass_biquad(
52
+ wave_b,
53
+ sample_rate=sr_b,
54
+ cutoff_freq=cutoff_freq
55
+ )
56
+
57
+ # ----------------------------------------------------------
58
+ # 3. Compute RMS of low-frequency portions
59
+ # ----------------------------------------------------------
60
+ # We'll do a simple global RMS (across channels & time)
61
+ # If you need per-channel matching, handle each channel separately.
62
+ a_rms = wave_a_low.pow(2).mean().sqrt().item() + eps
63
+ b_rms = wave_b_low.pow(2).mean().sqrt().item() + eps
64
+
65
+ # ----------------------------------------------------------
66
+ # 4. Scale 'a_low' so its energy matches 'b_low'
67
+ # ----------------------------------------------------------
68
+ scale_factor = b_rms / a_rms
69
+ wave_a_low_matched = wave_a_low * scale_factor
70
+
71
+ # ----------------------------------------------------------
72
+ # 5. High-pass 'b' to isolate high-frequency content
73
+ # ----------------------------------------------------------
74
+ wave_b_high = F.highpass_biquad(
75
+ wave_b,
76
+ sample_rate=sr_b,
77
+ cutoff_freq=cutoff_freq
78
+ )
79
+
80
+ # ----------------------------------------------------------
81
+ # 6. Combine: (scaled a_low) + (b_high)
82
+ # ----------------------------------------------------------
83
+ if wave_a_low_matched.size(1)!=wave_b_high.size(1):
84
+ print(f"Original lengths: a_low={wave_a_low_matched.size()}, b_high={wave_b_high.size()}")
85
+ min_length = min(wave_a_low_matched.size(1), wave_b_high.size(1))
86
+ wave_a_low_matched = wave_a_low_matched[:, :min_length]
87
+ wave_b_high = wave_b_high[:, :min_length]
88
+
89
+ print(f"After truncation: a_low={wave_a_low_matched.size()}, b_high={wave_b_high.size()}")
90
+ print(f"Samples truncated: {max(wave_a_low_matched.size(1), wave_b_high.size(1)) - min_length}")
91
+
92
+ wave_combined = wave_a_low_matched + wave_b_high
93
+
94
+ # (Optional) Normalize if needed to avoid clipping
95
+ # wave_combined /= max(wave_combined.abs().max(), 1.0)
96
+
97
+ # ----------------------------------------------------------
98
+ # 7. Save to c.mp3
99
+ # ----------------------------------------------------------
100
+ torchaudio.save(c_file, wave_combined, sample_rate=sr_b)
101
+
102
+ print(f"Successfully created '{os.path.basename(c_file)}' with matched low-frequency energy.")
103
+
104
+ if __name__ == "__main__":
105
+ stage2_output_dir = sys.argv[1]
106
+ recons_dir = os.path.join(stage2_output_dir, "recons", "mix")
107
+ vocoder_dir = os.path.join(stage2_output_dir, "vocoder", "mix")
108
+ save_dir = os.path.join(stage2_output_dir, "post_process")
109
+ os.makedirs(save_dir, exist_ok=True)
110
+
111
+ # Create dictionaries mapping IDs to filenames
112
+ recons_files = {}
113
+ vocoder_files = {}
114
+
115
+ pattern = r"mixed_([a-f0-9-]+)_xcodec_16k\.mp3$"
116
+
117
+ # Map IDs to filenames for recons/mix
118
+ for filename in os.listdir(recons_dir):
119
+ match = re.search(pattern, filename)
120
+ if match:
121
+ recons_files[(match.group(1)).lower()] = filename
122
+
123
+ print(recons_files)
124
+
125
+ pattern = r"__([a-f0-9-]+)\.mp3$"
126
+ # Map IDs to filenames for vocoder/mix
127
+ for filename in os.listdir(vocoder_dir):
128
+ match = re.search(pattern, filename)
129
+ if match:
130
+ vocoder_files[(match.group(1)).lower()] = filename
131
+
132
+ # Find common IDs
133
+ common_ids = set(recons_files.keys()) & set(vocoder_files.keys())
134
+ print(f"Found {len(common_ids)} matching file pairs")
135
+
136
+ # Create matched file lists
137
+ a_list = []
138
+ b_list = []
139
+ for id in common_ids:
140
+ a_list.append(os.path.join(recons_dir, recons_files[id]))
141
+ b_list.append(os.path.join(vocoder_dir, vocoder_files[id]))
142
+
143
+ # Process only matching pairs
144
+ for a, b in zip(a_list, b_list):
145
+ if os.path.exists(os.path.join(save_dir, os.path.basename(b))):
146
+ continue
147
+
148
+ replace_low_freq_with_energy_matched(
149
+ a_file=a, # 16kHz
150
+ b_file=b, # 48kHz
151
+ c_file=os.path.join(save_dir, os.path.basename(b)),
152
+ cutoff_freq=5500.0
153
+ )