dgoot's picture
Add LORA weight customization
9b97455
import os
import shutil
from urllib.parse import parse_qs, urlparse
import gradio as gr
import requests
import spaces
import torch
from diffusers import (
AutoencoderKL,
AutoPipelineForImage2Image,
StableDiffusionImg2ImgPipeline,
StableDiffusionXLImg2ImgPipeline,
)
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
download_from_original_stable_diffusion_ckpt,
)
from loguru import logger
from PIL import Image
from slugify import slugify
from tqdm import tqdm
from tqdm.contrib.concurrent import thread_map
# Civitai model pages listed as supported. The list is not referenced anywhere
# else in the visible code; the model actually served is selected through the
# MODEL_URL environment variable below.
SUPPORTED_MODELS = [
    "https://civitai.com/models/4384/dreamshaper",
    "https://civitai.com/models/44960/mpixel",
    "https://civitai.com/models/92444/lelo-lego-lora-for-xl-and-sd15",
    "https://civitai.com/models/120298/chinese-landscape-art",
    "https://civitai.com/models/150986/blueprintify-sd-xl-10",
    "https://civitai.com/models/257749/pony-diffusion-v6-xl",
]
DEFAULT_MODEL = "https://civitai.com/models/4384/dreamshaper"

# Runtime configuration: which Civitai model to serve, and the duration value
# handed to the `spaces.GPU` decorator on `infer`.
model_url = os.environ.get("MODEL_URL", DEFAULT_MODEL)
gpu_duration = int(os.environ.get("GPU_DURATION", 60))

logger.debug(f"Loading model info for: {model_url}")

# The model id is the second path segment of the URL: /models/<id>/<slug>.
model_url_parsed = urlparse(model_url)
try:
    model_id = int(model_url_parsed.path.split("/")[2])
except (IndexError, ValueError) as e:
    raise ValueError(
        f"Cannot extract a Civitai model id from MODEL_URL: {model_url}"
    ) from e

# An optional ?modelVersionId=<id> query parameter pins a specific version.
# `parse_qs` returns a list of values per key, hence the `[0]`.
model_version_id = parse_qs(model_url_parsed.query).get("modelVersionId")
if model_version_id is not None:
    model_version_id = int(model_version_id[0])
logger.debug(f"Model version id: {model_version_id}")

# Fetch the model metadata. The timeout keeps startup from hanging forever if
# the API never answers.
r = requests.get(f"https://civitai.com/api/v1/models/{model_id}", timeout=30)
try:
    r.raise_for_status()
except requests.HTTPError as e:
    # Re-raise with the response body as the message so the API's own error
    # text shows up in the traceback.
    raise requests.HTTPError(
        r.text.strip(), request=e.request, response=e.response
    ) from e

model = r.json()
logger.debug(f"Model info: {model}")

# Use the first listed version unless a specific one was requested.
model_versions = model["modelVersions"]
if model_version_id is None:
    model_version = model_versions[0]
else:
    model_version = next(
        (mv for mv in model_versions if mv["id"] == model_version_id), None
    )
    if model_version is None:
        # A bare `next()` would surface as an opaque StopIteration here.
        raise ValueError(
            f"Model version {model_version_id} not found for model {model_id}"
        )

# Validate the file set before downloading anything. Explicit raises are used
# instead of `assert` so the checks still run under `python -O`.
model_files = model_version["files"]
file_types = [file["type"] for file in model_files]
if len(model_files) > 2:
    raise ValueError(
        f"Expected at most 2 files per model version, got {len(model_files)}"
    )
if len(set(file_types)) != len(file_types):
    # Local file names are derived from the file type, so duplicates would
    # collide on disk.
    raise ValueError(f"Duplicate file types in model version: {file_types}")
if not all(file_type in ("Model", "VAE") for file_type in file_types):
    raise ValueError(f"Unsupported file types in model version: {file_types}")
if not all(file["metadata"]["format"] == "SafeTensor" for file in model_files):
    raise ValueError("Only SafeTensor files are supported")
