BackTo2014 commited on
Commit
0b22dd7
1 Parent(s): 860a47f

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +51 -52
README.md CHANGED
@@ -8,64 +8,63 @@ This is a simple attempt. I trained with CIFAR-10 dataset.
8
  # 生成图像有误...以下代码需修改!!!
9
 
10
  import torch
11
- from diffusers import DDPMPipeline, DDPMScheduler
12
- from diffusers.models import UNet2DModel
13
  from PIL import Image
 
14
  import matplotlib.pyplot as plt
15
 
16
- # 模型ID
17
- model_id = "BackTo2014/DDPM-test"
18
-
19
- # 检查设备
20
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
21
 
22
- # 加载UNet模型和配置文件
23
- try:
 
 
24
  unet = UNet2DModel.from_pretrained(
25
- model_id,
 
26
  ignore_mismatched_sizes=True,
27
  low_cpu_mem_usage=False,
28
- ).to(device) # 将模型移动到GPU上
29
- except ValueError as e:
30
- print(f"Error loading model: {e}")
31
-
32
- # 获取模型的state_dict
33
- state_dict = unet.state_dict()
34
-
35
- # 手动初始化缺失的权重
36
- for key in e.args[0].split(': ')[1].split(', '):
37
- name, size = key.split('.')
38
- size = tuple(map(int, size.replace(')', '').replace('(', '').split(',')))
39
-
40
- # 创建随机权重
41
- new_weight = torch.randn(size).to(device) # 将权重移动到GPU上
42
-
43
- # 更新state_dict
44
- state_dict[name] = new_weight
45
-
46
- # 加载更新后的state_dict
47
- unet.load_state_dict(state_dict).to(device) # 将模型移动到GPU上
48
-
49
- # 如果sample_size未定义,则手动设置
50
- if unet.config.sample_size is None:
51
- # 假设样本尺寸为 32x32
52
- unet.config.sample_size = (32, 32)
53
-
54
- # 初始化Scheduler
55
- scheduler = DDPMScheduler.from_config(model_id)
56
-
57
- # 创建DDPMPipeline
58
- pipeline = DDPMPipeline(unet=unet, scheduler=scheduler)
59
-
60
- # 生成图像
61
- generator = torch.manual_seed(0)
62
- image = pipeline(num_inference_steps=1000, generator=generator).images[0]
63
-
64
- # 使用matplotlib显示图像
65
- plt.imshow(image)
66
- plt.axis('off') # 不显示坐标轴
67
- plt.show()
68
-
69
- # 保存图像
70
- image.save("generated_image.png")
71
  ```
 
8
  # 生成图像有误...以下代码需修改!!!
9
 
10
  import torch
11
+ from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
 
12
  from PIL import Image
13
+ import os
14
  import matplotlib.pyplot as plt
15
 
16
+ # 设备选择
 
 
 
17
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
18
 
19
+ model_id = "BackTo2014/DDPM-test"
20
+
21
+ def load_and_eval(checkpoint_path, output_dir="./generated_images"):
22
+ # 加载 UNet 模型
23
  unet = UNet2DModel.from_pretrained(
24
+ model_id, # 替换为你的模型存储库名称
25
+ filename=checkpoint_path, # 使用传入的检查点文件名
26
  ignore_mismatched_sizes=True,
27
  low_cpu_mem_usage=False,
28
+ ).to(device)
29
+
30
+ # 确保 sample_size 是一个有效的尺寸信息
31
+ if unet.config.sample_size is None:
32
+ # 假设样本尺寸为 32x32 或者根据你的需求设置
33
+ unet.config.sample_size = (32, 32)
34
+
35
+ # 初始化调度器
36
+ scheduler = DDPMScheduler.from_config(model_id) # 替换为你的调度器存储库名称
37
+
38
+ # 创建管道
39
+ pipeline = DDPMPipeline(unet=unet, scheduler=scheduler)
40
+
41
+ # 设置生成参数
42
+ num_images = 4 # 生成4张图像
43
+ generator = torch.manual_seed(0) # 固定随机种子
44
+ num_inference_steps = 999 # 推理步数
45
+
46
+ # 生成图像
47
+ images = []
48
+ for _ in range(num_images):
49
+ image = pipeline(generator=generator, num_inference_steps=num_inference_steps).images[0]
50
+ images.append(image)
51
+
52
+ # 创建输出目录
53
+ if not os.path.exists(output_dir):
54
+ os.makedirs(output_dir)
55
+
56
+ # 保存图像
57
+ for i, img in enumerate(images):
58
+ img.save(os.path.join(output_dir, f"generated_image_{i}.png"))
59
+
60
+ # 使用 Matplotlib 显示图像
61
+ fig, axs = plt.subplots(1, len(images), figsize=(len(images) * 5, 5))
62
+ for ax, img in zip(axs.flatten(), images):
63
+ ax.imshow(img)
64
+ ax.axis('off')
65
+ plt.show()
66
+
67
+ if __name__ == "__main__":
68
+ checkpoint_path = "ckpt_141_.pt" # 检查点文件名
69
+ load_and_eval(checkpoint_path)
 
70
  ```