"""Utilities for sampling images, building text-direction embeddings and
running image edits with the pix2pix-zero pipelines."""

import hashlib
import os
import re
import sys
import time

import torch
from PIL import Image
from tqdm import tqdm
import openai

from utils.direction_utils import *

# The pix2pix-zero helper modules are imported by bare name below, so their
# directory has to be on sys.path before those imports run.
p = "submodules/pix2pix-zero/src/utils"
if p not in sys.path:
    sys.path.append(p)

from diffusers import DDIMScheduler
from edit_pipeline import EditingPipeline
from ddim_inv import DDIMInversion
from scheduler import DDIMInverseScheduler
from lavis.models import load_model_and_preprocess
from transformers import T5Tokenizer, AutoTokenizer, T5ForConditionalGeneration, BloomForCausalLM

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Compiled once; used by clean_l_sentences to drop every digit.
_DIGIT_RE = re.compile(r"\d")
# Deletes the characters ".", "-" and ")" in a single pass.
_PUNCT_TABLE = str.maketrans("", "", ".-)")
# Upper bound on back-to-back generation failures before giving up.
_MAX_CONSECUTIVE_FAILURES = 20


def load_sentence_embeddings(l_sentences, tokenizer, text_encoder, device=DEVICE):
    """Encode every sentence and return the average of the encodings.

    Each sentence is tokenized (padded/truncated to
    ``tokenizer.model_max_length``) and passed through ``text_encoder``; the
    first element of each encoder output is collected. The collected tensors
    are concatenated along dim 0, averaged over that dim, and a leading
    dimension of size 1 is added back.

    Raises:
        ValueError: if ``l_sentences`` is empty (there is nothing to average).
    """
    l_embeddings = []
    with torch.no_grad():
        for sent in tqdm(l_sentences):
            text_inputs = tokenizer(
                sent,
                padding="max_length",
                max_length=tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            prompt_embeds = text_encoder(text_input_ids.to(device), attention_mask=None)[0]
            l_embeddings.append(prompt_embeds)
    if not l_embeddings:
        raise ValueError("l_sentences must contain at least one sentence")
    return torch.cat(l_embeddings, dim=0).mean(dim=0).unsqueeze(0)


def launch_generate_sample(prompt, seed, negative_scale, num_ddim):
    """Sample an image for ``prompt`` from a seeded noise map.

    The noise map is saved under ``tmp/`` with a file name derived from the
    SHA-256 of its bytes and ``num_ddim``.

    Returns:
        A tuple ``(sample, noise_path)`` where ``sample`` is the first element
        of the pipeline output and ``noise_path`` is the saved noise file.
    """
    os.makedirs("tmp", exist_ok=True)

    edit_pipe = EditingPipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float32
    ).to(DEVICE)
    edit_pipe.scheduler = DDIMScheduler.from_config(edit_pipe.scheduler.config)

    # set the random seed and sample the input noise map
    if torch.cuda.is_available():
        torch.cuda.manual_seed(int(seed))
    else:
        torch.manual_seed(int(seed))
    z = torch.randn((1, 4, 64, 64), device=DEVICE)

    z_hashname = hashlib.sha256(z.cpu().numpy().tobytes()).hexdigest()
    z_inv_fname = f"tmp/{z_hashname}_ddim_{num_ddim}_inv.pt"
    torch.save(z, z_inv_fname)

    rec_pil = edit_pipe(
        prompt,
        num_inference_steps=num_ddim,
        x_in=z,
        only_sample=True,  # only generate the sampled image, not the edited image
        guidance_scale=negative_scale,
        negative_prompt="",  # use the empty string for the negative prompt
    )
    del edit_pipe
    torch.cuda.empty_cache()
    return rec_pil[0], z_inv_fname


def clean_l_sentences(ls):
    """Return the sentences with digits and the characters ``.-)`` removed,
    then stripped of surrounding whitespace."""
    return [_DIGIT_RE.sub("", x).translate(_PUNCT_TABLE).strip() for x in ls]


def _caption_with_word(line, word):
    """Return ``line`` unchanged if it contains every space-separated part of
    ``word``; otherwise return ``line`` with ``", {word}"`` appended."""
    if all(subword in line for subword in word.split(" ")):
        return line
    return line + f", {word}"