def download(file: str, url: str) -> None:
    """Download *url* to the local path *file* unless *file* already exists.

    The body is streamed in 1 MiB chunks into ``/tmp/<file>`` and moved to
    *file* only after the transfer finished, so an interrupted download never
    leaves a partial file at the final path.

    Args:
        file: Destination path; also used as the progress-bar label and as
            the temporary file name under ``/tmp``.
        url: URL to fetch.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    if os.path.exists(file):
        return

    temp_file = f"/tmp/{file}"
    # The `with` block closes the streamed connection even if writing fails;
    # the timeout prevents a stalled server from blocking startup forever.
    with requests.get(url, stream=True, timeout=60) as r:
        r.raise_for_status()
        # Not every server sends Content-Length (e.g. chunked responses);
        # tqdm accepts total=None and then shows a bar without a known end.
        content_length = r.headers.get("content-length")
        total = int(content_length) if content_length is not None else None
        with tqdm(
            desc=file, total=total, unit="B", unit_scale=True
        ) as pbar, open(temp_file, "wb") as f:
            for chunk in r.iter_content(chunk_size=1024 * 1024):
                f.write(chunk)
                pbar.update(len(chunk))
    shutil.move(temp_file, file)
# Display name of the model as reported by the fetched metadata; it is used
# for local file names and for the page heading.
model_name = model["name"]


def get_file_name(file_type):
    """Return the local file name for one of the model's files.

    The name has the form ``<model-slug>.<type-slug>.safetensors``, where both
    slugs come from ``slugify``.

    Args:
        file_type: File type label, e.g. ``"Model"`` or ``"VAE"``.
    """
    name_parts = (slugify(model_name), slugify(file_type), "safetensors")
    return ".".join(name_parts)
def _download_model_file(file):
    """Download one file entry of the selected model version to its local name."""
    download(get_file_name(file["type"]), file["downloadUrl"])


# Fetch all files of the model version concurrently. `thread_map` returns the
# collected results as a list, so the downloads are complete once it returns.
thread_map(_download_model_file, model_version["files"])
# Build the image-to-image pipeline for the downloaded model. Checkpoints are
# converted directly from the downloaded file; LORAs are loaded on top of a
# pretrained base pipeline.
model_type = model["type"]
if model_type == "Checkpoint":
    logger.debug("Loading pipeline for checkpoint")
    pipe_args = {}
    # Use the separate VAE only if one was downloaded with the checkpoint.
    if os.path.exists(get_file_name("VAE")):
        logger.debug("Loading VAE")
        pipe_args["vae"] = AutoencoderKL.from_single_file(
            get_file_name("VAE"),
            torch_dtype=torch.float16,
            use_safetensors=True,
        )
    base_model = model_version["baseModel"]
    if base_model == "SD 1.5":
        pipeline_class = StableDiffusionImg2ImgPipeline
    elif base_model == "SDXL 1.0":
        pipeline_class = StableDiffusionXLImg2ImgPipeline
    else:
        # Without this branch an unknown base model left `pipeline_class`
        # unbound and failed below with a NameError; raise the same error the
        # LORA branch uses instead.
        raise ValueError(f"Unsupported base model: {base_model}")
    pipe = download_from_original_stable_diffusion_ckpt(
        checkpoint_path_or_dict=get_file_name("Model"),
        from_safetensors=True,
        pipeline_class=pipeline_class,
        load_safety_checker=False,
        **pipe_args,
    )
elif model_type == "LORA":
    logger.debug("Loading pipeline for LORA")
    base_model = model_version["baseModel"]
    if base_model == "SD 1.5":
        pipe = AutoPipelineForImage2Image.from_pretrained(
            "stable-diffusion-v1-5/stable-diffusion-v1-5",
            safety_checker=None,
            requires_safety_checker=False,
            torch_dtype=torch.float16,
            use_safetensors=True,
            variant="fp16",
        )
    elif base_model == "SDXL 1.0":
        # Use AutoPipelineForImage2Image with the base model
        # since LORA are trained on base
        pipe = AutoPipelineForImage2Image.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0",
            torch_dtype=torch.float16,
            use_safetensors=True,
            variant="fp16",
        )
    else:
        raise ValueError(f"Unsupported base model: {base_model}")
    # The adapter name is what `infer` later passes to `set_adapters` to
    # change the LORA weight.
    adapter_name = slugify(model_name)
    pipe.load_lora_weights(get_file_name("Model"), adapter_name=adapter_name)
else:
    raise ValueError(f"Unsupported model type: {model_type}")

pipe = pipe.to("cuda")
@logger.catch(reraise=True)
@spaces.GPU(duration=gpu_duration)
def infer(
    prompt: str,
    init_image: Image.Image,
    negative_prompt: str | None,
    strength: float,
    num_inference_steps: int,
    guidance_scale: float,
    lora_weight: float,
    progress=gr.Progress(track_tqdm=True),
):
    """Run the image-to-image pipeline and return the first generated image.

    Args:
        prompt: Text prompt for the generation.
        init_image: Source image; it is downscaled in place to fit within
            1024x1024.
        negative_prompt: Optional negative prompt.
        strength: Pipeline ``strength``; a falsy value (0) is not passed on,
            so the pipeline default applies.
        num_inference_steps: Number of steps; 0 means pipeline default.
        guidance_scale: Guidance scale; 0 means pipeline default.
        lora_weight: Weight for the loaded LORA adapter; 0 means the default
            weight of 1.0. Ignored when the model is not a LORA.
        progress: Gradio progress tracker; not used directly in the body.

    Returns:
        The first image produced by the pipeline.

    Raises:
        gr.Error: If no initial image was provided.
    """
    if init_image is None:
        # Surface a readable message in the UI instead of an AttributeError
        # on `None.thumbnail`.
        raise gr.Error("Please provide an initial image")

    logger.info(f"Starting image generation: {dict(prompt=prompt, image=init_image)}")

    # Downscale the image
    init_image.thumbnail((1024, 1024))

    # Only forward settings the user actually changed; falsy values (0) are
    # dropped so the pipeline falls back to its own defaults.
    additional_args = {
        k: v
        for k, v in dict(
            strength=strength,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
        ).items()
        if v
    }

    if model_type == "LORA":
        # `pipe` is shared between requests, so set the weight on every call:
        # skipping the call for 0 would silently keep whatever weight an
        # earlier request chose. `adapter_name` only exists for LORA models,
        # hence the model_type guard.
        pipe.set_adapters(adapter_name, lora_weight or 1.0)

    logger.debug(f"Generating image: {dict(prompt=prompt, **additional_args)}")

    images = pipe(
        prompt=prompt,
        image=init_image,
        negative_prompt=negative_prompt,
        **additional_args,
    ).images
    return images[0]
# At viewport widths of 1280px or less, lay out the children of the
# #images-container element in a column instead of a row.
css = """
@media (max-width: 1280px) {
#images-container {
flex-direction: column;
}
}
"""

# Gradio UI: prompt box and run button, input/result images, and an accordion
# with advanced generation settings.
with gr.Blocks(css=css) as demo:
    with gr.Column():
        gr.Markdown("# Image-to-Image with Civitai Models")
        # Heading links back to the model page the app was configured with.
        gr.Markdown(f"## Model: [{model_name}]({model_url})")
        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )
            run_button = gr.Button("Run", scale=0, variant="primary")
        # elem_id ties this row to the responsive CSS rule above.
        with gr.Row(elem_id="images-container"):
            # type="pil" makes Gradio pass the upload to `infer` as a PIL image.
            init_image = gr.Image(label="Initial image", type="pil")
            result = gr.Image(label="Result")
        with gr.Accordion("Advanced Settings", open=False):
            negative_prompt = gr.Text(
                label="Negative prompt",
                max_lines=1,
                placeholder="Enter a negative prompt",
            )
            # All sliders start at 0, which `infer` treats as "not set".
            with gr.Row():
                strength = gr.Slider(
                    label="Strength",
                    minimum=0.0,
                    maximum=1.0,
                    step=0.01,
                    value=0.0,
                )
                # Only shown when the loaded model is a LORA.
                lora_weight = gr.Slider(
                    label="LORA weight",
                    minimum=0.0,
                    maximum=1.0,
                    step=0.01,
                    value=0.0,
                    visible=model_type == "LORA",
                )
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=0,
                    maximum=100,
                    step=1,
                    value=0,
                )
                guidance_scale = gr.Slider(
                    label="Guidance scale",
                    minimum=0.0,
                    maximum=100.0,
                    step=0.1,
                    value=0.0,
                )
    # Run `infer` on either the button click or pressing Enter in the prompt
    # box. The order of `inputs` must match the positional parameters of
    # `infer`.
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=[
            prompt,
            init_image,
            negative_prompt,
            strength,
            num_inference_steps,
            guidance_scale,
            lora_weight,
        ],
        outputs=[result],
    )

if __name__ == "__main__":
    demo.launch()