{ "cells": [ { "cell_type": "markdown", "id": "a885cf5d-c525-4f5b-a8e4-f67d2f699909", "metadata": {}, "source": [ "## Copyright 2023 Google LLC" ] }, { "cell_type": "code", "execution_count": null, "id": "d891d022-8979-40d4-848f-ecb84c17f12c", "metadata": { "jp-MarkdownHeadingCollapsed": true }, "outputs": [], "source": [ "# Copyright 2023 Google LLC\n", "#\n", "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# http://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "id": "540d8642-c203-471c-a66d-0d43aabb0706", "metadata": {}, "source": [ "# StyleAligned over SDXL from input image" ] }, { "cell_type": "markdown", "id": "483d0cf9", "metadata": {}, "source": [ "#### Model Load " ] }, { "cell_type": "code", "execution_count": null, "id": "23d54ea7-f7ab-4548-9b10-ece87216dc18", "metadata": {}, "outputs": [], "source": [ "from diffusers import StableDiffusionXLPipeline, DDIMScheduler\n", "import torch\n", "import mediapy\n", "import sa_handler\n", "import math\n", "\n", "\n", "scheduler = DDIMScheduler(\n", " beta_start=0.00085, beta_end=0.012, beta_schedule=\"scaled_linear\",\n", " clip_sample=False, set_alpha_to_one=False)\n", "\n", "pipeline = StableDiffusionXLPipeline.from_pretrained(\n", " \"stabilityai/stable-diffusion-xl-base-1.0\", torch_dtype=torch.float16, variant=\"fp16\",\n", " use_safetensors=True,\n", " scheduler=scheduler\n", ").to(\"cuda\")" ] }, { "cell_type": "markdown", "id": "c09b1a68", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "#### Ref image load and inversion" ] }, { "cell_type": "code", "execution_count": null, "id": "f4717854", "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "# DDIM inversion\n", "\n", "from diffusers.utils import load_image\n", "import inversion\n", "import numpy as np\n", "\n", "src_style = \"medieval painting\"\n", "src_prompt = f'Man laying in a bed, {src_style}.'\n", "image_path = './example_image/medieval-bed.jpeg'\n", "\n", "num_inference_steps = 50\n", "x0 = np.array(load_image(image_path).resize((1024, 1024)))\n", "zts = inversion.ddim_inversion(pipeline, x0, src_prompt, num_inference_steps, 2)\n", "mediapy.show_image(x0, title=\"innput reference image\", height=256)" ] }, { "cell_type": "code", "execution_count": null, "id": "1751c4fe", "metadata": {}, "outputs": [], "source": [ "prompts = [\n", " src_prompt,\n", " \"A man working on a laptop\",\n", " \"A man eats pizza\",\n", " \"A woman playig on saxophone\",\n", "]\n", "\n", "# some parameters you can adjust to control fidelity to reference\n", "shared_score_shift = np.log(2) # higher value induces higher fidelity, set 0 for no shift\n", "shared_score_scale = 1.0 # higher value induces higher, set 1 for no rescale\n", "\n", "# for very famouse images consider supressing attention to refference, here is a configuration example:\n", "# shared_score_shift = np.log(1)\n", "# shared_score_scale = 0.5\n", "\n", "for i in range(1, len(prompts)):\n", " prompts[i] = f'{prompts[i]}, {src_style}.'\n", "\n", "handler = sa_handler.Handler(pipeline)\n", "sa_args = sa_handler.StyleAlignedArgs(\n", " share_group_norm=True, share_layer_norm=True, share_attention=True,\n", " adain_queries=True, adain_keys=True, adain_values=False,\n", " shared_score_shift=shared_score_shift, shared_score_scale=shared_score_scale,)\n", "handler.register(sa_args)\n", "\n", "zT, inversion_callback = inversion.make_inversion_callback(zts, offset=5)\n", "\n", "g_cpu = torch.Generator(device='cpu')\n", "g_cpu.manual_seed(10)\n", "\n", "latents = torch.randn(len(prompts), 4, 128, 128, device='cpu', generator=g_cpu,\n", " dtype=pipeline.unet.dtype,).to('cuda:0')\n", "latents[0] = zT\n", "\n", "images_a = pipeline(prompts, latents=latents,\n", " callback_on_step_end=inversion_callback,\n", " num_inference_steps=num_inference_steps, guidance_scale=10.0).images\n", "\n", "handler.remove()\n", "mediapy.show_images(images_a, titles=[p[:-(len(src_style) + 3)] for p in prompts])" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.13" } }, "nbformat": 4, "nbformat_minor": 5 }