{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torch\n", "import os\n", "from transformers import ViTModel, ViTImageProcessor\n", "from utils import text_encoder_forward\n", "from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler\n", "from utils import latents_to_images, downsampling, merge_and_save_images\n", "from omegaconf import OmegaConf\n", "from accelerate.utils import set_seed\n", "from tqdm import tqdm\n", "from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput\n", "from PIL import Image\n", "from models.celeb_embeddings import embedding_forward\n", "import models.embedding_manager\n", "import importlib\n", "\n", "# seed = 42\n", "# set_seed(seed) \n", "# torch.cuda.set_device(0)\n", "\n", "# set your sd2.1 path\n", "model_path = \"/home/user/.cache/huggingface/hub/models--stabilityai--stable-diffusion-2-1/snapshots/5cae40e6a2745ae2b01ad92ae5043f95f23644d6\"\n", "pipe = StableDiffusionPipeline.from_pretrained(model_path) \n", "pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)\n", "pipe = pipe.to(\"cuda\")\n", "\n", "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n", "\n", "vae = pipe.vae\n", "unet = pipe.unet\n", "text_encoder = pipe.text_encoder\n", "tokenizer = pipe.tokenizer\n", "scheduler = pipe.scheduler\n", "\n", "input_dim = 64\n", "\n", "experiment_name = \"normal_GAN\" # \"normal_GAN\", \"man_GAN\", \"woman_GAN\" , \n", "if experiment_name == \"normal_GAN\":\n", " steps = 10000\n", "elif experiment_name == \"man_GAN\":\n", " steps = 7000\n", "elif experiment_name == \"woman_GAN\":\n", " steps = 6000\n", "else:\n", " print(\"Hello, please notice this ^_^\")\n", " assert 0\n", "\n", "\n", "original_forward = text_encoder.text_model.embeddings.forward\n", "text_encoder.text_model.embeddings.forward = embedding_forward.__get__(text_encoder.text_model.embeddings)\n", "embedding_manager_config = OmegaConf.load(\"datasets_face/identity_space.yaml\")\n", "Embedding_Manager = models.embedding_manager.EmbeddingManagerId_adain( \n", " tokenizer,\n", " text_encoder,\n", " device = device,\n", " training = True,\n", " experiment_name = experiment_name, \n", " num_embeds_per_token = embedding_manager_config.model.personalization_config.params.num_embeds_per_token, \n", " token_dim = embedding_manager_config.model.personalization_config.params.token_dim,\n", " mlp_depth = embedding_manager_config.model.personalization_config.params.mlp_depth,\n", " loss_type = embedding_manager_config.model.personalization_config.params.loss_type,\n", " vit_out_dim = input_dim,\n", ")\n", "embedding_path = os.path.join(\"training_weight\", experiment_name, \"embeddings_manager-{}.pt\".format(str(steps)))\n", "Embedding_Manager.load(embedding_path)\n", "text_encoder.text_model.embeddings.forward = original_forward\n", "\n", "print(\"finish init\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "1. 
{ "cell_type": "markdown", "metadata": {}, "source": [ "1. Create a new character and test it with prompts" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"# sample a random latent z\n", "random_embedding = torch.randn(1, 1, input_dim).to(device)\n", "\n",
"# map z to a pair of pseudo identity embeddings\n", "_, emb_dict = Embedding_Manager(tokenized_text=None, embedded_text=None, name_batch=None, random_embeddings=random_embedding, timesteps=None)\n", "\n",
"test_emb = emb_dict[\"adained_total_embedding\"].to(device)\n", "\n",
"v1_emb = test_emb[:, 0]\n", "v2_emb = test_emb[:, 1]\n", "\n",
"index = \"0000\"\n", "save_dir = os.path.join(\"test_results\", experiment_name, index)\n", "os.makedirs(save_dir, exist_ok=True)\n", "test_emb_path = os.path.join(save_dir, \"id_embeddings.pt\")\n", "torch.save(test_emb, test_emb_path)\n", "\n",
"# insert the two embeddings into the tokenizer & embedding layer\n", "tokens = [\"v1*\", \"v2*\"]\n", "embeddings = [v1_emb, v2_emb]\n",
"# add tokens and get their ids\n", "tokenizer.add_tokens(tokens)\n", "token_ids = tokenizer.convert_tokens_to_ids(tokens)\n", "\n",
"# resize the token embedding matrix and write in the new rows\n", "text_encoder.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=8)\n", "for token_id, embedding in zip(token_ids, embeddings):\n", "    text_encoder.get_input_embeddings().weight.data[token_id] = embedding\n", "\n",
"prompts_list = [\n", "    \"a photo of v1* v2*, facing to camera, best quality, ultra high res\",\n", "    \"v1* v2* wearing a Superman outfit, facing to camera, best quality, ultra high res\",\n", "    \"v1* v2* wearing a spacesuit, facing to camera, best quality, ultra high res\",\n", "    \"v1* v2* wearing a red sweater, facing to camera, best quality, ultra high res\",\n", "    \"v1* v2* wearing a blue hoodie, facing to camera, best quality, ultra high res\",\n", "]\n", "\n",
"for prompt in prompts_list:\n", "    image = pipe(prompt, guidance_scale=8.5).images[0]\n", "    save_img_path = os.path.join(save_dir, prompt.replace(\"v1* v2*\", \"a person\") + \".png\")\n", "    image.save(save_img_path)\n", "    print(save_img_path)" ] },
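{ "cell_type": "markdown", "metadata": {}, "source": [ "Optional: since identities are sampled from a latent space, you can also draw two z vectors and interpolate between them. The sketch below reuses the same Embedding_Manager call as above; that linear blends of z yield smooth identity changes is an assumption about the learned space, not something the method guarantees." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Exploratory sketch: interpolate between two sampled identities in z-space.\n", "# Assumption: linearly blended z vectors stay on the identity manifold.\n", "z_a = torch.randn(1, 1, input_dim).to(device)\n", "z_b = torch.randn(1, 1, input_dim).to(device)\n", "for t in (0.0, 0.25, 0.5, 0.75, 1.0):\n", "    z_mix = (1 - t) * z_a + t * z_b\n", "    _, mix_dict = Embedding_Manager(tokenized_text=None, embedded_text=None, name_batch=None, random_embeddings=z_mix, timesteps=None)\n", "    mix_emb = mix_dict[\"adained_total_embedding\"].to(device)\n", "    torch.save(mix_emb, os.path.join(save_dir, \"id_embeddings_mix_{:.2f}.pt\".format(t)))" ] },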
{ "cell_type": "markdown", "metadata": {}, "source": [ "2. Directly use previously generated pseudo identity embeddings" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"# the path of your generated embeddings\n", "test_emb_path = \"demo_embeddings/856.pt\"  # e.g. \"test_results/normal_GAN/0000/id_embeddings.pt\"\n", "test_emb = torch.load(test_emb_path).to(device)\n", "v1_emb = test_emb[:, 0]\n", "v2_emb = test_emb[:, 1]\n", "\n",
"index = \"chosen_index\"\n", "save_dir = os.path.join(\"test_results\", experiment_name, index)\n", "os.makedirs(save_dir, exist_ok=True)\n", "\n",
"# insert the two embeddings into the tokenizer & embedding layer;\n", "# add_tokens skips tokens that already exist, so re-running this is safe\n", "tokens = [\"v1*\", \"v2*\"]\n", "embeddings = [v1_emb, v2_emb]\n", "tokenizer.add_tokens(tokens)\n", "token_ids = tokenizer.convert_tokens_to_ids(tokens)\n", "\n",
"# resize the token embedding matrix and write in the new rows\n", "text_encoder.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=8)\n", "for token_id, embedding in zip(token_ids, embeddings):\n", "    text_encoder.get_input_embeddings().weight.data[token_id] = embedding\n", "\n",
"prompts_list = [\n",
"    \"a photo of v1* v2*, facing to camera, best quality, ultra high res\",\n",
"    \"v1* v2* wearing a Superman outfit, facing to camera, best quality, ultra high res\",\n",
"    \"v1* v2* wearing a spacesuit, facing to camera, best quality, ultra high res\",\n",
"    \"v1* v2* wearing a red sweater, facing to camera, best quality, ultra high res\",\n",
"    \"v1* v2* wearing a purple wizard outfit, facing to camera, best quality, ultra high res\",\n",
"    \"v1* v2* wearing a blue hoodie, facing to camera, best quality, ultra high res\",\n",
"    \"v1* v2* wearing headphones, facing to camera, best quality, ultra high res\",\n",
"    \"v1* v2* with red hair, facing to camera, best quality, ultra high res\",\n",
"    \"v1* v2* wearing headphones with red hair, facing to camera, best quality, ultra high res\",\n",
"    \"v1* v2* wearing a Christmas hat, facing to camera, best quality, ultra high res\",\n",
"    \"v1* v2* wearing sunglasses, facing to camera, best quality, ultra high res\",\n",
"    \"v1* v2* wearing sunglasses and necklace, facing to camera, best quality, ultra high res\",\n",
"    \"v1* v2* wearing a blue cap, facing to camera, best quality, ultra high res\",\n",
"    \"v1* v2* wearing a doctoral cap, facing to camera, best quality, ultra high res\",\n",
"    \"v1* v2* with white hair, wearing glasses, facing to camera, best quality, ultra high res\",\n",
"    \"v1* v2* in a helmet and vest riding a motorcycle, facing to camera, best quality, ultra high res\",\n",
"    \"v1* v2* holding a bottle of red wine, facing to camera, best quality, ultra high res\",\n",
"    \"v1* v2* driving a bus in the desert, facing to camera, best quality, ultra high res\",\n",
"    \"v1* v2* playing basketball, facing to camera, best quality, ultra high res\",\n",
"    \"v1* v2* playing the violin, facing to camera, best quality, ultra high res\",\n",
"    \"v1* v2* piloting a spaceship, facing to camera, best quality, ultra high res\",\n",
"    \"v1* v2* riding a horse, facing to camera, best quality, ultra high res\",\n",
"    \"v1* v2* coding in front of a computer, facing to camera, best quality, ultra high res\",\n",
"    \"v1* v2* laughing on the lawn, facing to camera, best quality, ultra high res\",\n",
"    \"v1* v2* frowning at the camera, facing to camera, best quality, ultra high res\",\n",
"    \"v1* v2* happily smiling, looking at the camera, facing to camera, best quality, ultra high res\",\n",
"    \"v1* v2* crying disappointedly, with tears flowing, facing to camera, best quality, ultra high res\",\n",
"    \"v1* v2* playing the guitar in the view of left side, facing to camera, best quality, ultra high res\",\n",
"    \"v1* v2* holding a bottle of red wine, upper body, facing to camera, best quality, ultra high res\",\n",
"    \"v1* v2* wearing sunglasses and necklace, close-up, in the view of right side, facing to camera, best quality, ultra high res\",\n",
"    \"v1* v2* riding a horse, in the view of the top, facing to camera, best quality, ultra high res\",\n",
"    \"v1* v2* wearing a doctoral cap, upper body, with the left side of the face facing the camera, best quality, ultra high res\",\n",
"    \"v1* v2* crying disappointedly, with tears flowing, with left side of the face facing the camera, best quality, ultra high res\",\n",
"    \"v1* v2* sitting in front of the camera, with a beautiful purple sunset at the beach in the background, best quality, ultra high res\",\n",
"    \"v1* v2* swimming in the pool, facing to camera, best quality, ultra high res\",\n",
"    \"v1* v2* climbing a mountain, facing to camera, best quality, ultra high res\",\n",
"    \"v1* v2* skiing on the snowy mountain, facing to camera, best quality, ultra high res\",\n",
"    \"v1* v2* in the snow, facing to camera, best quality, ultra high res\",\n",
"    \"v1* v2* in space wearing a spacesuit, facing to camera, best quality, ultra high res\",\n",
"]\n", "\n",
"for prompt in prompts_list:\n", "    image = pipe(prompt, guidance_scale=8.5).images[0]\n", "    save_img_path = os.path.join(save_dir, prompt.replace(\"v1* v2*\", \"a person\") + \".png\")\n", "    image.save(save_img_path)\n", "    print(save_img_path)" ] },
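{ "cell_type": "markdown", "metadata": {}, "source": [ "Optional: for reproducible samples, pass a seeded `torch.Generator` to the pipeline; `generator` is a standard `StableDiffusionPipeline` argument. The seed and output file name below are arbitrary choices for this sketch." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Reproducibility sketch: a fixed generator makes pipe() deterministic for a\n", "# given prompt; the seed value is arbitrary.\n", "generator = torch.Generator(device=device).manual_seed(42)\n", "prompt = \"a photo of v1* v2*, facing to camera, best quality, ultra high res\"\n", "image = pipe(prompt, guidance_scale=8.5, generator=generator).images[0]\n", "image.save(os.path.join(save_dir, \"seeded_sample.png\"))" ] }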
], "metadata": { "kernelspec": { "display_name": "lbl", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.5" } }, "nbformat": 4, "nbformat_minor": 2 }