isLandLZ commited on
Commit
5a3dda2
·
1 Parent(s): 5cfa643

Upload gan.py

Browse files
Files changed (1) hide show
  1. gan.py +259 -0
gan.py ADDED
@@ -0,0 +1,259 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import jittor as jt
2
+ from jittor import init
3
+ from jittor import nn
4
+ from jittor.dataset.mnist import MNIST
5
+ import jittor.transform as transform
6
+ import argparse
7
+ import os
8
+ import numpy as np
9
+ import math
10
+ import time
11
+ import cv2
12
+
13
+ jt.flags.use_cuda = 1
14
+ os.makedirs('images', exist_ok=True)
15
+ os.makedirs("saved_models", exist_ok=True)
16
+
17
+ parser = argparse.ArgumentParser()
18
+ parser.add_argument('--n_epochs', type=int, default=200, help='训练的时期数')
19
+ parser.add_argument('--batch_size', type=int, default=64, help='批次大小')
20
+ parser.add_argument('--lr', type=float, default=0.0002, help='学习率')
21
+ parser.add_argument('--b1', type=float, default=0.5, help='梯度的一阶动量衰减')
22
+ parser.add_argument('--b2', type=float, default=0.999, help='梯度的一阶动量衰减')
23
+ parser.add_argument('--n_cpu', type=int, default=8, help='批处理生成期间要使用的 cpu 线程数')
24
+ parser.add_argument('--latent_dim', type=int, default=100, help='潜在空间的维度')
25
+ parser.add_argument('--img_size', type=int, default=28, help='每个图像尺寸的大小')
26
+ parser.add_argument('--channels', type=int, default=1, help='图像通道数')
27
+ parser.add_argument('--sample_interval', type=int, default=400, help='图像样本之间的间隔')
28
+
29
+ opt = parser.parse_args()
30
+ print(opt)
31
+ img_shape = (opt.channels, opt.img_size, opt.img_size)
32
+
33
+ # 保存生成器生成的图片样本数据
34
+ def save_image(img, path, nrow=None):
35
+ N,C,W,H = img.shape# (25, 1, 28, 28)
36
+ '''
37
+ [-1,700,28] , img2的形状(1,700,28)
38
+ img[0][0][0] = img2[0][0]
39
+ img2:[
40
+ [1*28]
41
+ ......(一共700个)
42
+ ](1,700,28)
43
+ '''
44
+ img2=img.reshape([-1,W*nrow*nrow,H])
45
+ # [:,:28*5,:],img:(1,140,28)
46
+ img=img2[:,:W*nrow,:]
47
+ for i in range(1,nrow):#[1,5)
48
+ '''
49
+ img(1,140,28),img2(1,700,28)
50
+ img从(1,140,28)->(1,140,28+28)->...->(1,140,28+28+28+28)=(1,140,140)
51
+ np.concatenate把两个三维数组合并
52
+ '''
53
+ img=np.concatenate([img,img2[:,W*nrow*i:W*nrow*(i+1),:]],axis=2)
54
+ # img中的数据大小从(-1,1)--(+1)-->(0,2)--(/2)-->(0,1)--(*255)-->(0,255)转换成了像素值
55
+ img=(img+1.0)/2.0*255
56
+ # (1,140,140)--->(140,140,1)
57
+ # (channels通道数,imagesize,imagesize)转化为(imagesize,imagesize,channels通道数)
58
+ img=img.transpose((1,2,0))
59
+ # 根据地址保存图片样本数据
60
+ cv2.imwrite(path,img)
61
+
62
+ # 生成器
63
+ class Generator(nn.Module):
64
+
65
+ def __init__(self):
66
+ super(Generator, self).__init__()
67
+
68
+ def block(in_feat, out_feat, normalize=True):
69
+ layers = [nn.Linear(in_feat, out_feat)]
70
+ if normalize:
71
+ layers.append(nn.BatchNorm1d(out_feat, 0.8))
72
+ layers.append(nn.LeakyReLU(scale=0.2))
73
+ return layers
74
+ self.model = nn.Sequential(*block(opt.latent_dim, 128, normalize=False), *block(128, 256), *block(256, 512), *block(512, 1024), nn.Linear(1024, int(np.prod(img_shape))), nn.Tanh())
75
+
76
+ def execute(self, z):
77
+ img = self.model(z)
78
+ img = img.view((img.shape[0], *img_shape))
79
+ return img
80
+
81
+ # 判别器
82
+ class Discriminator(nn.Module):
83
+
84
+ def __init__(self):
85
+ super(Discriminator, self).__init__()
86
+ self.model = nn.Sequential(nn.Linear(int(np.prod(img_shape)), 512), nn.LeakyReLU(scale=0.2), nn.Linear(512, 256), nn.LeakyReLU(scale=0.2), nn.Linear(256, 1), nn.Sigmoid())
87
+
88
+ def execute(self, img):
89
+ img_flat = img.view((img.shape[0], (- 1)))
90
+ validity = self.model(img_flat)
91
+ return validity
92
+
93
+ # bce loss分类器 (b这里指的是binary,所以用于二分类问题)
94
+ '''
95
+ 源码:
96
+ class BCELoss(Module):
97
+ def __init__(self, weight=None, size_average=True):
98
+ self.weight = weight
99
+ self.size_average = size_average
100
+ def execute(self, output, target):
101
+ return bce_loss(output, target, self.weight, self.size_average)
102
+
103
+ # weight:表示对loss中每个元素的加权权值,默认为None
104
+ # size_average:指定输出的格式,包括'mean','sum'
105
+ # output:判别器对生成的数据的判别结果(64*1)
106
+ # target:判别器对真实的数据的判别结果(64*1)
107
+ def bce_loss(output, target, weight=None, size_average=True):
108
+ # jt.maximum(x,y):返回x和y的元素最大值
109
+ # 公式:损失值 = -权重*[ 理想结果*log(判别结果) + (1-理想结果)*log(1-判别结果) ]
110
+ loss = - (
111
+ target * jt.log(jt.maximum(output, 1e-20))
112
+ +
113
+ (1 - target) * jt.log(jt.maximum(1 - output, 1e-20))
114
+ )
115
+ if weight is not None:
116
+ loss *= weight
117
+ if size_average:
118
+ return loss.mean()# 求均值
119
+ else:
120
+ return loss.sum()# 求和
121
+ '''
122
+ # 对抗性损失函数
123
+ adversarial_loss = nn.BCELoss()
124
+
125
+ # 初始化生成器和判别器
126
+ generator = Generator()
127
+ discriminator = Discriminator()
128
+
129
+ # 配置数据加载器
130
+ transform = transform.Compose([
131
+ transform.Resize(size=opt.img_size),
132
+ transform.Gray(),
133
+ transform.ImageNormalize(mean=[0.5], std=[0.5]),
134
+ ])
135
+ dataloader = MNIST(train=True, transform=transform).set_attrs(batch_size=opt.batch_size, shuffle=True)
136
+
137
+ # 优化器
138
+ optimizer_G = jt.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
139
+ optimizer_D = jt.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
140
+
141
+ warmup_times = -1
142
+ run_times = 3000
143
+ total_time = 0.
144
+ cnt = 0
145
+
146
+ # ----------
147
+ # 训练
148
+ # ----------
149
+
150
+ for epoch in range(opt.n_epochs):#[1,200),200次迭代
151
+ for (i, (real_imgs, _)) in enumerate(dataloader):
152
+
153
+ '''
154
+ valid表示真,全为1,fake表示假,全为0
155
+ img.shape[0]:图像的垂直尺寸(高度)h
156
+ [ [1.0]...(一共h个)...[1.0] ] 64*1的数组
157
+ '''
158
+ valid = jt.ones([real_imgs.shape[0], 1]).stop_grad()
159
+ fake = jt.zeros([real_imgs.shape[0], 1]).stop_grad()
160
+
161
+ # ---------------------
162
+ # 训练生成器
163
+ # ---------------------
164
+
165
+ # TODO 第一步:生成服从正态分布的噪音数据
166
+ '''
167
+ 随机生成一个符合正态分布的噪声,numpy.random.normal(loc=0.0, scale=1.0, size=None)
168
+ loc:正态分布的均值,对应着这个分布的中心,0说明这一个以Y轴为对称轴的正态分布
169
+ scale:正态分布的标准差,对应分布的宽度,scale越大,正态分布的曲线越矮胖,scale越小,曲线越高瘦
170
+ shape:(图片的高度h,潜在空间的维度100) == (64,100) == z.shape
171
+ '''
172
+ z = jt.array(np.random.normal(0, 1, (real_imgs.shape[0], opt.latent_dim)).astype(np.float32))
173
+ # TODO 第二步:生成器加载噪音数据生成图片数据[64,1,28,28]
174
+ '''
175
+ gen_imgs的形状:(64,1,28,28), 64*1中每个元素都是28*28
176
+ [
177
+ [28*28]
178
+ ...... (一共64个28*28)
179
+ ]
180
+ '''
181
+ gen_imgs = generator(z)
182
+ # TODO 第三步:根据生成数据的判别结果和真的数据(都是64*1)计算损失值
183
+ '''
184
+ 把生成的图片数据放进判别器中,让判别器对其进行分类,计算出数据可能是真实数据的概率值(0-1之间的数)
185
+ valid当作是判别器分类的结果,全为1说明判别器认为这个数据来源于真实图片
186
+ adversarial_loss会调用bce_loss求损失值
187
+ 因为我们需要使生成器生成的数据越来越像真实的数据,所以我们需要这两个数据越来越相似[discriminator(gen_imgs)和valid]
188
+ loss(x,y)=-w*[ylogx+(1-y)log(1-x)]
189
+ 生成器理想条件下,discriminator(gen_imgs)=1,loss(1,1)=0
190
+ '''
191
+ g_loss = adversarial_loss(discriminator(gen_imgs), valid)
192
+ # TODO 第四步:反向传播,训练生成器的参数
193
+ optimizer_G.step(g_loss)
194
+
195
+ # ---------------------
196
+ # 训练判别器
197
+ # ---------------------
198
+
199
+ # TODO 第一步:根据训练集中的数据和真的数据计算损失值
200
+ '''
201
+ real_imgs:加载的训练集数据
202
+ 把训练集数据放进判别器,得到判别器对训练集数据的判别结果,计算出数据可能是真实数据的概率值
203
+ valid当作是判别器分类的结果,全为1说明判别器认为这个数据来源于真实图片
204
+ 因为我们需要使判别器把训练集数据判别为真实数据,所以我们需要使这两个数据越来越相似[discriminator(real_imgs), valid]
205
+ loss(x,y)=-w*[ylogx+(1-y)log(1-x)]
206
+ 判别器理想条件下,discriminator(real_imgs)=1,loss(1,1)=0
207
+ '''
208
+ real_loss = adversarial_loss(discriminator(real_imgs), valid)#
209
+ # TODO 第二步:根据生成数据的判别结果和假的数据(都是64*1)计算损失值
210
+ '''
211
+ gen_imgs:生成器生成的图片数据
212
+ 把生成的图片数据放进判别器中,让判别器对其进行分类,计算出数据可能是真实数据的概率值(0-1之间的数)
213
+ fake当作是判别器分类的结果,全为0说明判别器认为这个数据来源于生成的数据,而不是真实现实中的数据
214
+ 调用bce_loss求损失值
215
+ 因为我们需要使判别器能识别出机器生成的图片数据,所以我们需要使这两个数越来越相似[discriminator(gen_imgs), fake]
216
+ loss(x,y)=-w*[ylogx+(1-y)log(1-x)]
217
+ 判别器理想条件下,discriminator(gen_imgs)=0,loss(0,0)=0
218
+ '''
219
+ fake_loss = adversarial_loss(discriminator(gen_imgs), fake)#
220
+ # TODO 第三步:对这两个损失值求平均
221
+ d_loss = ((real_loss + fake_loss) / 2)
222
+ # TODO 第四步:反向传播,训练判别器的参数
223
+ optimizer_D.step(d_loss)
224
+
225
+ # ---------------------
226
+ # 打印训练进度,打印生成器和判别器的损失值
227
+ # 保存生成器生成的图片样本数据
228
+ # ---------------------
229
+
230
+ if warmup_times==-1:
231
+ '''
232
+ D loss:判别器的损失值,越小越好(0-1的数)
233
+ G loss:生成器的损失值,越小越好(0-1的数)
234
+ numpy():把Var数据类型的数据转换成array数据类型
235
+ '''
236
+ print(('[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]' % (epoch, opt.n_epochs, i, len(dataloader), d_loss.numpy()[0], g_loss.numpy()[0])))
237
+ # [0,200) * 938 + [0,938) = [0,199*938+937] = [0,187599]
238
+ batches_done = ((epoch * len(dataloader)) + i)
239
+ # opt.sample_interval = 400 , 187599 / 400 = 468
240
+ if ((batches_done % opt.sample_interval) == 0):
241
+ # gen_imgs.data[:25] -> (25, 1, 28, 28)
242
+ save_image(gen_imgs.data[:25], ('images/GAN_images/%d.png' % batches_done), nrow=5)
243
+ else:
244
+ jt.sync_all()
245
+ cnt += 1
246
+ print(cnt)
247
+ if cnt == warmup_times:
248
+ jt.sync_all(True)
249
+ sta = time.time()
250
+ if cnt > warmup_times + run_times:
251
+ jt.sync_all(True)
252
+ total_time = time.time() - sta
253
+ print(f"run {run_times} iters cost {total_time} seconds, and avg {total_time / run_times} one iter.")
254
+ exit(0)
255
+
256
+ # 指定地址保存训练好的模型
257
+ if (epoch+1) % 10 == 0:
258
+ generator.save("saved_models/generator_last.pkl")
259
+ discriminator.save("saved_models/discriminator_last.pkl")