# app.py — Stable Diffusion QR Code ControlNet demo (Gradio app).
# Author: blanchon; commit 771c08c ("🔥 Add app.py"); file size ~7 kB.
from typing import Optional
import gradio as gr
import qrcode
import torch
from diffusers import (
ControlNetModel,
EulerAncestralDiscreteScheduler,
StableDiffusionControlNetPipeline,
)
from gradio.components import Image, Radio, Slider, Textbox, Number
from PIL import Image as PilImage
from typing_extensions import Literal
def _select_device() -> str:
    """Return the preferred torch device name: "cuda", then "mps", else "cpu"."""
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"


def _dtype_for(device: str) -> torch.dtype:
    """Return the dtype to load model weights with on ``device``.

    Half precision is used on accelerators to save memory. On CPU, many
    float16 operators are not implemented in PyTorch, so float32 is used.
    """
    return torch.float32 if device == "cpu" else torch.float16


def _generator_device(device: str) -> str:
    """Return the device on which to create the seeded ``torch.Generator``.

    The generator must live on a device that actually exists on this host;
    hard-coding "cuda" crashes on MPS/CPU-only machines. For MPS a CPU
    generator is used (NOTE(review): follows the diffusers reproducibility
    guidance of seeding on CPU; confirm if an MPS generator is preferred).
    """
    return "cpu" if device == "mps" else device


def _make_qrcode_image(data: str, size: int = 512) -> "PilImage.Image":
    """Render ``data`` as a black-on-white RGB QR code resized to ``size`` x ``size``.

    The highest error-correction level (H) is used so the code stays scannable
    after the diffusion model has altered part of its modules.
    """
    qr = qrcode.QRCode(
        error_correction=qrcode.constants.ERROR_CORRECT_H,
        box_size=11,
        border=9,
    )
    qr.add_data(data)
    qr.make(fit=True)
    image = qr.make_image(fill_color="black", back_color="white").convert("RGB")
    return image.resize((size, size), PilImage.LANCZOS)


