|
|
|
|
|
|
|
import argparse
import functools
import os
import sys
|
|
|
|
|
# Put the repository root (one level above this script's directory) at the
# front of sys.path so the project-local ``utils`` package imported further
# down resolves when this file is executed directly as a script.
project_root = os.path.abspath(
    os.path.join(os.path.dirname(__file__), '..')
)
sys.path[:0] = [project_root]
|
|
|
import torch |
|
import random |
|
from audiocraft.models import AudioGen |
|
from audiocraft.data.audio import audio_write |
|
from utils.audio_utils import process_audio_generations |
|
|
|
_AUDIOGEN_CHECKPOINT = 'facebook/audiogen-medium'
_CLIP_DURATION_SECONDS = 5


@functools.lru_cache(maxsize=None)
def _load_model(checkpoint=_AUDIOGEN_CHECKPOINT, duration=_CLIP_DURATION_SECONDS):
    """Load an AudioGen checkpoint once per process and set its clip length.

    Loading pretrained weights is expensive, so the configured model is cached
    per ``(checkpoint, duration)`` pair instead of being rebuilt for every seed.
    """
    model = AudioGen.get_pretrained(checkpoint)
    model.set_generation_params(duration=duration)
    return model


def generate_audio(args):
    """Generate and save a single AudioGen clip for one (prompt, seed) job.

    Args:
        args: A ``(description, seed, prompt_dir)`` tuple, as produced by
            ``prepare_args``. ``description`` is the text prompt, ``seed`` is
            fed to ``torch.manual_seed`` and used as the output file name, and
            ``prompt_dir`` is the directory the clip is written into.

    Returns:
        None. Writes ``<prompt_dir>/<seed>.wav``. If that file already exists
        the job is skipped before any model is loaded.
    """
    description, seed, prompt_dir = args

    wav_path = os.path.join(prompt_dir, f"{seed}.wav")
    if os.path.exists(wav_path):
        print(f"Skipping seed {seed} - file already exists")
        return

    # Reuse one loaded model per process rather than reloading the weights
    # for every seed.
    model = _load_model()

    # Seed right before generation so each seed maps to a reproducible clip.
    torch.manual_seed(seed)
    wav = model.generate([description])

    # audio_write is given the path without an extension; it appends the
    # ".wav" suffix itself, which is why the skip check above uses "<seed>.wav".
    stem = os.path.join(prompt_dir, str(seed))
    print(f"Saving audio to: {stem}.wav")
    audio_write(stem, wav[0].cpu(), model.sample_rate, strategy="loudness", loudness_compressor=True)
|
|
|
def prepare_args(description, seeds, prompt_dir):
    """Build the per-seed job tuples consumed by ``generate_audio``.

    Args:
        description: Text prompt shared by every job.
        seeds: Iterable of seeds; one job is emitted per seed, in order.
        prompt_dir: Output directory shared by every job.

    Returns:
        A list of ``(description, seed, prompt_dir)`` tuples.
    """
    return [
        (description, current_seed, prompt_dir)
        for current_seed in seeds
    ]
|
|
|
if __name__ == '__main__':
    # Command-line entry point: generate --batch_size clips for each prompt.
    parser = argparse.ArgumentParser(description='Generate audio samples using AudioGen')
    parser.add_argument('prompts', nargs='+', help='One or more text prompts to generate audio from')
    parser.add_argument('--batch_size', type=int, default=25, help='Number of variations to generate per prompt (default: 25)')

    args = parser.parse_args()

    # A zero or negative variation count is meaningless; report it as a normal
    # argparse usage error (exit status 2) instead of passing it downstream.
    if args.batch_size < 1:
        parser.error('--batch_size must be a positive integer')

    process_audio_generations(
        descriptions=args.prompts,
        model_name='agen',
        generate_fn=generate_audio,
        prepare_args_fn=prepare_args,
        num_variations=args.batch_size,
    )
|
|