# Author: 0xadamm
# Commit 8a4950c: "added handler.py"
# (raw | history | blame) — 3.29 kB
# handler.py
import base64
import os
from io import BytesIO

import openai
import qrcode
import torch
from diffusers import (
    ControlNetModel,
    DDIMScheduler,
    StableDiffusionControlNetImg2ImgPipeline,
)
from diffusers.utils import load_image
from PIL import Image
class QRImageHandler:
    """Generate stylised QR codes with a QR ControlNet and Stable Diffusion img2img.

    Each call renders ``qrcode_data`` as a QR code and asks the OpenAI image
    API for an initial picture from ``prompt``. It then runs
    ``StableDiffusionControlNetImg2ImgPipeline`` with that picture as the
    img2img source and the QR code as the ControlNet condition. The first
    generated image is returned as a base64-encoded PNG string.
    """

    def __init__(
        self,
        controlnet_path="DionTimmer/controlnet_qrcode-control_v11p_sd21",
        pipeline_path="stabilityai/stable-diffusion-2-1",
        openai_api_key=None,
    ):
        """Load the ControlNet and the Stable Diffusion pipeline.

        Args:
            controlnet_path: Hub id or local path of the ControlNet weights.
            pipeline_path: Hub id or local path of the base Stable Diffusion
                model.
            openai_api_key: Key used for the OpenAI image request. If omitted,
                the ``OPENAI_API_KEY`` environment variable is read. If that is
                also unset, ``openai.api_key`` is left untouched, so whatever
                the ``openai`` module is already configured with is used.
        """
        # Credentials must never be committed to source. Take them from the
        # caller or the environment instead. (The key that used to be
        # hard-coded here has been exposed and should be revoked.)
        self.openai_api_key = openai_api_key or os.environ.get("OPENAI_API_KEY")
        self.controlnet = ControlNetModel.from_pretrained(
            controlnet_path, torch_dtype=torch.float16
        )
        self.pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
            pipeline_path,
            controlnet=self.controlnet,
            safety_checker=None,
            torch_dtype=torch.float16,
        )
        self.pipe.enable_xformers_memory_efficient_attention()
        self.pipe.scheduler = DDIMScheduler.from_config(self.pipe.scheduler.config)
        self.pipe.enable_model_cpu_offload()

    @staticmethod
    def resize_for_condition_image(input_image: Image.Image, resolution: int):
        """Return an RGB copy of *input_image* rescaled for the pipeline.

        The image is scaled so that its shorter side is ``resolution`` pixels.
        Each side is then rounded to the nearest multiple of 64 and the image
        is resampled with Lanczos filtering.
        """
        input_image = input_image.convert("RGB")
        width, height = input_image.size
        scale = float(resolution) / min(height, width)
        height = int(round(height * scale / 64.0)) * 64
        width = int(round(width * scale / 64.0)) * 64
        return input_image.resize((width, height), resample=Image.LANCZOS)

    def __call__(
        self,
        prompt,
        negative_prompt,
        qrcode_data,
        guidance_scale,
        controlnet_conditioning_scale,
        strength,
        generator_seed,
        width,
        height,
        num_inference_steps,
    ):
        """Generate a stylised QR-code image and return it as base64 PNG.

        Args:
            prompt: Text prompt used both for the OpenAI init image and for
                the diffusion pass.
            negative_prompt: Negative prompt for the diffusion pass.
            qrcode_data: Payload encoded into the QR code.
            guidance_scale: Classifier-free guidance scale for the pipeline.
            controlnet_conditioning_scale: Weight of the QR ControlNet
                condition.
            strength: img2img strength forwarded to the pipeline.
            generator_seed: Seed for the pipeline's random noise.
            width: Output width forwarded to the pipeline.
            height: Output height forwarded to the pipeline.
            num_inference_steps: Number of denoising steps.

        Returns:
            The first generated image, PNG-encoded then base64-encoded, as str.
        """
        if self.openai_api_key:
            openai.api_key = self.openai_api_key

        control_image = self._make_qr_image(qrcode_data)
        init_image = self.resize_for_condition_image(
            self._fetch_init_image(prompt), 768
        )

        # A dedicated CPU generator seeded this way yields the same noise as
        # torch.manual_seed(seed), without reseeding torch's global RNG.
        generator = torch.Generator(device="cpu").manual_seed(generator_seed)

        # NOTE(review): init_image comes from a 1024x1024 request resized to
        # 768x768, and the control image is 768 px wide. A width/height other
        # than 768 may not line up with the init-image latents, depending on
        # the diffusers version. Confirm before exposing other sizes.
        result = self.pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            image=init_image,
            control_image=control_image,
            width=width,
            height=height,
            guidance_scale=guidance_scale,
            controlnet_conditioning_scale=controlnet_conditioning_scale,
            generator=generator,
            strength=strength,
            num_inference_steps=num_inference_steps,
        )
        return self._encode_png(result.images[0])

    @staticmethod
    def _make_qr_image(qrcode_data, base_width=768):
        """Render *qrcode_data* as a black-on-white QR code ``base_width`` px wide."""
        qr = qrcode.QRCode(
            version=1,
            # Highest error-correction level; presumably chosen so the code
            # still scans after the diffusion pass distorts it.
            error_correction=qrcode.constants.ERROR_CORRECT_H,
            box_size=10,
            border=4,
        )
        qr.add_data(qrcode_data)
        # fit=True lets the library grow past version 1 when the data needs it.
        qr.make(fit=True)
        img = qr.make_image(fill_color="black", back_color="white")
        scale = base_width / float(img.size[0])
        new_height = int(float(img.size[1]) * scale)
        return img.resize((base_width, new_height), Image.LANCZOS)

    @staticmethod
    def _fetch_init_image(prompt):
        """Request one 1024x1024 image for *prompt* from OpenAI and download it."""
        # openai.Image.create is the legacy (openai<1.0) SDK interface.
        response = openai.Image.create(prompt=prompt, n=1, size="1024x1024")
        return load_image(response.data[0].url)

    @staticmethod
    def _encode_png(pil_image):
        """Return *pil_image* encoded as PNG and then as a base64 ``str``."""
        buffered = BytesIO()
        pil_image.save(buffered, format="PNG")
        return base64.b64encode(buffered.getvalue()).decode()