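"""Flask service that turns a text prompt into a talking-head video.

Pipeline: synthesize speech (ElevenLabs or OpenAI TTS), run a one-time
preprocessing pass on the source image (face crop + 3D coefficient
extraction), split the audio into short chunks, render each chunk in
parallel on the GPU, and stream the per-chunk results back to the client
as they complete.
"""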
from flask import Flask, request, jsonify, Response, stream_with_context
import torch
import shutil
import os
import sys
from time import strftime
from src.utils.preprocess import CropAndExtract
from src.test_audio2coeff import Audio2Coeff
from src.facerender.animate import AnimateFromCoeff
from src.generate_batch import get_data
from src.generate_facerender_batch import get_facerender_data
# from src.utils.init_path import init_path
import tempfile
from openai import OpenAI
import elevenlabs
from elevenlabs import set_api_key, generate, play, clone, Voice, VoiceSettings
import uuid
import time
from PIL import Image
import moviepy.editor as mp
import requests
import json
import pickle
# from dotenv import load_dotenv
from concurrent.futures import ProcessPoolExecutor, as_completed, ThreadPoolExecutor
# Load environment variables from .env file
# load_dotenv()

# Thread pool used to render audio chunks in parallel.
executor = ThreadPoolExecutor(max_workers=3)
class AnimationConfig:
    """Argument container mirroring the CLI options expected by the animation pipeline."""

    def __init__(self, driven_audio_path, source_image_path, result_folder, pose_style,
                 expression_scale, enhancer, still, preprocess, ref_pose_video_path, image_hardcoded):
        self.driven_audio = driven_audio_path
        self.source_image = source_image_path
        self.ref_eyeblink = None
        self.ref_pose = None
        self.checkpoint_dir = './checkpoints'
        self.result_dir = result_folder
        self.pose_style = pose_style
        self.batch_size = 2
        self.expression_scale = expression_scale
        self.input_yaw = None
        self.input_pitch = None
        self.input_roll = None
        self.enhancer = enhancer
        self.background_enhancer = None
        self.cpu = False
        self.face3dvis = False
        self.still = still
        self.preprocess = preprocess
        self.verbose = False
        self.old_version = False
        self.net_recon = 'resnet50'
        self.init_path = None
        self.use_last_fc = False
        self.bfm_folder = './checkpoints/BFM_Fitting/'
        self.bfm_model = 'BFM_model_front.mat'
        self.focal = 1015.
        self.center = 112.
        self.camera_d = 10.
        self.z_near = 5.
        self.z_far = 15.
        self.device = 'cuda'
        self.image_hardcoded = image_hardcoded
app = Flask(__name__)
# CORS(app)

# OpenAI client used by generate_audio(); reads OPENAI_API_KEY from the environment.
client = OpenAI()

TEMP_DIR = None
start_time = None

app.config['temp_response'] = None
app.config['generation_thread'] = None
app.config['text_prompt'] = None
app.config['final_video_path'] = None
app.config['final_video_duration'] = None

# Global checkpoint and config paths (the 'auido*' spellings follow the
# pretrained checkpoint filenames on disk).
dir_path = os.path.dirname(os.path.realpath(__file__))
current_root_path = dir_path

path_of_lm_croper = os.path.join(current_root_path, 'checkpoints', 'shape_predictor_68_face_landmarks.dat')
path_of_net_recon_model = os.path.join(current_root_path, 'checkpoints', 'epoch_20.pth')
dir_of_BFM_fitting = os.path.join(current_root_path, 'checkpoints', 'BFM_Fitting')
wav2lip_checkpoint = os.path.join(current_root_path, 'checkpoints', 'wav2lip.pth')
audio2pose_checkpoint = os.path.join(current_root_path, 'checkpoints', 'auido2pose_00140-model.pth')
audio2pose_yaml_path = os.path.join(current_root_path, 'src', 'config', 'auido2pose.yaml')
audio2exp_checkpoint = os.path.join(current_root_path, 'checkpoints', 'auido2exp_00300-model.pth')
audio2exp_yaml_path = os.path.join(current_root_path, 'src', 'config', 'auido2exp.yaml')
free_view_checkpoint = os.path.join(current_root_path, 'checkpoints', 'facevid2vid_00189-model.pth.tar')
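
# Each audio chunk goes through the same three stages inside process_chunk():
#   1. get_data()/Audio2Coeff  -- map the chunk's audio to motion coefficients,
#   2. get_facerender_data()   -- pack coefficients + cropped face into render batches,
#   3. AnimateFromCoeff        -- render the video (optionally enhanced) and return it
#      both as a base64 string and as a temp file path.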
# Function for running the actual task (using preprocessed data)
def process_chunk(audio_chunk, preprocessed_data, args):
    print("Entered Process Chunk Function")
    global audio2pose_checkpoint, audio2pose_yaml_path, audio2exp_checkpoint, audio2exp_yaml_path, wav2lip_checkpoint
    global free_view_checkpoint

    # The 'full' preprocess mode uses the still-mode face renderer config.
    if args.preprocess == 'full':
        mapping_checkpoint = os.path.join(current_root_path, 'checkpoints', 'mapping_00109-model.pth.tar')
        facerender_yaml_path = os.path.join(current_root_path, 'src', 'config', 'facerender_still.yaml')
    else:
        mapping_checkpoint = os.path.join(current_root_path, 'checkpoints', 'mapping_00229-model.pth.tar')
        facerender_yaml_path = os.path.join(current_root_path, 'src', 'config', 'facerender.yaml')

    first_coeff_path = preprocessed_data["first_coeff_path"]
    crop_pic_path = preprocessed_data["crop_pic_path"]
    crop_info_path = preprocessed_data["crop_info"]
    with open(crop_info_path, "rb") as f:
        crop_info = pickle.load(f)

    print("first_coeff_path", first_coeff_path)
    print("crop_pic_path", crop_pic_path)
    print("crop_info", crop_info)

    torch.cuda.empty_cache()

    # Audio -> motion coefficients for this chunk.
    batch = get_data(first_coeff_path, audio_chunk, args.device, ref_eyeblink_coeff_path=None, still=args.still)
    audio_to_coeff = Audio2Coeff(audio2pose_checkpoint, audio2pose_yaml_path,
                                 audio2exp_checkpoint, audio2exp_yaml_path,
                                 wav2lip_checkpoint, args.device)
    coeff_path = audio_to_coeff.generate(batch, args.result_dir, args.pose_style, ref_pose_coeff_path=None)

    # Coefficients -> rendered video for this chunk.
    animate_from_coeff = AnimateFromCoeff(free_view_checkpoint, mapping_checkpoint,
                                          facerender_yaml_path, args.device)
    data = get_facerender_data(coeff_path, crop_pic_path, first_coeff_path, audio_chunk,
                               args.batch_size, args.input_yaw, args.input_pitch, args.input_roll,
                               expression_scale=args.expression_scale, still_mode=args.still, preprocess=args.preprocess)

    torch.cuda.empty_cache()
    print("Will Enter Animation")
    result, base64_video, temp_file_path, _ = animate_from_coeff.generate(data, args.result_dir, args.source_image, crop_info,
                                                                          enhancer=args.enhancer, background_enhancer=args.background_enhancer, preprocess=args.preprocess)

    # video_clip = mp.VideoFileClip(temp_file_path)
    # duration = video_clip.duration
    app.config['temp_response'] = base64_video
    app.config['final_video_path'] = temp_file_path
    # app.config['final_video_duration'] = duration

    torch.cuda.empty_cache()
    return base64_video, temp_file_path
def create_temp_dir():
    return tempfile.TemporaryDirectory()


def save_uploaded_file(file, filename, TEMP_DIR):
    print("Entered save_uploaded_file")
    unique_filename = str(uuid.uuid4()) + "_" + filename
    file_path = os.path.join(TEMP_DIR.name, unique_filename)
    file.save(file_path)
    return file_path
def custom_cleanup(temp_dir, exclude_dir):
    # Iterate over the files and directories in TEMP_DIR
    for filename in os.listdir(temp_dir):
        file_path = os.path.join(temp_dir, filename)
        # Skip the directory we want to exclude
        if file_path != exclude_dir:
            try:
                if os.path.isdir(file_path):
                    shutil.rmtree(file_path)
                else:
                    os.remove(file_path)
                print(f"Deleted: {file_path}")
            except Exception as e:
                print(f"Failed to delete {file_path}. Reason: {e}")
def generate_audio(voice_cloning, voice_gender, text_prompt):
    print("generate_audio")
    if voice_cloning == 'no':
        if voice_gender == 'male':
            voice = 'echo'
            print('Entering Audio creation using elevenlabs')
            set_api_key('92e149985ea2732b4359c74346c3daee')

            audio = generate(text=text_prompt, voice="Daniel", model="eleven_multilingual_v2", stream=True, latency=4)
            with tempfile.NamedTemporaryFile(suffix=".mp3", prefix="text_to_speech_", dir=TEMP_DIR.name, delete=False) as temp_file:
                for chunk in audio:
                    temp_file.write(chunk)
                driven_audio_path = temp_file.name
            print('driven_audio_path', driven_audio_path)
            print('Audio file saved using elevenlabs')

        else:
            voice = 'nova'
            print('Entering Audio creation using OpenAI TTS')
            response = client.audio.speech.create(model="tts-1-hd",
                                                  voice=voice,
                                                  input=text_prompt)
            print('Audio created using OpenAI TTS')
            with tempfile.NamedTemporaryFile(suffix=".wav", prefix="text_to_speech_", dir=TEMP_DIR.name, delete=False) as temp_file:
                driven_audio_path = temp_file.name
            response.write_to_file(driven_audio_path)
            print('Audio file saved using OpenAI TTS')

    elif voice_cloning == 'yes':
        set_api_key('92e149985ea2732b4359c74346c3daee')
        # voice = clone(name = "User Cloned Voice",
        #               files = [user_voice_path])
        voice = Voice(voice_id="CEii8R8RxmB0zhAiloZg", name="Marc", settings=VoiceSettings(
            stability=0.71, similarity_boost=0.5, style=0.0, use_speaker_boost=True),)
        audio = generate(text=text_prompt, voice=voice, model="eleven_multilingual_v2", stream=True, latency=4)
        with tempfile.NamedTemporaryFile(suffix=".mp3", prefix="cloned_audio_", dir=TEMP_DIR.name, delete=False) as temp_file:
            for chunk in audio:
                temp_file.write(chunk)
            driven_audio_path = temp_file.name
        print('driven_audio_path', driven_audio_path)

    return driven_audio_path
# Preprocessing step that runs only once per source image
def run_preprocessing(args):
    global path_of_lm_croper, path_of_net_recon_model, dir_of_BFM_fitting
    first_frame_dir = os.path.join(args.result_dir, 'first_frame_dir')
    os.makedirs(first_frame_dir, exist_ok=True)

    # Cache the extracted coefficients so repeated requests with the hardcoded
    # image skip the expensive crop/extract step.
    fixed_temp_dir = "/tmp/preprocess_data"
    os.makedirs(fixed_temp_dir, exist_ok=True)
    preprocessed_data_path = os.path.join(fixed_temp_dir, "preprocessed_data.pkl")

    if os.path.exists(preprocessed_data_path) and args.image_hardcoded == "yes":
        print("Loading preprocessed data...")
        with open(preprocessed_data_path, "rb") as f:
            preprocessed_data = pickle.load(f)
        print("Loaded existing preprocessed data from:", preprocessed_data_path)
    else:
        print("Running preprocessing...")
        preprocess_model = CropAndExtract(path_of_lm_croper, path_of_net_recon_model, dir_of_BFM_fitting, args.device)
        first_coeff_path, crop_pic_path, crop_info = preprocess_model.generate(args.source_image, first_frame_dir, args.preprocess, source_image_flag=True)

        first_coeff_new_path = os.path.join(fixed_temp_dir, os.path.basename(first_coeff_path))
        crop_pic_new_path = os.path.join(fixed_temp_dir, os.path.basename(crop_pic_path))
        crop_info_new_path = os.path.join(fixed_temp_dir, "crop_info.pkl")
        shutil.move(first_coeff_path, first_coeff_new_path)
        shutil.move(crop_pic_path, crop_pic_new_path)
        with open(crop_info_new_path, "wb") as f:
            pickle.dump(crop_info, f)

        preprocessed_data = {"first_coeff_path": first_coeff_new_path,
                             "crop_pic_path": crop_pic_new_path,
                             "crop_info": crop_info_new_path}
        with open(preprocessed_data_path, "wb") as f:
            pickle.dump(preprocessed_data, f)
        print(f"Preprocessed data saved to: {preprocessed_data_path}")

    return preprocessed_data
def split_audio(audio_path, chunk_duration=5):
    audio_clip = mp.AudioFileClip(audio_path)
    total_duration = audio_clip.duration

    audio_chunks = []
    for start_time in range(0, int(total_duration), chunk_duration):
        end_time = min(start_time + chunk_duration, total_duration)
        chunk = audio_clip.subclip(start_time, end_time)
        with tempfile.NamedTemporaryFile(suffix=f"_chunk_{start_time}-{end_time}.wav", prefix="audio_chunk_", dir=TEMP_DIR.name, delete=False) as temp_file:
            chunk_path = temp_file.name
        chunk.write_audiofile(chunk_path)
        audio_chunks.append((start_time, chunk_path))

    return audio_chunks
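
# Each entry returned by split_audio() is a (start_time, chunk_path) tuple; the
# start time is carried through generate_chunks() so the caller can tell which
# part of the audio each streamed result belongs to.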
# Generator function to yield chunk results as they are processed
def generate_chunks(audio_chunks, preprocessed_data, args):
    # Submit every chunk to the thread pool, keyed by its start time.
    future_to_chunk = {executor.submit(process_chunk, chunk[1], preprocessed_data, args): chunk[0] for chunk in audio_chunks}

    for future in as_completed(future_to_chunk):
        idx = future_to_chunk[future]  # Start time of the chunk that was processed
        try:
            base64_video, temp_file_path = future.result()  # Get the result of the completed task
            yield json.dumps({'start_time': idx, 'path': temp_file_path}).encode('utf-8')
        except Exception as e:
            yield f"Task for chunk {idx} failed: {e}\n"
# Route path assumed here; adjust to match the deployment's actual endpoint.
@app.route('/run', methods=['POST'])
def parallel_processing():
    global start_time
    start_time = time.time()
    global TEMP_DIR
    global audio_chunks
    TEMP_DIR = create_temp_dir()
    print('request:', request.method)
    try:
        if request.method == 'POST':
            # source_image = request.files['source_image']
            image_path = '/home/user/app/images/out.jpg'
            source_image = Image.open(image_path)
            text_prompt = request.form['text_prompt']
            print('Input text prompt: ', text_prompt)
            text_prompt = text_prompt.strip()
            if not text_prompt:
                return jsonify({'error': 'Input text prompt cannot be blank'}), 400

            voice_cloning = request.form.get('voice_cloning', 'yes')
            image_hardcoded = request.form.get('image_hardcoded', 'no')
            chat_model_used = request.form.get('chat_model_used', 'openai')
            target_language = request.form.get('target_language', 'original_text')
            print('target_language', target_language)
            pose_style = int(request.form.get('pose_style', 1))
            expression_scale = float(request.form.get('expression_scale', 1))
            enhancer = request.form.get('enhancer', None)
            voice_gender = request.form.get('voice_gender', 'male')
            still_str = request.form.get('still', 'False')
            # Treat the flag as set only when the client actually sends 'true'.
            still = still_str.lower() == 'true'
            print('still', still)
            preprocess = request.form.get('preprocess', 'crop')
            print('preprocess selected: ', preprocess)
            # ref_pose_video = request.files.get('ref_pose', None)

            app.config['text_prompt'] = text_prompt
            print('Final output text prompt using openai: ', text_prompt)

            source_image_path = save_uploaded_file(source_image, 'source_image.png', TEMP_DIR)
            print(source_image_path)

            driven_audio_path = generate_audio(voice_cloning, voice_gender, text_prompt)

            save_dir = tempfile.mkdtemp(dir=TEMP_DIR.name)
            result_folder = os.path.join(save_dir, "results")
            os.makedirs(result_folder, exist_ok=True)

            ref_pose_video_path = None
            # if ref_pose_video:
            #     with tempfile.NamedTemporaryFile(suffix=".mp4", prefix="ref_pose_", dir=TEMP_DIR.name, delete=False) as temp_file:
            #         ref_pose_video_path = temp_file.name
            #         ref_pose_video.save(ref_pose_video_path)
            #     print('ref_pose_video_path', ref_pose_video_path)

    except Exception as e:
        app.logger.error(f"An error occurred: {e}")
        return jsonify({'status': 'error', 'message': str(e)}), 500
    args = AnimationConfig(driven_audio_path=driven_audio_path, source_image_path=source_image_path, result_folder=result_folder,
                           pose_style=pose_style, expression_scale=expression_scale, enhancer=enhancer, still=still,
                           preprocess=preprocess, ref_pose_video_path=ref_pose_video_path, image_hardcoded=image_hardcoded)

    preprocessed_data = run_preprocessing(args)

    chunk_duration = 5
    print(f"Splitting the audio into {chunk_duration}-second chunks...")
    audio_chunks = split_audio(driven_audio_path, chunk_duration=chunk_duration)
    print(f"Audio has been split into {len(audio_chunks)} chunks: {audio_chunks}")

    try:
        # Stream each chunk's result back to the client as soon as it is ready.
        return Response(stream_with_context(generate_chunks(audio_chunks, preprocessed_data, args)))
        # base64_video, temp_file_path, duration = process_chunk(driven_audio_path, preprocessed_data, args)
    except Exception as e:
        return jsonify({'status': 'error', 'message': str(e)}), 500
# Route path assumed here; adjust to match the deployment's actual health-check endpoint.
@app.route('/health', methods=['GET'])
def health_status():
    response = {"online": "true"}
    return jsonify(response)


if __name__ == '__main__':
    app.run(debug=True)