
Original model GitHub repository: DenoisingDiffusionProbabilityModel-ddpm-

This is a simple attempt: a DDPM trained on the CIFAR-10 dataset.
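
For context, the training loop looks roughly like the sketch below. This is a minimal reconstruction against the diffusers API, not the code actually used for this checkpoint; the UNet hyperparameters, batch size, learning rate, and epoch count are all assumptions.

import torch
import torch.nn.functional as F
from diffusers import DDPMScheduler, UNet2DModel
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# CIFAR-10 as 32x32 RGB tensors scaled to [-1, 1]
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5] * 3, [0.5] * 3),
])
dataset = datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)
loader = DataLoader(dataset, batch_size=128, shuffle=True)

# Assumed architecture: diffusers' default UNet2DModel at 32x32
model = UNet2DModel(sample_size=32, in_channels=3, out_channels=3).to(device)
scheduler = DDPMScheduler(num_train_timesteps=1000)
optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)

for epoch in range(1):  # train for many more epochs in practice
    for images, _ in loader:
        images = images.to(device)
        noise = torch.randn_like(images)
        timesteps = torch.randint(
            0, scheduler.config.num_train_timesteps, (images.shape[0],), device=device
        )
        # Forward-diffuse the clean images to the sampled timesteps
        noisy = scheduler.add_noise(images, noise, timesteps)
        # Predict the added noise and regress against it (epsilon objective)
        loss = F.mse_loss(model(noisy, timesteps).sample, noise)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()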

Usage

# Image generation is buggy... the code below may still need fixes!

import os

import matplotlib.pyplot as plt
import torch
from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
from huggingface_hub import hf_hub_download

# Select the device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model_id = "BackTo2014/DDPM-test"

def load_and_eval(checkpoint_path, output_dir="./generated_images"):
    # UNet2DModel.from_pretrained does not accept a `filename` argument,
    # so download the raw checkpoint from the Hub and load its state dict
    # manually. This assumes the checkpoint's keys match the UNet built
    # below; adjust the architecture to whatever the checkpoint was
    # actually trained with.
    ckpt_file = hf_hub_download(repo_id=model_id, filename=checkpoint_path)
    state_dict = torch.load(ckpt_file, map_location="cpu")

    # UNet for 32x32 RGB (CIFAR-10) images
    unet = UNet2DModel(sample_size=32, in_channels=3, out_channels=3)
    unet.load_state_dict(state_dict, strict=False)
    unet = unet.to(device)

    # DDPMScheduler.from_config expects a config dict rather than a repo id,
    # so build the scheduler directly with the standard 1000 training steps.
    scheduler = DDPMScheduler(num_train_timesteps=1000)

    # Assemble the pipeline
    pipeline = DDPMPipeline(unet=unet, scheduler=scheduler)

    # Generation parameters
    num_images = 4                    # number of images to sample
    generator = torch.manual_seed(0)  # fixed seed for reproducibility
    num_inference_steps = 1000        # full DDPM sampling schedule

    # Sample the images one at a time
    images = []
    for _ in range(num_images):
        image = pipeline(generator=generator, num_inference_steps=num_inference_steps).images[0]
        images.append(image)

    # Create the output directory if needed
    os.makedirs(output_dir, exist_ok=True)

    # Save the images
    for i, img in enumerate(images):
        img.save(os.path.join(output_dir, f"generated_image_{i}.png"))

    # Display the images with Matplotlib
    fig, axs = plt.subplots(1, len(images), figsize=(len(images) * 5, 5))
    for ax, img in zip(axs.flatten(), images):
        ax.imshow(img)
        ax.axis('off')
    plt.show()

if __name__ == "__main__":
    checkpoint_path = "ckpt_141_.pt"  # checkpoint filename in the repo
    load_and_eval(checkpoint_path)
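
Full DDPM sampling walks all 1000 steps and is slow, especially on CPU. As an optional sketch (not part of the original card), the same UNet can be sampled much faster with a DDIM scheduler built from the DDPM config, at some possible cost in quality; unet and scheduler here are the objects built inside load_and_eval:

from diffusers import DDIMPipeline, DDIMScheduler

# Reuse the DDPM noise schedule, but sample with DDIM in 50 steps
ddim_scheduler = DDIMScheduler.from_config(scheduler.config)
fast_pipeline = DDIMPipeline(unet=unet, scheduler=ddim_scheduler)
image = fast_pipeline(generator=torch.manual_seed(0), num_inference_steps=50).images[0]
image.save("ddim_sample.png")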