def gpt3_compute_word2sentences(task_type, word, num=100):
    """Collect at least ``num`` captions about ``word`` from the OpenAI
    completion API.

    Args:
        task_type: ``"object"`` or ``"style"``; selects the prompt template.
        word: subject (or style) the captions must mention.
        num: minimum number of captions to gather before returning.

    Raises:
        ValueError: if ``task_type`` is not one of the supported values.
    """
    if task_type == "object":
        template_prompt = f"Provide many captions for images containing {word}."
    elif task_type == "style":
        template_prompt = f"Provide many captions for images that are in the {word} style."
    else:
        # Previously this fell through and failed later on an unbound name.
        raise ValueError(f"Unsupported task_type: {task_type!r}")

    l_sentences = []
    while True:
        ret = openai.Completion.create(
            model="text-davinci-002",
            prompt=template_prompt,
            max_tokens=1000,
            temperature=1.0,
        )
        raw_return = ret.choices[0].text
        for line in raw_return.split("\n"):
            line = line.strip()
            # Lines of 10 characters or fewer are ignored.
            if len(line) > 10:
                l_sentences.append(_caption_with_word(line, word))
        # Brief pause between API requests.
        time.sleep(0.05)
        print(len(l_sentences))
        if len(l_sentences) >= num:
            break
    return clean_l_sentences(l_sentences)


def flant5xl_compute_word2sentences(word, num=100):
    """Sample at least ``num`` captions about ``word`` with google/flan-t5-xl.

    Captions that do not mention every part of ``word`` get ``", {word}"``
    appended. The result is passed through ``clean_l_sentences``.
    """
    text_input = (
        f"Provide a caption for images containing a {word}. "
        "The captions should be in English and should be no longer than 150 characters."
    )

    l_sentences = []
    tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-xl")
    model = T5ForConditionalGeneration.from_pretrained(
        "google/flan-t5-xl", device_map="auto", torch_dtype=torch.float16
    )
    input_ids = tokenizer(text_input, return_tensors="pt").input_ids.to(DEVICE)

    while True:
        outputs = model.generate(
            input_ids,
            temperature=0.9,
            num_return_sequences=16,
            do_sample=True,
            max_length=128,
        )
        # T5 is an encoder-decoder model: generate() returns only the decoder
        # tokens, not the prompt. Slicing off `input_length` tokens (as is done
        # for decoder-only models) discarded the beginning of every caption,
        # so the whole output is decoded here.
        output = tokenizer.batch_decode(outputs, skip_special_tokens=True)
        for line in output:
            l_sentences.append(_caption_with_word(line.strip(), word))
        print(len(l_sentences))
        if len(l_sentences) >= num:
            break

    l_sentences = clean_l_sentences(l_sentences)
    del model
    del tokenizer
    torch.cuda.empty_cache()
    return l_sentences


def bloomz_compute_sentences(word, num=100):
    """Sample at least ``num`` captions about ``word`` with bigscience/bloomz-7b1.

    A failed generation attempt is logged and retried. Unlike a bare
    ``except``, ``KeyboardInterrupt`` is not swallowed, and the function gives
    up after ``_MAX_CONSECUTIVE_FAILURES`` failures in a row instead of
    spinning forever on a persistent error.

    Raises:
        RuntimeError: if generation fails too many times consecutively.
    """
    l_sentences = []
    tokenizer = AutoTokenizer.from_pretrained("bigscience/bloomz-7b1")
    model = BloomForCausalLM.from_pretrained(
        "bigscience/bloomz-7b1", device_map="auto", torch_dtype=torch.float16
    )
    input_text = (
        f"Provide a caption for images containing a {word}. "
        "The captions should be in English and should be no longer than 150 characters. "
        "Caption:"
    )
    input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(DEVICE)
    # The model is decoder-only, so its output starts with the prompt tokens;
    # they are sliced off before decoding.
    input_length = input_ids.shape[1]
    t = 0.95
    eta = 1e-5
    min_length = 15

    failures = 0
    while True:
        try:
            outputs = model.generate(
                input_ids,
                temperature=t,
                num_return_sequences=16,
                do_sample=True,
                max_length=128,
                min_length=min_length,
                eta_cutoff=eta,
            )
            output = tokenizer.batch_decode(outputs[:, input_length:], skip_special_tokens=True)
        except Exception as err:
            failures += 1
            print(f"bloomz generation failed ({failures}/{_MAX_CONSECUTIVE_FAILURES}): {err}")
            if failures >= _MAX_CONSECUTIVE_FAILURES:
                raise RuntimeError("bloomz generation failed repeatedly") from err
            continue
        failures = 0

        for line in output:
            l_sentences.append(_caption_with_word(line.strip(), word))
        print(len(l_sentences))
        if len(l_sentences) >= num:
            break

    l_sentences = clean_l_sentences(l_sentences)
    del model
    del tokenizer
    torch.cuda.empty_cache()
    return l_sentences


