from pytorch_lightning import seed_everything

from scripts.demo.streamlit_helpers import *

SAVE_PATH = "outputs/demo/txt2img/"

# Selectable SDXL-base output resolutions. Each value is a (width, height)
# tuple (run_txt2img unpacks it as ``W, H``); each key is width / height
# rounded to two decimals.
SD_XL_BASE_RATIOS = {
    "0.5": (704, 1408),
    "0.52": (704, 1344),
    "0.57": (768, 1344),
    "0.6": (768, 1280),
    "0.68": (832, 1216),
    "0.72": (832, 1152),
    "0.78": (896, 1152),
    "0.82": (896, 1088),
    "0.88": (960, 1088),
    "0.94": (960, 1024),
    "1.0": (1024, 1024),
    "1.07": (1024, 960),
    "1.13": (1088, 960),
    "1.21": (1088, 896),
    "1.29": (1152, 896),
    "1.38": (1152, 832),
    "1.46": (1216, 832),
    "1.67": (1280, 768),
    "1.75": (1344, 768),
    "1.91": (1344, 704),
    "2.0": (1408, 704),
    "2.09": (1472, 704),
    "2.4": (1536, 640),
    "2.5": (1600, 640),
    "2.89": (1664, 576),
    "3.0": (1728, 576),
}

# Key of SD_XL_BASE_RATIOS preselected in the resolution dropdown.
_SDXL_DEFAULT_RATIO = "1.0"

# Per-model settings. "H"/"W" are the default sample size offered for
# non-SDXL-base models; "C" and "f" are forwarded to do_sample (presumably
# latent channels and autoencoder downsampling factor -- TODO confirm against
# streamlit_helpers). "config"/"ckpt" are the paths handed to init_st.
VERSION2SPECS = {
    "SDXL-base-1.0": {
        "H": 1024,
        "W": 1024,
        "C": 4,
        "f": 8,
        "is_legacy": False,
        "config": "configs/inference/sd_xl_base.yaml",
        "ckpt": "checkpoints/sd_xl_base_1.0.safetensors",
    },
    "SDXL-base-0.9": {
        "H": 1024,
        "W": 1024,
        "C": 4,
        "f": 8,
        "is_legacy": False,
        "config": "configs/inference/sd_xl_base.yaml",
        "ckpt": "checkpoints/sd_xl_base_0.9.safetensors",
    },
    "SD-2.1": {
        "H": 512,
        "W": 512,
        "C": 4,
        "f": 8,
        "is_legacy": True,
        "config": "configs/inference/sd_2_1.yaml",
        "ckpt": "checkpoints/v2-1_512-ema-pruned.safetensors",
    },
    "SD-2.1-768": {
        "H": 768,
        "W": 768,
        "C": 4,
        "f": 8,
        "is_legacy": True,
        "config": "configs/inference/sd_2_1_768.yaml",
        "ckpt": "checkpoints/v2-1_768-ema-pruned.safetensors",
    },
    "SDXL-refiner-0.9": {
        "H": 1024,
        "W": 1024,
        "C": 4,
        "f": 8,
        "is_legacy": True,
        "config": "configs/inference/sd_xl_refiner.yaml",
        "ckpt": "checkpoints/sd_xl_refiner_0.9.safetensors",
    },
    "SDXL-refiner-1.0": {
        "H": 1024,
        "W": 1024,
        "C": 4,
        "f": 8,
        "is_legacy": True,
        "config": "configs/inference/sd_xl_refiner.yaml",
        "ckpt": "checkpoints/sd_xl_refiner_1.0.safetensors",
    },
}


def load_img(display=True, key=None, device="cuda"):
    """Let the user provide an image and convert it to a model-ready tensor.

    The image is resized (not cropped) down to the nearest multiple of 64 in
    each dimension, converted to RGB and mapped from [0, 255] to [-1, 1].

    Args:
        display: Show the loaded image on the Streamlit page.
        key: Widget key forwarded to ``get_interactive_image``.
        device: Device the returned tensor is moved to.

    Returns:
        A float32 tensor of shape ``(1, 3, H, W)`` on ``device``, or ``None``
        when no image was provided or the image is smaller than 64 px on
        either side.
    """
    image = get_interactive_image(key=key)
    if image is None:
        return None
    if display:
        st.image(image)
    w, h = image.size
    print(f"loaded input image of size ({w}, {h})")
    # Resize to an integer multiple of 64 in both dimensions.
    width, height = (x - x % 64 for x in (w, h))
    if width == 0 or height == 0:
        # Rounding down would yield a zero-sized image, which PIL rejects.
        st.error(
            f"Input image of size ({w}, {h}) is too small; "
            "both sides must be at least 64 px."
        )
        return None
    image = image.resize((width, height))
    image = np.array(image.convert("RGB"))
    # (1, H, W, C) -> (1, C, H, W)
    image = image[None].transpose(0, 3, 1, 2)
    image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
    return image.to(device)


