File size: 2,093 Bytes
95d158d fbba87c 98740cc e3f2781 51138b6 95d158d 98740cc 51138b6 8b28fcd 3309557 8b28fcd 3309557 8b28fcd 245ede6 95d158d 245ede6 fbba87c 245ede6 fbba87c 73a8db3 245ede6 fbba87c 245ede6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 |
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import sys
import os
import argparse
# Add the project root directory to the Python path
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
sys.path.insert(0, project_root)
import torch
import random
from audiocraft.models import AudioGen
from audiocraft.data.audio import audio_write
from utils.audio_utils import process_audio_generations
# Per-process cache for the loaded AudioGen model. Loading the pretrained
# checkpoint is by far the most expensive step, so a process that handles
# several seeds should only pay for it once.
_AUDIOGEN_MODEL_CACHE = {}


def _get_audiogen_model():
    """Return this process's AudioGen model, loading it on first use."""
    model = _AUDIOGEN_MODEL_CACHE.get('model')
    if model is None:
        model = AudioGen.get_pretrained('facebook/audiogen-medium')
        model.set_generation_params(duration=5)  # generate 5 seconds
        _AUDIOGEN_MODEL_CACHE['model'] = model
    return model


def generate_audio(args):
    """Generate one audio clip for a single (description, seed, prompt_dir) task.

    Args:
        args: A 3-tuple ``(description, seed, prompt_dir)`` as produced by
            ``prepare_args``. ``description`` is the text prompt, ``seed`` is
            passed to ``torch.manual_seed`` and also names the output file,
            and ``prompt_dir`` is the directory the clip is written into.

    Returns:
        None. The result is the file ``<prompt_dir>/<seed>.wav``. If that
        file already exists the task is skipped and no model is loaded.
    """
    description, seed, prompt_dir = args

    # Skip work that a previous run already completed.
    wav_path = os.path.join(prompt_dir, f"{seed}.wav")
    if os.path.exists(wav_path):
        print(f"Skipping seed {seed} - file already exists")
        return

    # Reuse the model already loaded in this process instead of reloading
    # the checkpoint for every seed.
    model = _get_audiogen_model()

    # Set random seed for reproducibility
    torch.manual_seed(seed)
    wav = model.generate([description])  # Generate one at a time

    file_path = os.path.join(prompt_dir, str(seed))  # audio_write will add .wav extension
    print(f"Saving audio to: {file_path}.wav")
    # Will save with loudness normalization at -14 db LUFS
    audio_write(
        file_path,
        wav[0].cpu(),
        model.sample_rate,
        strategy="loudness",
        loudness_compressor=True,
    )
def prepare_args(description, seeds, prompt_dir):
    """Build the task list consumed by ``generate_audio``.

    Each seed in ``seeds`` yields one ``(description, seed, prompt_dir)``
    tuple; the returned list preserves the iteration order of ``seeds``.
    """
    tasks = []
    for seed in seeds:
        tasks.append((description, seed, prompt_dir))
    return tasks
if __name__ == '__main__':
    # Command-line interface: one or more prompts, plus how many variations
    # to produce for each of them.
    cli = argparse.ArgumentParser(
        description='Generate audio samples using AudioGen'
    )
    cli.add_argument(
        'prompts',
        nargs='+',
        help='One or more text prompts to generate audio from',
    )
    cli.add_argument(
        '--batch_size',
        type=int,
        default=25,
        help='Number of variations to generate per prompt (default: 25)',
    )
    options = cli.parse_args()

    # Hand the prompts to the shared driver together with this script's
    # task-building and generation callbacks; --batch_size sets the number
    # of variations per prompt.
    process_audio_generations(
        descriptions=options.prompts,
        model_name='agen',
        generate_fn=generate_audio,
        prepare_args_fn=prepare_args,
        num_variations=options.batch_size,
    )
|