def _save_sentences(path, sentences):
    """Write ``sentences`` to ``path``, one per line."""
    with open(path, "w") as f:
        f.writelines(line + "\n" for line in sentences)


def make_custom_dir(description, sent_type, api_key, org_key, l_custom_sentences):
    """Build the averaged sentence embedding for a user-defined description.

    ``sent_type`` selects how the sentences are obtained: fixed templates,
    GPT-3, flan-t5-xl, BLOOMZ-7B, or the newline-separated text in
    ``l_custom_sentences``. The sentences are then embedded with the
    Stable Diffusion tokenizer and text encoder.

    Raises:
        ValueError: if ``sent_type`` matches none of the supported options.
    """
    if sent_type == "fixed-template":
        l_sentences = generate_image_prompts_with_templates(description)
    elif "GPT3" in sent_type:
        openai.organization = org_key
        openai.api_key = api_key
        # Result unused; the call is made before any captions are requested.
        _ = openai.Model.retrieve("text-davinci-002")
        l_sentences = gpt3_compute_word2sentences("object", description, num=1000)
    elif "flan-t5-xl" in sent_type:
        l_sentences = flant5xl_compute_word2sentences(description, num=1000)
        # save the sentences to file
        os.makedirs("tmp", exist_ok=True)
        _save_sentences(f"tmp/flant5xl_sentences_{description}.txt", l_sentences)
    elif "BLOOMZ-7B" in sent_type:
        l_sentences = bloomz_compute_sentences(description, num=1000)
        # save the sentences to file
        os.makedirs("tmp", exist_ok=True)
        _save_sentences(f"tmp/bloomz_sentences_{description}.txt", l_sentences)
    elif sent_type == "custom sentences":
        l_sentences = l_custom_sentences.split("\n")
    else:
        # Previously this fell through and failed later on an unbound name.
        raise ValueError(f"Unsupported sentence type: {sent_type!r}")

    print(f"length of new sentence is {len(l_sentences)}")
    pipe = EditingPipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float32
    ).to(DEVICE)
    emb = load_sentence_embeddings(l_sentences, pipe.tokenizer, pipe.text_encoder, device=DEVICE)
    del pipe
    torch.cuda.empty_cache()
    return emb


def _get_direction_embedding(choice, custom_desc, sent_type, api_key, org_key,
                             custom_sentences, d_desc2name):
    """Return the embedding for one side (source or destination) of an edit.

    Built-in choices are looked up with ``get_emb``. For ``"make your own!"``
    the embedding is loaded from a cache file under ``tmp/`` when present,
    otherwise created with ``make_custom_dir`` and saved there.
    """
    if choice != "make your own!":
        return get_emb(d_desc2name[choice])

    cache_tag = sent_type
    if sent_type == "custom sentences":
        # The cache key must depend on the sentences themselves; otherwise
        # editing the custom sentences while keeping the same description
        # would silently reuse the stale embedding.
        digest = hashlib.sha256(custom_sentences.encode("utf-8")).hexdigest()[:16]
        cache_tag = f"{sent_type}_{digest}"
    outf_name = f"tmp/template_emb_{custom_desc}_{cache_tag}.pt"

    if os.path.exists(outf_name):
        return torch.load(outf_name, map_location=torch.device("cpu"), weights_only=True)
    emb = make_custom_dir(custom_desc, sent_type, api_key, org_key, custom_sentences)
    torch.save(emb, outf_name)
    return emb


def _run_edit(prompt, x_inv, text_dir, num_ddim, xa_guidance, guidance_scale, negative_prompt):
    """Run the editing pipeline and return the first edited result."""
    edit_pipe = EditingPipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float32
    ).to(DEVICE)
    edit_pipe.scheduler = DDIMScheduler.from_config(edit_pipe.scheduler.config)
    _, edit_pil = edit_pipe(
        prompt,
        num_inference_steps=num_ddim,
        x_in=x_inv,
        edit_dir=text_dir,
        guidance_amount=xa_guidance,
        guidance_scale=guidance_scale,
        negative_prompt=negative_prompt,
    )
    del edit_pipe
    torch.cuda.empty_cache()
    return edit_pil[0]