def _resolve_prompts(prompt, negative_prompt):
    """Fill in prompts that were not passed explicitly.

    Older callers relied on ``prompt`` / ``negative_prompt`` existing as
    module-level globals (they are set by this script's ``__main__`` block);
    keep that working while allowing explicit arguments. Falls back to ``""``
    when neither is available.
    """
    module_globals = globals()
    if prompt is None:
        prompt = module_globals.get("prompt", "")
    if negative_prompt is None:
        negative_prompt = module_globals.get("negative_prompt", "")
    return prompt, negative_prompt


def _size_conditioning(width, height):
    """Return the original/target size entries passed to the embedders."""
    return {
        "orig_width": width,
        "orig_height": height,
        "target_width": width,
        "target_height": height,
    }


def run_txt2img(
    state,
    version,
    version_dict,
    is_legacy=False,
    return_latents=False,
    filter=None,
    stage2strength=None,
    prompt=None,
    negative_prompt=None,
):
    """Render text-to-image controls and sample when "Sample" is pressed.

    Args:
        state: Model state from ``init_st``; ``state["model"]`` is sampled.
        version: Key of ``VERSION2SPECS``; SDXL-base models get a fixed
            resolution dropdown, others free H/W inputs.
        version_dict: The matching ``VERSION2SPECS`` entry.
        is_legacy: When False, ``force_uc_zero_embeddings=["txt"]`` is used.
        return_latents: Forwarded to ``do_sample``.
        filter: Forwarded to ``do_sample``.
        stage2strength: Forwarded to ``init_sampling``.
        prompt, negative_prompt: Text prompts. If omitted, the module-level
            globals of the same name are used (legacy behavior).

    Returns:
        The result of ``do_sample`` (presumably a ``(samples, latents)`` pair
        when ``return_latents`` is set -- see the ``__main__`` block), or
        ``None`` if the button was not pressed.
    """
    if version.startswith("SDXL-base"):
        W, H = st.selectbox(
            "Resolution:",
            list(SD_XL_BASE_RATIOS.values()),
            list(SD_XL_BASE_RATIOS).index(_SDXL_DEFAULT_RATIO),
        )
    else:
        H = st.number_input("H", value=version_dict["H"], min_value=64, max_value=2048)
        W = st.number_input("W", value=version_dict["W"], min_value=64, max_value=2048)
    C = version_dict["C"]
    F = version_dict["f"]

    prompt, negative_prompt = _resolve_prompts(prompt, negative_prompt)
    value_dict = init_embedder_options(
        get_unique_embedder_keys_from_conditioner(state["model"].conditioner),
        _size_conditioning(W, H),
        prompt=prompt,
        negative_prompt=negative_prompt,
    )
    sampler, num_rows, num_cols = init_sampling(stage2strength=stage2strength)
    num_samples = num_rows * num_cols

    if st.button("Sample"):
        st.write(f"**Model I:** {version}")
        return do_sample(
            state["model"],
            sampler,
            value_dict,
            num_samples,
            H,
            W,
            C,
            F,
            force_uc_zero_embeddings=["txt"] if not is_legacy else [],
            return_latents=return_latents,
            filter=filter,
        )
    return None


def run_img2img(
    state,
    version_dict,
    is_legacy=False,
    return_latents=False,
    filter=None,
    stage2strength=None,
    prompt=None,
    negative_prompt=None,
):
    """Render image-to-image controls and sample when "Sample" is pressed.

    The output size follows the (64-aligned) input image from ``load_img``.
    ``prompt`` / ``negative_prompt`` fall back to the module-level globals of
    the same name when omitted (legacy behavior).

    Returns:
        The result of ``do_img2img``, or ``None`` if no usable image was
        provided or the button was not pressed.
    """
    img = load_img()
    if img is None:
        return None
    H, W = img.shape[2], img.shape[3]

    prompt, negative_prompt = _resolve_prompts(prompt, negative_prompt)
    value_dict = init_embedder_options(
        get_unique_embedder_keys_from_conditioner(state["model"].conditioner),
        _size_conditioning(W, H),
        prompt=prompt,
        negative_prompt=negative_prompt,
    )
    strength = st.number_input(
        "**Img2Img Strength**", value=0.75, min_value=0.0, max_value=1.0
    )
    sampler, num_rows, num_cols = init_sampling(
        img2img_strength=strength,
        stage2strength=stage2strength,
    )
    num_samples = num_rows * num_cols

    if st.button("Sample"):
        return do_img2img(
            # Tile the single input image across the batch.
            repeat(img, "1 ... -> n ...", n=num_samples),
            state["model"],
            sampler,
            value_dict,
            num_samples,
            force_uc_zero_embeddings=["txt"] if not is_legacy else [],
            return_latents=return_latents,
            filter=filter,
        )
    return None


