作者简介:秃头小苏,致力于用最浅显的语言描绘问题
往期回顾:方针检测系列——Faster R-CNN原理详解 还不懂方针检测嘛?一起来看看Faster R-CNN源码解读
近期方针:拥有10000粉丝
支撑小苏:点赞、保藏⭐、留言
对立生成网络GAN系列——GAN
写在前面
其实关于GAN的解说我早就做过一期,点击☞☞☞了解详情因为最近会用到GAN的一些常识,自己又对GAN进行了一些收拾,有了一些新的知道,便写了这篇文章。那么这篇文章和早期的文章有什么区别呢?首要,早期的文章仅仅对GAN做了一个大概的知道,而这篇文章会贴合论文较为详细的解说GAN网络;其次,这次我预备写一个GAN系列,介绍一些经典的GAN网络,所以这篇文章和后面方案写的文章关联性更强。【注:我觉得咱们能够先去读一下我之前的文章,关于丢失函数部分我通过一个例子来解说,仍是很好了解的,文章也很短,能让咱们快速对GAN有一个感性的知道】
预备好了嘛,下面就正式发车了。
GAN简介
这儿先来简单的介绍一下GAN,其完整的名称为Generative Adversarial Nets (生成对立网络)
。其实这个起名还有个小故事,我扼要的说一下,咱们随便听听,就当放松了。其时作者Goodfellow
关于这篇文章其实是有好几个备选姓名的,后来一个中国人说GAN(干)
在中国有一种对立的意思,作者一听,直接拍案挑选了这个名称。
接下来让咱们看看论文中对GAN的解说,如下图所示:
我简单的来翻译一下,其大致意思是说:在咱们提出的对立生成网络中,有一个生成模型,也有一个对立模型,它们相互对立,相互促进。文中也举了个小例子,生成模型能够被以为是一个假币伪造团队,企图生产假币并运用,而判别器类似于差人,企图发现假币。这便是一个相互博弈的进程,生成模型不断的产生伪造水平高的假币,而判别器不断提高差人识别假币水平,直至两者到达一个平衡。这个平衡是指什么呢?即判别器关于生成模型产生的假币区分的成功率大致为50%,即很难区分真假。
生成对立网络✨✨✨
GAN丢失函数
这部分咱们首要结合生成对立网络的丢失函数来介绍网络的整个流程,首要呢,咱们需求对一些字母做一些解说。如下:
ZZ | 随机噪声 |
---|---|
Pz(Z)P_z(Z) | 随机噪声Z遵守的概率散布 |
G(Z;g)G(Z;\theta_g) | 生成器:输出为噪声Z,输出为假图画 |
PgP_g | 生成器生成的假图画遵守的概率散布 |
X∼PdataX \sim P_{data} | 实在数据遵守的概率散布 |
D(X;d)D(X;\theta_d) | 判别器:输入为图画,输出为该图画为实在图画的概率,概率在[0,1]之间 |
对上述字母有一定的了解后,下面就能够给出生成对立网络的丢失函数了,如下图所示:
图片来自B站同济子豪兄
乍一看这个公式你应该是懵逼的,下面就跟着我的思路来分化分化上述公式。首要这个公式应该有两部分,一部分为给定G,找到使V最大化的D;另一部分为给定D,找到使V最小化的G。
咱们先来看第一部分,即给定G,找到使V最大化的D。如下图所示:【注:咱们为什么想要找到使V最大化的D,是因为使V最大化的D会使判别器的作用最好】
首要看第①部分,因为判别器此刻的输入为XX,是实在数据,EX∼Pdata[logD(X)]E_{X \sim P_{data}}[logD(X)] 值越大表明判别器以为输入X为实在数据的概率越大,也即表明判别器才能越强,因而这项的输出越大对判别器来说越好。接着来看第②部分,留意此刻判别器的输入为G(Z)G(Z),即输入为假图画,那么此刻关于D(G(Z))D(G(Z))来说这个值越小,表明判别器断定假图画为实在数据的概率越小,相同表明判别器才能越强。需求留意的是第二项为log(1−D(G(Z)))log(1-D(G(Z))) 的希望,当判别器越强时,D(G(Z))D(G(Z)) 值越小,而Ez∼pz(z)[log(1−D(G(Z)))]E_{z \sim p_z(z)}[log(1-D(G(Z)))] 越大。【注:部分①和部分②要想使给定G时,判别器的作用最好,都需求最大化V,即给定G,找到最大化V的D会使判别器的作用最好。】为便利咱们了解,画出log(1−D(G(Z)))log(1-D(G(Z))) 的函数图画如下:
接着咱们来看第二部分,即给定D,找到使V最小化的G。如下图所示:【注:咱们为什么想要找到使V最小化的G,是因为使V最小化的G会使生成器的作用最好】
相同的,先来看第①部分,因为这次咱们是固定了D,而①只和D有关,因而这部分是常量,其对最小化V是没有任何影响的,能够舍去。那么咱们就来看看第②部分,此刻判别器的输入相同是G(Z)G(Z),为假图画。不同的是现在咱们期待的是生成器的作用好,即尽可能的瞒过判别器,也即希望D(G(Z))D(G(Z)) 尽可能大。D(G(Z))D(G(Z))越大就表明判别器断定假图画为实在数据的概率越大,也就表明生成器生成的图画作用好,能够很成功的骗过判别器。相同的D(G(Z))D(G(Z)) 值越大,Ez∼pz(z)[log(1−D(G(Z)))]E_{z \sim p_z(z)}[log(1-D(G(Z)))] 就越小,因而给定D,找到最小化V的G会使生成器的作用最好。
GAN流程
论文中在给出丢失函数后,又给了一个图例来解说GAN的进程,用原文的话来说便是一个不怎么正式,却更具教育意义的解说。(See Figure 1 for a less formal, more pedagogical explanation of the approach
)
图片来自B站同济子豪兄
其实上图中的文字标示现已将这个进程解说的适当详细了,我再来简单的复述一遍。首要图中黑点表明实在图画的散布,绿点表明生成图画的概率散布,蓝点表明判别器猜测XX为实在数据的概率。在(a)时,黑点和绿点的散布相差较大,判别器能大致区分实在图画和生成图画,但分辩作用不好。【在黑点会集区域蓝点的值普遍较高,表明猜测黑点为实在图画的概率较大;同理,在绿点会集区域蓝点的值普遍较低,表明猜测绿点为实在图画的概率较小。但蓝点存在一定的波动,作用不是很好。】 从(a)到(b)经过了判别器的练习,这会导致什么成果呢,从图(b)中能够发现,此刻蓝点表现的愈加稳定,在黑点会集处猜测概率大,在绿点会集处猜测概率小,也便是说此刻的判别器现已能很好的分辩什么是实在图画,什么是生成的假图画了。接下来从(b)到(c)经过了生成器的练习,这会导致什么成果呢,从图(c)中能够发现,此刻绿点逐步像黑点接近,即生成的图画愈加实在了,而此刻蓝点没有变化,这就会导致现在判别器对实在图画和生成图画的区分难度变大了。这样不断的练习判别器和生成器,最终变成图(d),即实在图画分部和生成器生成图画散布彻底一致,判别器猜测概率恒为0.5,也即此刻判别器彻底无法区分实在图画很生成图画了。
接下来论文中给出了练习GAN网络的伪代码,如下图所示:
图片来自B站同济子豪兄
假如我前文的描绘你都听懂了的话,其实这个进程就没什么好说的了,便是对判别器和生成器不断的迭代更新。需求留意的有两点,第一是在练习进程中,咱们是练习K次判别器,练习一次生成器;第二是在练习生成器进程中,咱们的丢失函数没有了1m∑i=1mlogD(x(i))\frac{1}{m}\sum\limits_{i = 1}^m {\log D({x^{(i)}})} 这一项,这个我在GAN丢失函数
这节有说到,因为练习生成器G时固定了判别器D,该项是定值,可省掉。【注:这儿的1m∑i=1mlogD(x(i))\frac{1}{m}\sum\limits_{i = 1}^m {\log D({x^{(i)}})} 和EX∼Pdata[logD(X)]E_{X \sim P_{data}}[logD(X)] 彻底一样,仅仅一个是用均值表明,一个用希望表明。】
GAN完成作用
论文中给出了GAN的一些完成作用的图片,如下图所示:
上面四个图中,留意黄框框住的并不是GAN生成的图片,它们表明与GAN生成图片最类似的原始实在图片。而GAN生成的图片为黄框左侧第一张图片,能够看出,GAN生成的作用仍是挺好的。
运用GAN生成手写数字小demo✨✨✨
上文算是把原理叙述清楚了,若你还不明白,渐渐的阅读每句话,加入自己的考虑,或许会有不一样的收获。那么这节我讲来讲讲通过GAN网络生成手写数字的小demo,通过这部分你会了解建立GAN网络的基本流程。下面就让咱们一起来学学吧!!!【注:其实大致的流程和一般分类网络的建立是类似的,相关分类网络的建立流程可参阅我的这篇博文】
首要练习一个模型肯定少不了数据集,咱们通过一下代码获取torch
自带的MNIST数据集
,代码如下:
#MNIST数据集获取
dataset = torchvision.datasets.MNIST("mnist_data", train=True, download=True,
transform=torchvision.transforms.Compose(
[
torchvision.transforms.Resize(28),
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize([0.5], [0.5]),
]
)
)
之后咱们通过DataLoader
方法加载数据集,代码如下:
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)
这样数据就预备好了,下面就来构建咱们的模型,分为生成器(Generator)和判别器(Discriminator)。【注:因为这期算是入门GAN,所以模型建立只采用了全连接层】
生成器模型建立:
class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()
self.model = nn.Sequential(
nn.Linear(latent_dim, 128),
torch.nn.BatchNorm1d(128),
torch.nn.GELU(),
nn.Linear(128, 256),
torch.nn.BatchNorm1d(256),
torch.nn.GELU(),
nn.Linear(256, 512),
torch.nn.BatchNorm1d(512),
torch.nn.GELU(),
nn.Linear(512, 1024),
torch.nn.BatchNorm1d(1024),
torch.nn.GELU(),
nn.Linear(1024, np.prod(image_size, dtype=np.int32)),
nn.Sigmoid(),
)
def forward(self, z):
# shape of z: [batchsize, latent_dim]
output = self.model(z)
image = output.reshape(z.shape[0], *image_size)
return image
判别器模型建立:
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator, self).__init__()
self.model = nn.Sequential(
nn.Linear(np.prod(image_size, dtype=np.int32), 512),
torch.nn.GELU(),
nn.Linear(512, 256),
torch.nn.GELU(),
nn.Linear(256, 128),
torch.nn.GELU(),
nn.Linear(128, 64),
torch.nn.GELU(),
nn.Linear(64, 32),
torch.nn.GELU(),
nn.Linear(32, 1),
nn.Sigmoid(),
)
def forward(self, image):
# shape of image: [batchsize, 1, 28, 28]
prob = self.model(image.reshape(image.shape[0], -1))
return prob
模型建立好后,咱们会对丢失函数、优化器等参数进行设置:
g_optimizer = torch.optim.Adam(generator.parameters(), lr=0.0003, betas=(0.4, 0.8), weight_decay=0.0001)
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=0.0003, betas=(0.4, 0.8), weight_decay=0.0001)
loss_fn = nn.BCELoss()
需求留意,这儿采用的是BCELOSS丢失函数,这个函数其实就对应着咱们GAN理论部分的丢失函数,这儿想了解更多的话能够参阅这个视频:BCE丢失函数阐明
这些设置好后,咱们就来练习咱们的GAN网络了,相关代码如下:这一部分我仍是主张咱们看一下这个视频,解说的比较清楚。【可直接空降到41分钟】
num_epoch = 200
for epoch in range(num_epoch):
for i, mini_batch in enumerate(dataloader):
gt_images, _ = mini_batch
z = torch.randn(batch_size, latent_dim)
pred_images = generator(z)
g_optimizer.zero_grad()
g_loss = loss_fn(discriminator(pred_images), labels_one)
g_loss.backward()
g_optimizer.step()
d_optimizer.zero_grad()
real_loss = loss_fn(discriminator(gt_images), labels_one)
fake_loss = loss_fn(discriminator(pred_images.detach()), labels_zero)
d_loss = (real_loss + fake_loss)
# 调查real_loss与fake_loss,同时下降同时到达最小值,并且差不多大,阐明D现已稳定了
d_loss.backward()
d_optimizer.step()
最终,我来展示一下练习成果吧!!!我是在服务器上进行练习的,所以仍是比较快的。先来看一下初始的图,都是一些随机的噪声,如下图所示:
再来看练习一段时间的成果,发现作用仍是蛮不错滴
论文下载地址
论文下载地址
参阅链接
生成对立网络GAN开山之作论文精读
原始GAN论文详解
GAN原了解说与PyTorch手写逐行解说
如若文章对你有所帮助,那就
咻咻咻咻~~duang~~点个赞呗
我正在参加技术社区创作者签约方案招募活动,点击链接报名投稿。