sketch-to-image / app.py
fyp1's picture
Upload 3 files
ea673b2 verified
raw
history blame
2.67 kB
from flask import Flask, request, jsonify
from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline, AutoencoderKL, EulerAncestralDiscreteScheduler
from PIL import Image
import torch
import base64
from io import BytesIO
from huggingface_hub import login
# Authenticate with Hugging Face Hub using the token from the HF_TOKEN environment variable
import os
# Raises KeyError at import time if HF_TOKEN is not set.
login(os.environ["HF_TOKEN"])

# Initialize Flask app
app = Flask(__name__)

# Choose the device first so the dtype can follow it: half precision is only
# used on GPU, because float16 pipelines are not expected to run on CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

# Load Hugging Face pipeline components
model_id = "fyp1/sketchToImage"

# The ControlNet and scheduler are addressed as sub-directories of the model
# location, so they are selected with `subfolder=` instead of being appended
# to the id (a three-part "namespace/repo/dir" string is not a valid Hub repo
# id; `subfolder=` also works when `model_id` is a local directory).
controlnet = ControlNetModel.from_pretrained(
    model_id,
    subfolder="controlnet",
    torch_dtype=dtype,
)
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=dtype)
scheduler = EulerAncestralDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler")

# Initialize Stable Diffusion XL ControlNet Pipeline
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    controlnet=controlnet,
    vae=vae,
    scheduler=scheduler,
    safety_checker=None,
    torch_dtype=dtype,
).to(device)
@app.route("/generate", methods=["POST"])
def generate_image():
    """Generate an image from a Base64-encoded sketch and a text prompt.

    The request body must be a JSON object with these keys:

    * ``sketch`` (required): Base64-encoded image data. A ``data:`` URL
      (``data:image/png;base64,...``) is also accepted.
    * ``prompt`` (optional): text prompt; defaults to ``"A default prompt"``.
    * ``negative_prompt`` (optional): defaults to
      ``"low quality, blurry, bad details"``.

    Returns:
        A ``(response, status)`` pair or a plain response:

        * 200 with ``{"image": <Base64-encoded PNG>}`` on success.
        * 400 with ``{"error": ...}`` when the body is not a JSON object, the
          sketch is missing, or the sketch cannot be decoded as an image.
        * 500 with ``{"error": ...}`` when anything else fails.
    """
    # silent=True yields None instead of raising on a missing/invalid JSON
    # body, so every malformed request gets the same JSON error shape.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object."}), 400

    # Extract prompt, sketch image (Base64), and optional parameters
    prompt = data.get("prompt", "A default prompt")
    negative_prompt = data.get("negative_prompt", "low quality, blurry, bad details")
    sketch_base64 = data.get("sketch", None)
    if not sketch_base64:
        return jsonify({"error": "Sketch image is required."}), 400

    # Drop a data-URL header if present. b64decode would otherwise discard
    # only its non-alphabet characters and decode the rest as garbage.
    if isinstance(sketch_base64, str) and sketch_base64.startswith("data:"):
        sketch_base64 = sketch_base64.partition(",")[2]

    try:
        try:
            # Decode and preprocess the sketch image
            sketch_bytes = base64.b64decode(sketch_base64)
            sketch_image = Image.open(BytesIO(sketch_bytes)).convert("L")  # grayscale
            sketch_image = sketch_image.resize((1024, 1024))
        except (ValueError, TypeError, OSError) as exc:
            # Bad Base64 (binascii.Error is a ValueError), a non-string
            # payload, or bytes that are not a readable image are client
            # errors, not server failures.
            return jsonify({"error": f"Invalid sketch image: {exc}"}), 400

        # Generate the image using the pipeline
        with torch.no_grad():
            images = pipe(
                prompt=prompt,
                negative_prompt=negative_prompt,
                image=sketch_image,
                controlnet_conditioning_scale=1.0,
                width=1024,
                height=1024,
                num_inference_steps=30,
            ).images

        # Convert output image to Base64
        buffered = BytesIO()
        images[0].save(buffered, format="PNG")
        image_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
        return jsonify({"image": image_base64})
    except Exception as e:
        # Top-level boundary: record the traceback server-side and keep the
        # JSON error contract for the client.
        app.logger.exception("Image generation failed")
        return jsonify({"error": str(e)}), 500
# Script entry point: start Flask's built-in server when run directly.
if __name__ == "__main__":
    # "0.0.0.0" binds on all network interfaces; port 7860 is presumably the
    # port the hosting environment expects -- confirm against deployment.
    app.run(port=7860, host="0.0.0.0")