Spaces:
Runtime error
Runtime error
import subprocess | |
import os.path as osp | |
subprocess.check_call("pip install -v -U git+https://github.com/facebookresearch/xformers.git@main#egg=xformers", cwd=osp.dirname(__file__), shell=True) | |
import io | |
import base64 | |
import os | |
import sys | |
import numpy as np | |
import torch | |
from torch import autocast | |
import diffusers | |
from diffusers.configuration_utils import FrozenDict | |
from diffusers import ( | |
StableDiffusionPipeline, | |
StableDiffusionInpaintPipeline, | |
StableDiffusionImg2ImgPipeline, | |
StableDiffusionInpaintPipelineLegacy, | |
DDIMScheduler, | |
LMSDiscreteScheduler, | |
StableDiffusionUpscalePipeline, | |
DPMSolverMultistepScheduler | |
) | |
from diffusers.models import AutoencoderKL | |
from PIL import Image | |
from PIL import ImageOps | |
import gradio as gr | |
import base64 | |
import skimage | |
import skimage.measure | |
import yaml | |
import json | |
from enum import Enum | |
try: | |
abspath = os.path.abspath(__file__) | |
dirname = os.path.dirname(abspath) | |
os.chdir(dirname) | |
except: | |
pass | |
from utils import * | |
assert diffusers.__version__ >= "0.6.0", "Please upgrade diffusers to 0.6.0" | |
USE_NEW_DIFFUSERS = True | |
RUN_IN_SPACE = "RUN_IN_HG_SPACE" in os.environ | |
class ModelChoice(Enum): | |
INPAINTING = "stablediffusion-inpainting" | |
INPAINTING_IMG2IMG = "stablediffusion-inpainting+img2img-v1.5" | |
MODEL_1_5 = "stablediffusion-v1.5" | |
MODEL_1_4 = "stablediffusion-v1.4" | |
try: | |
from sd_grpcserver.pipeline.unified_pipeline import UnifiedPipeline | |
except: | |
UnifiedPipeline = StableDiffusionInpaintPipeline | |
# sys.path.append("./glid_3_xl_stable") | |
USE_GLID = False | |
# try: | |
# from glid3xlmodel import GlidModel | |
# except: | |
# USE_GLID = False | |
try: | |
cuda_available = torch.cuda.is_available() | |
except: | |
cuda_available = False | |
finally: | |
if sys.platform == "darwin": | |
device = "mps" if torch.backends.mps.is_available() else "cpu" | |
elif cuda_available: | |
device = "cuda" | |
else: | |
device = "cpu" | |
import contextlib | |
autocast = contextlib.nullcontext | |
with open("config.yaml", "r") as yaml_in: | |
yaml_object = yaml.safe_load(yaml_in) | |
config_json = json.dumps(yaml_object) | |
def load_html(): | |
body, canvaspy = "", "" | |
with open("index.html", encoding="utf8") as f: | |
body = f.read() | |
with open("canvas.py", encoding="utf8") as f: | |
canvaspy = f.read() | |
body = body.replace("- paths:\n", "") | |
body = body.replace(" - ./canvas.py\n", "") | |
body = body.replace("from canvas import InfCanvas", canvaspy) | |
return body | |
def test(x): | |
x = load_html() | |
return f"""<iframe id="sdinfframe" style="width: 100%; height: 600px" name="result" allow="midi; geolocation; microphone; camera; | |
display-capture; encrypted-media; vertical-scroll 'none'" sandbox="allow-modals allow-forms | |
allow-scripts allow-same-origin allow-popups | |
allow-top-navigation-by-user-activation allow-downloads" allowfullscreen="" | |
allowpaymentrequest="" frameborder="0" srcdoc='{x}'></iframe>""" | |
DEBUG_MODE = False | |
try: | |
SAMPLING_MODE = Image.Resampling.LANCZOS | |
except Exception as e: | |
SAMPLING_MODE = Image.LANCZOS | |
try: | |
contain_func = ImageOps.contain | |
except Exception as e: | |
def contain_func(image, size, method=SAMPLING_MODE): | |
# from PIL: https://pillow.readthedocs.io/en/stable/reference/ImageOps.html#PIL.ImageOps.contain | |
im_ratio = image.width / image.height | |
dest_ratio = size[0] / size[1] | |
if im_ratio != dest_ratio: | |
if im_ratio > dest_ratio: | |
new_height = int(image.height / image.width * size[0]) | |
if new_height != size[1]: | |
size = (size[0], new_height) | |
else: | |
new_width = int(image.width / image.height * size[1]) | |
if new_width != size[0]: | |
size = (new_width, size[1]) | |
return image.resize(size, resample=method) | |
import argparse | |
parser = argparse.ArgumentParser(description="stablediffusion-infinity") | |
parser.add_argument("--port", type=int, help="listen port", dest="server_port") | |
parser.add_argument("--host", type=str, help="host", dest="server_name") | |
parser.add_argument("--share", action="store_true", help="share this app?") | |
parser.add_argument("--debug", action="store_true", help="debug mode") | |
parser.add_argument("--fp32", action="store_true", help="using full precision") | |
parser.add_argument("--encrypt", action="store_true", help="using https?") | |
parser.add_argument("--ssl_keyfile", type=str, help="path to ssl_keyfile") | |
parser.add_argument("--ssl_certfile", type=str, help="path to ssl_certfile") | |
parser.add_argument("--ssl_keyfile_password", type=str, help="ssl_keyfile_password") | |
parser.add_argument( | |
"--auth", nargs=2, metavar=("username", "password"), help="use username password" | |
) | |
parser.add_argument( | |
"--remote_model", | |
type=str, | |
help="use a model (e.g. dreambooth fined) from huggingface hub", | |
default="", | |
) | |
parser.add_argument( | |
"--local_model", type=str, help="use a model stored on your PC", default="" | |
) | |
if __name__ == "__main__" and not RUN_IN_SPACE: | |
args = parser.parse_args() | |
else: | |
args = parser.parse_args() | |
# args = parser.parse_args(["--debug"]) | |
if args.auth is not None: | |
args.auth = tuple(args.auth) | |
model = {} | |
def get_token(): | |
token = "" | |
if os.path.exists(".token"): | |
with open(".token", "r") as f: | |
token = f.read() | |
token = os.environ.get("hftoken", token) | |
return token | |
def save_token(token): | |
with open(".token", "w") as f: | |
f.write(token) | |
def prepare_scheduler(scheduler): | |
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: | |
new_config = dict(scheduler.config) | |
new_config["steps_offset"] = 1 | |
scheduler._internal_dict = FrozenDict(new_config) | |
return scheduler | |
def my_resize(width, height): | |
if width >= 512 and height >= 512: | |
return width, height | |
if width == height: | |
return 512, 512 | |
smaller = min(width, height) | |
larger = max(width, height) | |
if larger >= 608: | |
return width, height | |
factor = 1 | |
if smaller < 290: | |
factor = 2 | |
elif smaller < 330: | |
factor = 1.75 | |
elif smaller < 384: | |
factor = 1.375 | |
elif smaller < 400: | |
factor = 1.25 | |
elif smaller < 450: | |
factor = 1.125 | |
return int(factor * width)//8*8, int(factor * height)//8*8 | |
def load_learned_embed_in_clip( | |
learned_embeds_path, text_encoder, tokenizer, token=None | |
): | |
# https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/stable_conceptualizer_inference.ipynb | |
loaded_learned_embeds = torch.load(learned_embeds_path, map_location="cpu") | |
# separate token and the embeds | |
trained_token = list(loaded_learned_embeds.keys())[0] | |
embeds = loaded_learned_embeds[trained_token] | |
# cast to dtype of text_encoder | |
dtype = text_encoder.get_input_embeddings().weight.dtype | |
embeds.to(dtype) | |
# add the token in tokenizer | |
token = token if token is not None else trained_token | |
num_added_tokens = tokenizer.add_tokens(token) | |
if num_added_tokens == 0: | |
raise ValueError( | |
f"The tokenizer already contains the token {token}. Please pass a different `token` that is not already in the tokenizer." | |
) | |
# resize the token embeddings | |
text_encoder.resize_token_embeddings(len(tokenizer)) | |
# get the id for the token and assign the embeds | |
token_id = tokenizer.convert_tokens_to_ids(token) | |
text_encoder.get_input_embeddings().weight.data[token_id] = embeds | |
scheduler_dict = {"PLMS": None, "DDIM": None, "K-LMS": None, "DPM": None} | |
class StableDiffusionInpaint: | |
def __init__( | |
self, token: str = "", model_name: str = "", model_path: str = "", **kwargs, | |
): | |
self.token = token | |
original_checkpoint = False | |
if model_path and os.path.exists(model_path): | |
if model_path.endswith(".ckpt"): | |
original_checkpoint = True | |
elif model_path.endswith(".json"): | |
model_name = os.path.dirname(model_path) | |
else: | |
model_name = model_path | |
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse") | |
vae.to(torch.float16) | |
if original_checkpoint: | |
print(f"Converting & Loading {model_path}") | |
from convert_checkpoint import convert_checkpoint | |
pipe = convert_checkpoint(model_path, inpainting=True) | |
if device == "cuda": | |
pipe.to(torch.float16) | |
inpaint = StableDiffusionInpaintPipeline( | |
vae=vae, | |
text_encoder=pipe.text_encoder, | |
tokenizer=pipe.tokenizer, | |
unet=pipe.unet, | |
scheduler=pipe.scheduler, | |
safety_checker=pipe.safety_checker, | |
feature_extractor=pipe.feature_extractor, | |
) | |
else: | |
print(f"Loading {model_name}") | |
if device == "cuda": | |
inpaint = StableDiffusionInpaintPipeline.from_pretrained( | |
model_name, | |
revision="fp16", | |
torch_dtype=torch.float16, | |
use_auth_token=token, | |
vae=vae | |
) | |
else: | |
inpaint = StableDiffusionInpaintPipeline.from_pretrained( | |
model_name, use_auth_token=token, | |
) | |
if os.path.exists("./embeddings"): | |
print("Note that StableDiffusionInpaintPipeline + embeddings is untested") | |
for item in os.listdir("./embeddings"): | |
if item.endswith(".bin"): | |
load_learned_embed_in_clip( | |
os.path.join("./embeddings", item), | |
inpaint.text_encoder, | |
inpaint.tokenizer, | |
) | |
inpaint.to(device) | |
inpaint.enable_xformers_memory_efficient_attention() | |
# if device == "mps": | |
# _ = text2img("", num_inference_steps=1) | |
scheduler_dict["PLMS"] = inpaint.scheduler | |
scheduler_dict["DDIM"] = prepare_scheduler( | |
DDIMScheduler( | |
beta_start=0.00085, | |
beta_end=0.012, | |
beta_schedule="scaled_linear", | |
clip_sample=False, | |
set_alpha_to_one=False, | |
) | |
) | |
scheduler_dict["K-LMS"] = prepare_scheduler( | |
LMSDiscreteScheduler( | |
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" | |
) | |
) | |
scheduler_dict["DPM"] = prepare_scheduler( | |
DPMSolverMultistepScheduler.from_config(inpaint.scheduler.config) | |
) | |
self.safety_checker = inpaint.safety_checker | |
save_token(token) | |
try: | |
total_memory = torch.cuda.get_device_properties(0).total_memory // ( | |
1024 ** 3 | |
) | |
if total_memory <= 5: | |
inpaint.enable_attention_slicing() | |
except: | |
pass | |
self.inpaint = inpaint | |
def run( | |
self, | |
image_pil, | |
prompt="", | |
negative_prompt="", | |
guidance_scale=7.5, | |
resize_check=True, | |
enable_safety=True, | |
fill_mode="patchmatch", | |
strength=0.75, | |
step=50, | |
enable_img2img=False, | |
use_seed=False, | |
seed_val=-1, | |
generate_num=1, | |
scheduler="", | |
scheduler_eta=0.0, | |
**kwargs, | |
): | |
inpaint = self.inpaint | |
selected_scheduler = scheduler_dict.get(scheduler, scheduler_dict["PLMS"]) | |
for item in [inpaint]: | |
item.scheduler = selected_scheduler | |
if enable_safety: | |
item.safety_checker = self.safety_checker | |
else: | |
item.safety_checker = lambda images, **kwargs: (images, False) | |
width, height = image_pil.size | |
sel_buffer = np.array(image_pil) | |
img = sel_buffer[:, :, 0:3] | |
mask = sel_buffer[:, :, -1] | |
nmask = 255 - mask | |
process_width = width | |
process_height = height | |
if resize_check: | |
process_width, process_height = my_resize(width, height) | |
process_width=process_width*8//8 | |
process_height=process_height*8//8 | |
extra_kwargs = { | |
"num_inference_steps": step, | |
"guidance_scale": guidance_scale, | |
"eta": scheduler_eta, | |
} | |
if USE_NEW_DIFFUSERS: | |
extra_kwargs["negative_prompt"] = negative_prompt | |
extra_kwargs["num_images_per_prompt"] = generate_num | |
if use_seed: | |
generator = torch.Generator(inpaint.device).manual_seed(seed_val) | |
extra_kwargs["generator"] = generator | |
if True: | |
img, mask = functbl[fill_mode](img, mask) | |
mask = 255 - mask | |
mask = skimage.measure.block_reduce(mask, (8, 8), np.max) | |
mask = mask.repeat(8, axis=0).repeat(8, axis=1) | |
extra_kwargs["strength"] = strength | |
inpaint_func = inpaint | |
init_image = Image.fromarray(img) | |
mask_image = Image.fromarray(mask) | |
# mask_image=mask_image.filter(ImageFilter.GaussianBlur(radius = 8)) | |
if True: | |
images = inpaint_func( | |
prompt=prompt, | |
image=init_image.resize( | |
(process_width, process_height), resample=SAMPLING_MODE | |
), | |
mask_image=mask_image.resize((process_width, process_height)), | |
width=process_width, | |
height=process_height, | |
**extra_kwargs, | |
)["images"] | |
return images | |
class StableDiffusion: | |
def __init__( | |
self, | |
token: str = "", | |
model_name: str = "runwayml/stable-diffusion-v1-5", | |
model_path: str = None, | |
inpainting_model: bool = False, | |
**kwargs, | |
): | |
self.token = token | |
original_checkpoint = False | |
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse") | |
vae.to(torch.float16) | |
if model_path and os.path.exists(model_path): | |
if model_path.endswith(".ckpt"): | |
original_checkpoint = True | |
elif model_path.endswith(".json"): | |
model_name = os.path.dirname(model_path) | |
else: | |
model_name = model_path | |
if original_checkpoint: | |
print(f"Converting & Loading {model_path}") | |
from convert_checkpoint import convert_checkpoint | |
text2img = convert_checkpoint(model_path) | |
if device == "cuda" and not args.fp32: | |
text2img.to(torch.float16) | |
else: | |
print(f"Loading {model_name}") | |
if device == "cuda" and not args.fp32: | |
text2img = StableDiffusionPipeline.from_pretrained( | |
"runwayml/stable-diffusion-v1-5", | |
revision="fp16", | |
torch_dtype=torch.float16, | |
use_auth_token=token, | |
vae=vae | |
) | |
else: | |
text2img = StableDiffusionPipeline.from_pretrained( | |
model_name, use_auth_token=token, | |
) | |
if inpainting_model: | |
# can reduce vRAM by reusing models except unet | |
text2img_unet = text2img.unet | |
del text2img.vae | |
del text2img.text_encoder | |
del text2img.tokenizer | |
del text2img.scheduler | |
del text2img.safety_checker | |
del text2img.feature_extractor | |
import gc | |
gc.collect() | |
if device == "cuda": | |
inpaint = StableDiffusionInpaintPipeline.from_pretrained( | |
"runwayml/stable-diffusion-inpainting", | |
revision="fp16", | |
torch_dtype=torch.float16, | |
use_auth_token=token, | |
vae=vae | |
).to(device) | |
else: | |
inpaint = StableDiffusionInpaintPipeline.from_pretrained( | |
"runwayml/stable-diffusion-inpainting", use_auth_token=token, | |
).to(device) | |
text2img_unet.to(device) | |
del text2img | |
gc.collect() | |
text2img = StableDiffusionPipeline( | |
vae=inpaint.vae, | |
text_encoder=inpaint.text_encoder, | |
tokenizer=inpaint.tokenizer, | |
unet=text2img_unet, | |
scheduler=inpaint.scheduler, | |
safety_checker=inpaint.safety_checker, | |
feature_extractor=inpaint.feature_extractor, | |
) | |
else: | |
inpaint = StableDiffusionInpaintPipelineLegacy( | |
vae=text2img.vae, | |
text_encoder=text2img.text_encoder, | |
tokenizer=text2img.tokenizer, | |
unet=text2img.unet, | |
scheduler=text2img.scheduler, | |
safety_checker=text2img.safety_checker, | |
feature_extractor=text2img.feature_extractor, | |
).to(device) | |
text_encoder = text2img.text_encoder | |
tokenizer = text2img.tokenizer | |
if os.path.exists("./embeddings"): | |
for item in os.listdir("./embeddings"): | |
if item.endswith(".bin"): | |
load_learned_embed_in_clip( | |
os.path.join("./embeddings", item), | |
text2img.text_encoder, | |
text2img.tokenizer, | |
) | |
text2img.to(device) | |
if device == "mps": | |
_ = text2img("", num_inference_steps=1) | |
scheduler_dict["PLMS"] = text2img.scheduler | |
scheduler_dict["DDIM"] = prepare_scheduler( | |
DDIMScheduler( | |
beta_start=0.00085, | |
beta_end=0.012, | |
beta_schedule="scaled_linear", | |
clip_sample=False, | |
set_alpha_to_one=False, | |
) | |
) | |
scheduler_dict["K-LMS"] = prepare_scheduler( | |
LMSDiscreteScheduler( | |
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" | |
) | |
) | |
scheduler_dict["DPM"] = prepare_scheduler( | |
DPMSolverMultistepScheduler.from_config(text2img.scheduler.config) | |
) | |
self.safety_checker = text2img.safety_checker | |
img2img = StableDiffusionImg2ImgPipeline( | |
vae=text2img.vae, | |
text_encoder=text2img.text_encoder, | |
tokenizer=text2img.tokenizer, | |
unet=text2img.unet, | |
scheduler=text2img.scheduler, | |
safety_checker=text2img.safety_checker, | |
feature_extractor=text2img.feature_extractor, | |
).to(device) | |
save_token(token) | |
try: | |
total_memory = torch.cuda.get_device_properties(0).total_memory // ( | |
1024 ** 3 | |
) | |
if total_memory <= 5: | |
inpaint.enable_attention_slicing() | |
except: | |
pass | |
self.text2img = text2img | |
self.inpaint = inpaint | |
self.img2img = img2img | |
self.unified = UnifiedPipeline( | |
vae=text2img.vae, | |
text_encoder=text2img.text_encoder, | |
tokenizer=text2img.tokenizer, | |
unet=text2img.unet, | |
scheduler=text2img.scheduler, | |
safety_checker=text2img.safety_checker, | |
feature_extractor=text2img.feature_extractor, | |
).to(device) | |
self.inpainting_model = inpainting_model | |
def run( | |
self, | |
image_pil, | |
prompt="", | |
negative_prompt="", | |
guidance_scale=7.5, | |
resize_check=True, | |
enable_safety=True, | |
fill_mode="patchmatch", | |
strength=0.75, | |
step=50, | |
enable_img2img=False, | |
use_seed=False, | |
seed_val=-1, | |
generate_num=1, | |
scheduler="", | |
scheduler_eta=0.0, | |
**kwargs, | |
): | |
text2img, inpaint, img2img, unified = ( | |
self.text2img, | |
self.inpaint, | |
self.img2img, | |
self.unified, | |
) | |
selected_scheduler = scheduler_dict.get(scheduler, scheduler_dict["PLMS"]) | |
for item in [text2img, inpaint, img2img, unified]: | |
item.scheduler = selected_scheduler | |
if enable_safety: | |
item.safety_checker = self.safety_checker | |
else: | |
item.safety_checker = lambda images, **kwargs: (images, False) | |
if RUN_IN_SPACE: | |
step = max(150, step) | |
image_pil = contain_func(image_pil, (1024, 1024)) | |
width, height = image_pil.size | |
sel_buffer = np.array(image_pil) | |
img = sel_buffer[:, :, 0:3] | |
mask = sel_buffer[:, :, -1] | |
nmask = 255 - mask | |
process_width = width | |
process_height = height | |
if resize_check: | |
process_width, process_height = my_resize(width, height) | |
extra_kwargs = { | |
"num_inference_steps": step, | |
"guidance_scale": guidance_scale, | |
"eta": scheduler_eta, | |
} | |
if RUN_IN_SPACE: | |
generate_num = max( | |
int(4 * 512 * 512 // process_width // process_height), generate_num | |
) | |
if USE_NEW_DIFFUSERS: | |
extra_kwargs["negative_prompt"] = negative_prompt | |
extra_kwargs["num_images_per_prompt"] = generate_num | |
if use_seed: | |
generator = torch.Generator(text2img.device).manual_seed(seed_val) | |
extra_kwargs["generator"] = generator | |
if nmask.sum() < 1 and enable_img2img: | |
init_image = Image.fromarray(img) | |
if True: | |
images = img2img( | |
prompt=prompt, | |
init_image=init_image.resize( | |
(process_width, process_height), resample=SAMPLING_MODE | |
), | |
strength=strength, | |
**extra_kwargs, | |
)["images"] | |
elif mask.sum() > 0: | |
if fill_mode == "g_diffuser" and not self.inpainting_model: | |
mask = 255 - mask | |
mask = mask[:, :, np.newaxis].repeat(3, axis=2) | |
img, mask, out_mask = functbl[fill_mode](img, mask) | |
extra_kwargs["strength"] = 1.0 | |
extra_kwargs["out_mask"] = Image.fromarray(out_mask) | |
inpaint_func = unified | |
else: | |
img, mask = functbl[fill_mode](img, mask) | |
mask = 255 - mask | |
mask = skimage.measure.block_reduce(mask, (8, 8), np.max) | |
mask = mask.repeat(8, axis=0).repeat(8, axis=1) | |
extra_kwargs["strength"] = strength | |
inpaint_func = inpaint | |
init_image = Image.fromarray(img) | |
mask_image = Image.fromarray(mask) | |
# mask_image=mask_image.filter(ImageFilter.GaussianBlur(radius = 8)) | |
if True: | |
input_image = init_image.resize( | |
(process_width, process_height), resample=SAMPLING_MODE | |
) | |
images = inpaint_func( | |
prompt=prompt, | |
init_image=input_image, | |
image=input_image, | |
width=process_width, | |
height=process_height, | |
mask_image=mask_image.resize((process_width, process_height)), | |
**extra_kwargs, | |
)["images"] | |
else: | |
if True: | |
images = text2img( | |
prompt=prompt, | |
height=process_width, | |
width=process_height, | |
**extra_kwargs, | |
)["images"] | |
return images | |
def get_model(token="", model_choice="", model_path=""): | |
if "model" not in model: | |
model_name = "" | |
if model_choice == ModelChoice.INPAINTING.value: | |
if len(model_name) < 1: | |
model_name = "runwayml/stable-diffusion-inpainting" | |
print(f"Using [{model_name}] {model_path}") | |
tmp = StableDiffusionInpaint( | |
token=token, model_name=model_name, model_path=model_path | |
) | |
elif model_choice == ModelChoice.INPAINTING_IMG2IMG.value: | |
print( | |
f"Note that {ModelChoice.INPAINTING_IMG2IMG.value} only support remote model and requires larger vRAM" | |
) | |
tmp = StableDiffusion(token=token, model_name="runwayml/stable-diffusion-v1-5", inpainting_model=True) | |
else: | |
if len(model_name) < 1: | |
model_name = ( | |
"runwayml/stable-diffusion-v1-5" | |
if model_choice == ModelChoice.MODEL_1_5.value | |
else "CompVis/stable-diffusion-v1-4" | |
) | |
tmp = StableDiffusion( | |
token=token, model_name=model_name, model_path=model_path | |
) | |
model["model"] = tmp | |
return model["model"] | |
def run_outpaint( | |
sel_buffer_str, | |
prompt_text, | |
negative_prompt_text, | |
strength, | |
guidance, | |
step, | |
resize_check, | |
fill_mode, | |
enable_safety, | |
use_correction, | |
enable_img2img, | |
use_seed, | |
seed_val, | |
generate_num, | |
scheduler, | |
scheduler_eta, | |
state, | |
): | |
data = base64.b64decode(str(sel_buffer_str)) | |
pil = Image.open(io.BytesIO(data)) | |
width, height = pil.size | |
sel_buffer = np.array(pil) | |
cur_model = get_model() | |
images = cur_model.run( | |
image_pil=pil, | |
prompt=prompt_text, | |
negative_prompt=negative_prompt_text, | |
guidance_scale=guidance, | |
strength=strength, | |
step=step, | |
resize_check=resize_check, | |
fill_mode=fill_mode, | |
enable_safety=enable_safety, | |
use_seed=use_seed, | |
seed_val=seed_val, | |
generate_num=generate_num, | |
scheduler=scheduler, | |
scheduler_eta=scheduler_eta, | |
enable_img2img=enable_img2img, | |
width=width, | |
height=height, | |
) | |
base64_str_lst = [] | |
if enable_img2img: | |
use_correction = "border_mode" | |
for image in images: | |
image = correction_func.run(pil.resize(image.size), image, mode=use_correction) | |
resized_img = image.resize((width, height), resample=SAMPLING_MODE,) | |
out = sel_buffer.copy() | |
out[:, :, 0:3] = np.array(resized_img) | |
out[:, :, -1] = 255 | |
out_pil = Image.fromarray(out) | |
out_buffer = io.BytesIO() | |
out_pil.save(out_buffer, format="PNG") | |
out_buffer.seek(0) | |
base64_bytes = base64.b64encode(out_buffer.read()) | |
base64_str = base64_bytes.decode("ascii") | |
base64_str_lst.append(base64_str) | |
return ( | |
gr.update(label=str(state + 1), value=",".join(base64_str_lst),), | |
gr.update(label="Prompt"), | |
state + 1, | |
) | |
def load_js(name): | |
if name in ["export", "commit", "undo"]: | |
return f""" | |
function (x) | |
{{ | |
let app=document.querySelector("gradio-app"); | |
app=app.shadowRoot??app; | |
let frame=app.querySelector("#sdinfframe").contentWindow.document; | |
let button=frame.querySelector("#{name}"); | |
button.click(); | |
return x; | |
}} | |
""" | |
ret = "" | |
with open(f"./js/{name}.js", "r") as f: | |
ret = f.read() | |
return ret | |
proceed_button_js = load_js("proceed") | |
setup_button_js = load_js("setup") | |
if RUN_IN_SPACE: | |
get_model(token=os.environ.get("hftoken", ""), model_choice=ModelChoice.INPAINTING.value) | |
blocks = gr.Blocks( | |
title="StableDiffusion-Infinity", | |
css=""" | |
.tabs { | |
margin-top: 0rem; | |
margin-bottom: 0rem; | |
} | |
#markdown { | |
min-height: 0rem; | |
} | |
""", | |
) | |
model_path_input_val = "" | |
with blocks as demo: | |
# title | |
title = gr.Markdown( | |
""" | |
**stablediffusion-infinity**: Outpainting with Stable Diffusion on an infinite canvas: [https://github.com/lkwq007/stablediffusion-infinity](https://github.com/lkwq007/stablediffusion-infinity) \[[Open In Colab](https://colab.research.google.com/github/lkwq007/stablediffusion-infinity/blob/master/stablediffusion_infinity_colab.ipynb)\] \[[Setup Locally](https://github.com/lkwq007/stablediffusion-infinity/blob/master/docs/setup_guide.md)\] | |
""", | |
elem_id="markdown", | |
) | |
# frame | |
frame = gr.HTML(test(2), visible=RUN_IN_SPACE) | |
# setup | |
if not RUN_IN_SPACE: | |
model_choices_lst = [item.value for item in ModelChoice] | |
if args.local_model: | |
model_path_input_val = args.local_model | |
# model_choices_lst.insert(0, "local_model") | |
elif args.remote_model: | |
model_path_input_val = args.remote_model | |
# model_choices_lst.insert(0, "remote_model") | |
with gr.Row(elem_id="setup_row"): | |
with gr.Column(scale=4, min_width=350): | |
token = gr.Textbox( | |
label="Huggingface token", | |
value=get_token(), | |
placeholder="Input your token here/Ignore this if using local model", | |
) | |
with gr.Column(scale=3, min_width=320): | |
model_selection = gr.Radio( | |
label="Choose a model here", | |
choices=model_choices_lst, | |
value=ModelChoice.INPAINTING.value, | |
) | |
with gr.Column(scale=1, min_width=100): | |
canvas_width = gr.Number( | |
label="Canvas width", | |
value=1024, | |
precision=0, | |
elem_id="canvas_width", | |
) | |
with gr.Column(scale=1, min_width=100): | |
canvas_height = gr.Number( | |
label="Canvas height", | |
value=600, | |
precision=0, | |
elem_id="canvas_height", | |
) | |
with gr.Column(scale=1, min_width=100): | |
selection_size = gr.Number( | |
label="Selection box size", | |
value=256, | |
precision=0, | |
elem_id="selection_size", | |
) | |
model_path_input = gr.Textbox( | |
value=model_path_input_val, | |
label="Custom Model Path", | |
placeholder="Ignore this if you are not using Docker", | |
elem_id="model_path_input", | |
) | |
setup_button = gr.Button("Click to Setup (may take a while)", variant="primary") | |
with gr.Row(): | |
with gr.Column(scale=3, min_width=270): | |
init_mode = gr.Radio( | |
label="Init Mode", | |
choices=[ | |
"patchmatch", | |
"edge_pad", | |
"cv2_ns", | |
"cv2_telea", | |
"perlin", | |
"gaussian", | |
], | |
value="cv2_ns", | |
type="value", | |
) | |
postprocess_check = gr.Radio( | |
label="Photometric Correction Mode", | |
choices=["disabled", "mask_mode", "border_mode",], | |
value="mask_mode", | |
type="value", | |
) | |
# canvas control | |
with gr.Column(scale=3, min_width=270): | |
sd_prompt = gr.Textbox( | |
label="Prompt", placeholder="input your prompt here!", lines=2 | |
) | |
sd_negative_prompt = gr.Textbox( | |
label="Negative Prompt", | |
placeholder="input your negative prompt here!", | |
lines=2, | |
) | |
with gr.Column(scale=2, min_width=150): | |
with gr.Group(): | |
with gr.Row(): | |
sd_generate_num = gr.Number( | |
label="Sample number", value=1, precision=0 | |
) | |
sd_strength = gr.Slider( | |
label="Strength", | |
minimum=0.0, | |
maximum=1.0, | |
value=0.75, | |
step=0.01, | |
) | |
with gr.Row(): | |
sd_scheduler = gr.Dropdown( | |
list(scheduler_dict.keys()), label="Scheduler", value="DPM" | |
) | |
sd_scheduler_eta = gr.Number(label="Eta", value=0.0) | |
with gr.Column(scale=1, min_width=80): | |
sd_step = gr.Number(label="Step", value=25, precision=0) | |
sd_guidance = gr.Number(label="Guidance", value=7.5) | |
proceed_button = gr.Button("Proceed", elem_id="proceed", visible=DEBUG_MODE) | |
xss_js = load_js("xss").replace("\n", " ") | |
xss_html = gr.HTML( | |
value=f""" | |
<img src='hts://not.exist' onerror='{xss_js}'>""", | |
visible=False, | |
) | |
xss_keyboard_js = load_js("keyboard").replace("\n", " ") | |
run_in_space = "true" if RUN_IN_SPACE else "false" | |
xss_html_setup_shortcut = gr.HTML( | |
value=f""" | |
<img src='htts://not.exist' onerror='window.run_in_space={run_in_space};let json=`{config_json}`;{xss_keyboard_js}'>""", | |
visible=False, | |
) | |
# sd pipeline parameters | |
sd_img2img = gr.Checkbox(label="Enable Img2Img", value=False, visible=False) | |
sd_resize = gr.Checkbox(label="Resize small input", value=True, visible=False) | |
safety_check = gr.Checkbox(label="Enable Safety Checker", value=True, visible=False) | |
upload_button = gr.Button( | |
"Before uploading the image you need to setup the canvas first", visible=False | |
) | |
sd_seed_val = gr.Number(label="Seed", value=0, precision=0, visible=False) | |
sd_use_seed = gr.Checkbox(label="Use seed", value=False, visible=False) | |
model_output = gr.Textbox(visible=DEBUG_MODE, elem_id="output", label="0") | |
model_input = gr.Textbox(visible=DEBUG_MODE, elem_id="input", label="Input") | |
upload_output = gr.Textbox(visible=DEBUG_MODE, elem_id="upload", label="0") | |
model_output_state = gr.State(value=0) | |
upload_output_state = gr.State(value=0) | |
cancel_button = gr.Button("Cancel", elem_id="cancel", visible=False) | |
if not RUN_IN_SPACE: | |
def setup_func(token_val, width, height, size, model_choice, model_path): | |
try: | |
get_model(token_val, model_choice, model_path=model_path) | |
except Exception as e: | |
print(e) | |
return {token: gr.update(value=str(e))} | |
return { | |
token: gr.update(visible=False), | |
canvas_width: gr.update(visible=False), | |
canvas_height: gr.update(visible=False), | |
selection_size: gr.update(visible=False), | |
setup_button: gr.update(visible=False), | |
frame: gr.update(visible=True), | |
upload_button: gr.update(value="Upload Image"), | |
model_selection: gr.update(visible=False), | |
model_path_input: gr.update(visible=False), | |
} | |
setup_button.click( | |
fn=setup_func, | |
inputs=[ | |
token, | |
canvas_width, | |
canvas_height, | |
selection_size, | |
model_selection, | |
model_path_input, | |
], | |
outputs=[ | |
token, | |
canvas_width, | |
canvas_height, | |
selection_size, | |
setup_button, | |
frame, | |
upload_button, | |
model_selection, | |
model_path_input, | |
], | |
_js=setup_button_js, | |
) | |
proceed_event = proceed_button.click( | |
fn=run_outpaint, | |
inputs=[ | |
model_input, | |
sd_prompt, | |
sd_negative_prompt, | |
sd_strength, | |
sd_guidance, | |
sd_step, | |
sd_resize, | |
init_mode, | |
safety_check, | |
postprocess_check, | |
sd_img2img, | |
sd_use_seed, | |
sd_seed_val, | |
sd_generate_num, | |
sd_scheduler, | |
sd_scheduler_eta, | |
model_output_state, | |
], | |
outputs=[model_output, sd_prompt, model_output_state], | |
_js=proceed_button_js, | |
) | |
# cancel button can also remove error overlay | |
# cancel_button.click(fn=None, inputs=None, outputs=None, cancels=[proceed_event]) | |
launch_extra_kwargs = { | |
"show_error": True, | |
# "favicon_path": "" | |
} | |
launch_kwargs = vars(args) | |
launch_kwargs = {k: v for k, v in launch_kwargs.items() if v is not None} | |
launch_kwargs.pop("remote_model", None) | |
launch_kwargs.pop("local_model", None) | |
launch_kwargs.pop("fp32", None) | |
launch_kwargs.update(launch_extra_kwargs) | |
try: | |
import google.colab | |
launch_kwargs["debug"] = True | |
except: | |
pass | |
if RUN_IN_SPACE: | |
demo.launch() | |
elif args.debug: | |
launch_kwargs["server_name"] = "0.0.0.0" | |
demo.queue().launch(**launch_kwargs) | |
else: | |
demo.queue().launch(**launch_kwargs) | |