import base64
from io import BytesIO

import torch
from diffusers import DiffusionPipeline, FluxControlNetModel, FluxTransformer2DModel
from PIL import Image

# Load the models once at import time so every request reuses them
MODEL_ID = "jiteshdhamaniya/alimama-creative-FLUX.1-dev-Controlnet-Inpainting-Alpha"
CONTROLNET_MODEL = "jiteshdhamaniya/alimama-creative-FLUX.1-dev-Controlnet-Inpainting-Alpha"
TRANSFORMER_MODEL = "black-forest-labs/FLUX.1-dev"

# The ControlNet and the FLUX transformer are individual components, not full
# pipelines, so they must be loaded with their model classes rather than DiffusionPipeline
controlnet = FluxControlNetModel.from_pretrained(CONTROLNET_MODEL, torch_dtype=torch.bfloat16)
transformer = FluxTransformer2DModel.from_pretrained(TRANSFORMER_MODEL, subfolder="transformer", torch_dtype=torch.bfloat16)
# The base repo passed to DiffusionPipeline must contain a full pipeline
# definition (a model_index.json); here it is expected to resolve to the
# FLUX ControlNet inpainting pipeline
pipeline = DiffusionPipeline.from_pretrained(
    MODEL_ID,
    controlnet=controlnet,
    transformer=transformer,
    torch_dtype=torch.bfloat16,
).to("cuda" if torch.cuda.is_available() else "cpu")

# Inference entry point: takes a dict with a prompt plus Base64-encoded source
# and mask images, and returns a Base64-encoded PNG of the inpainted result
def handle(inputs, context):
    try:
        # Parse inputs, falling back to defaults where optional fields are missing
        prompt = inputs.get("prompt", "default prompt text")
        control_image_base64 = inputs.get("control_image")
        mask_image_base64 = inputs.get("mask_image")
        num_inference_steps = inputs.get("num_inference_steps", 28)
        guidance_scale = inputs.get("guidance_scale", 3.5)
        controlnet_conditioning_scale = inputs.get("controlnet_conditioning_scale", 0.9)

        # Both images are required; fail early with a clear message
        if control_image_base64 is None or mask_image_base64 is None:
            return {"status": "error", "message": "control_image and mask_image are required"}

        # Decode the Base64 payloads into RGB PIL images
        control_image = Image.open(BytesIO(base64.b64decode(control_image_base64))).convert("RGB")
        mask_image = Image.open(BytesIO(base64.b64decode(mask_image_base64))).convert("RGB")

        # Run the inpainting pipeline; control_image is the source image and
        # control_mask marks the region to repaint, following the alimama FLUX
        # inpainting pipeline's argument names
        result = pipeline(
            prompt=prompt,
            control_image=control_image,
            control_mask=mask_image,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            controlnet_conditioning_scale=controlnet_conditioning_scale,
        ).images[0]

        # Encode the generated image as a Base64 PNG string
        buffered = BytesIO()
        result.save(buffered, format="PNG")
        result_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")

        return {"status": "success", "image": result_base64}
    except Exception as e:
        # Surface any failure (bad payload, OOM, etc.) as a structured error
        return {"status": "error", "message": str(e)}