概述

GAN的发明者Ian Goodfellow2016年在Open AI任职期间发表了这篇论文,其中提到了GAN用于半监督学习(semi supervised)的方法。称为SSGAN。 作者给出了Theano+Lasagne实现。本文结合源码对这种方法的推导和实现进行讲解。1

半监督学习

考虑一个分类问题。 
如果训练集中大部分样本没有标记类别,只有少部分样本有标记。则需要用半监督学习(semi-supervised)方法来训练一个分类器。

wiki上的这张图很好地说明了无标记样本在半监督学习中发挥作用: 

半监督学习中有三个常用的基本假设 gan半监督分类_sed

如果只考虑有标记样本(黑白点),纯粹使用监督学习。则得到垂直的分类面。 
考虑了无标记样本(灰色点)之后,我们对样本的整体分布有了进一步认识,能够得到新的、更准确的分类面。

核心理念

在半监督学习中运用GAN的逻辑如下。

  • 无标记样本没有类别信息,无法训练分类器;
  • 引入GAN后,其中生成器(Generator)可以从随机信号生成伪样本;
  • 相比之下,原有的无标记样本拥有了人造类别:真。可以和伪样本一起训练分类器。 

举个通俗的例子:就算没人教认字,多练练分辨“是不是字”也对认字有好处。有粗糙的反馈,也比没有反馈强。

原理

框架

GAN中的两个核心模块是生成器(Generator)和鉴别器(Discriminator)。这里用分类器(Classifier)代替了鉴别器。 

半监督学习中有三个常用的基本假设 gan半监督分类_半监督学习中有三个常用的基本假设_02

训练集中包含有标签样本xlxl和无标签样本xuxu。 
生成器从随机噪声生成伪样本IfIf。 
分类器接受样本II,对于KK类分类问题,输出K+1K+1维估计ll,再经过softmax函数得到概率pp:其前KK维对应原有KK个类,最后一维对应“伪样本”类。 
pp的最大值位置对应为估计标签yy。



softmax(xi)=exp(xi)∑jexp(xj)softmax(xi)=exp⁡(xi)∑jexp⁡(xj)


三种误差

整个系统涉及三种误差。

对于训练集中的有标签样本,考察估计的标签是否正确。即,计算分类为相应的概率: 



Llabel=−E[lnp(y|x)]Llabel=−E[ln⁡p(y|x)]


对于训练集中的无标签样本,考察是否估计为“真”。即,计算不估计为K+1K+1类的概率: 



Lunlabel=−E[ln(1−p(K+1|x))]Lunlabel=−E[ln⁡(1−p(K+1|x))]


对于生成器产生的伪样本,考察是否估计为“伪”。即,计算估计为K+1K+1类的概率: 



Lfake=−E[lnp(K+1|x)]Lfake=−E[ln⁡p(K+1|x)]


推导

考虑softmax函数的一个特性: 



softmax(xi−c)=exp(xi−c)∑jexp(xj−c)=exp(xi)/exp(c)∑jexp(xj)/exp(c)=softmax(xi)softmax(xi−c)=exp⁡(xi−c)∑jexp⁡(xj−c)=exp⁡(xi)/exp(c)∑jexp⁡(xj)/exp⁡(c)=softmax(xi)



即,如果输入各维减去同一个数,softmax结果不变。 


于是,可以令

l→l−lK+1l→l−lK+1 ,有 lK+1=0lK+1=0 , p=softmax(l)p=softmax(l) 保持不变。

期望号略去不写,利用explK+1=1,exp⁡lK+1=1,后两种代价变为: 



Lunlabel=−ln[1−p(K+1|x)]=−ln[∑Kj=1explj∑Kj=1explj+explK+1]=−ln[∑j=1Kexplj]+ln[1+∑j=1Kexplj]Lunlabel=−ln⁡[1−p(K+1|x)]=−ln⁡[∑j=1Kexp⁡lj∑j=1Kexp⁡lj+exp⁡lK+1]=−ln⁡[∑j=1Kexp⁡lj]+ln⁡[1+∑j=1Kexp⁡lj]




