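# NOTE(review): the three lines below appear to be web-page residue captured with
# the file rather than program text; they are commented out so the module parses.
# Spaces:
# Paused
# Paused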
from flask import Flask, request, jsonify | |
import torch | |
import shutil | |
import os | |
import sys | |
from argparse import ArgumentParser | |
from time import strftime | |
from argparse import Namespace | |
from src.utils.preprocess import CropAndExtract | |
from src.test_audio2coeff import Audio2Coeff | |
from src.facerender.animate import AnimateFromCoeff | |
from src.generate_batch import get_data | |
from src.generate_facerender_batch import get_facerender_data | |
# from src.utils.init_path import init_path | |
import tempfile | |
from openai import OpenAI | |
import threading | |
import elevenlabs | |
from elevenlabs import set_api_key, generate, play, clone | |
from flask_cors import CORS, cross_origin | |
# from flask_swagger_ui import get_swaggerui_blueprint | |
import uuid | |
import time | |
# Wall-clock timestamp captured once, when this module is loaded.
# check_generation_status() subtracts it from the current time when it prints
# "Total time taken for execution", so that figure is measured from module load.
start_time = time.time()
class AnimationConfig:
    """Settings object handed to ``main()`` as its ``args`` argument.

    The constructor arguments carry the per-request values; every other
    attribute is set to the fixed value listed in ``__init__``.
    """

    def __init__(
        self,
        driven_audio_path,
        source_image_path,
        result_folder,
        pose_style,
        expression_scale,
        enhancer,
        still,
        preprocess,
        ref_pose_video_path,
    ):
        """Store the request-specific values next to the fixed defaults.

        Args:
            driven_audio_path: Stored as ``driven_audio``.
            source_image_path: Stored as ``source_image``.
            result_folder: Stored as ``result_dir``.
            pose_style: Stored as ``pose_style``.
            expression_scale: Stored as ``expression_scale``.
            enhancer: Stored as ``enhancer``.
            still: Stored as ``still``.
            preprocess: Stored as ``preprocess``.
            ref_pose_video_path: Stored as ``ref_pose``.
        """
        # Insertion order of this mapping is the order in which the attributes
        # are created on the instance.
        settings = {
            "driven_audio": driven_audio_path,
            "source_image": source_image_path,
            "ref_eyeblink": None,
            "ref_pose": ref_pose_video_path,
            "checkpoint_dir": './checkpoints',
            "result_dir": result_folder,
            "pose_style": pose_style,
            "batch_size": 2,
            "expression_scale": expression_scale,
            "input_yaw": None,
            "input_pitch": None,
            "input_roll": None,
            "enhancer": enhancer,
            "background_enhancer": None,
            "cpu": False,
            "face3dvis": False,
            "still": still,
            "preprocess": preprocess,
            "verbose": False,
            "old_version": False,
            "net_recon": 'resnet50',
            "init_path": None,
            "use_last_fc": False,
            "bfm_folder": './checkpoints/BFM_Fitting/',
            "bfm_model": 'BFM_model_front.mat',
            "focal": 1015.,
            "center": 112.,
            "camera_d": 10.,
            "z_near": 5.,
            "z_far": 15.,
            "device": 'cpu',
        }
        for attr_name, attr_value in settings.items():
            setattr(self, attr_name, attr_value)
# Flask application object with CORS enabled through flask_cors.
app = Flask(__name__)
CORS(app)

# Temporary directory of the most recent request; generate_video() rebinds it
# and check_generation_status() cleans it up.
TEMP_DIR = None

# Module-wide slots shared between the request handlers and the worker thread.
app.config.update(
    temp_response=None,
    generation_thread=None,
    text_prompt=None,
    final_video_path=None,
)
def main(args):
    """Generate the animated video for one request and publish it on ``app.config``.

    Builds the checkpoint and config paths relative to this file, constructs
    ``CropAndExtract``, ``Audio2Coeff`` and ``AnimateFromCoeff``, extracts
    coefficients from the source image (and from the optional reference
    videos), turns the driving audio into coefficients and renders the result.

    Args:
        args: Settings object (see ``AnimationConfig``). The attributes read
            here are ``source_image``, ``driven_audio``, ``result_dir``,
            ``pose_style``, ``device``, ``batch_size``, ``input_yaw``,
            ``input_pitch``, ``input_roll``, ``ref_eyeblink``, ``ref_pose``,
            ``preprocess``, ``checkpoint_dir``, ``still``, ``face3dvis``,
            ``expression_scale``, ``enhancer`` and ``background_enhancer``.

    Returns:
        ``(base64_video, temp_file_path)`` as produced by
        ``animate_from_coeff.generate``, or ``None`` when no coefficients could
        be extracted from the source image.

    Side effects:
        Creates directories under ``args.result_dir`` and stores the result in
        ``app.config['temp_response']`` and ``app.config['final_video_path']``.
    """
    pic_path = args.source_image
    audio_path = args.driven_audio
    save_dir = args.result_dir
    pose_style = args.pose_style
    device = args.device
    batch_size = args.batch_size
    input_yaw_list = args.input_yaw
    input_pitch_list = args.input_pitch
    input_roll_list = args.input_roll
    ref_eyeblink = args.ref_eyeblink
    ref_pose = args.ref_pose
    preprocess = args.preprocess

    current_root_path = os.path.dirname(os.path.realpath(__file__))
    print('current_root_path ',current_root_path)

    # Every checkpoint lives under the same directory and every YAML file under
    # src/config, so both prefixes are joined once.
    checkpoint_root = os.path.join(current_root_path, args.checkpoint_dir)
    config_root = os.path.join(current_root_path, 'src', 'config')

    path_of_lm_croper = os.path.join(checkpoint_root, 'shape_predictor_68_face_landmarks.dat')
    path_of_net_recon_model = os.path.join(checkpoint_root, 'epoch_20.pth')
    dir_of_BFM_fitting = os.path.join(checkpoint_root, 'BFM_Fitting')
    wav2lip_checkpoint = os.path.join(checkpoint_root, 'wav2lip.pth')
    audio2pose_checkpoint = os.path.join(checkpoint_root, 'auido2pose_00140-model.pth')
    audio2pose_yaml_path = os.path.join(config_root, 'auido2pose.yaml')
    audio2exp_checkpoint = os.path.join(checkpoint_root, 'auido2exp_00300-model.pth')
    audio2exp_yaml_path = os.path.join(config_root, 'auido2exp.yaml')
    free_view_checkpoint = os.path.join(checkpoint_root, 'facevid2vid_00189-model.pth.tar')

    # The 'full' preprocess mode uses its own mapping checkpoint and render config.
    if preprocess == 'full':
        mapping_checkpoint = os.path.join(checkpoint_root, 'mapping_00109-model.pth.tar')
        facerender_yaml_path = os.path.join(config_root, 'facerender_still.yaml')
    else:
        mapping_checkpoint = os.path.join(checkpoint_root, 'mapping_00229-model.pth.tar')
        facerender_yaml_path = os.path.join(config_root, 'facerender.yaml')

    # init model
    print(path_of_net_recon_model)
    preprocess_model = CropAndExtract(path_of_lm_croper, path_of_net_recon_model, dir_of_BFM_fitting, device)
    audio_to_coeff = Audio2Coeff(audio2pose_checkpoint, audio2pose_yaml_path,
                                 audio2exp_checkpoint, audio2exp_yaml_path,
                                 wav2lip_checkpoint, device)
    animate_from_coeff = AnimateFromCoeff(free_view_checkpoint, mapping_checkpoint,
                                          facerender_yaml_path, device)

    # Coefficients of the source image.
    first_frame_dir = os.path.join(save_dir, 'first_frame_dir')
    os.makedirs(first_frame_dir, exist_ok=True)
    first_coeff_path, crop_pic_path, crop_info = preprocess_model.generate(
        pic_path, first_frame_dir, args.preprocess, source_image_flag=True)
    print('first_coeff_path ',first_coeff_path)
    print('crop_pic_path ',crop_pic_path)
    if first_coeff_path is None:
        print("Can't get the coeffs of the input")
        return

    # Optional eye-blink reference video.
    if ref_eyeblink is not None:
        ref_eyeblink_videoname = os.path.splitext(os.path.split(ref_eyeblink)[-1])[0]
        ref_eyeblink_frame_dir = os.path.join(save_dir, ref_eyeblink_videoname)
        os.makedirs(ref_eyeblink_frame_dir, exist_ok=True)
        ref_eyeblink_coeff_path, _, _ = preprocess_model.generate(ref_eyeblink, ref_eyeblink_frame_dir)
    else:
        ref_eyeblink_coeff_path = None
    print('ref_eyeblink_coeff_path',ref_eyeblink_coeff_path)

    # Optional pose reference video; reuse the eye-blink coefficients when both
    # references point at the same file.
    if ref_pose is not None:
        if ref_pose == ref_eyeblink:
            ref_pose_coeff_path = ref_eyeblink_coeff_path
        else:
            ref_pose_videoname = os.path.splitext(os.path.split(ref_pose)[-1])[0]
            ref_pose_frame_dir = os.path.join(save_dir, ref_pose_videoname)
            os.makedirs(ref_pose_frame_dir, exist_ok=True)
            ref_pose_coeff_path, _, _ = preprocess_model.generate(ref_pose, ref_pose_frame_dir)
    else:
        ref_pose_coeff_path = None
    # This label used to read 'ref_eyeblink_coeff_path', which mislabelled the value.
    print('ref_pose_coeff_path',ref_pose_coeff_path)

    # Audio -> coefficients.
    batch = get_data(first_coeff_path, audio_path, device, ref_eyeblink_coeff_path, still=args.still)
    coeff_path = audio_to_coeff.generate(batch, save_dir, pose_style, ref_pose_coeff_path)

    if args.face3dvis:
        from src.face3d.visualize import gen_composed_video
        gen_composed_video(args, device, first_coeff_path, coeff_path, audio_path,
                           os.path.join(save_dir, '3dface.mp4'))

    # Coefficients -> video.
    data = get_facerender_data(coeff_path, crop_pic_path, first_coeff_path, audio_path,
                               batch_size, input_yaw_list, input_pitch_list, input_roll_list,
                               expression_scale=args.expression_scale, still_mode=args.still,
                               preprocess=args.preprocess)
    result, base64_video, temp_file_path = animate_from_coeff.generate(
        data, save_dir, pic_path, crop_info,
        enhancer=args.enhancer, background_enhancer=args.background_enhancer,
        preprocess=args.preprocess)
    print('The generated video is named:')

    # Publish the result for check_generation_status().
    app.config['temp_response'] = base64_video
    app.config['final_video_path'] = temp_file_path

    # The statements that used to follow this return (removing save_dir when not
    # verbose) could never execute and have been dropped; the request's
    # temporary directory is cleaned up by check_generation_status().
    return base64_video, temp_file_path
def create_temp_dir():
    """Create a fresh ``tempfile.TemporaryDirectory`` and return the object.

    The caller keeps the object; its ``name`` attribute is the directory path
    and ``cleanup()`` removes it.
    """
    temp_dir = tempfile.TemporaryDirectory()
    return temp_dir
def save_uploaded_file(file, filename, TEMP_DIR):
    """Save an uploaded file under a unique name inside ``TEMP_DIR``.

    Args:
        file: Object exposing ``save(path)`` (for example a Werkzeug upload).
        filename: Name appended to a random UUID prefix. Only its final path
            component is used, so a value such as ``"../x.png"`` cannot place
            the file outside the temporary directory.
        TEMP_DIR: ``tempfile.TemporaryDirectory`` (anything with a ``name``
            attribute holding the directory path).

    Returns:
        The path the file was written to.
    """
    safe_name = os.path.basename(filename)
    unique_filename = str(uuid.uuid4()) + "_" + safe_name
    file_path = os.path.join(TEMP_DIR.name, unique_filename)
    file.save(file_path)
    return file_path
# The OpenAI credential is read from the OPENAI_API_KEY environment variable so
# that no secret is stored in the source file. The key that used to be embedded
# on this line must be treated as leaked and revoked.
client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
def translate_text(text_prompt, target_language):
    """Request a translation of ``text_prompt`` into ``target_language``.

    Sends one system message and one user message to the ``gpt-4o-mini`` chat
    model through the module-level ``client``.

    Returns:
        The chat-completion response object as returned by the client; the
        caller reads the text from ``response.choices[0].message.content``.
    """
    system_message = {"role": "system", "content": "You are a helpful language translator assistant."}
    user_message = {
        "role": "user",
        "content": f"Translate completely without hallucination, end to end and the ouput should just be the translation of the text prompt and nothing else, and give the following text to {target_language} language and the text is: {text_prompt}",
    }
    # Output budget: the character length of the input plus a fixed 200.
    output_budget = len(text_prompt) + 200
    return client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[system_message, user_message],
        max_tokens=output_budget,
    )
def _configure_elevenlabs():
    """Pass the ElevenLabs API key from the environment to the elevenlabs client.

    Raises:
        RuntimeError: If the ELEVENLABS_API_KEY environment variable is unset.
    """
    api_key = os.environ.get("ELEVENLABS_API_KEY")
    if not api_key:
        raise RuntimeError("ELEVENLABS_API_KEY environment variable is not set")
    set_api_key(api_key)


def _stream_elevenlabs_audio(text_prompt, voice, prefix, temp_dir):
    """Stream ElevenLabs speech for ``text_prompt`` into a new .mp3 in ``temp_dir``; return its path."""
    audio = generate(text=text_prompt, voice=voice, model="eleven_multilingual_v2", stream=True, latency=4)
    with tempfile.NamedTemporaryFile(suffix=".mp3", prefix=prefix, dir=temp_dir.name, delete=False) as temp_file:
        for chunk in audio:
            temp_file.write(chunk)
    return temp_file.name


def _reserve_temp_path(suffix, prefix, temp_dir):
    """Create an empty, closed temp file in ``temp_dir`` and return its path."""
    with tempfile.NamedTemporaryFile(suffix=suffix, prefix=prefix, dir=temp_dir.name, delete=False) as temp_file:
        return temp_file.name


def generate_video():
    """Start a background video generation from the submitted form.

    NOTE(review): no ``@app.route`` decorator is visible on this function in the
    reviewed text; confirm how it is registered with the Flask app.

    Form fields read from the request:
        source_image (file, required), text_prompt (required),
        voice_cloning ('no' by default, or 'yes'),
        target_language (default 'original_text', meaning no translation),
        pose_style (int, default 1), expression_scale (float, default 1),
        enhancer (default None), voice_gender (default 'male'),
        still (default 'False'), preprocess (default 'crop'),
        ref_pose (file, optional), user_voice (file, required when
        voice_cloning is 'yes').

    Speech is produced with ElevenLabs for the male voice and for cloned voices,
    and with the OpenAI speech endpoint otherwise. The ElevenLabs key is taken
    from the ELEVENLABS_API_KEY environment variable; the key that used to be
    embedded in this function must be treated as leaked and revoked.

    Returns:
        A JSON body with ``message`` and ``process_id`` (the worker thread's
        ident) once the worker has been started; ``('Unsupported HTTP method',
        405)`` for non-POST requests; a JSON error with status 400 when
        ``voice_cloning`` is neither 'yes' nor 'no'; ``('An error occurred',
        500)`` when preparing the inputs raises.
    """
    global TEMP_DIR
    print('request:',request.method)
    # Reject other methods up front; previously they fell through to code that
    # used variables which had never been assigned.
    if request.method != 'POST':
        return 'Unsupported HTTP method', 405

    TEMP_DIR = create_temp_dir()
    try:
        source_image = request.files['source_image']
        text_prompt = request.form['text_prompt']
        print('Input text prompt: ',text_prompt)
        voice_cloning = request.form.get('voice_cloning', 'no')
        target_language = request.form.get('target_language', 'original_text')
        print('target_language',target_language)
        pose_style = int(request.form.get('pose_style', 1))
        expression_scale = float(request.form.get('expression_scale', 1))
        enhancer = request.form.get('enhancer', None)
        voice_gender = request.form.get('voice_gender', 'male')
        still_str = request.form.get('still', 'False')
        still = still_str.lower() == 'true'
        print('still', still)
        preprocess = request.form.get('preprocess', 'crop')
        print('preprocess selected: ',preprocess)
        ref_pose_video = request.files.get('ref_pose', None)

        # Without this check an unknown value left driven_audio_path unassigned.
        if voice_cloning not in ('no', 'yes'):
            return jsonify({"error": "voice_cloning must be 'yes' or 'no'"}), 400

        if target_language != 'original_text':
            response = translate_text(text_prompt, target_language)
            text_prompt = response.choices[0].message.content.strip()
        app.config['text_prompt'] = text_prompt
        print('Final text prompt: ',text_prompt)

        source_image_path = save_uploaded_file(source_image, 'source_image.png',TEMP_DIR)
        print(source_image_path)

        if voice_cloning == 'no':
            if voice_gender == 'male':
                print('Entering Audio creation using elevenlabs')
                _configure_elevenlabs()
                driven_audio_path = _stream_elevenlabs_audio(text_prompt, "Daniel", "text_to_speech_", TEMP_DIR)
                print('driven_audio_path',driven_audio_path)
                print('Audio file saved using elevenlabs')
            else:
                voice = 'nova'
                print('Entering Audio creation using whisper')
                response = client.audio.speech.create(model="tts-1-hd",
                                                      voice=voice,
                                                      input=text_prompt)
                print('Audio created using whisper')
                # The placeholder is closed before the response is written to
                # the same path.
                driven_audio_path = _reserve_temp_path(".wav", "text_to_speech_", TEMP_DIR)
                response.write_to_file(driven_audio_path)
                print('Audio file saved using whisper')
        else:
            user_voice = request.files['user_voice']
            user_voice_path = _reserve_temp_path(".wav", "user_voice_", TEMP_DIR)
            user_voice.save(user_voice_path)
            print('user_voice_path',user_voice_path)
            _configure_elevenlabs()
            voice = clone(name="User Cloned Voice",
                          files=[user_voice_path])
            driven_audio_path = _stream_elevenlabs_audio(text_prompt, voice, "cloned_audio_", TEMP_DIR)
            print('driven_audio_path',driven_audio_path)

        save_dir = tempfile.mkdtemp(dir=TEMP_DIR.name)
        result_folder = os.path.join(save_dir, "results")
        os.makedirs(result_folder, exist_ok=True)

        ref_pose_video_path = None
        if ref_pose_video:
            ref_pose_video_path = _reserve_temp_path(".mp4", "ref_pose_", TEMP_DIR)
            ref_pose_video.save(ref_pose_video_path)
            print('ref_pose_video_path',ref_pose_video_path)
    except Exception as e:
        app.logger.error(f"An error occurred: {e}")
        return "An error occurred", 500

    args = AnimationConfig(driven_audio_path=driven_audio_path, source_image_path=source_image_path,
                           result_folder=result_folder, pose_style=pose_style,
                           expression_scale=expression_scale, enhancer=enhancer, still=still,
                           preprocess=preprocess, ref_pose_video_path=ref_pose_video_path)
    if torch.cuda.is_available() and not args.cpu:
        args.device = "cuda"
    else:
        args.device = "cpu"

    # Drop the previous run's result so that a failed run cannot be reported as
    # completed with a stale video.
    app.config['temp_response'] = None
    app.config['final_video_path'] = None

    # NOTE(review): TEMP_DIR and the app.config slots are shared module state, so
    # a second request issued while a generation is running replaces them.
    generation_thread = threading.Thread(target=main, args=(args,))
    app.config['generation_thread'] = generation_thread
    generation_thread.start()
    response_data = {"message": "Video generation started",
                     "process_id": generation_thread.ident}
    return jsonify(response_data)
def check_generation_status():
    """Report the state of the generation identified by ``process_id``.

    NOTE(review): no ``@app.route`` decorator is visible on this function in the
    reviewed text; confirm how it is registered with the Flask app.

    Query parameters:
        process_id: The ``process_id`` value returned by ``generate_video()``.

    Returns:
        * ``{"error": "No process id provided"}`` when the parameter is missing.
        * A JSON error with status 400 when ``process_id`` is not an integer.
        * ``{"status": "in_progress"}`` while the matching worker thread is alive.
        * ``{"base64_video", "text_prompt", "status": "completed"}`` when a
          result is stored in ``app.config['temp_response']``. The generated
          file and the request's temporary directory are removed at that point.
        * A JSON error with status 404 when the worker is not running and no
          result is stored (for example because the generation failed).
    """
    global TEMP_DIR
    process_id = request.args.get('process_id', None)
    # process_id is required to check the status for that specific process
    if not process_id:
        return jsonify({"error":"No process id provided"})

    try:
        requested_ident = int(process_id)
    except ValueError:
        return jsonify({"error": "process_id must be an integer"}), 400

    generation_thread = app.config.get('generation_thread')
    if generation_thread and generation_thread.ident == requested_ident and generation_thread.is_alive():
        return jsonify({"status": "in_progress"}), 200

    final_response = app.config.get('temp_response')
    if not final_response:
        # This case used to fall through to "No process id provided", which
        # misreported the problem.
        return jsonify({"error": "No result available for this process id", "status": "not_found"}), 404

    response = {"base64_video": "","text_prompt":"", "status": ""}
    response["base64_video"] = final_response
    response["text_prompt"] = app.config.get('text_prompt')
    response["status"] = "completed"

    final_video_path = app.config['final_video_path']
    print('final_video_path',final_video_path)
    if final_video_path and os.path.exists(final_video_path):
        os.remove(final_video_path)
        print("Deleted video file:", final_video_path)
    # TEMP_DIR is None until generate_video() has run at least once.
    if TEMP_DIR is not None:
        TEMP_DIR.cleanup()

    end_time = time.time()
    total_time = round(end_time - start_time, 2)
    print("Total time taken for execution:", total_time, " seconds")
    return jsonify(response)
def health_status():
    """Return the fixed JSON body ``{"online": "true"}``.

    NOTE(review): no ``@app.route`` decorator is visible on this function in the
    reviewed text; confirm how it is registered with the Flask app.
    """
    return jsonify({"online": "true"})
if __name__ == '__main__':
    # Flask's debug mode enables the interactive Werkzeug debugger, which lets a
    # browser run arbitrary code on the server. It is therefore opt-in through
    # the FLASK_DEBUG environment variable instead of being hard-wired on.
    debug_enabled = os.environ.get('FLASK_DEBUG', '').strip().lower() in ('1', 'true', 'yes', 'on')
    app.run(debug=debug_enabled)