Latent Reconstruction yields noise
#3 opened by Blackroot
Hi, I'm getting unexpected results when using the model. I was only able to get the code to run after modifying it slightly, because `dc_ae = DCAE_HF.from_pretrained(f"mit-han-lab/dc-ae-f64c128-in-1.0")` expects a model name, so I pass `model_name` explicitly when loading from a local directory.
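The model name matters because it selects the architecture, which in turn fixes the latent shape to expect. A minimal stdlib sketch for checking that, assuming DC-AE names follow the `f<downsample>c<channels>` convention (e.g. `f32c32` means 32× spatial downsampling and 32 latent channels); `expected_latent_shape` is a hypothetical helper, not part of efficientvit:

```python
import re


def expected_latent_shape(model_name, image_size, batch_size=1):
    """Hypothetical helper: derive (B, C, H, W) from a DC-AE model name.

    Assumes the name contains "f<downsample>c<channels>",
    e.g. "dc-ae-f32c32-sana-1.0" -> downsample 32, 32 channels.
    """
    match = re.search(r"f(\d+)c(\d+)", model_name)
    if match is None:
        raise ValueError(f"cannot parse model name: {model_name!r}")
    downsample, channels = int(match.group(1)), int(match.group(2))
    if image_size % downsample != 0:
        raise ValueError(f"image size {image_size} is not divisible by {downsample}")
    side = image_size // downsample
    return (batch_size, channels, side, side)


print(expected_latent_shape("dc-ae-f32c32-sana-1.0", 512))  # (1, 32, 16, 16)
print(expected_latent_shape("dc-ae-f64c128-in-1.0", 512))   # (1, 128, 8, 8)
```

A matching shape only confirms which architecture was instantiated; it says nothing about whether the reconstruction itself is correct.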
My test code:
```python
from PIL import Image
import torch
import torchvision.transforms as transforms
from torchvision.utils import save_image

from efficientvit.ae_model_zoo import DCAE_HF
from efficientvit.apps.utils.image import DMCrop

# Loading from a local directory: model_name must be passed explicitly
dc_ae = DCAE_HF.from_pretrained("./dc-ae", model_name="dc-ae-f32c32-sana-1.0")


def get_tensor_bytes(tensor):
    """Return the memory footprint of a tensor as a human-readable string."""
    num_bytes = tensor.nelement() * tensor.element_size()
    units = ["B", "KB", "MB", "GB", "TB"]
    size = float(num_bytes)
    unit_index = 0
    while size >= 1024 and unit_index < len(units) - 1:
        size /= 1024
        unit_index += 1
    return f"{size:.2f} {units[unit_index]}"


device = torch.device("cuda")
dc_ae = dc_ae.to(device).eval()

# Crop to 512x512, convert to [0, 1], then normalize to [-1, 1]
transform = transforms.Compose([
    DMCrop(512),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

image = Image.open("/home/blackroot/Desktop/vaes/ComfyUI_00312_.png").convert("RGB")
x = transform(image).unsqueeze(0).to(device)
print("Input shape:", x.shape)
print("Input range:", torch.min(x).item(), torch.max(x).item())

with torch.no_grad():
    latent = dc_ae.encode(x)
    print("Latent shape:", latent.shape)
    print("Latent range:", torch.min(latent).item(), torch.max(latent).item())

    reconstructed = dc_ae.decode(latent)
    print("Output shape:", reconstructed.shape)
    print("Output range:", torch.min(reconstructed).item(), torch.max(reconstructed).item())

# Map the output from [-1, 1] back to [0, 1] before saving
save_image(reconstructed * 0.5 + 0.5, "reconstructed.png")
```
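The preprocessing and the saving step in the script are inverse affine maps. With `Normalize` using mean 0.5 and std 0.5 per channel on the `[0, 1]` output of `ToTensor`:

$$
x_{\text{norm}} = \frac{x - 0.5}{0.5} = 2x - 1 \in [-1, 1],
\qquad
x = 0.5\,x_{\text{norm}} + 0.5 \in [0, 1]
$$

So the printed input range should lie within [-1, 1], and `reconstructed * 0.5 + 0.5` undoes the normalization before `save_image`, assuming the decoder outputs values in that same [-1, 1] range.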
Here's the example output. The shapes all look correct, but as you can see below, the reconstruction is mostly random noise. Any idea what might be going on here?
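To put a number on "mostly noise" rather than judging by eye, one option is the PSNR between the input and the reconstruction, both mapped back to [0, 1]. A minimal stdlib sketch over flat sequences of pixel values; `psnr` is a hypothetical helper, and with the tensors from the script one could pass `(x * 0.5 + 0.5).flatten().tolist()` and the same expression for `reconstructed`:

```python
import math


def psnr(reference, reconstruction, max_value=1.0):
    """Peak signal-to-noise ratio in dB between two equal-length sequences.

    Values are assumed to lie in [0, max_value]; higher means closer.
    """
    if len(reference) != len(reconstruction) or not reference:
        raise ValueError("inputs must be non-empty and the same length")
    mse = sum((a - b) ** 2 for a, b in zip(reference, reconstruction)) / len(reference)
    if mse == 0:
        return math.inf  # identical inputs
    return 10 * math.log10(max_value ** 2 / mse)


# Toy example: a uniform error of 0.1 per pixel gives about 20 dB
print(round(psnr([0.2, 0.5, 0.8], [0.3, 0.6, 0.9]), 2))
```

Pure Python is slow on a full 512×512×3 image but fine for a one-off check; running the same function on a known-good autoencoder's output gives a reference point to compare against.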