k4d3 commited on
Commit
245ede6
·
1 Parent(s): 5b0ddd1

Refactor audio generation scripts to streamline processing and enhance functionality

Browse files

This commit introduces significant improvements to the audio generation workflow in both `audiogen_medium.py` and `stable_audio.py`. Key changes include:
- Removal of redundant seed extraction logic and integration of a new `process_audio_generations` function to handle audio generation in a more organized manner.
- Consolidation of argument preparation for audio generation into a dedicated `prepare_args` function, improving code clarity and maintainability.
- Enhanced user feedback during the audio generation process, ensuring clearer communication of the actions being performed.

These modifications optimize the audio generation process, improve code organization, and enhance the overall user experience.

audio/audiogen_medium.py CHANGED
@@ -4,18 +4,10 @@
4
  import sys
5
  import os
6
  import torch
7
- import torchaudio
8
  import random
9
- import multiprocessing as mp
10
  from audiocraft.models import AudioGen
11
  from audiocraft.data.audio import audio_write
12
-
13
- def get_seed_from_filename(filename):
14
- """Extract seed from filename like '12345.wav'"""
15
- try:
16
- return int(filename.split('.')[0])
17
- except:
18
- return None
19
 
20
  def generate_audio(args):
21
  description, seed, prompt_dir = args
@@ -37,61 +29,14 @@ def generate_audio(args):
37
  # Will save with loudness normalization at -14 db LUFS
38
  audio_write(file_path, wav[0].cpu(), model.sample_rate, strategy="loudness", loudness_compressor=True)
39
 
40
- if __name__ == '__main__':
41
- # Set start method to spawn for CUDA multiprocessing
42
- mp.set_start_method('spawn')
43
-
44
- descriptions = sys.argv[1:]
45
- if not descriptions:
46
- print('At least one prompt should be provided')
47
- sys.exit(1)
48
-
49
- # Base output directory
50
- base_output_dir = 'generated_audio'
51
- os.makedirs(base_output_dir, exist_ok=True)
52
-
53
- # Generate 25 variations for each prompt
54
- num_variations = 25
55
- num_processes = 3 # Number of parallel models to run
56
- seed_range = (0, 1000000) # Use seeds between 0 and 1,000,000
57
-
58
- for description in descriptions:
59
- # Create a safe folder name from the description
60
- folder_name = description.replace(' ', '_').replace('/', '_').replace('\\', '_')
61
- folder_name = ''.join(c for c in folder_name if c.isalnum() or c in '_-')
62
- prompt_dir = os.path.join(base_output_dir, folder_name)
63
- os.makedirs(prompt_dir, exist_ok=True)
64
-
65
- print(f"\nGenerating variations for prompt: {description}")
66
- print(f"Saving in directory: {prompt_dir}")
67
-
68
- # Get existing seeds
69
- existing_seeds = set()
70
- for filename in os.listdir(prompt_dir):
71
- if filename.endswith('.wav'):
72
- seed = get_seed_from_filename(filename)
73
- if seed is not None:
74
- existing_seeds.add(seed)
75
-
76
- if len(existing_seeds) >= num_variations:
77
- print(f"All {num_variations} variations already exist in {prompt_dir}, skipping...")
78
- continue
79
-
80
- # Generate new random seeds that haven't been used yet
81
- needed_variations = num_variations - len(existing_seeds)
82
- new_seeds = set()
83
- while len(new_seeds) < needed_variations:
84
- seed = random.randint(*seed_range)
85
- if seed not in existing_seeds and seed not in new_seeds:
86
- new_seeds.add(seed)
87
-
88
- print(f"Generating {needed_variations} new variations using {num_processes} parallel processes...")
89
- print(f"Using seeds: {sorted(new_seeds)}")
90
-
91
- # Prepare arguments for parallel processing
92
- args_list = [(description, seed, prompt_dir) for seed in new_seeds]
93
-
94
- # Use multiprocessing to distribute the work
95
- with mp.Pool(processes=num_processes) as pool:
96
- pool.map(generate_audio, args_list)
97
 
 
 
 
 
 
 
 
 
