"""RunPod serverless worker that swaps the vocals of a song for a generated voice.

For each job the worker downloads a song, splits it into vocal and
instrumental stems, converts the vocals with a voice model, mixes the
converted vocals back over the instrumental and uploads the results to S3.
"""

# NOTE: the progress prints are deliberately interleaved with the imports so
# that worker start-up logs show which import is currently in progress.
print("importing runpod")
import runpod
print("importing requests")
import requests
print("importing generate_wav")
from voice_generation import generate_wav
print("importing boto3")
import boto3
print("importing os")
import os
print("importing uuid")
import uuid
print("importing pydub")
from pydub import AudioSegment
from vocalsplit.inference import main as split_audio
import time

print("setting up environment variables")
AWS_ACCESS_KEY_ID = os.environ.get('AWS_ACCESS_KEY_ID')
AWS_SECRET_ACCESS_KEY = os.environ.get('AWS_SECRET_ACCESS_KEY')

# Voice models available without a download, keyed by the `voice_model_id`
# a job may pass in its input.
models = {
    'kanye': 'weights/kanye.pth',
    'rose-bp': 'weights/rose-bp.pth',
    'jungkook': 'weights/jungkook.pth',
    'iu': 'weights/iu.pth',
    'drake': 'weights/drake.pth',
    'ariana-grande': 'weights/ariana-grande.pth'
}

# (connect, read) timeouts in seconds for outgoing downloads, so a stalled
# server cannot block the worker forever.
REQUEST_TIMEOUT = (10, 120)

print('run handler. Removed 2nd gen')


def combine_audio(voice_path: str, instrumental_path: str) -> None:
    """Overlay the voice track on the instrumental and write ``combined.mp3``.

    The shorter track is padded with trailing silence so both have the same
    length before they are mixed.

    Args:
        voice_path: Path of the audio file containing the vocals.
        instrumental_path: Path of the audio file containing the instrumental.
    """
    # No explicit format: the handler passes WAV files here, so the container
    # type is left to be detected from the file instead of forcing "mp3".
    instrumental = AudioSegment.from_file(instrumental_path)
    voice = AudioSegment.from_file(voice_path)

    # len() of an AudioSegment is its duration in milliseconds, which is the
    # unit AudioSegment.silent() expects.
    length = max(len(instrumental), len(voice))
    instrumental = instrumental + AudioSegment.silent(duration=length - len(instrumental))
    voice = voice + AudioSegment.silent(duration=length - len(voice))

    combined = instrumental.overlay(voice)
    combined.export("combined.mp3", format="mp3")


def upload_file_to_s3(local_file_path: str, s3_file_path: str) -> dict:
    """Upload a local file to the ``voice-gen-audios`` bucket.

    Args:
        local_file_path: Path of the file to upload.
        s3_file_path: Object key to store the file under.

    Returns:
        ``{"url": ...}`` on success, or ``{"error": ...}`` if the upload failed.
    """
    bucket_name = 'voice-gen-audios'
    s3 = boto3.client(
        's3',
        aws_access_key_id=AWS_ACCESS_KEY_ID,
        aws_secret_access_key=AWS_SECRET_ACCESS_KEY,
    )
    try:
        s3.upload_file(local_file_path, bucket_name, s3_file_path)
    except boto3.exceptions.S3UploadFailedError as e:
        # Keep the underlying cause in the worker logs; the returned message
        # stays free of provider details.
        print(f"s3 upload failed: {e}")
        return {"error": f"failed to upload file {local_file_path} to s3 as {s3_file_path}"}
    return {"url": f"https://{bucket_name}.s3.eu-north-1.amazonaws.com/{s3_file_path}"}


def clean_up_files(remove_voice_model: bool = False) -> dict:
    """Delete the temporary files a job writes to the working directory.

    Files that do not exist are skipped: a job that failed part-way never
    created them, and that must not stop the remaining files from being
    removed.

    Args:
        remove_voice_model: Also delete the downloaded ``voice_model.pth``.

    Returns:
        ``{"success": ...}`` when nothing is left behind, or ``{"error": ...}``
        naming the files that could not be removed.
    """
    files = [
        "song.mp3",
        "song_Instruments.wav",
        "song_Vocals.wav",
        "output_voice.wav",
        "combined.mp3",
    ]
    if remove_voice_model:
        files.append("voice_model.pth")

    failed = []
    for file in files:
        try:
            os.remove(file)
        except FileNotFoundError:
            # Already absent, so there is nothing to clean up.
            continue
        except OSError as e:
            print(f"failed to remove file {file}: {e}")
            failed.append(file)

    if failed:
        return {"error": f"failed to remove file {', '.join(failed)}"}
    return {"success": "files removed successfully"}


