Allow non-512x512 with chameleon tokenizer
- .gitignore +1 -0
- app.py +5 -3
- chameleon/image_tokenizer.py +3 -2
.gitignore ADDED
@@ -0,0 +1 @@
+__pycache__
app.py CHANGED
@@ -130,11 +130,13 @@ class ChameleonVQImageRoundtripPipeline(ImageRoundtripPipeline):
     def roundtrip_image(self, image, output_type="pil"):
         # image = self.tokenizer._vqgan_input_from(image).to(device)
         image = self.preprocess(image).to(device)
+        _, _, im_height, im_width = image.shape
         _, _, [_, _, latents] = self.tokenizer._vq_model.encode(image)
-
-
+        scale = self.vae_scale_factor
+        shape = (1, im_height // scale, im_width // scale)
+        output = self.tokenizer.pil_from_img_toks(latents, shape=shape)
         # we actually do want this to be a grid, sorry!
-        latents = latents.reshape(
+        latents = latents.reshape(*shape)
 
         return (
             output,
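The key change: instead of reshaping the latents to a fixed 32x32 grid, the token grid is now derived from the input image size and the VAE scale factor. A minimal sketch of that arithmetic, assuming a scale factor of 16 (the value implied by the old 512x512 -> 32x32 default; the actual pipeline reads self.vae_scale_factor):

# Sketch of the grid-shape arithmetic added in roundtrip_image.
# The scale factor of 16 is an assumption here; it matches the previous
# fixed 32x32 token grid for 512x512 inputs.
def token_grid_shape(im_height: int, im_width: int, scale: int = 16) -> tuple:
    return (1, im_height // scale, im_width // scale)

print(token_grid_shape(512, 512))  # (1, 32, 32) -- the old fixed grid
print(token_grid_shape(512, 768))  # (1, 32, 48) -- a non-512x512 image now works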
chameleon/image_tokenizer.py CHANGED
@@ -115,10 +115,11 @@ class ImageTokenizer:
 
         return pil_image
 
-
+    # darknoon: added shape parameter
+    def pil_from_img_toks(self, img_tensor: torch.Tensor, shape = (1, 32, 32,)) -> PIL.Image:
         emb_dim = self._vq_model.quantize.embedding.weight.shape[-1]
         codebook_entry = self._vq_model.quantize.get_codebook_entry(
-            img_tensor, (
+            img_tensor, (*shape, emb_dim)
         )
         pixels = self._vq_model.decode(codebook_entry)
         return self._pil_from_chw_tensor(pixels[0])
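With the new shape parameter, the decode helper can reconstruct non-square token grids. A hypothetical usage sketch (the tokenizer instance, the token tensor, and the scale factor of 16 are assumptions; only pil_from_img_toks and its shape argument come from this commit):

# Hypothetical example: decode the tokens of a 512x768 image back to a PIL image.
# `tokenizer` is an ImageTokenizer and `latents` its VQ token ids (both assumed).
shape = (1, 512 // 16, 768 // 16)  # (1, 32, 48) grid, assuming a scale factor of 16
pil_image = tokenizer.pil_from_img_toks(latents, shape=shape)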