4
  import sys
5
  import os
6
  import torch
 
7
  import random
 
8
  from audiocraft.models import AudioGen
9
  from audiocraft.data.audio import audio_write
10
+ from utils.audio_utils import process_audio_generations
 
 
 
 
 
 
11
 
12
  def generate_audio(args):
13
  description, seed, prompt_dir = args
 
29
  # Will save with loudness normalization at -14 db LUFS
30
  audio_write(file_path, wav[0].cpu(), model.sample_rate, strategy="loudness", loudness_compressor=True)
31
 
32
+ def prepare_args(description, seeds, prompt_dir):
33
+ """Prepare arguments for the generate_audio function"""
34
+ return [(description, seed, prompt_dir) for seed in seeds]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
+ if __name__ == '__main__':
37
+ process_audio_generations(
38
+ descriptions=sys.argv[1:],
39
+ model_name='audiogen',
40
+ generate_fn=generate_audio,
41
+ prepare_args_fn=prepare_args
42
+ )
audio/stable_audio.py CHANGED
@@ -6,15 +6,11 @@ import os
6
  import torch
7
  import soundfile as sf
8
  import random
9
- import multiprocessing as mp
10
  from diffusers import StableAudioPipeline
 
11
 
12
- def get_seed_from_filename(filename):
13
- """Extract seed from filename like '12345.wav'"""
14
- try:
15
- return int(filename.split('.')[0])
16
- except:
17
- return None
18
 
19
  def generate_audio(args):
20
  description, negative_prompt, seed, prompt_dir = args
@@ -46,63 +42,14 @@ def generate_audio(args):
46
  print(f"Saving audio to: {file_path}")
47
  sf.write(file_path, output, pipe.vae.sampling_rate)
48
 
49
- if __name__ == '__main__':
50
- # Set start method to spawn for CUDA multiprocessing
51
- mp.set_start_method('spawn')
52
-
53
- descriptions = sys.argv[1:]
54
- if not descriptions:
55
- print('At least one prompt should be provided')
56
- sys.exit(1)
57
-
58
- # Default negative prompt
59
- negative_prompt = "Low quality, noise, distortion, low fidelity"
60
-
61
- # Base output directory
62
- base_output_dir = 'generated_audio/sa'
63
- os.makedirs(base_output_dir, exist_ok=True)
64
 
65
- # Generate 25 variations for each prompt
66
- num_variations = 25
67
- num_processes = 3 # Number of parallel models to run
68
- seed_range = (0, 1000000) # Use seeds between 0 and 1,000,000
69
-
70
- for description in descriptions:
71
- # Create a safe folder name from the description
72
- folder_name = description.replace(' ', '_').replace('/', '_').replace('\\', '_')
73
- folder_name = ''.join(c for c in folder_name if c.isalnum() or c in '_-')
74
- prompt_dir = os.path.join(base_output_dir, folder_name)
75
- os.makedirs(prompt_dir, exist_ok=True)
76
-
77
- print(f"\nGenerating variations for prompt: {description}")
78
- print(f"Saving in directory: {prompt_dir}")
79
-
80
- # Get existing seeds
81
- existing_seeds = set()
82
- for filename in os.listdir(prompt_dir):
83
- if filename.endswith('.wav'):
84
- seed = get_seed_from_filename(filename)
85
- if seed is not None:
86
- existing_seeds.add(seed)
87
-
88
- if len(existing_seeds) >= num_variations:
89
- print(f"All {num_variations} variations already exist in {prompt_dir}, skipping...")
90
- continue
91
-
92
- # Generate new random seeds that haven't been used yet
93
- needed_variations = num_variations - len(existing_seeds)
94
- new_seeds = set()
95
- while len(new_seeds) < needed_variations:
96
- seed = random.randint(*seed_range)
97
- if seed not in existing_seeds and seed not in new_seeds:
98
- new_seeds.add(seed)
99
-
100
- print(f"Generating {needed_variations} new variations using {num_processes} parallel processes...")
101
- print(f"Using seeds: {sorted(new_seeds)}")
102
-
103
- # Prepare arguments for parallel processing
104
- args_list = [(description, negative_prompt, seed, prompt_dir) for seed in new_seeds]
105
-
106
- # Use multiprocessing to distribute the work
107
- with mp.Pool(processes=num_processes) as pool:
108
- pool.map(generate_audio, args_list)
 
