"""Flask service that renders an image from a sketch with SDXL + ControlNet.

POST /generate accepts a JSON object with:
    sketch           (required) Base64-encoded image, raw or as a data URL.
    prompt           (optional) text prompt.
    negative_prompt  (optional) negative text prompt.

The response is ``{"image": <Base64 PNG>}`` on success or
``{"error": <message>}`` with status 400 (bad request) or 500 (generation
failure).
"""

import base64
import os
from io import BytesIO

import torch
from diffusers import (
    AutoencoderKL,
    ControlNetModel,
    EulerAncestralDiscreteScheduler,
    StableDiffusionXLControlNetPipeline,
)
from flask import Flask, jsonify, request
from huggingface_hub import login
from PIL import Image

# Authenticate with the Hugging Face Hub using the token stored in the
# HF_TOKEN environment variable. A missing variable raises KeyError here, so
# the service fails at startup instead of on the first request.
login(os.environ["HF_TOKEN"])

# Initialize Flask app
app = Flask(__name__)

# Run on GPU when available. Half precision is only selected on GPU:
# diffusers warns that float16 pipelines moved to CPU error out or run
# extremely slowly, so the CPU fallback uses float32.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32

# Side length (pixels) used for both the conditioning sketch and the output.
IMAGE_SIZE = 1024
NUM_INFERENCE_STEPS = 30

# Load Hugging Face pipeline components. The ControlNet weights and the
# scheduler config live in sub-directories of one repository, so they are
# addressed with ``subfolder=`` -- "<repo>/<subdir>" is not a valid repo id.
model_id = "fyp1/sketchToImage"
controlnet = ControlNetModel.from_pretrained(
    model_id, subfolder="controlnet", torch_dtype=DTYPE
)
vae = AutoencoderKL.from_pretrained(
    "madebyollin/sdxl-vae-fp16-fix", torch_dtype=DTYPE
)
scheduler = EulerAncestralDiscreteScheduler.from_pretrained(
    model_id, subfolder="scheduler"
)

# Initialize Stable Diffusion XL ControlNet Pipeline
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    controlnet=controlnet,
    vae=vae,
    scheduler=scheduler,
    safety_checker=None,
    torch_dtype=DTYPE,
).to(DEVICE)


def _decode_sketch(sketch_base64):
    """Decode a Base64 sketch into a square grayscale PIL image.

    Args:
        sketch_base64: Base64 text, either raw or wrapped in a
            ``data:<mime>;base64,<payload>`` data URL.

    Returns:
        A mode "L" image resized to IMAGE_SIZE x IMAGE_SIZE.

    Raises:
        ValueError: if the value is not a string, is not valid Base64, or
            does not decode to an image PIL can read.
    """
    if not isinstance(sketch_base64, str):
        raise ValueError("Sketch must be a Base64-encoded string.")

    # Drop a data-URL header; b64decode would otherwise silently fold its
    # alphabet characters into the payload and corrupt the image bytes.
    if sketch_base64.startswith("data:"):
        _, _, sketch_base64 = sketch_base64.partition(",")

    try:
        sketch_bytes = base64.b64decode(sketch_base64)
        sketch = Image.open(BytesIO(sketch_bytes)).convert("L")
    except (ValueError, OSError) as err:
        # binascii.Error is a ValueError; PIL's unreadable-image errors are
        # OSError subclasses.
        raise ValueError("Sketch is not a valid Base64-encoded image.") from err

    return sketch.resize((IMAGE_SIZE, IMAGE_SIZE))


@app.route("/generate", methods=["POST"])
def generate_image():
    """Generate an image from the posted sketch and prompts.

    Returns:
        A JSON response: ``{"image": ...}`` with a Base64-encoded PNG, or
        ``{"error": ...}`` with status 400 for invalid input and 500 when
        generation or encoding fails.
    """
    # silent=True yields None instead of raising on a missing or malformed
    # body, so every bad request gets the same JSON error shape.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object."}), 400

    # Extract prompt, sketch image (Base64), and optional parameters
    prompt = data.get("prompt", "A default prompt")
    negative_prompt = data.get(
        "negative_prompt", "low quality, blurry, bad details"
    )
    sketch_base64 = data.get("sketch")

    if not sketch_base64:
        return jsonify({"error": "Sketch image is required."}), 400

    # An undecodable sketch is the client's fault: answer 400, not 500.
    try:
        sketch_image = _decode_sketch(sketch_base64)
    except ValueError as err:
        return jsonify({"error": str(err)}), 400

    try:
        # Generate the image using the pipeline
        with torch.no_grad():
            images = pipe(
                prompt=prompt,
                negative_prompt=negative_prompt,
                image=sketch_image,
                controlnet_conditioning_scale=1.0,
                width=IMAGE_SIZE,
                height=IMAGE_SIZE,
                num_inference_steps=NUM_INFERENCE_STEPS,
            ).images

        # Convert output image to Base64
        buffered = BytesIO()
        images[0].save(buffered, format="PNG")
        image_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
    except Exception as e:
        # Request boundary: record the traceback server-side, then report
        # the failure to the caller.
        app.logger.exception("Image generation failed")
        return jsonify({"error": str(e)}), 500

    return jsonify({"image": image_base64})


if __name__ == "__main__":
    app.run(host="0.0.0.0", port=7860)