# Seed-VC-api / app.py
import os
from flask import Flask, request, jsonify, send_file, Response
import torch
import torchaudio
import librosa
import yaml
import numpy as np
from pydub import AudioSegment
from modules.commons import build_model, load_checkpoint, recursive_munch
from hf_utils import load_custom_model_from_hf
from modules.campplus.DTDNN import CAMPPlus
from modules.bigvgan import bigvgan
from transformers import AutoFeatureExtractor, WhisperModel
from modules.audio import mel_spectrogram
from modules.rmvpe import RMVPE
from io import BytesIO
# Initialize Flask app
app = Flask(__name__)
# Setup device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load model and configuration (same as in the original code)
dit_checkpoint_path, dit_config_path = load_custom_model_from_hf("Plachta/Seed-VC",
"DiT_seed_v2_uvit_whisper_small_wavenet_bigvgan_pruned.pth",
"config_dit_mel_seed_uvit_whisper_small_wavenet.yml")
with open(dit_config_path, 'r') as f:
    config = yaml.safe_load(f)
model_params = recursive_munch(config['model_params'])
model = build_model(model_params, stage='DiT')
hop_length = config['preprocess_params']['spect_params']['hop_length']
sr = config['preprocess_params']['sr']
# Load checkpoints
model, _, _, _ = load_checkpoint(model, None, dit_checkpoint_path,
load_only_params=True, ignore_modules=[], is_distributed=False)
for key in model:
    model[key].eval()
    model[key].to(device)
model.cfm.estimator.setup_caches(max_batch_size=1, max_seq_length=8192)
# Load additional models
campplus_ckpt_path = load_custom_model_from_hf("funasr/campplus", "campplus_cn_common.bin", config_filename=None)
campplus_model = CAMPPlus(feat_dim=80, embedding_size=192)
campplus_model.load_state_dict(torch.load(campplus_ckpt_path, map_location="cpu"))
campplus_model.eval()
campplus_model.to(device)
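# Illustrative sketch, not part of the original file: CAMP++ consumes 80-dim
# Kaldi-style fbank features from 16 kHz audio and returns a 192-dim speaker
# embedding. The helper below is an assumption about how the model loaded
# above could be called; Seed-VC's exact preprocessing may differ.
def extract_speaker_embedding(waves_16k):
    # waves_16k: [1, T] float tensor at 16 kHz
    feat = torchaudio.compliance.kaldi.fbank(
        waves_16k, num_mel_bins=80, dither=0, sample_frequency=16000
    )
    feat = feat - feat.mean(dim=0, keepdim=True)  # per-utterance mean normalization
    with torch.no_grad():
        return campplus_model(feat.unsqueeze(0).to(device))  # [1, 192]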
bigvgan_model = bigvgan.BigVGAN.from_pretrained('nvidia/bigvgan_v2_22khz_80band_256x', use_cuda_kernel=False)
bigvgan_model.remove_weight_norm()
bigvgan_model = bigvgan_model.eval().to(device)
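# Illustrative sketch, an assumption rather than the original pipeline: BigVGAN
# vocodes a mel spectrogram [B, n_mels, frames] back to a waveform, which is
# how the 22 kHz vocoder loaded above would typically be invoked.
def mel_to_wave(mel):
    with torch.no_grad():
        wav = bigvgan_model(mel.to(device))  # [B, 1, frames * hop_length], roughly in [-1, 1]
    return wav.squeeze().cpu().numpy()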
whisper_name = getattr(model_params.speech_tokenizer, 'whisper_name', "openai/whisper-small")
whisper_model = WhisperModel.from_pretrained(whisper_name, torch_dtype=torch.float16).to(device)
del whisper_model.decoder
whisper_feature_extractor = AutoFeatureExtractor.from_pretrained(whisper_name)
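# Illustrative sketch (an assumption about usage): with the decoder deleted,
# only the Whisper encoder is needed. Log-mel features from 16 kHz audio are
# produced by the feature extractor and passed through the encoder; Seed-VC's
# exact chunking and stride handling may differ.
def extract_whisper_features(waves_16k):
    inputs = whisper_feature_extractor(
        waves_16k.squeeze(0).cpu().numpy(),
        sampling_rate=16000,
        return_tensors="pt",
    )
    input_features = inputs.input_features.to(device, dtype=torch.float16)
    with torch.no_grad():
        return whisper_model.encoder(input_features).last_hidden_state  # [1, frames, hidden]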
# f0 conditioned model
dit_checkpoint_path_f0, dit_config_path_f0 = load_custom_model_from_hf("Plachta/Seed-VC",
"DiT_seed_v2_uvit_whisper_base_f0_44k_bigvgan_pruned_ft_ema.pth",
"config_dit_mel_seed_uvit_whisper_base_f0_44k.yml")
with open(dit_config_path_f0, 'r') as f:
    config_f0 = yaml.safe_load(f)
model_params_f0 = recursive_munch(config_f0['model_params'])
model_f0 = build_model(model_params_f0, stage='DiT')
hop_length_f0 = config_f0['preprocess_params']['spect_params']['hop_length']
sr_f0 = config_f0['preprocess_params']['sr']
# Load checkpoints for f0 model
model_f0, _, _, _ = load_checkpoint(model_f0, None, dit_checkpoint_path_f0,
load_only_params=True, ignore_modules=[], is_distributed=False)
for key in model_f0:
    model_f0[key].eval()
    model_f0[key].to(device)
model_f0.cfm.estimator.setup_caches(max_batch_size=1, max_seq_length=8192)
# F0 extractor
model_path = load_custom_model_from_hf("lj1995/VoiceConversionWebUI", "rmvpe.pt", None)
rmvpe = RMVPE(model_path, is_half=False, device=device)
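# Illustrative note (assumption): RMVPE front ends derived from the RVC codebase
# typically expose `infer_from_audio(audio, thred)`, which takes a 1-D 16 kHz
# waveform and returns a per-frame F0 contour in Hz (0 where unvoiced), e.g.:
#   f0 = rmvpe.infer_from_audio(wave_16k, thred=0.03)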
# Define Mel spectrogram conversion
def to_mel(x):
    mel_fn_args = {
        "n_fft": config['preprocess_params']['spect_params']['n_fft'],
        "win_size": config['preprocess_params']['spect_params']['win_length'],
        "hop_size": config['preprocess_params']['spect_params']['hop_length'],
        "num_mels": config['preprocess_params']['spect_params']['n_mels'],
        "sampling_rate": sr,
        "fmin": 0,
        "fmax": None,
        "center": False
    }
    return mel_spectrogram(x, **mel_fn_args)
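# Example (assuming a [1, T] waveform tensor at `sr`): to_mel(x) returns a
# [1, n_mels, frames] mel spectrogram with roughly T // hop_length frames.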
def adjust_f0_semitones(f0_sequence, n_semitones):
    factor = 2 ** (n_semitones / 12)
    return f0_sequence * factor
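# Example: +12 semitones doubles F0 (2 ** (12 / 12) == 2), so a 220 Hz contour
# becomes 440 Hz; -12 semitones halves it.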
def crossfade(chunk1, chunk2, overlap):
    fade_out = np.cos(np.linspace(0, np.pi / 2, overlap)) ** 2
    fade_in = np.cos(np.linspace(np.pi / 2, 0, overlap)) ** 2
    chunk2[:overlap] = chunk2[:overlap] * fade_in + chunk1[-overlap:] * fade_out
    return chunk2
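# Illustrative usage (an assumption about how chunked output would be stitched):
# blend the head of each new chunk with the tail of the running stream, then
# replace that tail with the blended chunk, e.g.:
#   blended = crossfade(prev_chunk, new_chunk, overlap)
#   stream = np.concatenate([stream[:-overlap], blended])
# The cos^2 / sin^2 fade pair sums to 1, so the overlap keeps constant amplitude.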
# Define the Flask route for voice conversion
@app.route('/convert', methods=['POST'])
def voice_conversion_api():
    # Get the input files and parameters from the request
    source = request.files['source']
    target = request.files['target']
    diffusion_steps = int(request.form['diffusion_steps'])
    length_adjust = float(request.form['length_adjust'])
    inference_cfg_rate = float(request.form['inference_cfg_rate'])
    # bool() is True for any non-empty string (including "false"), so parse the flags explicitly
    f0_condition = request.form['f0_condition'].lower() == 'true'
    auto_f0_adjust = request.form['auto_f0_adjust'].lower() == 'true'
    pitch_shift = int(request.form['pitch_shift'])
    # Read source and target audio at the model sample rate
    source_audio = librosa.load(source, sr=sr)[0]
    ref_audio = librosa.load(target, sr=sr)[0]
    # Move to tensors on the target device; the reference is capped at 25 seconds
    source_audio = torch.tensor(source_audio).unsqueeze(0).float().to(device)
    ref_audio = torch.tensor(ref_audio[:sr * 25]).unsqueeze(0).float().to(device)
    # Resample to 16 kHz for the content and speaker-embedding front ends
    ref_waves_16k = torchaudio.functional.resample(ref_audio, sr, 16000)
    converted_waves_16k = torchaudio.functional.resample(source_audio, sr, 16000)
    # The actual Seed-VC generation logic (content extraction, speaker embedding,
    # DiT inference and vocoding) is omitted here and should replace the placeholder below.
    output_wave = np.random.randn(sr * 10)  # Placeholder: 10 s of noise at the model sample rate
    output_wave = (np.clip(output_wave, -1.0, 1.0) * 32767.0).astype(np.int16)
    # Convert to MP3 and send the response
    mp3_file = BytesIO()
    AudioSegment(
        output_wave.tobytes(), frame_rate=sr,
        sample_width=output_wave.dtype.itemsize, channels=1
    ).export(mp3_file, format="mp3", bitrate="320k")
    mp3_file.seek(0)  # Ensure the stream is at the beginning
    return send_file(mp3_file, mimetype="audio/mpeg", as_attachment=True, download_name="converted_audio.mp3")
if __name__ == "__main__":
    # Run the Flask app
    app.run(host='0.0.0.0', debug=True, port=7860)