# Hugging Face Space: running on ZeroGPU.
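# Gradio image-to-image demo that downloads a Civitai checkpoint or LORA
# (selected via the MODEL_URL environment variable) and serves it with diffusers.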
import os
import shutil
from urllib.parse import parse_qs, urlparse

import gradio as gr
import requests
import spaces
import torch
from diffusers import (
    AutoencoderKL,
    AutoPipelineForImage2Image,
    StableDiffusionImg2ImgPipeline,
    StableDiffusionXLImg2ImgPipeline,
)
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
    download_from_original_stable_diffusion_ckpt,
)
from loguru import logger
from PIL import Image
from slugify import slugify
from tqdm import tqdm
from tqdm.contrib.concurrent import thread_map

SUPPORTED_MODELS = [
    "https://civitai.com/models/4384/dreamshaper",
    "https://civitai.com/models/44960/mpixel",
    "https://civitai.com/models/92444/lelo-lego-lora-for-xl-and-sd15",
    "https://civitai.com/models/120298/chinese-landscape-art",
    "https://civitai.com/models/150986/blueprintify-sd-xl-10",
    "https://civitai.com/models/257749/pony-diffusion-v6-xl",
]
DEFAULT_MODEL = "https://civitai.com/models/4384/dreamshaper"

model_url = os.environ.get("MODEL_URL", DEFAULT_MODEL)
gpu_duration = int(os.environ.get("GPU_DURATION", 60))
logger.debug(f"Loading model info for: {model_url}") | |
model_url_parsed = urlparse(model_url) | |
model_id = int(model_url_parsed.path.split("/")[2]) | |
model_version_id = parse_qs(model_url_parsed.query).get("modelVersionId") | |
if model_version_id is not None: | |
model_version_id = int(model_version_id[0]) | |
logger.debug(f"Model version id: {model_version_id}") | |
r = requests.get(f"https://civitai.com/api/v1/models/{model_id}")
try:
    r.raise_for_status()
except requests.HTTPError as e:
    raise requests.HTTPError(
        r.text.strip(), request=e.request, response=e.response
    ) from e
model = r.json()
logger.debug(f"Model info: {model}")

model_version = (
    model["modelVersions"][0]
    if model_version_id is None
    else next(mv for mv in model["modelVersions"] if mv["id"] == model_version_id)
)
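
# Sanity checks: the version should ship at most a model file plus an optional
# VAE, each file type at most once, and only in SafeTensor format.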
assert len(model_version["files"]) <= 2
assert len({file["type"] for file in model_version["files"]}) == len(
    model_version["files"]
)
assert all(file["type"] in ["Model", "VAE"] for file in model_version["files"])
assert all(
    file["metadata"]["format"] in ["SafeTensor"] for file in model_version["files"]
)
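

# Stream each download to /tmp first and move it into place once complete,
# so an interrupted download never leaves a truncated file behind.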
def download(file: str, url: str):
    if os.path.exists(file):
        return
    r = requests.get(url, stream=True)
    r.raise_for_status()
    temp_file = f"/tmp/{file}"
    with tqdm(
        desc=file, total=int(r.headers["content-length"]), unit="B", unit_scale=True
    ) as pbar, open(temp_file, "wb") as f:
        for chunk in r.iter_content(chunk_size=1024 * 1024):
            f.write(chunk)
            pbar.update(len(chunk))
    shutil.move(temp_file, file)


model_name = model["name"]


def get_file_name(file_type):
    return f"{slugify(model_name)}.{slugify(file_type)}.safetensors"


# Download the model file (and the VAE, if present) in parallel.
for _ in thread_map(
    lambda file: download(get_file_name(file["type"]), file["downloadUrl"]),
    model_version["files"],
):
    pass
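

# Build the img2img pipeline. Checkpoints are converted from the downloaded
# single-file weights (plus an optional VAE); LORAs are loaded on top of the
# matching Stable Diffusion base pipeline.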
model_type = model["type"]
if model_type == "Checkpoint":
    logger.debug("Loading pipeline for checkpoint")
    pipe_args = {}
    if os.path.exists(get_file_name("VAE")):
        logger.debug("Loading VAE")
        pipe_args["vae"] = AutoencoderKL.from_single_file(
            get_file_name("VAE"),
            torch_dtype=torch.float16,
            use_safetensors=True,
        )
    base_model = model_version["baseModel"]
    if base_model == "SD 1.5":
        pipeline_class = StableDiffusionImg2ImgPipeline
    elif base_model == "SDXL 1.0":
        pipeline_class = StableDiffusionXLImg2ImgPipeline
    else:
        raise ValueError(f"Unsupported base model: {base_model}")
    pipe = download_from_original_stable_diffusion_ckpt(
        checkpoint_path_or_dict=get_file_name("Model"),
        from_safetensors=True,
        pipeline_class=pipeline_class,
        load_safety_checker=False,
        **pipe_args,
    )
elif model_type == "LORA":
    logger.debug("Loading pipeline for LORA")
    base_model = model_version["baseModel"]
    if base_model == "SD 1.5":
        pipe = AutoPipelineForImage2Image.from_pretrained(
            "stable-diffusion-v1-5/stable-diffusion-v1-5",
            safety_checker=None,
            requires_safety_checker=False,
            torch_dtype=torch.float16,
            use_safetensors=True,
            variant="fp16",
        )
    elif base_model == "SDXL 1.0":
        # Use AutoPipelineForImage2Image with the base model,
        # since LORAs are trained on top of the base.
        pipe = AutoPipelineForImage2Image.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0",
            torch_dtype=torch.float16,
            use_safetensors=True,
            variant="fp16",
        )
    else:
        raise ValueError(f"Unsupported base model: {base_model}")
    adapter_name = slugify(model_name)
    pipe.load_lora_weights(get_file_name("Model"), adapter_name=adapter_name)
else:
    raise ValueError(f"Unsupported model type: {model_type}")

pipe = pipe.to("cuda")
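

# Assumption: this Space runs on ZeroGPU (the `spaces` import and the
# GPU_DURATION variable are otherwise unused), so the inference function is
# wrapped with the ZeroGPU decorator to request a GPU for up to
# `gpu_duration` seconds per call.
@spaces.GPU(duration=gpu_duration)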
def infer(
    prompt: str,
    init_image: Image.Image,
    negative_prompt: str | None,
    strength: float,
    num_inference_steps: int,
    guidance_scale: float,
    lora_weight: float,
    progress=gr.Progress(track_tqdm=True),
):
    logger.info(f"Starting image generation: {dict(prompt=prompt, image=init_image)}")
    # Downscale the image
    init_image.thumbnail((1024, 1024))
    # Only forward parameters the user actually set (non-zero), so the
    # pipeline falls back to its own defaults otherwise.
    additional_args = {
        k: v
        for k, v in dict(
            strength=strength,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
        ).items()
        if v
    }
    if lora_weight:
        pipe.set_adapters(adapter_name, lora_weight)
    logger.debug(f"Generating image: {dict(prompt=prompt, **additional_args)}")
    images = pipe(
        prompt=prompt,
        image=init_image,
        negative_prompt=negative_prompt,
        **additional_args,
    ).images
    return images[0]
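

# Stack the two image panels vertically on narrow screens.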
css = """ | |
@media (max-width: 1280px) { | |
#images-container { | |
flex-direction: column; | |
} | |
} | |
""" | |

with gr.Blocks(css=css) as demo:
    with gr.Column():
        gr.Markdown("# Image-to-Image with Civitai Models")
        gr.Markdown(f"## Model: [{model_name}]({model_url})")
        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )
            run_button = gr.Button("Run", scale=0, variant="primary")
        with gr.Row(elem_id="images-container"):
            init_image = gr.Image(label="Initial image", type="pil")
            result = gr.Image(label="Result")
        with gr.Accordion("Advanced Settings", open=False):
            negative_prompt = gr.Text(
                label="Negative prompt",
                max_lines=1,
                placeholder="Enter a negative prompt",
            )
            with gr.Row():
                strength = gr.Slider(
                    label="Strength",
                    minimum=0.0,
                    maximum=1.0,
                    step=0.01,
                    value=0.0,
                )
                lora_weight = gr.Slider(
                    label="LORA weight",
                    minimum=0.0,
                    maximum=1.0,
                    step=0.01,
                    value=0.0,
                    visible=model_type == "LORA",
                )
            num_inference_steps = gr.Slider(
                label="Number of inference steps",
                minimum=0,
                maximum=100,
                step=1,
                value=0,
            )
            guidance_scale = gr.Slider(
                label="Guidance scale",
                minimum=0.0,
                maximum=100.0,
                step=0.1,
                value=0.0,
            )
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=[
            prompt,
            init_image,
            negative_prompt,
            strength,
            num_inference_steps,
            guidance_scale,
            lora_weight,
        ],
        outputs=[result],
    )

if __name__ == "__main__":
    demo.launch()