6
  import torch
7
  import soundfile as sf
8
  import random
 
9
  from diffusers import StableAudioPipeline
10
+ from utils.audio_utils import process_audio_generations
11
 
12
+ # Default negative prompt
13
+ NEGATIVE_PROMPT = "Low quality, noise, distortion, low fidelity"
 
 
 
 
14
 
15
  def generate_audio(args):
16
  description, negative_prompt, seed, prompt_dir = args
 
42
  print(f"Saving audio to: {file_path}")
43
  sf.write(file_path, output, pipe.vae.sampling_rate)
44
 
45
+ def prepare_args(description, seeds, prompt_dir):
46
+ """Prepare arguments for the generate_audio function"""
47
+ return [(description, NEGATIVE_PROMPT, seed, prompt_dir) for seed in seeds]
 
 
 
 
 
 
 
 
 
 
 
 
48
 
49
+ if __name__ == '__main__':
50
+ process_audio_generations(
51
+ descriptions=sys.argv[1:],
52
+ model_name='sa',
53
+ generate_fn=generate_audio,
54
+ prepare_args_fn=prepare_args
55
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audio/tango_audio.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+
4
+ import sys
5
+ import os
6
+ import random
7
+ import soundfile as sf
8
+ from tango import Tango
9
+ from utils.audio_utils import process_audio_generations
10
+
11
+ def generate_audio(args):
12
+ description, seed, prompt_dir = args
13
+ wav_path = os.path.join(prompt_dir, f"{seed}.wav")
14
+ if os.path.exists(wav_path):
15
+ print(f"Skipping seed {seed} - file already exists")
16
+ return
17
+
18
+ # Initialize model for this process
19
+ tango = Tango("declare-lab/tango")
20
+
21
+ # Set random seed for reproducibility
22
+ random.seed(seed)
23
+
24
+ # Generate audio
25
+ audio = tango.generate(description)
26
+
27
+ # Save the audio
28
+ file_path = os.path.join(prompt_dir, f"{seed}.wav")
29
+ print(f"Saving audio to: {file_path}")
30
+ sf.write(file_path, audio, samplerate=16000)
31
+
32
+ def prepare_args(description, seeds, prompt_dir):
33
+ """Prepare arguments for the generate_audio function"""
34
+ return [(description, seed, prompt_dir) for seed in seeds]
35
+
36
+ if __name__ == '__main__':
37
+ process_audio_generations(
38
+ descriptions=sys.argv[1:],
39
+ model_name='tango',
40
+ generate_fn=generate_audio,
41
+ prepare_args_fn=prepare_args
42
+ )
caption/jtp2.py CHANGED
@@ -447,5 +447,3 @@ def create_tags(threshold):
447
 
448
  if __name__ == "__main__":
449
  process_directory(args.directory, args.threshold, args.cpu, args.no_grad)
450
-
451
-
 
447
 
448
  if __name__ == "__main__":
449
  process_directory(args.directory, args.threshold, args.cpu, args.no_grad)
 
 
caption/wdv3.py CHANGED
@@ -395,4 +395,3 @@ if __name__ == "__main__":
395
  print(f"Available models: {list(MODEL_REPO_MAP.keys())}")
396
  raise ValueError(f"Unknown model name '{opts.model}'")
397
  main(opts)
398
-
 
395
  print(f"Available models: {list(MODEL_REPO_MAP.keys())}")
396
  raise ValueError(f"Unknown model name '{opts.model}'")
397
  main(opts)
 
utils/audio_utils.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+
4
+ import os
5
+ import sys
6
+ import random
7
+ import multiprocessing as mp
8
+
9
+ def get_seed_from_filename(filename):
10
+ """Extract seed from filename like '12345.wav'"""
11
+ try:
12
+ return int(filename.split('.')[0])
13
+ except:
14
+ return None
15
+
16
+ def setup_generation_dir(base_output_dir, description):
17
+ """Setup and return the directory for a given prompt"""
18
+ os.makedirs(base_output_dir, exist_ok=True)
19
+
20
+ # Create a safe folder name from the description
21
+ folder_name = description.replace(' ', '_').replace('/', '_').replace('\\', '_')
22
+ folder_name = ''.join(c for c in folder_name if c.isalnum() or c in '_-')
23
+ prompt_dir = os.path.join(base_output_dir, folder_name)
24
+ os.makedirs(prompt_dir, exist_ok=True)
25
+ return prompt_dir
26
+
27
+ def get_existing_seeds(prompt_dir):
28
+ """Get set of seeds from existing wav files in directory"""
29
+ existing_seeds = set()
30
+ for filename in os.listdir(prompt_dir):
31
+ if filename.endswith('.wav'):
32
+ seed = get_seed_from_filename(filename)
33
+ if seed is not None:
34
+ existing_seeds.add(seed)
35
+ return existing_seeds
36
+
37
+ def generate_new_seeds(needed_variations, existing_seeds, seed_range=(0, 1000000)):
38
+ """Generate new unique random seeds"""
39
+ new_seeds = set()
40
+ while len(new_seeds) < needed_variations:
41
+ seed = random.randint(*seed_range)
42
+ if seed not in existing_seeds and seed not in new_seeds:
43
+ new_seeds.add(seed)
44
+ return new_seeds
45
+
46
+ def process_audio_generations(descriptions, model_name, generate_fn, prepare_args_fn, num_variations=25, num_processes=3):
47
+ """
48
+ Shared logic for processing audio generations across different models.
49
+
50
+ Args:
51
+ descriptions: List of text prompts to generate audio for
52
+ model_name: Name of the model (used for output directory)
53
+ generate_fn: Function that generates a single audio sample
54
+ prepare_args_fn: Function that prepares arguments for generate_fn
55
+ num_variations: Number of variations to generate per prompt
56
+ num_processes: Number of parallel processes to use
57
+ """
58
+ # Set start method for multiprocessing
59
+ mp.set_start_method('spawn', force=True)
60
+
61
+ if not descriptions:
62
+ print('At least one prompt should be provided')
63
+ sys.exit(1)
64
+
65
+ # Base output directory
66
+ base_output_dir = f'generated_audio/{model_name}'
67
+
68
+ for description in descriptions:
69
+ prompt_dir = setup_generation_dir(base_output_dir, description)
70
+ print(f"\nGenerating variations for prompt: {description}")
71
+ print(f"Saving in directory: {prompt_dir}")
72
+
73
+ # Get existing seeds and check if we need to generate more
74
+ existing_seeds = get_existing_seeds(prompt_dir)
75
+ if len(existing_seeds) >= num_variations:
76
+ print(f"All {num_variations} variations already exist in {prompt_dir}, skipping...")
77
+ continue
78
+
79
+ # Generate new random seeds that haven't been used yet
80
+ needed_variations = num_variations - len(existing_seeds)
81
+ new_seeds = generate_new_seeds(needed_variations, existing_seeds)
82
+
83
+ print(f"Generating {needed_variations} new variations using {num_processes} parallel processes...")
84
+ print(f"Using seeds: {sorted(new_seeds)}")
85
+
86
+ # Prepare arguments for parallel processing
87
+ args_list = prepare_args_fn(description, new_seeds, prompt_dir)
88
+
89
+ # Use multiprocessing to distribute the work
90
+ with mp.Pool(processes=num_processes) as pool:
91
+ pool.map(generate_fn, args_list)