Latent Reconstruction yields noise

#3
by Blackroot - opened

Hi, I'm getting some unexpected results when using the model. I was only able to get the code to run after modifying it slightly, because
dc_ae = DCAE_HF.from_pretrained(f"mit-han-lab/dc-ae-f64c128-in-1.0")
expects a model name.

My test code:

from PIL import Image
import torch
import torchvision.transforms as transforms
from torchvision.utils import save_image
from efficientvit.ae_model_zoo import DCAE_HF
from efficientvit.apps.utils.image import DMCrop

# Load weights from a local folder, passing the architecture name explicitly
dc_ae = DCAE_HF.from_pretrained("./dc-ae", model_name="dc-ae-f32c32-sana-1.0")

def get_tensor_bytes(tensor):
    # Human-readable memory footprint of a tensor
    num_bytes = tensor.nelement() * tensor.element_size()

    units = ['B', 'KB', 'MB', 'GB', 'TB']
    size = float(num_bytes)
    unit_index = 0

    while size >= 1024 and unit_index < len(units) - 1:
        size /= 1024
        unit_index += 1

    return f"{size:.2f} {units[unit_index]}"

device = torch.device("cuda")
dc_ae = dc_ae.to(device).eval()

# Resize and center-crop to 512x512, then map pixels from [0, 1] to [-1, 1]
transform = transforms.Compose([
    DMCrop(512),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

image = Image.open("/home/blackroot/Desktop/vaes/ComfyUI_00312_.png").convert('RGB')
x = transform(image).unsqueeze(0).to(device)  # add batch dimension: (1, 3, 512, 512)

print("Input shape:", x.shape)
print("Input range:", torch.min(x).item(), torch.max(x).item())

with torch.no_grad():
    latent = dc_ae.encode(x)
    print("Latent shape:", latent.shape)
    print("Latent range:", torch.min(latent).item(), torch.max(latent).item())

    reconstructed = dc_ae.decode(latent)
    print("Output shape:", reconstructed.shape)
    print("Output range:", torch.min(reconstructed).item(), torch.max(reconstructed).item())

# Map the output back from [-1, 1] to [0, 1] before saving
save_image(reconstructed * 0.5 + 0.5, "reconstructed.png")

Here's the example output. The shapes all look correct, but as you can see below, the reconstruction is mostly random noise. Any idea what might be going on here?
original.png

reconstructed.png
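Since the shapes "look correct", it may help to spell out what correct means for each checkpoint. The sketch below assumes that DC-AE model names encode f<N> as the spatial downsampling factor and c<M> as the latent channel count (as in dc-ae-f32c32-sana-1.0 and dc-ae-f64c128-in-1.0 above); expected_latent_shape is a hypothetical helper, not an efficientvit function.

```python
import re


def expected_latent_shape(model_name: str, height: int, width: int, batch: int = 1) -> tuple:
    """Expected latent shape, ASSUMING the name contains an 'f<factor>c<channels>' tag."""
    match = re.search(r"f(\d+)c(\d+)", model_name)
    if match is None:
        raise ValueError(f"no f<N>c<M> tag in {model_name!r}")
    factor, channels = int(match.group(1)), int(match.group(2))
    # The encoder downsamples each spatial side by `factor`
    if height % factor or width % factor:
        raise ValueError(f"{height}x{width} is not divisible by factor {factor}")
    return (batch, channels, height // factor, width // factor)


print(expected_latent_shape("dc-ae-f32c32-sana-1.0", 512, 512))
print(expected_latent_shape("dc-ae-f64c128-in-1.0", 512, 512))
```

For a 512x512 input this gives (1, 32, 16, 16) and (1, 128, 8, 8). Note that a matching shape only confirms the network was built for the model_name that was passed; it does not show that the loaded weights actually belong to that architecture.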

MIT HAN Lab org

Hi, sorry for the late reply.

The pretrained model is meant to be loaded with this command, without modification: dc_ae = DCAE_HF.from_pretrained(f"mit-han-lab/dc-ae-f64c128-in-1.0"). Did you encounter any errors when using it?

You can refer to the instructions here and here.

Best,
Junyu
