import torch
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file


def load_unet_model(base, repo, ckpt, device="cpu", dtype=torch.float16):
    """
    Build a UNet from a base model's config and load checkpoint weights into it.

    The UNet architecture is read from the ``unet`` subfolder of ``base``.
    The weights are a safetensors file ``ckpt`` downloaded from the Hugging
    Face Hub repository ``repo``; they are loaded into the newly built model.

    Args:
        base (str): Base model name or path whose ``unet`` subfolder holds the
            UNet config.
        repo (str): Hub repository that contains the checkpoint.
        ckpt (str): Safetensors checkpoint filename inside ``repo``.
        device (str): Device to place the model and the loaded weights on.
        dtype (torch.dtype): Parameter dtype for the model. Defaults to
            ``torch.float16``, the value this function previously hard-coded.
            NOTE(review): half precision on ``device="cpu"`` may be slow or
            unsupported for some ops; consider ``torch.float32`` there.

    Returns:
        UNet2DConditionModel: The UNet with the checkpoint weights loaded.
    """
    # Imported locally so that importing this module does not require diffusers.
    from diffusers import UNet2DConditionModel

    # Passing a model name/path directly to ``from_config`` is deprecated in
    # diffusers, so load the config dict explicitly and build from that.
    config = UNet2DConditionModel.load_config(base, subfolder="unet")
    unet = UNet2DConditionModel.from_config(config).to(device, dtype)

    weights_path = hf_hub_download(repo, ckpt)
    # load_state_dict copies tensors into the existing parameters, so the
    # model keeps the requested ``dtype``.
    unet.load_state_dict(load_file(weights_path, device=device))
    return unet