本文为稀土技术社区首发签约文章,14天内制止转载,14天后未获授权制止转载,侵权必究!

作者简介:秃头小苏,致力于用最通俗的言语描述问题

往期回忆:对立生成网络GAN系列——GAN原理及手写数字生成小事例   对立生成网络GAN系列——DCGAN简介及人脸图画生成事例   对立生成网络GAN系列——AnoGAN原理及缺点检测实战   对立生成网络GAN系列——EGBAD原理及缺点检测实战

近期目标:写好专栏的每一篇文章

支持小苏:点赞、收藏⭐、留言

对立生成网络GAN系列——WGAN原理及实战演练

写在前面

​  在前面我现已写过好几篇关于GAN的文章,感兴趣的能够点击下列链接了解详情:

  • [1]对立生成网络GAN系列——GAN原理及手写数字生成小事例
  • [2]对立生成网络GAN系列——DCGAN简介及人脸图画生成事例
  • [3]对立生成网络GAN系列——CycleGAN原理
  • [4] 对立生成网络GAN系列——AnoGAN原理及缺点检测实战
  • [5]对立生成网络GAN系列——EGBAD原理及缺点检测实战

​  在文章[5]的最终,咱们从练习结果中能够看出作用并不是很理想,其实难以练习是GAN网络练习普遍存在的问题。本篇文章将为咱们带来WGAN(Wasserstein Generative Adversarial Networks ),旨在处理GAN难train的问题。这篇论文中有许多的理论推导,我不会带着咱们一个个推,当然许多我也不会,可是我尽或许的把一些要害部分给咱们叙说清楚,让咱们从心底认可WGAN,觉得WGAN是合理的,是美好的。

​  预备好了吗,就让咱们一同来学学WGAN吧!

WGAN原理详解

GAN为什么练习困难

​  上文咱们谈到GAN普遍存在练习困难的问题,而WGAN是用来处理此问题的。所谓对症下药,咱们榜首步应该要知道GAN为什么会存在练习难的问题,这样咱们才能从本质上对其进行改进。下面就跟随我的脚步一同来看看吧!!!

​  还记得咱们在[1]中给出的GAN网络的丢失函数吗?如下图所示:

对抗生成网络GAN系列——WGAN原理及实战演练

​  为了便利表述,对上图公式做如下调整:

min⁡Gmax⁡DV=Ex∼pr[log⁡D(x)]+Ex∼pg[log⁡(1−D(x))]\mathop {\min }\limits_G \mathop {\max }\limits_D V = {E_{x \sim {{\rm{p}}_r}}}[\log D(x)] + {E_{x \sim {{\rm{p}}_g}}}[\log (1 – D(x))]

​  即把后一项的G(z)G(z)用x来表明,x为生成器生成的散布。

​  咱们现在希望求得生成器固定是最大化的判别器D。首要,咱们关于一个随机的样本,它或许是实在样本,也或许是生成的样本,那么关于这个样本的丢失为:

Pr(x)log⁡D(x)+Pg(x)log⁡[1−D(x)]{{\rm{P}}_r}(x)\log D(x) + {P_g}(x)\log [1 – D(x)]

​  预求最大化的D,则对上式的D(x)求导,并让上式导函数为0,即:

Pr(x)D(x)−Pg(x)1−D(x)=0\frac{{{{\rm{P}}_r}(x)}}{{D(x)}} – \frac{{{{\rm{P}}_g}(x)}}{{1 – D(x)}} = 0

​  化简上式,得最优的D表达式为:

D∗(x)=Pr(x)Pr(x)+Pg(x){D^*}(x) = \frac{{{{\rm{P}}_r}(x)}}{{{{\rm{P}}_r}(x) + {{\rm{P}}_g}(x)}}        ———— 式1

​  咱们能够来看看这个式1是否契合咱们的直觉。若pr(x)=0、pg(x)≠0p_r(x)=0、p_g(x)\ne 0表明x彻底遵守生成散布,是假图画,所以最优判别器D给出的概率为0;若pr(x)≠0、pg(x)=0p_r(x) \ne 0、p_g(x)= 0表明x彻底遵守实在散布,是真图画,所以最优判别器D给出的概率为1;若pr(x)=pg(x)p_r(x)=p_g(x),表明x遵守生成散布和遵守实在散布的数据相同多,因而最优判别器D给出的概率为50%。

