WGAN


文章目录

  • WGAN
  • 摘要
  • 1、问题
  • 1.1 The nature of data
  • 1.2 sampling
  • 1.3 What is the problem of JS divergence?
  • 2、WGAN
  • 2.1 Wasserstein distance
  • 2.2 Wasserstein distance的计算
  • 3、GAN的问题
  • 3.1 在训练上
  • 3.2 GAN for Sequence Generation
  • 4、GAN的评估
  • 方法
  • 总结与展望



摘要


本章要讲的是WGAN及GAN的评估。在对GAN进行训练时,有几个难点,因为在高维空间中P_G和P_data的分布重叠部分是非常少的,且就算重叠部分多当训练数据不够多时D也可以轻易辨别G生成的数据,同时JS divergence对于没有重叠的分布计算结果都是log2,从而无法正确测量出P_G和P_data的距离,模型的进步难以看出,所以GAN是很难训练的,为了解决该问题可以用WGAN,WGAN本质是用Wasserstein distance替换掉JS divergence,以确保D可以有效测量出P_G和P_data的距离。GAN的生成主要有两类问题,一类是生成数据不够真实,另一类是生成的数据多样性不够,所以在对GAN的好坏进行评估时,需要从以上两个方面进行评估。



1、问题

1.1 The nature of data

图片是P_G和P_data在低维空间的manifold,在高维空间中P_G和P_data重叠的部分往往是特别少,以二维空间为例,它们的分布曲线就像两条直线,相交点是极少的。

gan与wgan及其实战 pytorch wgan和gan_概率分布

1.2 sampling

我们在训练时,是不知道P_G和P_data的分布的,都是从其data中sampling一些点来进行训练,当sampling的点量不够多时,就算P_G和P_data有较大相交面积可以重合,D也是可以很容易的将其区分的。

gan与wgan及其实战 pytorch wgan和gan_数据_02

1.3 What is the problem of JS divergence?

对于两个没有重叠的分布,JS divergence的结果都是log2,所有我们就只能知道相交的和不相交的。

gan与wgan及其实战 pytorch wgan和gan_数据_03


所以在实际操作时,当我们用JS divergence时,我们训练一个Binary classifier时,当训练完后会发现,D的正确率接近100%,也就是说G并没有骗过D,这是为啥?是因为我们sample的图片数量太少,D可以轻易的分辨。

2、WGAN

第一节可以发现训练困难主要是JS divergence的问题,要解决这个问题,需要找一个新的计算divergence的方式。

2.1 Wasserstein distance

现在有两个分布P和Q,将P的分布以Q为目标进行改造,改造时移动的平均距离就是Wasserstein distance,当P的分布比较复杂时,变成Q的方法就有很多种了,不同的方法算出来的平均距离也就不同了。所以此时是需穷举所有的方法,取其中平均距离最小的方法作为Wasserstein distance.

gan与wgan及其实战 pytorch wgan和gan_数据_04


如此,当P_G和P_data比较相似时,Wasserstein distance的值就会比较小,解决了JS divergence存在的问题.

gan与wgan及其实战 pytorch wgan和gan_数据_05

2.2 Wasserstein distance的计算

从下面的function可以看出这个就是要求D给P_data的分数越高越好,给P_G的分数越低越好.

gan与wgan及其实战 pytorch wgan和gan_数据_06


其中D也必须是足够平滑的function.

3、GAN的问题

3.1 在训练上

虽然我们有了WGAN,但是GAN仍然是比较难训练的,它有一个本质上的问题,在GAN中,Generator(G)要做的事情是产生假的图片去骗过Discriminator(D),D要做的事情就是去辨别真的图片和G产生的图片,在训练时,它们两个是互相进步共同成长的,只有其中一个发生问题停止进步,那么另外一个也会停止进步,也就是说我们在训练GAN时,当D一下子没训练好,无法分辨真的和G生成的图片差异,G就会失去进步的目标就会停止进步,与此同时D也会跟着停下来,而此时若还没找到最好的结果,这次训练就失败了.所有说,GAN现在还是很难训练的.

3.2 GAN for Sequence Generation

training GAN时,最难的是training句子生成的GAN,例如下图,我们在对其进行训练时,用梯度下降的优化算法时,改变Decoder的参数时,因为改变幅度很小,所以输出的最大值max可能并没有变化,这就导致输出没有变化,然后model就坏掉了。当然Reinforcement learning(RL)是可以对其进行硬training的,但是RL本身就是比较难train的,GAN也很难train,加在一起就更难train了。但是最近有一篇ScrachGAN 文章有介绍了如何将GAN train起来。

gan与wgan及其实战 pytorch wgan和gan_生成图片_07

4、GAN的评估

将产生的图片y,丢给影像分类系统里看其产生怎样的图片分布P(c|y),当分布是比较集中的,说明影像分类系统很明确的知道产生了怎样的图片,而当影像分类系统识别不出产生的图片,则说明此时产生的图片是不真实的。

gan与wgan及其实战 pytorch wgan和gan_生成图片_08


但是光用这个办法是不够的,会被Mode Collapse的现象骗过去, Mode Collapse是machine产生的图片虽然很真实但其实只是在几张图片上进行了微调,显然这样的GAN是不好的,但是影像分类系统识别不出来。

gan与wgan及其实战 pytorch wgan和gan_生成图片_09


现在并么有好的办法可以解决这个问题,还有一个Mode Dropping的问题,例如在做人脸生成时,生成的人脸全部都是黄皮肤的,这个现在也没有得到本质上的解决。

gan与wgan及其实战 pytorch wgan和gan_生成图片_10


这两个问题的本质都是生成图片多样性不够。有一个评估的办法是将一定量的GAN产生的图片丢给影像分类系统,取每张图片得到的分布的平均值,当平均分布比较集中则说明生成图片多样性不够。

gan与wgan及其实战 pytorch wgan和gan_生成图片_11

方法

  1. WGAN是用Wasserstein distance替代JS divergence来测量P_G和P_data的距离,Wasserstein distance的公式如下:
  2. 对GAN进行评估时,以影像分类系统为例,对生成图片进行真实性评估时可以将G生成图片丢给影像分类系统,观察影像分类系统输出的概率分布,当概率分布比较集中则说明G生成图片真实度高。对生成图片进行多样性评估时可以将G生成的多个图片丢给影像分类系统,再计算其平均概率分布,当概率分布比较集中则说明G生成的图片多样性不够,反之亦然。

总结与展望

GAN的训练困难,主要原因在于JS divergence无法准确的测量P_G和P_data的分布差距,WGAN用Wasserstein distance替代JS divergence,可以很好的解决这个问题。虽然如此,但是WGAN依旧是很难train的,本质在于G和D是共同进步的,当训练时,D遇到错误停止进步了,G也会跟着停下来,至今还没有很好的办法能解决这个问题。对GAN的好坏进行评估时,可以用分类系统对GAN生成的数据进行分析可以很好的完成model的评估。