import os
import shutil
from urllib.parse import parse_qs, urlparse

import gradio as gr
import requests
import spaces
import torch
from diffusers import (
    AutoencoderKL,
    AutoPipelineForImage2Image,
    StableDiffusionImg2ImgPipeline,
    StableDiffusionXLImg2ImgPipeline,
)
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
    download_from_original_stable_diffusion_ckpt,
)
from loguru import logger
from PIL import Image
from slugify import slugify
from tqdm import tqdm
from tqdm.contrib.concurrent import thread_map

# Civitai models known to work with this app.
SUPPORTED_MODELS = [
    "https://civitai.com/models/4384/dreamshaper",
    "https://civitai.com/models/44960/mpixel",
    "https://civitai.com/models/92444/lelo-lego-lora-for-xl-and-sd15",
    "https://civitai.com/models/120298/chinese-landscape-art",
    "https://civitai.com/models/150986/blueprintify-sd-xl-10",
    "https://civitai.com/models/257749/pony-diffusion-v6-xl",
]
DEFAULT_MODEL = "https://civitai.com/models/4384/dreamshaper"

model_url = os.environ.get("MODEL_URL", DEFAULT_MODEL)
gpu_duration = int(os.environ.get("GPU_DURATION", 60))

logger.debug(f"Loading model info for: {model_url}")
# Model page URLs look like https://civitai.com/models/<id>/<slug>,
# optionally with a ?modelVersionId=<id> query parameter.
model_url_parsed = urlparse(model_url)
model_id = int(model_url_parsed.path.split("/")[2])
model_version_id = parse_qs(model_url_parsed.query).get("modelVersionId")
if model_version_id is not None:
    model_version_id = int(model_version_id[0])
logger.debug(f"Model version id: {model_version_id}")

r = requests.get(f"https://civitai.com/api/v1/models/{model_id}")
try:
    r.raise_for_status()
except requests.HTTPError as e:
    # Re-raise with the response body so the Civitai error message is visible.
    raise requests.HTTPError(
        r.text.strip(), request=e.request, response=e.response
    ) from e
model = r.json()
logger.debug(f"Model info: {model}")

# Use the latest version unless a specific one was requested in the URL.
model_version = (
    model["modelVersions"][0]
    if model_version_id is None
    else next(mv for mv in model["modelVersions"] if mv["id"] == model_version_id)
)

# Expect at most a model file plus an optional VAE, each of a distinct type,
# all in safetensors format.
assert len(model_version["files"]) <= 2
assert len({file["type"] for file in model_version["files"]}) == len(
    model_version["files"]
)
assert all(file["type"] in ["Model", "VAE"] for file in model_version["files"])
assert all(
    file["metadata"]["format"] in ["SafeTensor"] for file in model_version["files"]
)


def download(file: str, url: str):
    if os.path.exists(file):
        return
    r = requests.get(url, stream=True)
    r.raise_for_status()
    # Download to a temp path first so an interrupted download is not
    # mistaken for a complete file on the next run.
    temp_file = f"/tmp/{file}"
    with tqdm(
        desc=file, total=int(r.headers["content-length"]), unit="B", unit_scale=True
    ) as pbar, open(temp_file, "wb") as f:
        for chunk in r.iter_content(chunk_size=1024 * 1024):
            f.write(chunk)
            pbar.update(len(chunk))
    shutil.move(temp_file, file)


model_name = model["name"]


def get_file_name(file_type):
    return f"{slugify(model_name)}.{slugify(file_type)}.safetensors"


# Download the model file (and VAE, if any) concurrently.
for _ in thread_map(
    lambda file: download(get_file_name(file["type"]), file["downloadUrl"]),
    model_version["files"],
):
    pass

model_type = model["type"]
if model_type == "Checkpoint":
    logger.debug("Loading pipeline for checkpoint")
    pipe_args = {}
    if os.path.exists(get_file_name("VAE")):
        logger.debug("Loading VAE")
        pipe_args["vae"] = AutoencoderKL.from_single_file(
            get_file_name("VAE"),
            torch_dtype=torch.float16,
            use_safetensors=True,
        )
    base_model = model_version["baseModel"]
    if base_model == "SD 1.5":
        pipeline_class = StableDiffusionImg2ImgPipeline
    elif base_model == "SDXL 1.0":
        pipeline_class = StableDiffusionXLImg2ImgPipeline
    else:
        raise ValueError(f"Unsupported base model: {base_model}")
    pipe = download_from_original_stable_diffusion_ckpt(
        checkpoint_path_or_dict=get_file_name("Model"),
        from_safetensors=True,
        pipeline_class=pipeline_class,
        load_safety_checker=False,
        **pipe_args,
    )
elif model_type == "LORA":
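    # A LORA ships only low-rank weight deltas rather than a full checkpoint,
    # so the matching base pipeline is loaded from the Hub and the downloaded
    # file is attached as a named adapter below.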
LORA") base_model = model_version["baseModel"] if base_model == "SD 1.5": pipe = AutoPipelineForImage2Image.from_pretrained( "stable-diffusion-v1-5/stable-diffusion-v1-5", safety_checker=None, requires_safety_checker=False, torch_dtype=torch.float16, use_safetensors=True, variant="fp16", ) elif base_model == "SDXL 1.0": # Use AutoPipelineForImage2Image with the base model # since LORA are trained on base pipe = AutoPipelineForImage2Image.from_pretrained( "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, use_safetensors=True, variant="fp16", ) else: raise ValueError(f"Unsupported base model: {base_model}") adapter_name = slugify(model_name) pipe.load_lora_weights(get_file_name("Model"), adapter_name=adapter_name) else: raise ValueError(f"Unsupported model type: {model_type}") pipe = pipe.to("cuda") @logger.catch(reraise=True) @spaces.GPU(duration=gpu_duration) def infer( prompt: str, init_image: Image.Image, negative_prompt: str | None, strength: float, num_inference_steps: int, guidance_scale: float, lora_weight: float, progress=gr.Progress(track_tqdm=True), ): logger.info(f"Starting image generation: {dict(prompt=prompt, image=init_image)}") # Downscale the image init_image.thumbnail((1024, 1024)) additional_args = { k: v for k, v in dict( strength=strength, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, ).items() if v } if lora_weight: pipe.set_adapters(adapter_name, lora_weight) logger.debug(f"Generating image: {dict(prompt=prompt, **additional_args)}") images = pipe( prompt=prompt, image=init_image, negative_prompt=negative_prompt, **additional_args, ).images return images[0] css = """ @media (max-width: 1280px) { #images-container { flex-direction: column; } } """ with gr.Blocks(css=css) as demo: with gr.Column(): gr.Markdown("# Image-to-Image with Civitai Models") gr.Markdown(f"## Model: [{model_name}]({model_url})") with gr.Row(): prompt = gr.Text( label="Prompt", show_label=False, max_lines=1, placeholder="Enter your prompt", container=False, ) run_button = gr.Button("Run", scale=0, variant="primary") with gr.Row(elem_id="images-container"): init_image = gr.Image(label="Initial image", type="pil") result = gr.Image(label="Result") with gr.Accordion("Advanced Settings", open=False): negative_prompt = gr.Text( label="Negative prompt", max_lines=1, placeholder="Enter a negative prompt", ) with gr.Row(): strength = gr.Slider( label="Strength", minimum=0.0, maximum=1.0, step=0.01, value=0.0, ) lora_weight = gr.Slider( label="LORA weight", minimum=0.0, maximum=1.0, step=0.01, value=0.0, visible=model_type == "LORA", ) num_inference_steps = gr.Slider( label="Number of inference steps", minimum=0, maximum=100, step=1, value=0, ) guidance_scale = gr.Slider( label="Guidance scale", minimum=0.0, maximum=100.0, step=0.1, value=0.0, ) gr.on( triggers=[run_button.click, prompt.submit], fn=infer, inputs=[ prompt, init_image, negative_prompt, strength, num_inference_steps, guidance_scale, lora_weight, ], outputs=[result], ) if __name__ == "__main__": demo.launch()