# ttv / app.py
# Spanicin's picture
# Update app.py
# 7efbaa8 verified
# raw
# history blame
# 4.29 kB
import torch
from diffusers import AnimateDiffPipeline, LCMScheduler, MotionAdapter
from diffusers.utils import export_to_video
from flask import Flask, request, jsonify
from flask_cors import CORS
import base64
import tempfile
import os
import threading
import traceback
# Flask application object; CORS is enabled on it so browser clients hosted
# on other origins can call the endpoints below.
app = Flask(__name__)
CORS(app)

# AnimateDiff pipeline, created lazily by download_pipeline() on first use.
pipe = None

# Single-slot job state kept on the app config:
#   temp_response     -- result of the most recent generation (None until set)
#   generation_thread -- threading.Thread running the current/last generation
app.config.update(temp_response=None, generation_thread=None)
def download_pipeline():
    """Download the model weights and build the global animation pipeline.

    The pipeline is assembled in a local variable and only published to the
    module-level ``pipe`` after every configuration step has succeeded.
    Assigning ``pipe`` early would leave a half-configured pipeline behind if
    a later step (scheduler swap, LoRA loading, offload setup) raised, and
    callers that test ``pipe is None`` would then never retry initialisation.

    Returns:
        bool: True when ``pipe`` has been set to a fully configured pipeline,
        False if any step raised (the error is printed with its traceback).
    """
    global pipe
    try:
        print('Downloading the model weights')
        # Motion adapter + base model, both loaded in half precision.
        adapter = MotionAdapter.from_pretrained("wangfuyun/AnimateLCM", torch_dtype=torch.float16)
        pipeline = AnimateDiffPipeline.from_pretrained(
            "emilianJR/epiCRealism",
            motion_adapter=adapter,
            torch_dtype=torch.float16,
        )
        # Replace the default scheduler with LCM, reusing the existing config.
        pipeline.scheduler = LCMScheduler.from_config(pipeline.scheduler.config, beta_schedule="linear")
        pipeline.load_lora_weights(
            "wangfuyun/AnimateLCM",
            weight_name="AnimateLCM_sd15_t2v_lora.safetensors",
            adapter_name="lcm-lora",
        )
        pipeline.set_adapters(["lcm-lora"], [0.8])
        pipeline.enable_vae_slicing()
        pipeline.enable_model_cpu_offload()
    except Exception as e:
        print(f"Error downloading pipeline: {e}")
        traceback.print_exc()
        return False

    # Publish only once everything above succeeded.
    pipe = pipeline
    return True
def generate_and_export_animation(prompt):
    """Generate a short video for ``prompt`` and return it base64-encoded.

    The pipeline is initialised on first use via ``download_pipeline()``.
    Frames are exported to a temporary ``.mp4`` file, read back, encoded as
    base64 and the temporary file is always removed afterwards.

    Args:
        prompt: Text prompt passed to the animation pipeline.

    Returns:
        dict: Always a plain, JSON-serialisable dict so that callers can
        store it and later set ``['status']`` on it:

        * success -- ``{'video_base64': <str>, 'status': None}``
        * failure -- ``{'video_base64': '', 'status': 'failed',
          'error': <message>}``

        Earlier versions returned tuples on failure, which broke consumers
        that treat the result as a dict.
    """
    # Ensure the animation pipeline is initialized.
    if pipe is None:
        if not download_pipeline():
            return {
                'video_base64': '',
                'status': 'failed',
                'error': "Failed to initialize animation pipeline",
            }

    temp_video_path = None
    try:
        print('Generating Video frames')
        output = pipe(
            prompt=prompt,
            negative_prompt="bad quality, worse quality, low resolution, blur",
            num_frames=16,
            guidance_scale=2.0,
            num_inference_steps=6,
        )
        print('Video frames generated')

        # Reserve a path for the video; delete=False because the file is
        # written by export_to_video after this handle has been closed.
        with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as temp_file:
            temp_video_path = temp_file.name
        print('temp_video_path', temp_video_path)

        export_to_video(output.frames[0], temp_video_path)
        with open(temp_video_path, 'rb') as video_file:
            video_base64 = base64.b64encode(video_file.read()).decode('utf-8')

        # Log only the size: the payload itself can be megabytes of text.
        print('response_data video_base64 length', len(video_base64))
        return {'video_base64': video_base64, 'status': None}
    except Exception as e:
        print(f"Error generating animation: {e}")
        traceback.print_exc()  # Print exception details to console
        return {
            'video_base64': '',
            'status': 'failed',
            'error': f"Failed to generate animation: {str(e)}",
        }
    finally:
        # Remove the temporary video on both the success and failure paths.
        if temp_video_path is not None and os.path.exists(temp_video_path):
            os.remove(temp_video_path)
