|
import torch |
|
from diffusers import AnimateDiffPipeline, LCMScheduler, MotionAdapter |
|
from diffusers.utils import export_to_video |
|
from flask import Flask, request, jsonify |
|
from flask_cors import CORS |
|
import base64 |
|
import tempfile |
|
import os |
|
import threading |
|
import traceback |
|
|
|
app = Flask(__name__)
CORS(app)  # allow cross-origin requests so a browser front-end can call this API

# Lazily-initialized AnimateDiff pipeline; populated by download_pipeline()
# on the first generation request.
pipe = None

# Shared single-job state: the finished generation result and the worker thread.
# NOTE(review): only ONE job is tracked at a time — a second /run request
# overwrites both entries; confirm single-user usage is intended.
app.config['temp_response'] = None
app.config['generation_thread'] = None
|
|
|
def download_pipeline():
    """Download and initialize the global AnimateDiff + LCM pipeline.

    Loads the AnimateLCM motion adapter and the epiCRealism base model in
    fp16, switches to the LCM scheduler (linear beta schedule), attaches the
    AnimateLCM LoRA weights, and enables VAE slicing + model CPU offload to
    reduce GPU memory pressure.

    Returns:
        bool: True if the pipeline was initialized successfully, False otherwise.
    """
    global pipe
    try:
        print('Downloading the model weights')

        adapter = MotionAdapter.from_pretrained(
            "wangfuyun/AnimateLCM", torch_dtype=torch.float16
        )
        pipe = AnimateDiffPipeline.from_pretrained(
            "emilianJR/epiCRealism", motion_adapter=adapter, torch_dtype=torch.float16
        )
        # LCM sampling requires the linear beta schedule on the scheduler.
        pipe.scheduler = LCMScheduler.from_config(
            pipe.scheduler.config, beta_schedule="linear"
        )
        pipe.load_lora_weights(
            "wangfuyun/AnimateLCM",
            weight_name="AnimateLCM_sd15_t2v_lora.safetensors",
            adapter_name="lcm-lora",
        )
        pipe.set_adapters(["lcm-lora"], [0.8])
        pipe.enable_vae_slicing()
        pipe.enable_model_cpu_offload()
        return True
    except Exception as e:
        # Best-effort init: report the failure (with traceback, consistent with
        # generate_and_export_animation) and let the caller surface it.
        print(f"Error downloading pipeline: {e}")
        traceback.print_exc()
        return False
|
|
|
def generate_and_export_animation(prompt):
    """Generate a short video for *prompt* and return it base64-encoded.

    Lazily initializes the global pipeline on first use.

    Returns:
        dict: on success {'video_base64': <str>, 'status': None};
              on failure {'error': <str>, 'status': 'failed'}.
              Both outcomes are plain dicts so the /status handler can store
              and mutate the result uniformly. (The previous code returned a
              Flask Response tuple on error, which made /status crash on
              `final_response['status'] = ...`.)
    """
    global pipe

    if pipe is None:
        if not download_pipeline():
            return {'error': 'Failed to initialize animation pipeline',
                    'status': 'failed'}

    try:
        print('Generating Video frames')
        output = pipe(
            prompt=prompt,
            negative_prompt="bad quality, worse quality, low resolution, blur",
            num_frames=16,
            guidance_scale=2.0,
            num_inference_steps=6,
        )
        print('Video frames generated')

        # Reserve a temp file name (delete=False so it survives the close),
        # then guarantee cleanup even if export or read fails.
        with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as temp_file:
            temp_video_path = temp_file.name
        print('temp_video_path', temp_video_path)
        try:
            export_to_video(output.frames[0], temp_video_path)
            with open(temp_video_path, 'rb') as video_file:
                video_binary = video_file.read()
        finally:
            # Previously the file was only removed on the success path (leak).
            if os.path.exists(temp_video_path):
                os.remove(temp_video_path)

        video_base64 = base64.b64encode(video_binary).decode('utf-8')
        response_data = {'video_base64': video_base64, 'status': None}
        print('response_data', response_data)
        return response_data

    except Exception as e:
        print(f"Error generating animation: {e}")
        traceback.print_exc()
        # Plain dict, not jsonify(...)/Response — see docstring.
        return {'error': f"Failed to generate animation: {str(e)}",
                'status': 'failed'}
|
|
|
def background(prompt):
    """Worker entry point: run generation and stash the result in app.config.

    Runs inside an application context so any Flask machinery used during
    generation has access to the current app.
    """
    with app.app_context():
        result = generate_and_export_animation(prompt)
        app.config['temp_response'] = result
|
|
|
|
|
@app.route('/run', methods=['POST'])
def handle_animation_request():
    """Kick off video generation in a background thread for the posted prompt.

    Expects form field 'prompt'; responds immediately with the worker
    thread's ident as a pseudo process id.
    """
    prompt = request.form.get('prompt')
    if not prompt:
        return jsonify({"message": "Please provide a valid text prompt."}), 400

    worker = threading.Thread(target=background, args=(prompt,))
    app.config['generation_thread'] = worker
    worker.start()

    return jsonify({
        "message": "Video generation started",
        "process_id": worker.ident,
    })
|
|
|
@app.route('/status', methods=['GET'])
def check_animation_status():
    """Report the state of the most recent generation job.

    Query args:
        process_id: the thread ident returned by /run. Required, but its
            value is not otherwise used — only one job is tracked at a time.

    Returns JSON: {'status': 'in_progress'} while the worker is alive, the
    completed result dict once available, or an explicit error response.
    (Previously the missing-process_id and no-result paths fell through
    returning None, which made Flask raise a 500.)
    """
    process_id = request.args.get('process_id', None)

    if not process_id:
        return jsonify({"message": "Please provide a process_id."}), 400

    generation_thread = app.config.get('generation_thread')
    if generation_thread and generation_thread.is_alive():
        return jsonify({"status": "in_progress"}), 200

    final_response = app.config.get('temp_response')
    if isinstance(final_response, dict):
        print('final', final_response)
        if 'error' in final_response:
            # Worker finished with a failure; surface it as a server error.
            final_response['status'] = 'failed'
            return jsonify(final_response), 500
        final_response['status'] = 'completed'
        return jsonify(final_response)

    # Nothing started yet, or the worker produced no usable result.
    return jsonify({"status": "unknown"}), 404
|
|
|
if __name__ == '__main__':
    # Flask development server; debug=True enables the reloader/debugger and
    # must not be used in production.
    app.run(debug=True)
|
|