Lfake=−ln[p(K+1|x)]=ln[1+∑j=1Kexplj]Lfake=−ln⁡[p(K+1|x)]=ln⁡[1+∑j=1Kexp⁡lj]


上述推导可以让我们省去lK+1lK+1,让分类器仍然输出K维的估计ll。

对于第一个代价,由于分类器输入必定来自前K类,所以可以直接使用ll的前K维: 



Llabel=−ln[p(y|x,y<K+1)]=−ln[exply∑Kj=1explj]=−ly+ln[∑j=1Kexplj]Llabel=−ln⁡[p(y|x,y<K+1)]=−ln⁡[exp⁡ly∑j=1Kexp⁡lj]=−ly+ln⁡[∑j=1Kexp⁡lj]


引入两个函数,使得书写更为简洁:



LSE(x)=ln[∑j=1expxj]LSE(x)=ln⁡[∑j=1exp⁡xj]




softplus(x)=ln(1+expx)softplus(x)=ln⁡(1+exp⁡x)


三个误差: 



Llabel=−ly+LSE(l)Llabel=−ly+LSE(l)




Lunlabel=−LSE(l)+softplus(LSE(l))Lunlabel=−LSE(l)+softplus(LSE(l))




Lfake=softplus(LSE(l))Lfake=softplus(LSE(l))


优化目标

对于分类器来说,希望上述误差尽量小。引入权重ww,得到分类器优化目标: 



LD=Llabel+w2(Lunlabel+Lfake)LD=Llabel+w2(Lunlabel+Lfake)


对于生成器来说,希望其输出的伪样本能够骗过分类器。生成器优化目标与分类器的第三项相反: 



LG=−LfakeLG=−Lfake


实验

本文的实验包含三个图像分类问题。分类器接受图像xx,输出KK类分类结果ll。生成器从均匀分布的噪声zz生成一张图像xx。

MNIST

10分类问题,图像为28*28灰度。

生成器是一个3层线性网络: 

半监督学习中有三个常用的基本假设 gan半监督分类_半监督学习_03

分类器是一个6层线性网络: 

半监督学习中有三个常用的基本假设 gan半监督分类_sed_04

训练样本60K个,测试样本10K个。 
选择不同数量的训练样本给予标记,考察测试样本中错误个数。使用不同随机数种子重复10次:

有标记样本

20

50

100

200

占比

0.033%

0.083%

0.17%

0.33%

错误个数

1677±452

221±136

93±6.5

90±4.2

Cifar10

10分类问题,图像为32*32彩色。

生成器是一个4层反卷积网络: 

半监督学习中有三个常用的基本假设 gan半监督分类_生成器_05

分类器是一个9层卷积网络: 

半监督学习中有三个常用的基本假设 gan半监督分类_半监督学习中有三个常用的基本假设_06

训练样本50K个,测试样本10K个。 
选择不同数量的训练样本给予标记,考察测试样本中错误个数。使用不同的测试/训练分割重复10次:

有标记样本

1000

2000

4000

8000

占比

2%

4%

8%

16%

错误个数

21.83±2.01

19.61±2.09

18.63±2.32

17.72±1.82

SVHN

10分类问题,图像为32*32彩色。

生成器(上)以及分类器(下)和CIFAR10的结构非常类似。 

半监督学习中有三个常用的基本假设 gan半监督分类_半监督学习中有三个常用的基本假设_07

训练样本73K,测试样本26K。 
选择不同数量的训练样本给予标记,考察测试样本中错误个数。使用不同的测试/训练分割重复10次:

有标记样本

500

1000

2000

占比

0.68%

1.4%

2.7%

错误个数

18.84±4.8

8.11±1.3

6.16±0.58


  1. USC的Shao-Hua Sun也给出了一个Tensorflow实现。但没有处理训练集中的无标签样本