# "Spaces: Paused" — Hugging Face Space status banner captured by the page
# scrape; not Python code. Kept as a comment so the file parses.
from flask import Flask, request, jsonify | |
from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline, AutoencoderKL, EulerAncestralDiscreteScheduler | |
from PIL import Image | |
import torch | |
import base64 | |
from io import BytesIO | |
from huggingface_hub import login | |
# Authenticate with the Hugging Face Hub using the HF_TOKEN environment
# variable. Raises KeyError at startup if HF_TOKEN is unset — the private
# model repos below cannot be fetched without it.
import os
login(os.environ["HF_TOKEN"])
# Initialize Flask app
app = Flask(__name__)

# Load Hugging Face pipeline components (downloads weights on first run).
model_id = "fyp1/sketchToImage"
# ControlNet branch of the sketch-to-image repo; fp16 halves GPU memory use.
controlnet = ControlNetModel.from_pretrained(f"{model_id}/controlnet", torch_dtype=torch.float16)
# Community fp16-compatible SDXL VAE ("fp16-fix") instead of the stock VAE.
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
# Scheduler config stored alongside the fine-tuned model.
scheduler = EulerAncestralDiscreteScheduler.from_pretrained(f"{model_id}/scheduler")

# Initialize Stable Diffusion XL ControlNet Pipeline on top of the SDXL base.
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    controlnet=controlnet,
    vae=vae,
    scheduler=scheduler,
    safety_checker=None,  # NOTE(review): SDXL pipelines don't take a safety_checker kwarg — likely ignored; confirm
    torch_dtype=torch.float16,
).to("cuda" if torch.cuda.is_available() else "cpu")  # fall back to CPU when no GPU is present
@app.route("/generate", methods=["POST"])
def generate_image():
    """Handle POST /generate: turn a Base64 sketch plus prompt into an image.

    Expected JSON body:
        prompt (str, optional): text prompt; defaults to "A default prompt".
        negative_prompt (str, optional): defaults to a generic quality filter.
        sketch (str, required): Base64-encoded sketch image.

    Returns:
        200 with {"image": <base64 PNG>} on success,
        400 with {"error": ...} when the sketch is missing,
        500 with {"error": ...} on any processing/inference failure.
    """
    # NOTE(review): the original function carried no @app.route decorator, so
    # Flask never registered it and the endpoint was unreachable; registered
    # here at /generate (POST). Confirm the intended path against the client.
    # get_json(silent=True) returns None instead of raising on a missing or
    # non-JSON body, so the 400 guard below can fire cleanly.
    data = request.get_json(silent=True) or {}
    prompt = data.get("prompt", "A default prompt")
    negative_prompt = data.get("negative_prompt", "low quality, blurry, bad details")
    sketch_base64 = data.get("sketch")
    if not sketch_base64:
        return jsonify({"error": "Sketch image is required."}), 400
    try:
        # Decode and preprocess the sketch image for the ControlNet condition.
        sketch_bytes = base64.b64decode(sketch_base64)
        sketch_image = Image.open(BytesIO(sketch_bytes)).convert("L")  # grayscale
        sketch_image = sketch_image.resize((1024, 1024))  # SDXL native resolution
        # Generate the image using the pipeline; inference only, no gradients.
        with torch.no_grad():
            images = pipe(
                prompt=prompt,
                negative_prompt=negative_prompt,
                image=sketch_image,
                controlnet_conditioning_scale=1.0,
                width=1024,
                height=1024,
                num_inference_steps=30,
            ).images
        # Encode the first generated image back to Base64 PNG for the response.
        buffered = BytesIO()
        images[0].save(buffered, format="PNG")
        image_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
        return jsonify({"image": image_base64})
    except Exception as e:
        # Broad catch is deliberate at this HTTP boundary: surface as a 500
        # JSON error rather than an unhandled traceback.
        return jsonify({"error": str(e)}), 500
if __name__ == "__main__":
    # Bind on all interfaces; 7860 is the conventional Hugging Face Spaces port.
    app.run(host="0.0.0.0", port=7860)