import torch
import gradio as gr
from PIL import Image
import qrcode
from pathlib import Path
from multiprocessing import cpu_count
import requests
import io
import os

from diffusers import (
    StableDiffusionPipeline,
    StableDiffusionControlNetImg2ImgPipeline,
    ControlNetModel,
    DDIMScheduler,
    DPMSolverMultistepScheduler,
    DEISMultistepScheduler,
    HeunDiscreteScheduler,
    EulerDiscreteScheduler,
)

# Module-level QRCode instance (unused below; inference() builds its own per request).
qrcode_generator = qrcode.QRCode(
    version=1,
    error_correction=qrcode.constants.ERROR_CORRECT_H,
    box_size=10,
    border=4,
)
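
# The two checkpoints below are fetched from the Hugging Face Hub on first run:
# a ControlNet conditioned on QR codes ("DionTimmer/controlnet_qrcode-control_v1p_sd15")
# attached to a Stable Diffusion 1.5 img2img pipeline, loaded in fp16 on the GPU.
# xformers memory-efficient attention reduces VRAM use during inference.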
controlnet = ControlNetModel.from_pretrained(
    "DionTimmer/controlnet_qrcode-control_v1p_sd15", torch_dtype=torch.float16
)

pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    controlnet=controlnet,
    safety_checker=None,
    torch_dtype=torch.float16,
).to("cuda")
pipe.enable_xformers_memory_efficient_attention()


def resize_for_condition_image(input_image: Image.Image, resolution: int):
    # Scale the image so its shorter side lands near `resolution`, then snap
    # both dimensions to multiples of 64 as expected by the diffusion pipeline.
    input_image = input_image.convert("RGB")
    W, H = input_image.size
    k = float(resolution) / min(H, W)
    H *= k
    W *= k
    H = int(round(H / 64.0)) * 64
    W = int(round(W / 64.0)) * 64
    img = input_image.resize((W, H), resample=Image.LANCZOS)
    return img
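
# Worked example for resize_for_condition_image: a 1024x640 input with
# resolution=768 gives k = 768/640 = 1.2, so the image is resized to 1216x768
# (1228.8 rounded to the nearest multiple of 64).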


SAMPLER_MAP = {
    "DPM++ Karras SDE": lambda config: DPMSolverMultistepScheduler.from_config(
        config, use_karras_sigmas=True, algorithm_type="sde-dpmsolver++"
    ),
    "DPM++ Karras": lambda config: DPMSolverMultistepScheduler.from_config(
        config, use_karras_sigmas=True
    ),
    "Heun": lambda config: HeunDiscreteScheduler.from_config(config),
    "Euler": lambda config: EulerDiscreteScheduler.from_config(config),
    "DDIM": lambda config: DDIMScheduler.from_config(config),
    "DEIS": lambda config: DEISMultistepScheduler.from_config(config),
}
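
# A scheduler is swapped in by name at inference time, reusing the current
# scheduler's config so model-specific settings carry over, e.g.:
#   pipe.scheduler = SAMPLER_MAP["Euler"](pipe.scheduler.config)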


def inference(
    first_name: str = "John",
    last_name: str = "Doe",
    telephone_number: str = "+60123456789",
    email_address: str = "john.doe@example.com",
    url: str = "https://example.com",
    prompt: str = "Sky view of highly aesthetic, ancient greek thermal baths in beautiful nature",
    negative_prompt: str = "ugly, disfigured, low quality, blurry, nsfw",
):
    # Generation settings are fixed here rather than exposed in the UI.
    guidance_scale = 7.5
    controlnet_conditioning_scale = 1.5
    strength = 0.9
    seed = -1
    sampler = "DPM++ Karras SDE"
    qrcode_image = None

    # Pack the contact details into a MECARD record (terminated by ";;").
    qr_code_content = f"MECARD:N:{last_name},{first_name};TEL:{telephone_number};EMAIL:{email_address};URL:{url};;"
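    # With the default arguments the encoded payload is:
    # MECARD:N:Doe,John;TEL:+60123456789;EMAIL:john.doe@example.com;URL:https://example.com;;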

    if prompt is None or prompt == "":
        raise gr.Error("Prompt is required")

    if qr_code_content == "":
        raise gr.Error("Content is required")

    pipe.scheduler = SAMPLER_MAP[sampler](pipe.scheduler.config)

    # A seed of -1 means "no fixed seed": use a fresh generator rather than a
    # manually seeded one.
    generator = torch.manual_seed(seed) if seed != -1 else torch.Generator()

    if qr_code_content != "" or qrcode_image is None or qrcode_image.size == (1, 1):
        print("Generating QR Code from content")
        qr = qrcode.QRCode(
            version=1,
            error_correction=qrcode.constants.ERROR_CORRECT_H,
            box_size=10,
            border=4,
        )
        qr.add_data(qr_code_content)
        qr.make(fit=True)

        qrcode_image = qr.make_image(fill_color="black", back_color="white")
        qrcode_image = resize_for_condition_image(qrcode_image, 768)
    else:
        print("Using QR Code Image")
        qrcode_image = resize_for_condition_image(qrcode_image, 768)

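    # The QR code serves both as the img2img init image and as the ControlNet
    # conditioning image, so the diffusion process repaints it (strength=0.9)
    # while the ControlNet steers the result toward the QR pattern.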
    out = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        image=qrcode_image,
        control_image=qrcode_image,
        width=768,
        height=768,
        guidance_scale=float(guidance_scale),
        controlnet_conditioning_scale=float(controlnet_conditioning_scale),
        generator=generator,
        strength=float(strength),
        num_inference_steps=40,
    )
    return out.images[0]
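
# A minimal sketch of calling the generator directly, without the Gradio UI
# (assumes the models above loaded successfully on a CUDA GPU):
#
#   image = inference(
#       first_name="Jane",
#       last_name="Smith",
#       url="https://example.org",
#       prompt="Japanese rock garden at dawn, soft light, highly detailed",
#   )
#   image.save("qr_art.png")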


generator = gr.Interface(
    fn=inference,
    inputs=[
        gr.Textbox(
            label="First Name",
            value="John",
        ),
        gr.Textbox(
            label="Last Name",
            value="Doe",
        ),
        gr.Textbox(
            label="Telephone Number",
            value="+60123456789",
        ),
        gr.Textbox(
            label="Email Address",
            value="john.doe@example.com",
        ),
        gr.Textbox(
            label="URL",
            value="https://example.com",
        ),
        gr.Textbox(
            label="Prompt",
            value="Sky view of highly aesthetic, ancient greek thermal baths in beautiful nature",
        ),
        gr.Textbox(
            label="Negative Prompt",
            value="ugly, disfigured, low quality, blurry, nsfw",
        ),
    ],
    outputs="image",
)


if __name__ == "__main__":
    generator.queue(concurrency_count=1, max_size=20)
    generator.launch()