File size: 2,093 Bytes
95d158d
 
 
 
 
fbba87c
98740cc
 
 
 
 
e3f2781
51138b6
95d158d
 
98740cc
51138b6
8b28fcd
 
 
 
 
 
 
 
 
 
3309557
8b28fcd
 
 
3309557
8b28fcd
 
 
 
 
245ede6
 
 
95d158d
245ede6
fbba87c
 
 
 
 
 
245ede6
fbba87c
73a8db3
245ede6
fbba87c
 
245ede6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import sys
import os
import argparse

# Add the project root directory to the Python path
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
sys.path.insert(0, project_root)

import torch
import random
from audiocraft.models import AudioGen
from audiocraft.data.audio import audio_write
from utils.audio_utils import process_audio_generations

def generate_audio(args):
    """Generate a single AudioGen clip for one packed task tuple.

    Args:
        args: A ``(description, seed, prompt_dir)`` tuple, packed into one
            argument so the function can be mapped over a list of tasks.
            ``description`` is the text prompt, ``seed`` names the output
            file and seeds torch, and ``prompt_dir`` is the output folder.

    Returns:
        None. The clip is written to ``<prompt_dir>/<seed>.wav``; if that
        file already exists the call prints a notice and does nothing else.
    """
    prompt_text, seed, out_dir = args

    # audio_write appends the .wav suffix itself, so work from the bare stem.
    stem = os.path.join(out_dir, str(seed))
    if os.path.exists(f"{stem}.wav"):
        print(f"Skipping seed {seed} - file already exists")
        return

    # The model is loaded inside the call so each worker process gets its own.
    audiogen = AudioGen.get_pretrained('facebook/audiogen-medium')
    audiogen.set_generation_params(duration=5)  # 5-second clips

    # Seed torch right before sampling so a given seed reproduces its clip.
    torch.manual_seed(seed)
    generated = audiogen.generate([prompt_text])  # batch of exactly one prompt

    print(f"Saving audio to: {stem}.wav")
    # Loudness-normalised write (the original note cites -14 dB LUFS).
    audio_write(
        stem,
        generated[0].cpu(),
        audiogen.sample_rate,
        strategy="loudness",
        loudness_compressor=True,
    )

def prepare_args(description, seeds, prompt_dir):
    """Build the list of task tuples consumed by ``generate_audio``.

    Args:
        description: Text prompt shared by every task.
        seeds: Iterable of seeds; one task is produced per seed, in order.
        prompt_dir: Output directory shared by every task.

    Returns:
        A list of ``(description, seed, prompt_dir)`` tuples.
    """
    tasks = []
    for seed in seeds:
        tasks.append((description, seed, prompt_dir))
    return tasks

if __name__ == '__main__':
    # Command-line entry point: parse the prompts and the per-prompt variation
    # count, then pass them to the shared generation driver with this
    # module's AudioGen callbacks.
    cli = argparse.ArgumentParser(description='Generate audio samples using AudioGen')
    cli.add_argument(
        'prompts',
        nargs='+',
        help='One or more text prompts to generate audio from',
    )
    cli.add_argument(
        '--batch_size',
        type=int,
        default=25,
        help='Number of variations to generate per prompt (default: 25)',
    )
    options = cli.parse_args()

    process_audio_generations(
        descriptions=options.prompts,
        model_name='agen',
        generate_fn=generate_audio,
        prepare_args_fn=prepare_args,
        num_variations=options.batch_size,  # --batch_size is the variation count
    )