def get_voice_model(event: dict) -> dict:
    """Resolve the voice model requested by a job.

    A ``voice_model_id`` selects one of the pre-loaded ``models`` and takes
    precedence; otherwise the model is downloaded from ``voice_model_url``
    into ``voice_model.pth``.

    Args:
        event: Job payload; the options are read from ``event["input"]``.

    Returns:
        ``{"model_path": ...}`` on success, or ``{"error": ...}``.
    """
    voice_model_id = event["input"].get("voice_model_id", "")
    voice_model_url = event["input"].get("voice_model_url", "")

    if not voice_model_url and not voice_model_id:
        return {"error": "voice_model_url or voice_model_id is required"}

    if voice_model_id:
        if voice_model_id not in models:
            return {"error": "model not found in pre-loaded models"}
        return {"model_path": models[voice_model_id]}

    print("downloading voice_model")
    try:
        voice_model_response = requests.get(voice_model_url, timeout=REQUEST_TIMEOUT)
    except requests.RequestException as e:
        return {"error": f"failed to download voice_model, error: {e}"}
    if voice_model_response.status_code != 200:
        return {"error": f"failed to download voice_model, error: {voice_model_response.text}"}

    with open("voice_model.pth", "wb") as f:
        f.write(voice_model_response.content)
    return {"model_path": "voice_model.pth"}


def _process_song(song_url, model_path, file_id, user_id):
    """Run the download, split, convert, mix and upload pipeline for one job.

    Returns the success payload, or ``{"error": ...}`` describing the first
    step that failed. Temporary files are left for the caller to remove.
    """
    if not song_url:
        return {"error": "song_url is required"}

    try:
        song_file = requests.get(song_url, timeout=REQUEST_TIMEOUT)
    except requests.RequestException as e:
        print(f"song download failed: {e}")
        return {"error": "failed to download song_file"}
    if song_file.status_code != 200:
        return {"error": "failed to download song_file"}

    with open("song.mp3", "wb") as f:
        f.write(song_file.content)

    splitting_start = time.time()  # remove after testing
    split_audio("song.mp3")
    time_taken_splitting = time.time() - splitting_start  # remove after testing
    print(f"splitting took {time_taken_splitting} seconds")  # remove after testing

    # The splitter is expected to leave both stems next to the input file.
    if not os.path.exists("song_Instruments.wav") or not os.path.exists("song_Vocals.wav"):
        return {"error": "failed to split song"}

    song_instruments = upload_file_to_s3("song_Instruments.wav", f"{file_id}-split-instruments.wav")
    song_vocals = upload_file_to_s3("song_Vocals.wav", f"{file_id}-split-voice.wav")
    if "error" in song_instruments:
        return song_instruments
    if "error" in song_vocals:
        return song_vocals

    generation_start = time.time()  # remove after testing
    generation = generate_wav(
        audio_file='song_Vocals.wav',
        method='pm',
        index_rate=0.6,
        output_file='output_voice.wav',
        model_path=model_path,
    )
    time_taken_generation = time.time() - generation_start  # remove after testing
    print(f"generation took {time_taken_generation} seconds")  # remove after testing

    # NOTE(review): assumes generate_wav returns a dict-like result that
    # carries an "error" key on failure; confirm against voice_generation.
    if "error" in generation:
        return {"error": generation.get("error")}

    combine_audio("output_voice.wav", "song_Instruments.wav")
    if not os.path.exists("combined.mp3"):
        return {"error": "failed to combine audio"}

    combined = upload_file_to_s3("combined.mp3", f"{file_id}.mp3")
    output_voice = upload_file_to_s3("output_voice.wav", f"{file_id}-generated-voice.wav")
    if combined_error := combined.get("error"):
        return {"error": combined_error}
    if output_voice_error := output_voice.get("error"):
        return {"error": output_voice_error}

    return {
        "combined_url": combined.get("url"),
        "output_voice_url": output_voice.get("url"),
        "user_id": user_id,
        "time_taken_splitting": time_taken_splitting,  # remove after testing
        "time_taken_generation": time_taken_generation,  # remove after testing
    }


def handler(event):
    """Handle one job: replace the vocals of a song with a generated voice.

    Reads ``song_url``, ``voice_model_id`` / ``voice_model_url`` and the
    optional ``user_id`` from ``event["input"]``.

    Args:
        event: Job payload passed in by the RunPod serverless runtime.

    Returns:
        On success a dict with ``combined_url``, ``output_voice_url``,
        ``user_id`` and the two timing values. Every failure is reported as
        ``{"error": <message>}`` so callers can always test for that key.
    """
    print(event)
    file_id = str(uuid.uuid4())
    user_id = event["input"].get("user_id", "not provided")

    if not AWS_ACCESS_KEY_ID or not AWS_SECRET_ACCESS_KEY:
        return {"error": "AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY are missing from environment variables"}

    voice_model = get_voice_model(event)
    if "error" in voice_model:
        return voice_model

    # Only a downloaded model is a temporary file; pre-loaded weights stay.
    need_to_remove_voice_model = voice_model.get("model_path") == "voice_model.pth"

    try:
        result = _process_song(
            event["input"].get("song_url", ""),
            voice_model.get("model_path"),
            file_id,
            user_id,
        )
    finally:
        # Always clean up, including on failure or an unexpected exception:
        # workers are reused, and stale stems from an earlier job would
        # otherwise satisfy the os.path.exists checks of the next one.
        cleanup_result = clean_up_files(need_to_remove_voice_model)

    cleanup_error = cleanup_result.get("error")
    if "error" in result:
        if cleanup_error:
            print(cleanup_error)
        return result
    if cleanup_error:
        return {"error": cleanup_error}
    return result


runpod.serverless.start({"handler": handler})