def launch_main(img_in_real, img_in_synth, src, src_custom, dest, dest_custom,
                num_ddim, xa_guidance, edit_mul, fpath_z_gen, gen_prompt,
                sent_type_src, sent_type_dest, api_key, org_key,
                custom_sentences_src, custom_sentences_dest):
    """Edit a real or a synthetic image along a source -> destination direction.

    Exactly one of ``img_in_real`` / ``img_in_synth`` must be given.

    * Real image: the image is resized so its longer side is 512, captioned
      with BLIP and inverted with DDIM inversion (caption and inversion are
      cached under ``tmp/`` keyed by the SHA-256 of the resized image bytes),
      then edited using the caption as both prompt and negative prompt.
    * Synthetic image: the noise map stored at ``fpath_z_gen`` is edited with
      ``gen_prompt`` and an empty negative prompt.

    Returns:
        The first element of the edited output of the editing pipeline.

    Raises:
        ValueError: if both or neither of the two images are provided.
    """
    d_name2desc = get_all_directions_names()
    d_desc2name = {v: k for k, v in d_name2desc.items()}
    os.makedirs("tmp", exist_ok=True)

    # generate custom direction first
    src_emb = _get_direction_embedding(
        src, src_custom, sent_type_src, api_key, org_key, custom_sentences_src, d_desc2name
    )
    dest_emb = _get_direction_embedding(
        dest, dest_custom, sent_type_dest, api_key, org_key, custom_sentences_dest, d_desc2name
    )
    text_dir = (dest_emb.to(DEVICE) - src_emb.to(DEVICE)) * edit_mul

    if img_in_real is not None and img_in_synth is None:
        print("using real image")
        # resize the image so that the longer side is 512
        width, height = img_in_real.size
        scale_factor = 512 / width if width > height else 512 / height
        new_size = (int(width * scale_factor), int(height * scale_factor))
        img_in_real = img_in_real.resize(new_size, Image.Resampling.LANCZOS)

        img_hash = hashlib.sha256(img_in_real.tobytes()).hexdigest()
        inv_fname = f"tmp/{img_hash}_ddim_{num_ddim}_inv.pt"
        caption_fname = f"tmp/{img_hash}_caption.txt"

        # make the caption if it hasn't been made before
        if not os.path.exists(caption_fname):
            # BLIP
            model_blip, vis_processors, _ = load_model_and_preprocess(
                name="blip_caption",
                model_type="base_coco",
                is_eval=True,
                device=torch.device(DEVICE),
            )
            _image = vis_processors["eval"](img_in_real).unsqueeze(0).to(DEVICE)
            prompt_str = model_blip.generate({"image": _image})[0]
            del model_blip
            torch.cuda.empty_cache()
            with open(caption_fname, "w") as f:
                f.write(prompt_str)
        else:
            with open(caption_fname, "r") as f:
                prompt_str = f.read().strip()
        print(f"CAPTION: {prompt_str}")

        # do the inversion if it hasn't been done before
        if not os.path.exists(inv_fname):
            # inversion pipeline
            pipe_inv = DDIMInversion.from_pretrained(
                "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float32
            ).to(DEVICE)
            pipe_inv.scheduler = DDIMInverseScheduler.from_config(pipe_inv.scheduler.config)
            x_inv, _, _ = pipe_inv(
                prompt_str,
                guidance_scale=1,
                num_inversion_steps=num_ddim,
                img=img_in_real,
                torch_dtype=torch.float32,
            )
            x_inv = x_inv.detach()
            torch.save(x_inv, inv_fname)
            del pipe_inv
            torch.cuda.empty_cache()
        else:
            x_inv = torch.load(inv_fname, map_location=torch.device("cpu"), weights_only=True)

        # do the editing; the unedited prompt is used as the negative prompt
        return _run_edit(
            prompt_str, x_inv, text_dir, num_ddim, xa_guidance,
            guidance_scale=5.0, negative_prompt=prompt_str,
        )

    if img_in_real is None and img_in_synth is not None:
        print("using synthetic image")
        x_inv = torch.load(fpath_z_gen, map_location=torch.device("cpu"), weights_only=True)
        # the empty string is used as the negative prompt
        return _run_edit(
            gen_prompt, x_inv, text_dir, num_ddim, xa_guidance,
            guidance_scale=5, negative_prompt="",
        )

    raise ValueError(f"Invalid image type found: {img_in_real} {img_in_synth}")


if __name__ == "__main__":
    print(flant5xl_compute_word2sentences("cat wearing sunglasses", num=100))