def background(prompt):
    """Thread target: run one generation and stash its result on the app.

    The work runs inside an application context, and whatever
    ``generate_and_export_animation`` returns is stored unchanged under
    ``app.config['temp_response']``.
    """
    with app.app_context():
        app.config['temp_response'] = generate_and_export_animation(prompt)
@app.route('/run', methods=['POST'])
def handle_animation_request():
    """Start a background video generation for the form field ``prompt``.

    Job state is a single slot on ``app.config`` (one thread, one result), so
    only one generation may run at a time.

    Responses:
        200 -- generation started; body carries ``message`` and
               ``process_id`` (the worker thread's ident).
        400 -- ``prompt`` missing, empty or whitespace-only.
        409 -- a previous generation is still running; body carries the
               ``process_id`` of that running job.
    """
    prompt = request.form.get('prompt')
    if not prompt or not prompt.strip():
        return jsonify({"message": "Please provide a valid text prompt."}), 400

    # Refuse overlapping jobs: a second thread would overwrite the tracked
    # thread and result of the first one.
    running = app.config.get('generation_thread')
    if running is not None and running.is_alive():
        return jsonify({
            "message": "A video generation is already in progress.",
            "process_id": running.ident,
        }), 409

    # Drop any result left over from an earlier job so it cannot be mistaken
    # for the outcome of this one.
    app.config['temp_response'] = None

    generation_thread = threading.Thread(target=background, args=(prompt,))
    app.config['generation_thread'] = generation_thread
    generation_thread.start()

    # ident is only populated once the thread has been started.
    response_data = {"message": "Video generation started", "process_id": generation_thread.ident}
    return jsonify(response_data)
@app.route('/status', methods=['GET'])
def check_animation_status():
    """Report the state of the current/last video generation.

    Query parameters:
        process_id: Required. Identifier returned by ``/run``. It is only
            checked for presence; it is not compared against the tracked
            thread.

    Responses:
        200 -- ``{"status": "in_progress"}`` while the worker thread is
               alive, or the stored result with ``status`` set to
               ``"completed"`` once it has finished successfully.
        400 -- ``process_id`` missing.
        404 -- no job is running and no result is stored.
        500 -- the stored result describes a failed generation; body carries
               ``status: "failed"`` and an ``error`` message.

    Every path returns a response; previously several paths fell through and
    returned ``None``, which Flask rejects.
    """
    process_id = request.args.get('process_id', None)
    if not process_id:
        return jsonify({
            "status": "error",
            "message": "Missing required query parameter 'process_id'.",
        }), 400

    generation_thread = app.config.get('generation_thread')
    if generation_thread and generation_thread.is_alive():
        return jsonify({"status": "in_progress"}), 200

    result = app.config.get('temp_response')
    if not result:
        return jsonify({
            "status": "not_found",
            "message": "No finished video generation is available.",
        }), 404

    if isinstance(result, dict):
        # Work on a copy so the stored result is not mutated by polling.
        final_response = dict(result)
        if final_response.get('error'):
            final_response['status'] = 'failed'
            return jsonify(final_response), 500
        final_response['status'] = 'completed'
        return jsonify(final_response)

    # Non-dict results are failure shapes: either (None, "<message>") or
    # (<JSON response with an "error" key>, <status code>).
    error = "Video generation failed"
    if isinstance(result, tuple) and len(result) == 2:
        body, detail = result
        if isinstance(detail, str):
            error = detail
        elif hasattr(body, 'get_json'):
            error = (body.get_json(silent=True) or {}).get('error', error)
    return jsonify({"status": "failed", "error": error}), 500
if __name__ == '__main__':
app.run(debug=True)