​  现在咱们现已找到了最大化的D,现要将其固定,寻找最小化的G,即求下式:

min⁡GV=Ex∼pr[log⁡D(x)]+Ex∼pg[log⁡(1−D(x))]\mathop {\min }\limits_G V = {E_{x \sim {{\rm{p}}_r}}}[\log D(x)] + {E_{x \sim {{\rm{p}}_g}}}[\log (1 – D(x))]    ———— 式2

​  将最大化的D即式1的D∗(x)D^*(x)代入式2得:

min⁡GV=Ex∼pr[log⁡(pr(x)pr(x)+pg(x))]+Ex∼pg[log⁡(pg(x)pr(x)+pg(x))]\mathop {\min }\limits_G V = {E_{x \sim {{\rm{p}}_r}}}[\log (\frac{{{{\rm{p}}_r}(x)}}{{{{\rm{p}}_r}(x) + {{\rm{p}}_g}(x)}})] + {E_{x \sim {{\rm{p}}_g}}}[\log (\frac{{{{\rm{p}}_g}(x)}}{{{{\rm{p}}_r}(x) + {{\rm{p}}_g}(x)}})]

​  再化简一步得式3:

min⁡GV=Ex∼pr[log⁡(pr(x)12pr(x)+pg(x))]+Ex∼pg[log⁡(pg(x)12pr(x)+pg(x))]−2log⁡2\mathop {\min }\limits_G V = {E_{x \sim {{\rm{p}}_r}}}[\log (\frac{{{{\rm{p}}_r}(x)}}{{\frac{1}{2}{{\rm{p}}_r}(x) + {{\rm{p}}_g}(x)}})] + {E_{x \sim {{\rm{p}}_g}}}[\log (\frac{{{{\rm{p}}_g}(x)}}{{\frac{1}{2}{{\rm{p}}_r}(x) + {{\rm{p}}_g}(x)}})] – 2\log 2    ———— 式3

​  化简成式3的方式是为了和KL散度和JS散度相联系。咱们先来介绍KL散度和JS散度的界说:【咱们不要觉得散度这个词如同很巨大上相同,咱们其实就能够理解为间隔】

  • KL散度

    KL(P1∣∣P2)=Ex∼P1logP1P2KL(P_1||P_2)=E_{x \sim P_1}log \frac {P_1}{P_2}

    上式是用希望来表明KL散度的,咱们也能够用积分或求和方式来表达,如下:

    KL(P1∣∣P2)=∫xP1log⁡P1P2dxKL(P_1||P_2)=\int\limits_x {{P_1}\log \frac{{{P_1}}}{{{P_2}}}} dxKL(P1∣∣P2)=∑p1log⁡P1P2KL(P_1||P_2)= \sum {{p_1}\log } \frac{{{P_1}}}{{{P_2}}}

    我来大致的解说一下KL散度的意义,用KL(P1∣∣P2)=∑P1(−logP2−(−logP1))KL(P_1||P_2)=\sum {{P_1}( – log{P_2} – ( – log{P_1}))}解说,如下图:

    对抗生成网络GAN系列——WGAN原理及实战演练

    ​  关于KL散度我就介绍这么多,它还有一些性质,像非对称性等等,我这儿就不过多介绍了,感兴趣的能够自己去阅览阅览相关材料。

  • JS散度 JS(P1∣∣P2)=12KL(P1∣∣P1+P22)+12KL(P2∣∣P1+P22)JS({P_1}||{P_2}) = \frac{1}{2}KL({P_1}||\frac{{{P_1} + {P_2}}}{2}) + \frac{1}{2}KL({P_2}||\frac{{{P_1} + {P_2}}}{2})

    能够看出,JS散度是依据KL散度来界说的,它具有对称性。


​  有了JS散度的界说,咱们就能够将式3变换成如下方式:

min⁡GV=2JS(Pr∣∣Pg)−2log⁡2\mathop {\min }\limits_G V = 2JS({P_r}||{P_g}) – 2\log 2      ———— 式4

​  这时候必定许多人会想了,为什么咱们需要把式3化成式4这个JS散度的方式,这是由于咱们的GAN练习困难很大原因便是这个JS散度捣的鬼。