def main():
    """Load the ControlNet + Stable Diffusion pipelines and launch the Gradio UI."""
    device = _select_device()
    dtype = _dtype_for(device)

    # Both ControlNets are shared by every pipeline built below and receive the
    # same QR code image as conditioning input.
    controlnet_tile = ControlNetModel.from_pretrained(
        "lllyasviel/control_v11f1e_sd15_tile",
        torch_dtype=dtype,
        use_safetensors=False,
    ).to(device)
    controlnet_brightness = ControlNetModel.from_pretrained(
        "ioclab/control_v1p_sd15_brightness",
        torch_dtype=dtype,
        use_safetensors=True,
    ).to(device)

    def make_pipe(hf_repo: str, device: str) -> StableDiffusionControlNetPipeline:
        """Build a multi-ControlNet pipeline for ``hf_repo`` and place it on ``device``."""
        pipe = StableDiffusionControlNetPipeline.from_pretrained(
            hf_repo,
            controlnet=[controlnet_tile, controlnet_brightness],
            torch_dtype=dtype,
        )
        pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
        # pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
        return pipe.to(device)

    # Pipelines are parked on the CPU; move_pipe brings the selected one onto `device`.
    # Uncomment entries here to offer more models (the Radio choices follow this dict).
    pipes = {
        "DreamShaper": make_pipe("Lykon/DreamShaper", "cpu"),
        # "Realistic Vision V1.4": make_pipe("SG161222/Realistic_Vision_V1.4", "cpu"),
        # "OpenJourney": make_pipe("prompthero/openjourney", "cpu"),
        # "Anything V3": make_pipe("Linaqruf/anything-v3.0", "cpu"),
    }

    def move_pipe(model_name: str) -> StableDiffusionControlNetPipeline:
        """Offload all other pipelines to CPU and return ``model_name``'s on ``device``.

        The ControlNets are shared objects, so offloading another pipeline also
        moves them to CPU; moving the selected pipeline last brings them back.
        """
        for name, pipe in pipes.items():
            if name != model_name:
                pipe.to("cpu")
        return pipes[model_name].to(device)

    def predict(
        model: Literal[
            "DreamShaper",
            # "Realistic Vision V1.4",
            # "OpenJourney",
            # "Anything V3"
        ],
        qrcode_data: str,
        prompt: str,
        negative_prompt: Optional[str] = None,
        num_inference_steps: int = 100,
        guidance_scale: float = 9,
        controlnet_conditioning_tile: float = 0.25,
        controlnet_conditioning_brightness: float = 0.45,
        seed: int = 1331,
    ) -> PilImage.Image:
        """Generate an image whose layout follows a QR code encoding ``qrcode_data``.

        Args:
            model: Key into ``pipes`` selecting the base Stable Diffusion model.
            qrcode_data: Text/URL to encode in the QR code.
            prompt: Positive text prompt.
            negative_prompt: Optional negative text prompt.
            num_inference_steps: Number of denoising steps.
            guidance_scale: Classifier-free guidance scale.
            controlnet_conditioning_tile: Weight of the tile ControlNet.
            controlnet_conditioning_brightness: Weight of the brightness ControlNet.
            seed: Seed for the random generator, for reproducible outputs.

        Returns:
            The first generated image.

        Raises:
            ValueError: If ``model`` is not one of the loaded pipelines.
        """
        if model not in pipes:
            raise ValueError(f"Unknown model {model!r}; expected one of {sorted(pipes)}")
        pipe = move_pipe(model)
        # int() is defensive: manual_seed rejects floats should the UI send one.
        generator = torch.Generator(device=_generator_device(device)).manual_seed(int(seed))
        qrcode_image = _make_qrcode_image(qrcode_data)
        return pipe(
            prompt,
            # One conditioning image per ControlNet, in the order given to make_pipe.
            [qrcode_image, qrcode_image],
            num_inference_steps=num_inference_steps,
            generator=generator,
            negative_prompt=negative_prompt,
            guidance_scale=guidance_scale,
            controlnet_conditioning_scale=[
                controlnet_conditioning_tile,
                controlnet_conditioning_brightness,
            ],
        ).images[0]

    ui = gr.Interface(
        fn=predict,
        inputs=[
            Radio(
                value="DreamShaper",
                label="Model",
                choices=list(pipes),
            ),
            Textbox(
                value="https://twitter.com/JulienBlanchon",
                label="QR Code Data",
            ),
            Textbox(
                value="Japanese ramen with chopsticks, egg and steam, ultra detailed 8k",
                label="Prompt",
            ),
            Textbox(
                value="logo, watermark, signature, text, BadDream, UnrealisticDream",
                label="Negative Prompt",
            ),
            Slider(
                value=100,
                label="Number of Inference Steps",
                minimum=10,
                maximum=400,
                step=1,
            ),
            Slider(
                value=9,
                label="Guidance Scale",
                minimum=1,
                maximum=20,
                step=1,
            ),
            Slider(
                value=0.25,
                label="Controlnet Conditioning Tile",
                minimum=0.0,
                maximum=1.0,
                step=0.05,
            ),
            Slider(
                value=0.45,
                label="Controlnet Conditioning Brightness",
                minimum=0.0,
                maximum=1.0,
                step=0.05,
            ),
            Number(
                value=1,
                label="Seed",
                precision=0,
            ),
        ],
        outputs=Image(
            label="Generated Image",
            type="pil",
        ),
        examples=[
            [
                "DreamShaper",
                "https://twitter.com/JulienBlanchon",
                "Japanese ramen with chopsticks, egg and steam, ultra detailed 8k",
                "logo, watermark, signature, text, BadDream, UnrealisticDream",
                100,
                9,
                0.25,
                0.45,
                1,
            ],
            # [
            #     "Anything V3",
            #     "https://twitter.com/JulienBlanchon",
            #     "Japanese ramen with chopsticks, egg and steam, ultra detailed 8k",
            #     "logo, watermark, signature, text, BadDream, UnrealisticDream",
            #     100,
            #     9,
            #     0.25,
            #     0.60,
            #     1,
            # ],
            [
                "DreamShaper",
                "https://twitter.com/JulienBlanchon",
                "processor, chipset, electricity, black and white board",
                "logo, watermark, signature, text, BadDream, UnrealisticDream",
                300,
                9,
                0.50,
                0.30,
                1,
            ],
        ],
        # Runs predict on every example at startup to pre-render the outputs.
        cache_examples=True,
        title="Stable Diffusion QR Code Controlnet",
        description="Generate QR Code with Stable Diffusion and Controlnet",
        allow_flagging="never",
        max_batch_size=1,
    )
    # NOTE(review): up to 10 concurrent requests share one device; with several
    # models enabled, move_pipe could offload a pipeline another request is using.
    ui.queue(concurrency_count=10).launch()
if __name__ == "__main__":
    # Launch the Gradio app only when run as a script, never on import.
    main()