import spaces import gradio as gr import torch from diffusers import AutoPipelineForInpainting from PIL import Image, ImageFilter from transformers import ( AutoModelForCausalLM, AutoTokenizer, BlipForConditionalGeneration, BlipProcessor, Owlv2ForObjectDetection, Owlv2Processor, SamModel, SamProcessor, ) def delete_model(model): model.to("cpu") del model torch.cuda.empty_cache() @spaces.GPU() def run_language_model(edit_prompt, caption, device): language_model_id = "Qwen/Qwen1.5-0.5B-Chat" language_model = AutoModelForCausalLM.from_pretrained( language_model_id, device_map="auto" ) tokenizer = AutoTokenizer.from_pretrained(language_model_id) messages = [ {"role": "system", "content": "Follow the examples and return the expected output"}, {"role": "user", "content": "Caption: a blue sky with fluffy clouds\nQuery: Make the sky stormy"}, {"role": "assistant", "content": "A: sky\nB: a stormy sky with heavy gray clouds, torrential rain, gloomy, overcast"}, {"role": "user", "content": "Caption: a cat sleeping on a sofa\nQuery: Change the cat to a dog"}, {"role": "assistant", "content": "A: cat\nB: a dog sleeping on a sofa, cozy and comfortable, snuggled up in a warm blanket, peaceful"}, {"role": "user", "content": "Caption: a snowy mountain peak\nQuery: Replace the snow with greenery"}, {"role": "assistant", "content": "A: snow\nB: a lush green mountain peak in summer, clear blue skies, birds flying overhead, serene and majestic"}, {"role": "user", "content": "Caption: a vintage car parked by the roadside\nQuery: Change the car to a modern electric vehicle"}, {"role": "assistant", "content": "A: car\nB: a sleek modern electric vehicle parked by the roadside, cutting-edge design, environmentally friendly, silent and powerful"}, {"role": "user", "content": "Caption: a wooden bridge over a river\nQuery: Make the bridge stone"}, {"role": "assistant", "content": "A: bridge\nB: an ancient stone bridge over a river, moss-covered, sturdy and timeless, with clear waters flowing beneath"}, {"role": "user", "content": "Caption: a bowl of salad on the table\nQuery: Replace salad with soup"}, {"role": "assistant", "content": "A: bowl\nB: a bowl of steaming hot soup on the table, scrumptious, with garnishing"}, {"role": "user", "content": "Caption: a book on a desk surrounded by stationery\nQuery: Remove all stationery, add a laptop"}, {"role": "assistant", "content": "A: stationery\nB: a book on a desk with a laptop next to it, modern study setup, focused and productive, technology and education combined"}, {"role": "user", "content": "Caption: a cup of coffee on a wooden table\nQuery: Change coffee to tea"}, {"role": "assistant", "content": "A: cup\nB: a steaming cup of tea on a wooden table, calming and aromatic, with a slice of lemon on the side, inviting"}, {"role": "user", "content": "Caption: a small pen on a white table\nQuery: Change the pen to an elaborate fountain pen"}, {"role": "assistant", "content": "A: pen\nB: an elaborate fountain pen on a white table, sleek and elegant, with intricate designs, ready for writing"}, {"role": "user", "content": "Caption: a plain notebook on a desk\nQuery: Replace the notebook with a journal"}, {"role": "assistant", "content": "A: notebook\nB: an artistically decorated journal on a desk, vibrant cover, filled with creativity, inspiring and personalized"}, {"role": "user", "content": f"Caption: {caption}\nQuery: {edit_prompt}"}, ] text = tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True ) model_inputs = tokenizer([text], return_tensors="pt").to(device) with torch.no_grad(): generated_ids = language_model.generate( model_inputs.input_ids, max_new_tokens=512, temperature=0.0, do_sample=False ) generated_ids = [ output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids) ] response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0] output_generation_a, output_generation_b = response.split("\n") to_replace = output_generation_a[2:].strip() replace_caption = output_generation_b[2:].strip() delete_model(language_model) return (to_replace, replace_caption) @spaces.GPU() def run_image_captioner(image, device): caption_model_id = "Salesforce/blip-image-captioning-base" caption_model = BlipForConditionalGeneration.from_pretrained(caption_model_id).to( device ) caption_processor = BlipProcessor.from_pretrained(caption_model_id) inputs = caption_processor(image, return_tensors="pt").to(device) with torch.no_grad(): outputs = caption_model.generate(**inputs, max_new_tokens=200) caption = caption_processor.decode(outputs[0], skip_special_tokens=True) delete_model(caption_model) return caption @spaces.GPU() def run_segmentation(image, object_to_segment, device): # OWL-V2 for object detection owl_v2_model_id = "google/owlv2-base-patch16-ensemble" processor = Owlv2Processor.from_pretrained(owl_v2_model_id) od_model = Owlv2ForObjectDetection.from_pretrained(owl_v2_model_id).to(device) text_queries = [object_to_segment] inputs = processor(text=text_queries, images=image, return_tensors="pt").to(device) with torch.no_grad(): outputs = od_model(**inputs) target_sizes = torch.tensor([image.size]).to(device) results = processor.post_process_object_detection( outputs, threshold=0.1, target_sizes=target_sizes )[0] boxes = results["boxes"].tolist() delete_model(od_model) # SAM for image segmentation sam_model_id = "facebook/sam-vit-base" seg_model = SamModel.from_pretrained(sam_model_id).to(device) processor = SamProcessor.from_pretrained(sam_model_id) input_boxes = [boxes] inputs = processor(image, input_boxes=input_boxes, return_tensors="pt").to(device) with torch.no_grad(): outputs = seg_model(**inputs) masks = processor.image_processor.post_process_masks( outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu(), )[0] # Merge the masks masks = torch.max(masks[:, 0, ...], dim=0, keepdim=False).values delete_model(seg_model) return masks @spaces.GPU() def run_inpainting(image, replaced_caption, masks, generator, device): pipeline = AutoPipelineForInpainting.from_pretrained( "diffusers/stable-diffusion-xl-1.0-inpainting-0.1", torch_dtype=torch.float16, variant="fp16", ).to(device) masks = Image.fromarray(masks.numpy()) dilation_image = masks.filter(ImageFilter.MaxFilter(3)) prompt = replaced_caption negative_prompt = """lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality""" output = pipeline( prompt=prompt, image=image, mask_image=dilation_image, negative_prompt=negative_prompt, guidance_scale=7.5, strength=1.0, generator=generator, ).images[0] delete_model(pipeline) return output def run_open_gen_fill(image, edit_prompt): device = "cuda" if torch.cuda.is_available() else "cpu" # Resize the image to (512, 512) image = image.resize((512, 512)) # Caption the input image caption = run_image_captioner(image, device=device) # Run the langauge model to extract the object for segmentation # and get the replaced caption to_replace, replace_caption = run_language_model( edit_prompt=edit_prompt, caption=caption, device=device ) # Segment the `to_replace` object from the input image masks = run_segmentation(image, to_replace, device=device) # Diffusion pipeline for inpainting generator = torch.Generator(device).manual_seed(17) output = run_inpainting( image=image, replaced_caption=replaced_caption, masks=masks, generator=generator, device=device ) return ( to_replace, replace_with, caption, replaced_caption, Image.fromarray(masks.numpy()), output, ) def setup_gradio_interface(): block = gr.Blocks() with block: gr.Markdown("