​  为什么说练习困难是JS散度捣的鬼呢?经过上文咱们知道,咱们将判别器练习的越好,即判别器越挨近最优判别器,此刻生成器的loss会等价为式4中的JS散度的方式,当咱们练习生成器来最小化生成器丢失时也便是最小化式4中的JS散度。这这过程看似十分的合理,只要咱们不断的练习,实在散布和生成散布越来越挨近,JS散度应该越来越小,直到两个散布彻底一致,此刻JS散度为0。

​  可是理想很饱满,实际很骨感,JS散度并不会跟着实在散布和生成散布越来越近而使其值越来越小,而是很大概率保持log2log2不变。

对抗生成网络GAN系列——WGAN原理及实战演练

​  如上图所示,跟着pgp_gpdatap_{data}(也便是上文中所述的prp_r)越来越近,它们的JS散度却不发生变化,一向等于log2,只有当pgp_gpdatap_{data}有堆叠时才会发生变化,当彻底堆叠时,JS散度的值为0。也便是说,不论pgp_gpdatap_{data}相距多远,只要它们没有堆叠,那么它们的JS散度就一向是常数log2,也便是在练习过程中loss一向不发生变化,这也意味着生成器的梯度一向为0,即生成器产生了梯度消失的现象。

​  读了上文,我觉得你至少应该有两个疑问,其一是为什么两个散布不重合,它们JS散度一向为log2,其二为既然只要pgp_gpdatap_{data}有堆叠,那么它们的JS散度就不是log2,这样咱们是不是就能够更新梯度了呢?那么它们堆叠的概率大吗?咱们一个一个的来处理咱们的疑问。

  1. 什么两个散布不重合,它们JS散度一向为log2

    这个其实是由JS散度的性质来决议的,这儿带咱们推一推。首要JS散度的表达式如下:

    JS(P1∣∣P2)=12KL(P1∣∣P1+P22)+12KL(P2∣∣P1+P22)JS({P_1}||{P_2}) = \frac{1}{2}KL({P_1}||\frac{{{P_1} + {P_2}}}{2}) + \frac{1}{2}KL({P_2}||\frac{{{P_1} + {P_2}}}{2})

    上文我也给出了KL散度的公式,将KL散度公式代入上式,得:

    JS(P1∣∣P2)=12∑p1(x)log⁡2p1(x)p1(x)+p2(x)+12∑p2(x)log⁡2p2(x)p1(x)+p2(x)JS({P_1}||{P_2}) = \frac{1}{2}\sum {{p_1}(x)\log } \frac{{2{p_1}(x)}}{{{p_1}(x) + {p_2}(x)}} + \frac{1}{2}\sum {{p_2}(x)\log } \frac{{2{p_2}(x)}}{{{p_1}(x) + {p_2}(x)}}

    再化简上式,将log2提出来,留意这儿能够提是由于∑p1(x)=∑p2(x)=1\sum {{p_1}(x) = } \sum {{p_2}(x) = } 1,提取后公式如下:

    JS(P1∣∣P2)=12∑p1(x)log⁡p1(x)p1(x)+p2(x)+12∑p2(x)log⁡p2(x)p1(x)+p2(x)+log⁡2JS({P_1}||{P_2}) = \frac{1}{2}\sum {{p_1}(x)\log } \frac{{{p_1}(x)}}{{{p_1}(x) + {p_2}(x)}} + \frac{1}{2}\sum {{p_2}(x)\log } \frac{{{p_2}(x)}}{{{p_1}(x) + {p_2}(x)}} + \log 2

    现在咱们留意到上式现已有一个log2了,为证明咱们的定论,咱们只需证明两个散布不重合时上式左面两项为0即可。咱们以下图为例进行解说:

    对抗生成网络GAN系列——WGAN原理及实战演练

    ​  关于上图中的两个散布P1和P2P_1和P_2,能够看出它们没有堆叠部分。咱们设p1(x)p_1(x)为x落在P1P_1散布上的概率,设p2(x)p_2(x)为x落在P2P_2散布上的概率,咱们来看JS(P1∣∣P2)JS(P_1||P_2)的左面两项,即12∑p1(x)log⁡p1(x)p1(x)+p2(x)+12∑p2(x)log⁡p2(x)p1(x)+p2(x)\frac{1}{2}\sum {{p_1}(x)\log } \frac{{{p_1}(x)}}{{{p_1}(x) + {p_2}(x)}} + \frac{1}{2}\sum {{p_2}(x)\log } \frac{{{p_2}(x)}}{{{p_1}(x) + {p_2}(x)}}

    ​  当x<5x<5时,p2(x)p_2(x)近似为0,则:

    12∑p1(x)log⁡p1(x)p1(x)+p2(x)+12∑p2(x)log⁡p2(x)p1(x)+p2(x)=12∑p1(x)log⁡p1(x)p1(x)+0+12∑0log⁡0p1(x)+0=0\frac{1}{2}\sum {{p_1}(x)\log } \frac{{{p_1}(x)}}{{{p_1}(x) + {p_2}(x)}} + \frac{1}{2}\sum {{p_2}(x)\log } \frac{{{p_2}(x)}}{{{p_1}(x) + {p_2}(x)}} = \frac{1}{2}\sum {{p_1}(x)\log } \frac{{{p_1}(x)}}{{{p_1}(x) + 0}} + \frac{1}{2}\sum {0\log } \frac{0}{{{p_1}(x) + 0}} = 0

    ​  当x≥5x \ge 5时,p1(x)p_1(x)近似为0,则:

    12∑p1(x)log⁡p1(x)p1(x)+p2(x)+12∑p2(x)log⁡p2(x)p1(x)+p2(x)=12∑0log⁡00+p2(x)+12∑p2(x)log⁡p2(x)0+p2(x)=0\frac{1}{2}\sum {{p_1}(x)\log } \frac{{{p_1}(x)}}{{{p_1}(x) + {p_2}(x)}} + \frac{1}{2}\sum {{p_2}(x)\log } \frac{{{p_2}(x)}}{{{p_1}(x) + {p_2}(x)}} = \frac{1}{2}\sum {0\log } \frac{0}{{0 + {p_2}(x)}} + \frac{1}{2}\sum {{p_2}(x)\log } \frac{{{p_2}(x)}}{{0 + {p_2}(x)}} = 0

    ​  因而关于任意的x都会有12∑p1(x)log⁡p1(x)p1(x)+p2(x)+12∑p2(x)log⁡p2(x)p1(x)+p2(x)=0\frac{1}{2}\sum {{p_1}(x)\log } \frac{{{p_1}(x)}}{{{p_1}(x) + {p_2}(x)}} + \frac{1}{2}\sum {{p_2}(x)\log } \frac{{{p_2}(x)}}{{{p_1}(x) + {p_2}(x)}}=0 ,进而会有JS(P1∣∣P2)=log2JS(P_1||P_2)=log2


    这部分主要参阅了这篇文章,我认为其对JS散度为什么等于log2解说的十分清楚


  2. 任意的两个散布pgp_gpdatap_{data}堆叠的概率有多大

    ​  首要,我直接给咱们清晰定论,堆叠概率很小,特别小。我阐明一下这儿所说的堆叠概率指的是堆叠部分不行疏忽的概率,也便是说pgp_gpdatap_{data}或许会堆叠,可是往往堆叠部分都是能够疏忽的。这部分咱们能够参阅这篇文章,写的十分好,关于这部分的解说也比较全面。

    ​  我从两个视点协助咱们理解,其一,pgp_gpdatap_{data}假设为一条曲线,pgp_g开端往往都是随机生成的,所以很难和pdatap_{data}重合,往往都是有一些交点,而这些交点是能够疏忽的。实际上,pgp_gpdatap_{data}的维度往往很高,它们彻底堆叠的概率十分之低。

    对抗生成网络GAN系列——WGAN原理及实战演练

    ​  其二,咱们往往是从pgp_gpdatap_{data}中采样一些点出来练习,即便pgp_gpdatap_{data}有堆叠,可是经过采样后,它们很大或许就没有堆叠了。

    对抗生成网络GAN系列——WGAN原理及实战演练