def apply_refiner(
    input,
    state,
    sampler,
    num_samples,
    prompt,
    negative_prompt,
    filter=None,
    finish_denoising=False,
):
    """Run the second-stage refiner on first-stage latents via img2img.

    Args:
        input: First-stage latents; dims 2 and 3 are read as latent height
            and width.
        state: Refiner model state from ``init_st``.
        sampler: Sampler built for the refiner.
        num_samples: Batch size forwarded to ``do_img2img``.
        prompt, negative_prompt: Text conditioning.
        filter: Forwarded to ``do_img2img``.
        finish_denoising: When True, no extra noise is added
            (``add_noise=False``).

    Returns:
        The result of ``do_img2img``.
    """
    # Size conditioning is in pixels; latents are assumed to be 8x smaller
    # (every entry in VERSION2SPECS uses f=8).
    height = input.shape[2] * 8
    width = input.shape[3] * 8
    value_dict = _size_conditioning(width, height)
    value_dict.update(
        prompt=prompt,
        negative_prompt=negative_prompt,
        crop_coords_top=0,
        crop_coords_left=0,
        aesthetic_score=6.0,
        negative_aesthetic_score=2.5,
    )

    st.warning(f"refiner input shape: {input.shape}")
    return do_img2img(
        input,
        state["model"],
        sampler,
        value_dict,
        num_samples,
        skip_encode=True,  # input is already latent
        filter=filter,
        add_noise=not finish_denoising,
    )


def _init_refiner_stage():
    """Render the refiner options and load the second-stage model.

    Returns:
        ``(state2, sampler2, stage2strength, finish_denoising)``, where
        ``stage2strength`` is ``None`` when the refiner should not finish
        denoising.
    """
    st.write("__________________________")
    version2 = st.selectbox("Refiner:", ["SDXL-refiner-1.0", "SDXL-refiner-0.9"])
    st.warning(
        f"Running with {version2} as the second stage model. Make sure to provide (V)RAM :) "
    )
    st.write("**Refiner Options:**")

    version_dict2 = VERSION2SPECS[version2]
    state2 = init_st(version_dict2, load_filter=False)
    # Same guard as for the first stage: only show a message if there is one.
    if state2["msg"]:
        st.info(state2["msg"])

    stage2strength = st.number_input(
        "**Refinement strength**", value=0.15, min_value=0.0, max_value=1.0
    )

    sampler2, *_ = init_sampling(
        key=2,
        img2img_strength=stage2strength,
        specify_num_samples=False,
    )
    st.write("__________________________")
    finish_denoising = st.checkbox("Finish denoising with refiner.", True)
    if not finish_denoising:
        stage2strength = None
    return state2, sampler2, stage2strength, finish_denoising


if __name__ == "__main__":
    st.title("Stable Diffusion")
    version = st.selectbox("Model Version", list(VERSION2SPECS.keys()), 0)
    version_dict = VERSION2SPECS[version]
    if st.checkbox("Load Model"):
        mode = st.radio("Mode", ("txt2img", "img2img"), 0)
    else:
        mode = "skip"
    st.write("__________________________")

    set_lowvram_mode(st.checkbox("Low vram mode", True))

    if version.startswith("SDXL-base"):
        add_pipeline = st.checkbox("Load SDXL-refiner?", False)
        st.write("__________________________")
    else:
        add_pipeline = False

    seed = st.sidebar.number_input("seed", value=42, min_value=0, max_value=int(1e9))
    seed_everything(seed)

    save_locally, save_path = init_save_locally(os.path.join(SAVE_PATH, version))

    if mode != "skip":
        state = init_st(version_dict, load_filter=True)
        if state["msg"]:
            st.info(state["msg"])

        is_legacy = version_dict["is_legacy"]

        prompt = st.text_input(
            "prompt",
            "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
        )
        if is_legacy:
            negative_prompt = st.text_input("negative prompt", "")
        else:
            # Non-legacy models sample with force_uc_zero_embeddings=["txt"],
            # so a negative prompt is presumably unused -- TODO confirm.
            negative_prompt = ""

        stage2strength = None
        finish_denoising = False
        if add_pipeline:
            state2, sampler2, stage2strength, finish_denoising = _init_refiner_stage()

        if mode == "txt2img":
            out = run_txt2img(
                state,
                version,
                version_dict,
                is_legacy=is_legacy,
                return_latents=add_pipeline,
                filter=state.get("filter"),
                stage2strength=stage2strength,
                prompt=prompt,
                negative_prompt=negative_prompt,
            )
        elif mode == "img2img":
            out = run_img2img(
                state,
                version_dict,
                is_legacy=is_legacy,
                return_latents=add_pipeline,
                filter=state.get("filter"),
                stage2strength=stage2strength,
                prompt=prompt,
                negative_prompt=negative_prompt,
            )
        else:
            raise ValueError(f"unknown mode {mode}")

        # With return_latents the samplers return (samples, latents).
        if isinstance(out, (tuple, list)):
            samples, samples_z = out
        else:
            samples = out
            samples_z = None

        if add_pipeline and samples_z is not None:
            st.write("**Running Refinement Stage**")
            samples = apply_refiner(
                samples_z,
                state2,
                sampler2,
                samples_z.shape[0],
                prompt=prompt,
                negative_prompt=negative_prompt if is_legacy else "",
                filter=state.get("filter"),
                finish_denoising=finish_denoising,
            )

        if save_locally and samples is not None:
            perform_save_locally(save_path, samples)