pOpsPaper committed on
Commit
71d3bec
1 Parent(s): 5fdb8b2

Added space

app.py ADDED
@@ -0,0 +1,101 @@
1
+ import gradio as gr
2
+ from pops import PopsPipelines
3
+
4
+ BLOCK_WIDTH = 250
5
+ BLOCK_HEIGHT = 270
6
+ FONT_SIZE = 3.5
7
+
8
+ pops_pipelines = PopsPipelines()
9
+
10
+ def run_equation_1(object_path, text, texture_path):
11
+ image = pops_pipelines.run_instruct_texture(object_path, text, texture_path)
12
+ return image
13
+
14
+ def run_equation_2(object_path, texture_path, scene_path):
15
+ image = pops_pipelines.run_texture_scene(object_path, texture_path, scene_path)
16
+ return image
17
+
18
+ with gr.Blocks(css='style.css') as demo:
19
+ gr.HTML('''<h1>p<span class="o-pops">O</span>ps: Photo-Inspired Diffusion <span class="o-operators">O</span>perators</h1>''')
20
+ gr.HTML('<div style="text-align: center;"><h3><a href="https://popspaper.github.io/pOps/">https://popspaper.github.io/pOps/</a></h3></div>')
21
+ gr.HTML(
22
+ '<div style="text-align: center;">Our method learns operators that are applied directly in the image embedding space, resulting in a variety of semantic operations that can then be realized as images using an image diffusion model.</div>')
23
+ with gr.Row(equal_height=True,elem_classes='justified-element'):
24
+ with gr.Column(scale=0,min_width=BLOCK_WIDTH):
25
+ object_path_eq_1 = gr.Image(label="Upload object image", type="filepath",width=BLOCK_WIDTH,height=BLOCK_HEIGHT)
26
+ with gr.Column(scale=0,min_width=50):
27
+ gr.HTML(f'''<div style="justify-content: center; align-items: center;min-height:{BLOCK_HEIGHT}px"><span class="vertical-center" style="color:#82cf8e;font-size:{FONT_SIZE}rem;font-family:'Google Sans', sans-serif;">O</span></div>''')
28
+ with gr.Column(scale=0,min_width=200):
29
+ with gr.Group(elem_classes='instruct'):
30
+ text_eq_1 = gr.Textbox(value="",label="Enter adjective",max_lines=1,placeholder='e.g. melting, shiny, spiky',elem_classes='vertical-center')
31
+ with gr.Column(scale=0,min_width=50):
32
+ gr.HTML(f'''<div style="justify-content: center; align-items: center;min-height:{BLOCK_HEIGHT}px"><span class="vertical-center" style="color:#efa241;font-size:{FONT_SIZE}rem;font-family:'Google Sans', sans-serif;">O</span></div>''')
33
+ with gr.Column(scale=0,min_width=BLOCK_WIDTH):
34
+ texture_path_eq_1 = gr.Image(label="Upload texture image", type="filepath",width=BLOCK_WIDTH,height=BLOCK_HEIGHT)
35
+ with gr.Column(scale=0,min_width=50):
36
+ gr.HTML(f'''<div style="justify-content: center; align-items: center;min-height:{BLOCK_HEIGHT}px"><span class="vertical-center" style="color:#efa241;font-size:{FONT_SIZE}rem;font-family:'Google Sans', sans-serif;">=</span></div>''')
37
+ with gr.Column(scale=0,min_width=BLOCK_WIDTH):
38
+ output_eq_1 = gr.Image(label="Output",width=BLOCK_WIDTH,height=BLOCK_HEIGHT)
39
+ with gr.Row(equal_height=True, elem_classes='justified-element'):
40
+ run_button_eq_1 = gr.Button("Run Instruct and Texture Equation",elem_classes='small-elem')
41
+ run_button_eq_1.click(fn=run_equation_1,inputs=[object_path_eq_1, text_eq_1, texture_path_eq_1],outputs=[output_eq_1])
42
+ with gr.Row(equal_height=True, elem_classes='justified-element'):
43
+ pass
44
+ with gr.Row(equal_height=True,elem_classes='justified-element'):
45
+ with gr.Column(scale=0,min_width=BLOCK_WIDTH):
46
+ object_path_eq_2 = gr.Image(label="Upload object image", type="filepath",width=BLOCK_WIDTH,height=BLOCK_HEIGHT)
47
+ with gr.Column(scale=0,min_width=50):
48
+ gr.HTML(f'''<div style="justify-content: center; align-items: center;min-height:{BLOCK_HEIGHT}px"><span class="vertical-center" style="color:#efa241;font-size:{FONT_SIZE}rem;font-family:'Google Sans', sans-serif;">O</span></div>''')
49
+ with gr.Column(scale=0,min_width=BLOCK_WIDTH):
50
+ texture_path_eq_2 = gr.Image(label="Upload texture image", type="filepath",width=BLOCK_WIDTH,height=BLOCK_HEIGHT)
51
+ # texture_path = gr.Image(label="Upload texture image", type="filepath",width=BLOCK_WIDTH,height=BLOCK_HEIGHT)
52
+ with gr.Column(scale=0,min_width=50):
53
+ gr.HTML(f'''<div style="justify-content: center; align-items: center;min-height:{BLOCK_HEIGHT}px"><span class="vertical-center" style="color:#A085FF;font-size:{FONT_SIZE}rem;font-family:'Google Sans', sans-serif;">O</span></div>''')
54
+ with gr.Column(scale=0,min_width=BLOCK_WIDTH):
55
+ scene_path_eq_2 = gr.Image(label="Upload scene image", type="filepath",width=BLOCK_WIDTH,height=BLOCK_HEIGHT)
56
+ with gr.Column(scale=0,min_width=50):
57
+ gr.HTML(f'''<div style="justify-content: center; align-items: center;min-height:{BLOCK_HEIGHT}px"><span class="vertical-center" style="color:#A085FF;font-size:{FONT_SIZE}rem;font-family:'Google Sans', sans-serif;">=</span></div>''')
58
+ with gr.Column(scale=0,min_width=BLOCK_WIDTH):
59
+ output_eq_2 = gr.Image(label="Output",width=BLOCK_WIDTH,height=BLOCK_HEIGHT)
60
+ with gr.Row(equal_height=True, elem_classes='justified-element'):
61
+ run_button_eq_2 = gr.Button("Run Texture and Scene Equation",elem_classes='small-elem')
62
+ run_button_eq_2.click(fn=run_equation_2,inputs=[object_path_eq_2, texture_path_eq_2, scene_path_eq_2],outputs=[output_eq_2])
63
+
64
+
65
+ with gr.Row(equal_height=True, elem_classes='justified-element'):
66
+ with gr.Column(scale=1):
67
+ examples = [
68
+ ['inputs/birmingham-museums-trust-q2OwlfXAYfo-unsplash.jpg', 'enormous',
69
+ 'inputs/mihaly-varga-AQFfdEY3X4Q-unsplash.jpg'],
70
+ ['inputs/r-n-tyfqOL1FAQc-unsplash.jpg', 'group', 'inputs/george-webster-p1VZ5IbT2Tg-unsplash.jpg'],
71
+ ]
72
+ gr.Examples(examples=examples,
73
+ inputs=[object_path_eq_1, text_eq_1, texture_path_eq_1],
74
+ outputs=[output_eq_1],
75
+ fn=run_equation_1,
76
+ cache_examples=False)
77
+ examples_2 = [
78
+ ['inputs/hannah-pemberton-3d82e5_ylGo-unsplash.jpg', 'inputs/engin-akyurt-aXVro7lQyUM-unsplash.jpg', 'inputs/alexandra-zelena-phskyemu_c4-unsplash.jpg'],
79
+ ]
80
+ gr.Examples(examples=examples_2,
81
+ inputs=[object_path_eq_2, texture_path_eq_2, scene_path_eq_2],
82
+ outputs=[output_eq_2],
83
+ fn=run_equation_2,
84
+ cache_examples=False)
85
+ with gr.Column(scale=1):
86
+ gr.HTML('''
87
+ <div class="column">
88
+ <h2 class="">🎶 Learn More 🎶</h2>
89
+ <div class="">
90
+ <div height="100%">
91
+ <video src="https://github.com/pOpsPaper/pOps/raw/gh-pages/static/figures/teaser_video.mp4" controls ></video>
92
+ </div>
93
+ </div>
94
+ <div class=""><small>
95
+ Audio track for the teaser video was generated with the help of <a href="https://suno.com/">suno</a>.
96
+ </small>
97
+ </div>
98
+ </div>
99
+ ''')
100
+
101
+ demo.queue().launch()
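For reference, a minimal headless sketch of the same two equations without the Gradio UI, using the bundled example images from `inputs/`; it assumes a CUDA GPU since `PopsPipelines` places its models on cuda:0, and the output filenames are illustrative:

```py
# Minimal headless sketch of the two equations exposed by the demo above
# (assumes a CUDA GPU, since PopsPipelines places its models on cuda:0).
from pops import PopsPipelines

pipelines = PopsPipelines()

# Equation 1: object ∘ adjective ∘ texture -> image
image_1 = pipelines.run_instruct_texture(
    'inputs/birmingham-museums-trust-q2OwlfXAYfo-unsplash.jpg',  # object
    'enormous',                                                  # adjective
    'inputs/mihaly-varga-AQFfdEY3X4Q-unsplash.jpg',              # texture
)
image_1.save('equation_1.png')  # illustrative output path

# Equation 2: object ∘ texture ∘ scene -> image
image_2 = pipelines.run_texture_scene(
    'inputs/hannah-pemberton-3d82e5_ylGo-unsplash.jpg',          # object
    'inputs/engin-akyurt-aXVro7lQyUM-unsplash.jpg',              # texture
    'inputs/alexandra-zelena-phskyemu_c4-unsplash.jpg',          # scene
)
image_2.save('equation_2.png')  # illustrative output path
```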
inputs/alexandra-zelena-phskyemu_c4-unsplash.jpg ADDED
inputs/birmingham-museums-trust-q2OwlfXAYfo-unsplash.jpg ADDED
inputs/engin-akyurt-aXVro7lQyUM-unsplash.jpg ADDED
inputs/george-webster-p1VZ5IbT2Tg-unsplash.jpg ADDED
inputs/hannah-pemberton-3d82e5_ylGo-unsplash.jpg ADDED
inputs/mihaly-varga-AQFfdEY3X4Q-unsplash.jpg ADDED
inputs/r-n-tyfqOL1FAQc-unsplash.jpg ADDED
model/__init__.py ADDED
File without changes
model/pipeline_pops.py ADDED
@@ -0,0 +1,553 @@
1
+ from typing import List, Optional, Union
2
+
3
+ import PIL
4
+ import torch
5
+ from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection
6
+
7
+ from diffusers.models import PriorTransformer
8
+ from diffusers.schedulers import UnCLIPScheduler
9
+ from diffusers.utils import (
10
+ is_accelerate_available,
11
+ is_accelerate_version,
12
+ logging,
13
+ replace_example_docstring,
14
+ )
15
+ from diffusers.utils.torch_utils import randn_tensor
16
+
17
+ from diffusers.pipelines.kandinsky import KandinskyPriorPipelineOutput
18
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
19
+
20
+
21
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
22
+
23
+ EXAMPLE_DOC_STRING = """
24
+ Examples:
25
+ ```py
26
+ >>> from diffusers import KandinskyV22Pipeline, KandinskyV22PriorPipeline
27
+ >>> import torch
28
+
29
+ >>> pipe_prior = KandinskyV22PriorPipeline.from_pretrained("kandinsky-community/kandinsky-2-2-prior")
30
+ >>> pipe_prior.to("cuda")
31
+ >>> prompt = "red cat, 4k photo"
32
+ >>> image_emb, negative_image_emb = pipe_prior(prompt).to_tuple()
33
+
34
+ >>> pipe = KandinskyV22Pipeline.from_pretrained("kandinsky-community/kandinsky-2-2-decoder")
35
+ >>> pipe.to("cuda")
36
+ >>> image = pipe(
37
+ ... image_embeds=image_emb,
38
+ ... negative_image_embeds=negative_image_emb,
39
+ ... height=768,
40
+ ... width=768,
41
+ ... num_inference_steps=50,
42
+ ... ).images
43
+ >>> image[0].save("cat.png")
44
+ ```
45
+ """
46
+
47
+ EXAMPLE_INTERPOLATE_DOC_STRING = """
48
+ Examples:
49
+ ```py
50
+ >>> from diffusers import KandinskyV22PriorPipeline, KandinskyV22Pipeline
51
+ >>> from diffusers.utils import load_image
52
+ >>> import PIL
53
+ >>> import torch
54
+ >>> from torchvision import transforms
55
+
56
+ >>> pipe_prior = KandinskyV22PriorPipeline.from_pretrained(
57
+ ... "kandinsky-community/kandinsky-2-2-prior", torch_dtype=torch.float16
58
+ ... )
59
+ >>> pipe_prior.to("cuda")
60
+ >>> img1 = load_image(
61
+ ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
62
+ ... "/kandinsky/cat.png"
63
+ ... )
64
+ >>> img2 = load_image(
65
+ ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
66
+ ... "/kandinsky/starry_night.jpeg"
67
+ ... )
68
+ >>> images_texts = ["a cat", img1, img2]
69
+ >>> weights = [0.3, 0.3, 0.4]
70
+ >>> out = pipe_prior.interpolate(images_texts, weights)
71
+ >>> pipe = KandinskyV22Pipeline.from_pretrained(
72
+ ... "kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16
73
+ ... )
74
+ >>> pipe.to("cuda")
75
+ >>> image = pipe(
76
+ ... image_embeds=out.image_embeds,
77
+ ... negative_image_embeds=out.negative_image_embeds,
78
+ ... height=768,
79
+ ... width=768,
80
+ ... num_inference_steps=50,
81
+ ... ).images[0]
82
+ >>> image.save("starry_cat.png")
83
+ ```
84
+ """
85
+
86
+
87
+ class pOpsPipeline(DiffusionPipeline):
88
+ """
89
+ Pipeline for generating image embeddings with a learned pOps operator prior, built on the Kandinsky 2.2 prior
90
+
91
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
92
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
93
+
94
+ Args:
95
+ prior ([`PriorTransformer`]):
96
+ The canonincal unCLIP prior to approximate the image embedding from the text embedding.
97
+ image_encoder ([`CLIPVisionModelWithProjection`]):
98
+ Frozen image-encoder.
99
+ text_encoder ([`CLIPTextModelWithProjection`]):
100
+ Frozen text-encoder.
101
+ tokenizer (`CLIPTokenizer`):
102
+ Tokenizer of class
103
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
104
+ scheduler ([`UnCLIPScheduler`]):
105
+ A scheduler to be used in combination with `prior` to generate image embedding.
106
+ image_processor ([`CLIPImageProcessor`]):
107
+ A image_processor to be used to preprocess image from clip.
108
+ """
109
+
110
+ _exclude_from_cpu_offload = ["prior"]
111
+
112
+ def __init__(
113
+ self,
114
+ prior: PriorTransformer,
115
+ image_encoder: CLIPVisionModelWithProjection,
116
+ text_encoder: CLIPTextModelWithProjection,
117
+ tokenizer: CLIPTokenizer,
118
+ scheduler: UnCLIPScheduler,
119
+ image_processor: CLIPImageProcessor,
120
+ ):
121
+ super().__init__()
122
+
123
+ self.register_modules(
124
+ prior=prior,
125
+ text_encoder=text_encoder,
126
+ tokenizer=tokenizer,
127
+ scheduler=scheduler,
128
+ image_encoder=image_encoder,
129
+ image_processor=image_processor,
130
+ )
131
+
132
+ @torch.no_grad()
133
+ @replace_example_docstring(EXAMPLE_INTERPOLATE_DOC_STRING)
134
+ def interpolate(
135
+ self,
136
+ images_and_prompts: List[Union[str, PIL.Image.Image, torch.FloatTensor]],
137
+ weights: List[float],
138
+ num_images_per_prompt: int = 1,
139
+ num_inference_steps: int = 25,
140
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
141
+ latents: Optional[torch.FloatTensor] = None,
142
+ negative_prior_prompt: Optional[str] = None,
143
+ negative_prompt: str = "",
144
+ guidance_scale: float = 4.0,
145
+ device=None,
146
+ ):
147
+ """
148
+ Function invoked when using the prior pipeline for interpolation.
149
+
150
+ Args:
151
+ images_and_prompts (`List[Union[str, PIL.Image.Image, torch.FloatTensor]]`):
152
+ list of prompts and images to guide the image generation.
153
+ weights: (`List[float]`):
154
+ list of weights for each condition in `images_and_prompts`
155
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
156
+ The number of images to generate per prompt.
157
+ num_inference_steps (`int`, *optional*, defaults to 100):
158
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
159
+ expense of slower inference.
160
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
161
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
162
+ to make generation deterministic.
163
+ latents (`torch.FloatTensor`, *optional*):
164
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
165
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
166
+ tensor will be generated by sampling using the supplied random `generator`.
167
+ negative_prior_prompt (`str`, *optional*):
168
+ The prompt not to guide the prior diffusion process. Ignored when not using guidance (i.e., ignored if
169
+ `guidance_scale` is less than `1`).
170
+ negative_prompt (`str` or `List[str]`, *optional*):
171
+ The prompt not to guide the image generation. Ignored when not using guidance (i.e., ignored if
172
+ `guidance_scale` is less than `1`).
173
+ guidance_scale (`float`, *optional*, defaults to 4.0):
174
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
175
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
176
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
177
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
178
+ usually at the expense of lower image quality.
179
+
180
+ Examples:
181
+
182
+ Returns:
183
+ [`KandinskyPriorPipelineOutput`] or `tuple`
184
+ """
185
+
186
+ device = device or self.device
187
+
188
+ if len(images_and_prompts) != len(weights):
189
+ raise ValueError(
190
+ f"`images_and_prompts` contains {len(images_and_prompts)} items and `weights` contains {len(weights)} items - they should be lists of same length"
191
+ )
192
+
193
+ image_embeddings = []
194
+ for cond, weight in zip(images_and_prompts, weights):
195
+ if isinstance(cond, str):
196
+ image_emb = self(
197
+ cond,
198
+ num_inference_steps=num_inference_steps,
199
+ num_images_per_prompt=num_images_per_prompt,
200
+ generator=generator,
201
+ latents=latents,
202
+ negative_prompt=negative_prior_prompt,
203
+ guidance_scale=guidance_scale,
204
+ ).image_embeds.unsqueeze(0)
205
+
206
+ elif isinstance(cond, (PIL.Image.Image, torch.Tensor)):
207
+ if isinstance(cond, PIL.Image.Image):
208
+ cond = (
209
+ self.image_processor(cond, return_tensors="pt")
210
+ .pixel_values[0]
211
+ .unsqueeze(0)
212
+ .to(dtype=self.image_encoder.dtype, device=device)
213
+ )
214
+
215
+ image_emb = self.image_encoder(cond)["image_embeds"].repeat(num_images_per_prompt, 1).unsqueeze(0)
216
+
217
+ else:
218
+ raise ValueError(
219
+ f"`images_and_prompts` can only contains elements to be of type `str`, `PIL.Image.Image` or `torch.Tensor` but is {type(cond)}"
220
+ )
221
+
222
+ image_embeddings.append(image_emb * weight)
223
+
224
+ image_emb = torch.cat(image_embeddings).sum(dim=0)
225
+
226
+ out_zero = self(
227
+ negative_prompt,
228
+ num_inference_steps=num_inference_steps,
229
+ num_images_per_prompt=num_images_per_prompt,
230
+ generator=generator,
231
+ latents=latents,
232
+ negative_prompt=negative_prior_prompt,
233
+ guidance_scale=guidance_scale,
234
+ )
235
+ zero_image_emb = out_zero.negative_image_embeds if negative_prompt == "" else out_zero.image_embeds
236
+
237
+ return KandinskyPriorPipelineOutput(image_embeds=image_emb, negative_image_embeds=zero_image_emb)
238
+
239
+ # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents
240
+ def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
241
+ if latents is None:
242
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
243
+ else:
244
+ if latents.shape != shape:
245
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
246
+ latents = latents.to(device)
247
+
248
+ latents = latents * scheduler.init_noise_sigma
249
+ return latents
250
+
251
+ # Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_prior.KandinskyPriorPipeline.get_zero_embed
252
+ def get_zero_embed(self, batch_size=1, device=None):
253
+ device = device or self.device
254
+ zero_img = torch.zeros(1, 3, self.image_encoder.config.image_size, self.image_encoder.config.image_size).to(
255
+ device=device, dtype=self.image_encoder.dtype
256
+ )
257
+ zero_image_emb = self.image_encoder(zero_img)["image_embeds"]
258
+ zero_image_emb = zero_image_emb.repeat(batch_size, 1)
259
+ return zero_image_emb
260
+
261
+ # Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_prior.KandinskyPriorPipeline._encode_prompt
262
+ def _encode_prompt(
263
+ self,
264
+ prompt,
265
+ device,
266
+ num_images_per_prompt,
267
+ do_classifier_free_guidance,
268
+ negative_prompt=None,
269
+ ):
270
+ batch_size = len(prompt) if isinstance(prompt, list) else 1
271
+ # get prompt text embeddings
272
+ text_inputs = self.tokenizer(
273
+ prompt,
274
+ padding="max_length",
275
+ max_length=self.tokenizer.model_max_length,
276
+ truncation=True,
277
+ return_tensors="pt",
278
+ )
279
+ text_input_ids = text_inputs.input_ids
280
+ text_mask = text_inputs.attention_mask.bool().to(device)
281
+
282
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
283
+
284
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
285
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
286
+ logger.warning(
287
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
288
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
289
+ )
290
+ text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
291
+
292
+ text_encoder_output = self.text_encoder(text_input_ids.to(device))
293
+
294
+ prompt_embeds = text_encoder_output.text_embeds
295
+ text_encoder_hidden_states = text_encoder_output.last_hidden_state
296
+
297
+ prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
298
+ text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
299
+ text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0)
300
+
301
+ if do_classifier_free_guidance:
302
+ uncond_tokens: List[str]
303
+ if negative_prompt is None:
304
+ uncond_tokens = [""] * batch_size
305
+ elif type(prompt) is not type(negative_prompt):
306
+ raise TypeError(
307
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
308
+ f" {type(prompt)}."
309
+ )
310
+ elif isinstance(negative_prompt, str):
311
+ uncond_tokens = [negative_prompt]
312
+ elif batch_size != len(negative_prompt):
313
+ raise ValueError(
314
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
315
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
316
+ " the batch size of `prompt`."
317
+ )
318
+ else:
319
+ uncond_tokens = negative_prompt
320
+
321
+ uncond_input = self.tokenizer(
322
+ uncond_tokens,
323
+ padding="max_length",
324
+ max_length=self.tokenizer.model_max_length,
325
+ truncation=True,
326
+ return_tensors="pt",
327
+ )
328
+ uncond_text_mask = uncond_input.attention_mask.bool().to(device)
329
+ negative_prompt_embeds_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device))
330
+
331
+ negative_prompt_embeds = negative_prompt_embeds_text_encoder_output.text_embeds
332
+ uncond_text_encoder_hidden_states = negative_prompt_embeds_text_encoder_output.last_hidden_state
333
+
334
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
335
+
336
+ seq_len = negative_prompt_embeds.shape[1]
337
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt)
338
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len)
339
+
340
+ seq_len = uncond_text_encoder_hidden_states.shape[1]
341
+ uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1)
342
+ uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view(
343
+ batch_size * num_images_per_prompt, seq_len, -1
344
+ )
345
+ uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0)
346
+
347
+ # done duplicates
348
+
349
+ # For classifier free guidance, we need to do two forward passes.
350
+ # Here we concatenate the unconditional and text embeddings into a single batch
351
+ # to avoid doing two forward passes
352
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
353
+ text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states])
354
+
355
+ text_mask = torch.cat([uncond_text_mask, text_mask])
356
+
357
+ return prompt_embeds, text_encoder_hidden_states, text_mask
358
+
359
+ def enable_model_cpu_offload(self, gpu_id=0):
360
+ r"""
361
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
362
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
363
+ method is called, and the model remains on the GPU until the next model runs. Memory savings are lower than with
364
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
365
+ """
366
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
367
+ from accelerate import cpu_offload_with_hook
368
+ else:
369
+ raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
370
+
371
+ device = torch.device(f"cuda:{gpu_id}")
372
+
373
+ if self.device.type != "cpu":
374
+ self.to("cpu", silence_dtype_warnings=True)
375
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
376
+
377
+ hook = None
378
+ for cpu_offloaded_model in [self.text_encoder, self.prior]:
379
+ _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
380
+
381
+ # We'll offload the last model manually.
382
+ self.prior_hook = hook
383
+
384
+ _, hook = cpu_offload_with_hook(self.image_encoder, device, prev_module_hook=self.prior_hook)
385
+
386
+ self.final_offload_hook = hook
387
+
388
+ @torch.no_grad()
389
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
390
+ def __call__(
391
+ self,
392
+ input_embeds: torch.FloatTensor,
393
+ input_hidden_states: torch.FloatTensor,
394
+ negative_input_embeds: Optional[torch.FloatTensor] = None,
395
+ negative_input_hidden_states: Optional[torch.FloatTensor] = None,
396
+ input_mask: Optional[torch.FloatTensor]=None,
397
+ num_images_per_prompt: int = 1,
398
+ num_inference_steps: int = 25,
399
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
400
+ latents: Optional[torch.FloatTensor] = None,
401
+ guidance_scale: float = 1.0,
402
+ output_type: Optional[str] = "pt", # pt only
403
+ return_dict: bool = True,
404
+ ):
405
+ """
406
+ Function invoked when calling the pipeline for generation.
407
+
408
+ Args:
409
+ input_embeds (`torch.FloatTensor`):
+ Conditioning image embedding passed to the prior as its `proj_embedding`.
+ input_hidden_states (`torch.FloatTensor`):
+ Sequence of conditioning embeddings passed to the prior as `encoder_hidden_states`.
+ negative_input_embeds (`torch.FloatTensor`, *optional*):
+ Unconditional counterpart of `input_embeds`; required when `guidance_scale > 1`.
+ negative_input_hidden_states (`torch.FloatTensor`, *optional*):
+ Unconditional counterpart of `input_hidden_states`; required when `guidance_scale > 1`.
414
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
415
+ The number of images to generate per prompt.
416
+ num_inference_steps (`int`, *optional*, defaults to 100):
417
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
418
+ expense of slower inference.
419
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
420
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
421
+ to make generation deterministic.
422
+ latents (`torch.FloatTensor`, *optional*):
423
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
424
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
425
+ tensor will be generated by sampling using the supplied random `generator`.
426
+ guidance_scale (`float`, *optional*, defaults to 4.0):
427
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
428
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
429
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
430
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
431
+ usually at the expense of lower image quality.
432
+ output_type (`str`, *optional*, defaults to `"pt"`):
433
+ The output format of the generated image embeddings. Choose between: `"np"` (`np.array`) or `"pt"`
434
+ (`torch.Tensor`).
435
+ return_dict (`bool`, *optional*, defaults to `True`):
436
+ Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
437
+
438
+ Examples:
439
+
440
+ Returns:
441
+ [`KandinskyPriorPipelineOutput`] or `tuple`
442
+ """
443
+
444
+ do_classifier_free_guidance = guidance_scale > 1.0
445
+ if do_classifier_free_guidance:
446
+ if negative_input_embeds is None or negative_input_hidden_states is None:
447
+ raise ValueError('negative_input_embeds and negative_input_hidden_states must be provided')
448
+
449
+ device = self._execution_device
450
+
451
+ batch_size = input_embeds.shape[0]
452
+ batch_size = batch_size * num_images_per_prompt
453
+
454
+ prompt_embeds, text_encoder_hidden_states, text_mask = self._encode_prompt(
455
+ "", device, num_images_per_prompt, False, ""
456
+ )
457
+
458
+ # prior
459
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
460
+ prior_timesteps_tensor = self.scheduler.timesteps
461
+
462
+ embedding_dim = self.prior.config.embedding_dim
463
+
464
+ latents = self.prepare_latents(
465
+ (batch_size, embedding_dim),
466
+ prompt_embeds.dtype,
467
+ device,
468
+ generator,
469
+ latents,
470
+ self.scheduler,
471
+ )
472
+
473
+ for i, t in enumerate(self.progress_bar(prior_timesteps_tensor)):
474
+ # expand the latents if we are doing classifier free guidance
475
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
476
+
477
+ # TODO: I can stop being dependent on the text encoder size
478
+ image_feat_seq = torch.zeros_like(text_encoder_hidden_states)
479
+ image_feat_seq[:, :input_hidden_states.shape[1]] = input_hidden_states
480
+ if input_mask is not None:
481
+ image_txt_mask = input_mask
482
+ else:
483
+ image_txt_mask = torch.zeros_like(text_mask)
484
+ image_txt_mask[:, :input_hidden_states.shape[1]] = 1
485
+ proj_embedding = input_embeds
486
+
487
+ if do_classifier_free_guidance:
488
+ neg_image_feat_seq = torch.zeros_like(text_encoder_hidden_states)
489
+ neg_image_feat_seq[:, :negative_input_hidden_states.shape[1]] = negative_input_hidden_states
490
+ if input_mask is not None:
491
+ neg_image_txt_mask = input_mask
492
+ else:
493
+ neg_image_txt_mask = torch.zeros_like(text_mask)
494
+ neg_image_txt_mask[:, :negative_input_hidden_states.shape[1]] = 1
495
+ proj_embedding = torch.cat([negative_input_embeds, proj_embedding])
496
+ image_feat_seq = torch.cat([neg_image_feat_seq, image_feat_seq])
497
+ image_txt_mask = torch.cat([neg_image_txt_mask, image_txt_mask])
498
+
499
+ predicted_image_embedding = self.prior(
500
+ latent_model_input,
501
+ timestep=t,
502
+ proj_embedding=proj_embedding,
503
+ encoder_hidden_states=image_feat_seq,
504
+ attention_mask=image_txt_mask,
505
+ ).predicted_image_embedding
506
+
507
+ if do_classifier_free_guidance:
508
+ # print(f'Doing guidance with scale {guidance_scale}')
509
+ predicted_image_embedding_uncond, predicted_image_embedding_text = predicted_image_embedding.chunk(2)
510
+ predicted_image_embedding = predicted_image_embedding_uncond + guidance_scale * (
511
+ predicted_image_embedding_text - predicted_image_embedding_uncond
512
+ )
513
+
514
+ if i + 1 == prior_timesteps_tensor.shape[0]:
515
+ prev_timestep = None
516
+ else:
517
+ prev_timestep = prior_timesteps_tensor[i + 1]
518
+
519
+ latents = self.scheduler.step(
520
+ predicted_image_embedding,
521
+ timestep=t,
522
+ sample=latents,
523
+ generator=generator,
524
+ prev_timestep=prev_timestep,
525
+ ).prev_sample
526
+
527
+ latents = self.prior.post_process_latents(latents)
528
+
529
+ image_embeddings = latents
530
+
531
+ # if a negative prompt has been defined, we split the image embedding into two
532
+ # if negative_prompt is None:
533
+ zero_embeds = self.get_zero_embed(latents.shape[0], device=latents.device)
534
+
535
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
536
+ self.final_offload_hook.offload()
537
+ # else:
538
+ # image_embeddings, zero_embeds = image_embeddings.chunk(2)
539
+ #
540
+ # if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
541
+ # self.prior_hook.offload()
542
+
543
+ if output_type not in ["pt", "np"]:
544
+ raise ValueError(f"Only the output types `pt` and `np` are supported not output_type={output_type}")
545
+
546
+ if output_type == "np":
547
+ image_embeddings = image_embeddings.cpu().numpy()
548
+ zero_embeds = zero_embeds.cpu().numpy()
549
+
550
+ if not return_dict:
551
+ return (image_embeddings, zero_embeds)
552
+
553
+ return KandinskyPriorPipelineOutput(image_embeds=image_embeddings, negative_image_embeds=zero_embeds)
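Unlike the stock Kandinsky prior, `pOpsPipeline.__call__` is conditioned on image embeddings and hidden states rather than a text prompt. The sketch below shows one way to drive it directly, mirroring `PopsPipelines.run_binary` in `pops.py` further down; the repository names, checkpoint path and guidance scale are taken from that file, while the two image paths and the `encode` helper are assumptions for illustration:

```py
# Sketch: conditioning the pOps prior on two images instead of a text prompt,
# mirroring PopsPipelines.run_binary in pops.py (texturing operator shown).
import torch
from PIL import Image
from diffusers import PriorTransformer
from huggingface_hub import hf_hub_download
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

from model import pops_utils
from model.pipeline_pops import pOpsPipeline

prior_repo = 'kandinsky-community/kandinsky-2-2-prior'
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    prior_repo, subfolder='image_encoder', torch_dtype=torch.float16).eval().to('cuda')
processor = CLIPImageProcessor.from_pretrained(prior_repo, subfolder='image_processor')

# Load the learned texturing operator weights into the Kandinsky prior.
prior = PriorTransformer.from_pretrained(prior_repo, subfolder='prior')
ckpt = hf_hub_download(repo_id='pOpsPaper/operators',
                       filename='models/texturing/learned_prior.pth')
prior.load_state_dict(torch.load(ckpt, map_location='cpu'), strict=False)

pipeline = pOpsPipeline.from_pretrained(prior_repo, prior=prior.to(torch.float16),
                                        image_encoder=image_encoder,
                                        torch_dtype=torch.float16).to('cuda')

def encode(path):  # helper for this sketch only; the paths below are placeholders
    pil = Image.open(path).convert('RGB').resize((512, 512))
    return torch.tensor(processor(pil)['pixel_values'][0]).unsqueeze(0).to('cuda', torch.float16)

input_embeds, input_hidden_states = pops_utils.preprocess(
    encode('object.jpg'), encode('texture.jpg'), image_encoder,
    pipeline.prior.clip_mean.detach(), pipeline.prior.clip_std.detach())

out = pipeline(input_embeds=input_embeds,
               input_hidden_states=input_hidden_states,
               negative_input_embeds=torch.zeros_like(input_embeds),
               negative_input_hidden_states=torch.zeros_like(input_hidden_states),
               num_inference_steps=25, num_images_per_prompt=1,
               guidance_scale=8.0)  # pops.py uses 8.0 for the texturing operator
image_embeds = out.image_embeds  # feed to the Kandinsky 2.2 decoder to render pixels
```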
model/pops_utils.py ADDED
@@ -0,0 +1,41 @@
1
+
2
+ from typing import List, Tuple
3
+
4
+ import torch
5
+ from torch import nn
6
+
7
+ def preprocess(image_a: torch.Tensor, image_b: torch.Tensor, image_encoder: nn.Module, clip_mean: torch.Tensor,
8
+ clip_std: torch.Tensor, should_drop_cond: List[Tuple[bool, bool]] = None, concat_hidden_states=None,
9
+ image_list=None):
10
+ with torch.no_grad():
11
+ image_list = [] if image_list is None else image_list
12
+ additional_list = []
13
+ if image_a is not None:
14
+ additional_list.append(image_a)
15
+ if image_b is not None:
16
+ additional_list.append(image_b)
17
+ image_list = additional_list + image_list
18
+ embeds_list = []
19
+ for image in image_list:
20
+ # If the input is already an embedding vector, skip the image encoder
21
+ if len(image.shape) == 2:
22
+ image_embeds = image
23
+ else:
24
+ encoder_outs = image_encoder(image, output_hidden_states=False)
25
+ image_embeds = encoder_outs.image_embeds
26
+ image_embeds = (image_embeds - clip_mean) / clip_std
27
+ embeds_list.append(image_embeds.unsqueeze(1))
28
+ if should_drop_cond is not None:
29
+ for b_ind in range(embeds_list[0].shape[0]):
30
+ should_drop_a, should_drop_b = should_drop_cond[b_ind]
31
+ if should_drop_a:
32
+ embeds_list[0][b_ind] = torch.zeros_like(embeds_list[0][b_ind])
33
+ if should_drop_b and image_b is not None:
34
+ embeds_list[1][b_ind] = torch.zeros_like(embeds_list[1][b_ind])
35
+ if concat_hidden_states is not None:
36
+ embeds_list.append(concat_hidden_states)
37
+ out_hidden_states = torch.concat(embeds_list, dim=1)
38
+
39
+ image_embeds = torch.zeros_like(embeds_list[0].squeeze(1))
40
+
41
+ return image_embeds, out_hidden_states
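A small shape-contract sketch for `preprocess`: inputs that are already 2-D embedding vectors skip the image encoder entirely, so placeholder encoder and CLIP statistics suffice here. The embedding width of 1280 is an assumption (the CLIP image-embedding size used by the Kandinsky 2.2 prior):

```py
# Shape-contract sketch for preprocess: 2-D inputs bypass the image encoder.
import torch
from model import pops_utils

dim = 1280  # assumed CLIP image-embedding width of the Kandinsky 2.2 prior
emb_a = torch.randn(1, dim)  # pre-computed embedding for input A
emb_b = torch.randn(1, dim)  # pre-computed embedding for input B

image_embeds, hidden_states = pops_utils.preprocess(
    emb_a, emb_b, image_encoder=None,                       # encoder unused for 2-D inputs
    clip_mean=torch.zeros(dim), clip_std=torch.ones(dim))   # placeholder statistics

print(image_embeds.shape)   # torch.Size([1, 1280]) -- zeros, later used as proj_embedding
print(hidden_states.shape)  # torch.Size([1, 2, 1280]) -- one token per conditioning input
```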
pops.py ADDED
@@ -0,0 +1,223 @@
1
+ import gradio as gr
2
+ import torch
3
+ from PIL import Image
4
+ from diffusers import PriorTransformer, UNet2DConditionModel, KandinskyV22Pipeline
5
+ from huggingface_hub import hf_hub_download
6
+ from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor, CLIPTokenizer, CLIPTextModelWithProjection
7
+
8
+ from model import pops_utils
9
+ from model.pipeline_pops import pOpsPipeline
10
+
11
+ kandinsky_prior_repo: str = 'kandinsky-community/kandinsky-2-2-prior'
12
+ kandinsky_decoder_repo: str = 'kandinsky-community/kandinsky-2-2-decoder'
13
+ prior_texture_repo: str = 'models/texturing/learned_prior.pth'
14
+ prior_instruct_repo: str = 'models/instruct/learned_prior.pth'
15
+ prior_scene_repo: str = 'models/scene/learned_prior.pth'
16
+ prior_repo = "pOpsPaper/operators"
17
+
18
+ gpu = torch.device('cuda')
19
+ cpu = torch.device('cpu')
20
+
21
+ class PopsPipelines:
22
+ def __init__(self):
23
+ weight_dtype = torch.float16
24
+ self.weight_dtype = weight_dtype
25
+ device = 'cuda:0'
26
+ self.device = device
27
+ self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(kandinsky_prior_repo,
28
+ subfolder='image_encoder',
29
+ torch_dtype=weight_dtype).eval()
30
+ self.image_encoder.requires_grad_(False)
31
+
32
+ self.image_processor = CLIPImageProcessor.from_pretrained(kandinsky_prior_repo,
33
+ subfolder='image_processor')
34
+
35
+ self.tokenizer = CLIPTokenizer.from_pretrained(kandinsky_prior_repo, subfolder='tokenizer')
36
+ self.text_encoder = CLIPTextModelWithProjection.from_pretrained(kandinsky_prior_repo,
37
+ subfolder='text_encoder',
38
+ torch_dtype=weight_dtype).eval().to(device)
39
+
40
+ # Load the full decoder model for visualization
41
+ self.unet = UNet2DConditionModel.from_pretrained(kandinsky_decoder_repo,
42
+ subfolder='unet').to(torch.float16).to(device)
43
+
44
+
45
+ self.decoder = KandinskyV22Pipeline.from_pretrained(kandinsky_decoder_repo, unet=self.unet,
46
+ torch_dtype=torch.float16)
47
+ self.decoder = self.decoder.to(device)
48
+
49
+
50
+ self.priors_dict = {
51
+ 'texturing':{'repo':prior_texture_repo},
52
+ 'instruct': {'repo': prior_instruct_repo},
53
+ 'scene': {'repo':prior_scene_repo}
54
+ }
55
+
56
+ for prior_type in self.priors_dict:
57
+ prior_path = self.priors_dict[prior_type]['repo']
58
+ prior = PriorTransformer.from_pretrained(
59
+ kandinsky_prior_repo, subfolder="prior"
60
+ )
61
+
62
+ # Load from huggingface
63
+ prior_path = hf_hub_download(repo_id=prior_repo, filename=str(prior_path))
64
+ prior_state_dict = torch.load(prior_path, map_location=device)
65
+ prior.load_state_dict(prior_state_dict, strict=False)
66
+
67
+ prior.eval()
68
+ prior = prior.to(weight_dtype)
69
+
70
+ prior_pipeline = pOpsPipeline.from_pretrained(kandinsky_prior_repo,
71
+ prior=prior,
72
+ image_encoder=self.image_encoder,
73
+ torch_dtype=torch.float16)
74
+
75
+ self.priors_dict[prior_type]['pipeline'] = prior_pipeline
76
+
77
+ def process_image(self, input_path):
78
+ if input_path is None:
79
+ return None
80
+ image_pil = Image.open(input_path).convert("RGB").resize((512, 512))
81
+ image = torch.Tensor(self.image_processor(image_pil)['pixel_values'][0]).to(self.device).unsqueeze(0).to(
82
+ self.weight_dtype)
83
+
84
+ return image
85
+
86
+ def process_text(self, text):
87
+ text_inputs = self.tokenizer(
88
+ text,
89
+ padding="max_length",
90
+ max_length=self.tokenizer.model_max_length,
91
+ truncation=True,
92
+ return_tensors="pt",
93
+ )
94
+ mask = text_inputs.attention_mask.bool() # [0]
95
+
96
+ text_encoder_output = self.text_encoder(text_inputs.input_ids.to(self.device))
97
+ text_encoder_hidden_states = text_encoder_output.last_hidden_state
98
+ text_encoder_concat = text_encoder_hidden_states[:, :mask.sum().item()]
99
+ return text_encoder_concat
100
+
101
+ def run_binary(self, input_a, input_b, prior_type):
102
+ # Move pipeline to GPU
103
+ pipeline = self.priors_dict[prior_type]['pipeline']
104
+ pipeline.to('cuda')
105
+ input_image_embeds, input_hidden_state = pops_utils.preprocess(input_a, input_b,
106
+ self.image_encoder,
107
+ pipeline.prior.clip_mean.detach(),
108
+ pipeline.prior.clip_std.detach())
109
+
110
+ negative_input_embeds = torch.zeros_like(input_image_embeds)
111
+ negative_hidden_states = torch.zeros_like(input_hidden_state)
112
+
113
+ guidance_scale = 1.0
114
+ if prior_type == 'texturing':
115
+ guidance_scale = 8.0
116
+
117
+ img_emb = pipeline(input_embeds=input_image_embeds, input_hidden_states=input_hidden_state,
118
+ negative_input_embeds=negative_input_embeds,
119
+ negative_input_hidden_states=negative_hidden_states,
120
+ num_inference_steps=25,
121
+ num_images_per_prompt=1,
122
+ guidance_scale=guidance_scale)
123
+
124
+ # Optional
125
+ if prior_type == 'scene':
126
+ # The scene operator is the closest to what the average embedding represents for a background image, so incorporate that as well
127
+ mean_emb = 0.5 * input_hidden_state[:, 0] + 0.5 * input_hidden_state[:, 1]
128
+ mean_emb = (mean_emb * pipeline.prior.clip_std) + pipeline.prior.clip_mean
129
+ alpha = 0.4
130
+ img_emb.image_embeds = (1 - alpha) * img_emb.image_embeds + alpha * mean_emb
131
+
132
+ # Move pipeline to CPU
133
+ pipeline.to('cpu')
134
+ return img_emb
135
+
136
+ def run_instruct(self, input_a, text):
137
+ text_encodings = self.process_text(text)
138
+
139
+ # Move pipeline to GPU
140
+ instruct_pipeline = self.priors_dict['instruct']['pipeline']
141
+ instruct_pipeline.to('cuda')
142
+ input_image_embeds, input_hidden_state = pops_utils.preprocess(input_a, None,
143
+ self.image_encoder,
144
+ instruct_pipeline.prior.clip_mean.detach(), instruct_pipeline.prior.clip_std.detach(),
145
+ concat_hidden_states=text_encodings)
146
+
147
+ negative_input_embeds = torch.zeros_like(input_image_embeds)
148
+ negative_hidden_states = torch.zeros_like(input_hidden_state)
149
+ img_emb = instruct_pipeline(input_embeds=input_image_embeds, input_hidden_states=input_hidden_state,
150
+ negative_input_embeds=negative_input_embeds,
151
+ negative_input_hidden_states=negative_hidden_states,
152
+ num_inference_steps=25,
153
+ num_images_per_prompt=1,
154
+ guidance_scale=1.0)
155
+
156
+ # Move pipeline to CPU
157
+ instruct_pipeline.to('cpu')
158
+ return img_emb
159
+
160
+ def render(self, img_emb):
161
+ images = self.decoder(image_embeds=img_emb.image_embeds, negative_image_embeds=img_emb.negative_image_embeds,
162
+ num_inference_steps=50, height=512,
163
+ width=512, guidance_scale=4).images
164
+
165
+ return images[0]
166
+
167
+ def run_instruct_texture(self, image_object_path, text_instruct, image_texture_path):
168
+ # Process both inputs
169
+ image_object = self.process_image(image_object_path)
170
+ image_texture = self.process_image(image_texture_path)
171
+
172
+ if image_object is None:
173
+ raise gr.Error('Object image is required')
174
+
175
+ current_emb = None
176
+
177
+ if image_texture is None:
178
+ instruct_input = image_object
179
+ else:
180
+ # Run texturing
181
+ current_emb = self.run_binary(input_a=image_object, input_b=image_texture,prior_type='texturing')
182
+ instruct_input = current_emb.image_embeds
183
+
184
+ if text_instruct != '':
185
+ current_emb = self.run_instruct(input_a=instruct_input, text=text_instruct)
186
+
187
+ if current_emb is None:
188
+ raise gr.Error('At least one of the inputs is required')
189
+
190
+ # Render as image
191
+ image = self.render(current_emb)
192
+
193
+ return image
194
+
195
+ def run_texture_scene(self, image_object_path, image_texture_path, image_scene_path):
196
+ # Process all inputs
197
+ image_object = self.process_image(image_object_path)
198
+ image_texture = self.process_image(image_texture_path)
199
+ image_scene = self.process_image(image_scene_path)
200
+
201
+ if image_object is None:
202
+ raise gr.Error('Object image is required')
203
+
204
+ current_emb = None
205
+
206
+ if image_texture is None:
207
+ scene_input = image_object
208
+ else:
209
+ # Run texturing
210
+ current_emb = self.run_binary(input_a=image_object, input_b=image_texture, prior_type='texturing')
211
+ scene_input = current_emb.image_embeds
212
+
213
+ # Run scene
214
+ if image_scene is not None:
215
+ current_emb = self.run_binary(input_a=scene_input, input_b=image_scene, prior_type='scene')
216
+
217
+ if current_emb is None:
218
+ raise gr.Error('At least one of the images is required')
219
+ # Render as image
220
+ image = self.render(current_emb)
221
+
222
+ return image
223
+
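`run_instruct_texture` and `run_texture_scene` above compose operators by feeding one operator's predicted `image_embeds` back in as `input_a` of the next. A minimal sketch of doing that chaining by hand, assuming a CUDA GPU; the image paths and output filename are placeholders:

```py
# Sketch of chaining operators by hand: texture the object first, then place the
# textured result into a scene, then render with the Kandinsky decoder.
from pops import PopsPipelines

pipelines = PopsPipelines()  # assumes a CUDA GPU

object_img = pipelines.process_image('my_object.jpg')   # placeholder paths
texture_img = pipelines.process_image('my_texture.jpg')
scene_img = pipelines.process_image('my_scene.jpg')

textured = pipelines.run_binary(object_img, texture_img, prior_type='texturing')
# One operator's predicted embedding becomes input_a of the next operator.
composed = pipelines.run_binary(textured.image_embeds, scene_img, prior_type='scene')

pipelines.render(composed).save('textured_object_in_scene.png')  # illustrative output
```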
requirements.txt ADDED
@@ -0,0 +1,6 @@
1
+ diffusers
2
+ transformers
3
+ Pillow
4
+ accelerate
5
+ torch
6
+ torchvision
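The requirements are unpinned, and the demo additionally assumes a CUDA device since `pops.py` places the decoder and priors on cuda:0. A quick sanity check before launching `app.py`:

```py
# Quick environment sanity check before running `python app.py`
# (pops.py hard-codes cuda:0 for the decoder and moves each prior to the GPU on use).
import torch

assert torch.cuda.is_available(), 'pOps demo expects a CUDA GPU (models go to cuda:0)'
print('Using GPU:', torch.cuda.get_device_name(0))
```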
style.css ADDED
@@ -0,0 +1,55 @@
1
+ h1, h2, h3 {
2
+ text-align: center;
3
+ margin: 0;
4
+ }
5
+ .vertical-center {
6
+ margin: 0;
7
+ position: absolute;
8
+ top: 50%;
9
+ -ms-transform: translateY(-50%);
10
+ transform: translateY(-50%);
11
+ }
12
+
13
+ .instruct {
14
+ min-height: 250px;
15
+ background-color: transparent;
16
+ border: transparent;
17
+ }
18
+
19
+
20
+
21
+ #component-0{
22
+ justify-content: center;
23
+ align-items: center;
24
+ }
25
+
26
+ #component-2{
27
+ justify-content: center;
28
+ align-items: center;
29
+ }
30
+
31
+ /*#component-3{*/
32
+ /* justify-content: center;*/
33
+ /* align-items: center;*/
34
+ /*}*/
35
+
36
+ .justified-element {
37
+ /*display: flex;*/
38
+ justify-content: center;
39
+ align-items: center;
40
+ }
41
+
42
+ .small-elem {
43
+ max-width: 400px;
44
+ }
45
+
46
+
47
+
48
+ .o-pops {
49
+ color: #82cf8e; /* Light green color */
50
+ font-weight: bold;
51
+ }
52
+ .o-operators {
53
+ color: #ac85cc; /* Light purple color */
54
+ font-weight: bold;
55
+ }