​  解说了这么多,咱们应该知道为什么GAN会练习困难了吧。最终总结一下,这是由于当判别器练习最优时,生成器的丢失函数等价于JS散度,其梯度往往一向为0,得不到更新,所以很难train。


​  我不知道咱们是否会有这样的疑问,既然说生成器梯度得不大更新,为什么咱们在练习GAN时仍是能得到一些较好的图片呢,不应该什么也得不到呢。我是这样认为的,咱们上述的推导都是树立在最优判别器的基础上的,可是在咱们实操过程中往往一开端判别器功能是不理想的,所以生成器仍是有梯度更新的。还有一点,便是我之前的文章代码中生成器的丢失其实运用的−logD(x)-logD(x),这种丢失函数也会存在一些问题,往往呈现梯度不稳定,模式崩溃问题,可是梯度消失问题呈现比较少,所以仍是能练习的。关于运用这种丢失函数带来的缺点我这儿也不叙说了,咱们感兴趣的能够看看这篇文章。

EM distance(推土机间隔)的引入

​  咱们上文详细分析了一般GAN存在的缺点,主要是由于和JS散度相关的丢失函数导致的。大佬们就在考虑能否有一种丢失能够代替JS散度呢?于是,WGAN应运而生,其提出了一种新的衡量两个散布间隔的规范——Wasserstein Metric,也叫推土机间隔(Earth-Mover distance)。下面就让咱们来看看什么是推土机间隔吧!!!


【这部分参阅此篇文章】

​  下图左面有6个盒子,咱们希望将它们都移动到右侧的虚线框内。比方将盒子1从方位1移动到方位7,移动了6步,咱们就将此过程的价值设为6;同理,将盒子5从方位3移动到方位10,移动了7步,那么此过程的价值为7,依此类推。

对抗生成网络GAN系列——WGAN原理及实战演练

                图1 EM distance示例

​  很显然,咱们有许多种不同的计划,下图给出了两种不同的计划:

对抗生成网络GAN系列——WGAN原理及实战演练

​  上图中右侧的表格表明盒子是如何移动的。比方在1{\gamma _1}中,其榜首行榜首列的值为1,表明有1个框从方位1移动到了方位7;榜首行的第四列的值为2,表明有2个框从方位1移动到了方位10。上图中给出的两种计划的价值总和都为42,可是关于一个问题并不是一切的移动计划的价值都是固定的,比方下图:

对抗生成网络GAN系列——WGAN原理及实战演练

​  在这个比如中,上图展现的两种计划的价值是不同的,一个为2,一个为6。,而推土机间隔便是穷举一切的移动计划,最小的移动价值对应的便是推土机间隔。对应本列来说,推土机间隔等于2。

​  相信经过上文的表述,你现已对推土机间隔有了必定了了解。现给出推土机间隔是数学界说,如下:

W(Pr,Pg)=inf⁡∈∏(Pr,Pg)E(x,y)[∣∣x−y∣∣]{\rm{W}}({P_r},{P_g}) = \mathop {\inf }\limits_{\gamma \in \prod ({P_r},{P_g})} {E_{(x,y)}}[||x – y||]

​  看到这个公式咱们是不是都懵逼了呢,这儿做相关的解说。∏(Pr,Pg){\prod ({P_r},{P_g})}表明边际散布Pr和PgP_r和P_g一切组合起来的联合散布(x,y)\gamma(x,y)的集合。咱们仍是用图1中的比如来解说,∏(Pr,Pg){\prod ({P_r},{P_g})}就表明一切的运送计划\gamma,下图仅列举了两种计划:

对抗生成网络GAN系列——WGAN原理及实战演练

E(x,y)[∣∣x−y∣∣]{E_{(x,y)}}[||x – y||] 能够看成是关于一个计划\gamma 移动的价值,inf⁡∈∏(Pr,Pg)E(x,y)[∣∣x−y∣∣]\mathop {\inf }\limits_{\gamma \in \prod ({P_r},{P_g})} {E_{(x,y)}}[||x – y||] 就表明在一切的计划中的最小价值,这个最小咱们便是W(Pr,Pg)W(P_r,P_g),即推土机间隔。


​  现在咱们现已知道了推土机间隔是什么,可是咱们还没解说清楚咱们为什么要用推土机间隔,即推土机间隔为什么能够代替JS散度成为更优的丢失函数?咱们来看这样的一个比如,如下图所示:

对抗生成网络GAN系列——WGAN原理及实战演练

​  上图有两个散布P1和P2P_1和P_2P1P_1在线段AB上均匀散布,P2P_2在CD上均匀散布,参数\theta能够操控两个散布的间隔。咱们由前文对JS散度的解说,能够得到:

JS(P1∣∣P2)={log⁡2≠00=0JS({P_1}||{P_2}) = \left\{ \begin{array}{l} \log 2{\rm{ }} \quad \theta \ne {\rm{0}}\\ 0{\rm{ }} \quad \quad \ \ \theta {\rm{ = 0}} \end{array} \right.

​  而关于推土机间隔来说,能够得到:

W(P1,P2)=∣∣W(P_1,P_2)=|\theta|

​  这样比照能够看出,推土机间隔是滑润的,这样在练习时,即便两个散布不堆叠推土机间隔仍然能够提高梯度,这一点是JS散度无法完成的。


WGAN的完成

​  现在咱们现已有了推土机间隔的界说,同时也解说了推土机间隔相较于JS散度的优势。可是咱们想要直接运用推土机间隔来界说生成器的丢失似乎是困难的,由于这个式子W(Pr,Pg)=inf⁡∈∏(Pr,Pg)E(x,y)[∣∣x−y∣∣]{\rm{W}}({P_r},{P_g}) = \mathop {\inf }\limits_{\gamma \in \prod ({P_r},{P_g})} {E_{(x,y)}}[||x – y||] 是难以直接求解的,可是呢,作者大大用了一个定理将上式变化成了如下方式:

W(Pr,Pg)=1ksup⁡∣∣f∣∣L≤KEx∼Pr[f(x)]−Ex∼Pg[f(x)]{\rm{W}}({P_r},{P_g}) = \frac{1}{k}\mathop {\sup }\limits_{||f|{|_L} \le K} {E_{x \sim {{\rm{P}}_r}}}[f(x)] – {E_{x \sim {{\rm{P}}_g}}}[f(x)]

​  留意上式中的f有一个约束,即∣∣f∣∣L≤K{||f|{|_L} \le K},咱们称为lipschitz接连条件。这个约束其实便是约束了函数f的导数。它的界说如下:

∣f(x1)−f(x2)∣≤K∣x1−x2∣|f(x_1)-f(x_2)| \le K|x_1-x_2|

​ 即:

∣f(x1)−f(x2)∣∣x1−x2∣≤K\frac{|f(x_1)-f(x_2)|}{|x_1-x_2|} \le K

​  很显然,lipschitz接连就约束了f的斜率的绝对值小于等于K,这个K称为Libschitz常数。咱们来举个比如协助咱们理解,如下图所示:

对抗生成网络GAN系列——WGAN原理及实战演练

​  上图中log(x)的斜率无界,故log(x)不满意lipschitz接连条件;而sin(x)斜率的绝对值都小于1,故sin(x)满意lipschitz接连条件。


​  这样,咱们只需要找到一个lipschitz函数,就能够核算推土机间隔了。至于怎样找这个lipschitz函数呢,便是咱们搞深度学习的那一套啦,只需要树立一个深度学习网络来进行学习就好啦。实际上,咱们新树立的判别器网络和之前的的基本是一致的,只是最终没有运用sigmoid函数,而是直接输出一个分数,这个分数能够反应输入图画的实在程度。

​  呼呼呼~~,WGAN的原理就为咱们介绍到这儿了,咱们掌握了多少呢?其实我认为只看一篇文章是很难把WGAN的一切细节都理解的,咱们能够看看本文的参阅文献,结合多篇文章看看能不能协助咱们处理一些困惑。

WGAN代码实战

​  WGAN的代码实战我不计划贴出一堆代码了,只阐明一下WGAN相较于一般GAN做了哪些改变。首要咱们给出论文中练习WGAN的流程图,如下:

对抗生成网络GAN系列——WGAN原理及实战演练

​  其实WGAN相较于原始GAN只做了4点改变,别离如下:

  1. 判别器最终不运用sigmoid函数
  2. 生成器和判别器的loss不取log
  3. 每次更新判别器参数后将判别器的权重截断
  4. 不适应根据动量的优化算法,引荐运用RMSProp

现在代码中别离对上述的4点做相关解说:【很简略,所以咱们想要将原始GAN修正为WGAN就按照下列的几点来修正就好了】

  1. 判别器最终不运用sigmoid函数

    这个咱们一般只需要删去判别器网络中的最终一个sigmoid层就能够了,十分简略。可是我还想提示咱们一下,有时候你在看别人的原始GAN时,他的判别器网络中并没有sigmoid函数,而是在界说丢失函数时运用了BCEWithLogitsLoss函数,这个函数会先对数据做sigmoid,相关代码如下:

    # 界说丢失函数
    criterion = nn.BCEWithLogitsLoss(reduction='mean')
    

    假如这时你想删去sigmoid函数,只需将BCEWithLogitsLoss函数修正成BCELoss即可。关于BCEWithLogitsLossBCELoss的差异我这篇文章有相关解说,感兴趣的能够看看。

  2. 生成器和判别器的loss不取log

    咱们先来看原始GAN判别器的loss是怎样界说的,如下:

    d_loss_real = criterion(d_out_real.view(-1), label_real)
    d_loss_fake = criterion(d_out_fake.view(-1), label_fake)
    d_loss = d_loss_real + d_loss_fake
    

    原始GAN运用了criterion函数,这便是咱们上文界说的BCEWithLogitsLossBCELoss,其内部是一个log函数,不清楚的相同能够参阅我的这篇文章。

    而WGAN判别器的丢失直接是两个希望(均值)的差,就对应理论中Ex∼Pr[f(x)]−Ex∼Pg[f(x)]{E_{x \sim {{\rm{P}}_r}}}[f(x)] – {E_{x \sim {{\rm{P}}_g}}}[f(x)] 这部分啦【可是需要加个负号,由于咱们都是运用的梯度下降方法,要最小化这个丢失,上式是最大化的公式】,咱们来看看代码中的完成:

    d_loss = -(torch.mean(d_out_real.view(-1))-torch.mean(d_out_fake.view(-1)))
    

    看完判别器的丢失,咱们再来看生成器的丢失:

    g_loss = criterion(d_out_fake.view(-1), label_real)  #原始GAN丢失
    ---------------------------------------------------
    g_loss = -torch.mean(d_out_fake.view(-1))            #WGAN丢失
    
  3. 每次更新判别器参数后将判别器的权重截断

    首要来说说没什么要进行权重截断,这是由于lipschitz接连条件不好确认,作者为了便利直接简略粗暴的约束了权重参数在[-c,c]这个范围了,这样就必定会存在一个常数K使得函数f满意lipschitz接连条件,详细的完成代码如下:

    # clip D weights between -0.01, 0.01  权重剪裁
      for p in D.parameters():
          p.data.clamp_(-0.01, 0.01)
    

    留意这步是在判别器每次反向传达更新梯度结束后进行的。还需要提示咱们一点,在练习WGAN时往往会多练习几次判别器然后再练习一次生成器,论文中是练习5次判别器后练习一次生成器,关于这一点在上文WGAN的流程图中也有所体现。为什么要这样做呢?我考虑这样做或许是想把判别器练习的更好后再来练习生成器,由于在前面的理论部分咱们的推导都是树立在最优判别器的前提下的。

  4. 不适应根据动量的优化算法,引荐运用RMSProp

    这部分就属于玄学部分了,作者做实验发现像Adam这类根据动量的优化算法作用不好,然后运用了RMSProp优化算法。我决议这部分咱们也不要纠结,直接用就好。相同咱们来看看代码是怎样写的,如下:

    optimizerG = torch.optim.RMSprop(G.parameters(), lr=5e-5)
    optimizerD = torch.optim.RMSprop(D.parameters(), lr=5e-5)
    

​  这样,WGAN的代码实战我就为咱们介绍到这儿,不论你前文的WGAN原理听理解了否,可是WGAN代码相信你是必定会修正的,改动的十分之少,咱们快去试试吧。

小结

​  这部分的理论确实是有必定难度的,我也看了十分十分多的视频和博客,写了许多笔记。我觉得咱们不必定要弄懂每一个细节,只要对里边的一些要害公式,要害思维有清晰的掌握即可;而这部分的实验较原始GAN需要修正的仅有四点,十分简略,咱们都能够试试。最终希望咱们都能够有所收获,就像WGAN相同稳定的进步,一同加油吧!!!

参阅链接

令人赞不绝口的Wasserstein GAN

GAN:两者散布不重合JS散度为log2的数学证明

GAN — Wasserstein GAN & WGAN-GP

GAN — Spectral Normalization

李宏毅【機器學習2021】生成式對抗網路

如若文章对你有所协助,那就         

对抗生成网络GAN系列——WGAN原理及实战演练