作者:Zheng Li
https://zhuanlan.zhihu.com/p/684269963
大家别只点收藏,多点点赞~
《PromptKD: Unsupervised Prompt Distillation for Vision-Language Models》
主页:https://zhengli97.github.io/PromptKD/
代码:https:///zhengli97/PromptKD
论文:https://arxiv.org/abs/2403.02781
一句话概括:
PromptKD是一个简单有效的基于prompt的视觉语言模型蒸馏新方法,在prompt learning的11个benchmark数据集上大幅领先,达到了SOTA。
大白话背景介绍
已经很了解VLMs和prompt learning的同学可以直接跳过,到背景问题~
什么是视觉-语言模型(Vision-Language Models, VLMs)?
视觉语言模型VLM一般由两个部分构成,即视觉(Vision)部分和语言(Language)部分。
以一个经典的VLM网络 CLIP[1] 的结构为例:
图1. CLIP架构。图片来自于CLIP论文。
如图1所示,CLIP由text branch和image branch组成。
其中, text branch主要由transformer构成,当要进行cls_num个类的分类任务时,会取每个类别对应的名称,如"plane", "car", "dog",与"a photo of a"进行组合,作为prompt输入进text encoder,得到大小为[cls_num, feat_dim]的text feature。
image branch的核心就是对输入的图像提取image feature,其通常为ResNet或者ViT[2]。图像经过image encoder之后得到image feature,其大小为[batch_size, feat_dim]。
将两个feature进行相乘就得到了预测logits。
CLIP有两个明确的特性,是这个工作的基础:
- CLIP可以进行zero-shot分类,即对未见过的类别进行识别,并保持很高的性能。而传统的CNN或者ViT由于模型架构限制不可以。
- 对于已知的类别,CLIP的text branch只需要一次forward就可以得到对应text feature用于分类。
什么是提示学习(Prompt Learning)?
在Text Branch部分中,a photo of a {class_name} 这样的描述太过宽泛,明显不是最优的。例如对于图2(b)的花,手工设计的a flower photo a {class}要描述的更加精确,其产生的结果就更好。
图2. 蓝色方块代表手动设计的prompt,绿色方块代表网络学习得到的learnable prompt。绿色方块acc超越了蓝色。图片来自于CoOp论文。
这就产生来两个问题,第一,固定模板的prompt不是最优的。第二,针对性的手工设计费时费力,且无法泛化。
于是,提示学习(Prompt Learning)[3] [4]就提出将prompt变成了一种learnable的方式,通过优化的方法让prompt在下游数据集上学习适用的表征,来替代手工设计的prompt,参考图2中的绿色方块。
这样优势是,可以在少量数据的情况下,仅通过引入一少部分的可学习参数(即learnable prompt),就可以将原始的CLIP快速适用到下游的任务/数据,同时在性能上比全参数微调的结果更好[4]。
实验衡量指标是什么?
有三个指标,分别是base acc,novel acc和harmonic mean。
以imagenet-1k数据集为例,会取1000类中的前500类作为base class,后500类作为novel class。模型在base class上训练,完成后在base class和novel class上测试acc性能。因为novel class与base class数据类别不重复,所以novel acc可以有效反应模型泛化性能。harmonic mean指标是对base acc和novel acc的综合反映,为harmonic mean = (2*base acc*novel acc) / (base acc+novel acc)。总体的harmonic mean值越高,模型综合性能越好。
背景问题
prompt learning的核心作用是,保持原始CLIP参数不变,通过引入小部分learnable prompt参数,来将大的原始的经过预训练的CLIP模型适用到下游任务/数据上,提升CLIP模型在下游任务的性能,同时保持CLIP模型zero-shot能力。
除去一直发展至今的各种设计prompt形式的工作[3] [5] [6] [7] [8] [9] [10] [11] [12] [13],现如今最前沿的prompt learning方法主要还可以分为另外两类:
1. 引入额外数据/信息。这一类工作核心就是通过引入额外的数据或信息,做法包括但不限于,
(1) 通过LLM来生成{class_name}相关的语句,获得额外的有关{class_name}的特性 特征[14] [15][16],或者更多描述性语句[17] [18] [19] [20]。
(2) 引入额外的数据源,从wikipedia上引入文本描述[21],从额外数据集例如ImageNet-21K来做预训练 [22]。
(3) 设计给原始图像数据引入额外的tag或标注[23] [24] [25]。
从以上的方式我们看到,大部分引入额外数据信息的工作都是围绕text branch展开,本质原因是输入的text本身"{class_name}"或"a photo of a {classname}"包含信息太少,丰富度要远低于image,通过额外的域内文本信息的引入,可以显著增强text feature的质量。所以text feature的质量是关键。
同时,可以看到,围绕image branch的工作是相对较少的。这时候问题就来了:那我们可不可以用同样的思路来增强image feature呢?
诶,这个方法好!因为互联网内往往存在非常大量的图像数据,很容易获取。
但问题是这些图像往往是没有标注的,没办法用gt训,如果要去进行标注,需要消耗很多的时间或者钱。明显限制了这种方式的应用。
2. 利用原始CLIP自身信息约束模型学习[19] [26] [27] [28] [29] [30] [31],防止过拟合。
在Prompt learning中,learnable prompt的参数量是相对较少的,在经过大量base class数据训练之后,模型会对base class数据存在过拟合,丧失对novel class的泛化性能。要解决这个问题,一种非常有效的做法就是利用vanilla CLIP来约束带有prompt的模型的学习。
以ICCV 23 PromptSRC为例,如图3所示,
图3. PromptSRC结构图。图片来自于PromptSRC论文。
图3这篇工作就看两条线,蓝线和灰线。
蓝线,就是原始CLIP的前向计算路径,分别会得到对应的image和text feature。
灰线,就是带有learnable prompt的计算过程,也会得到对应的feature。
在两条线的末尾,计算了三个loss,这里就是用原始CLIP产生的image和text feature来约束由含有learnable prompt产生的image和text feature。通过这样的约束,限制了prompt向着base class过拟合,达到了SOTA的性能。
由这个工作我们就想,如果换一个更好的模型来做约束是不是性能会更好?
于是,这就引出了我们的工作。
方法
PromptKD其实核心就在做一件事,引入更大的CLIP模型作为teacher,解决了上面提到的三个问题。
(1) 重用(Reuse) teacher CLIP产生的text feature用于学生的训练和推断。这样确保了text feature高质量的同时,还显著的节省计算量,训练时只涉及student的image encoder。
(2) 对齐学生CLIP和教师CLIP的logits。让大的CLIP模型给小的学生CLIP模型提供更好的监督。
(3) 因为有了教师CLIP的存在,就解决了数据量限制的问题,我们可以用大量的无标签domain data来训学生,不再拘泥于原来有限的有标签数据。在训练时,我们直接可以使用数据集的全量数据作为无标签数据进行蒸馏,这样一来就prompt就可以学到更广泛的domain knowledge。同时高性能的教师CLIP也保证了用于蒸馏的软标签的准确性。
我们先来看一个简单的结构缩略图:
图4. PromptKD框架简略图。
黄色的方块部分代表的就是教师CLIP,在教师CLIP经过训练之后,直接一次forward,得到并保存下来对应类别的text feaure,也就得到了图4中的Pre-stored Text Feature。
蓝色的方块代表的是学生CLIP,这里其实就只有一个image encoder,在带有learnablr prompt的输入进入image encoder之后会得到对应的image feature,这是因为与teacher text feature在维度上不匹配,所以经过一个Projector,将512转成768维的特征。然后再与Pre-stored Text Feature相乘,得到logits。
然后进行蒸馏。
完整的框架图如图5所示:
图5. PromptKD整体框架图。
图5里就是图4过程的细化。
这里将PromptKD的每个阶段都进行了详细的阐明。大家看图就明白了~
第一阶段,教师模型的预训练。在这里,我们选择之前的SOTA方法PromptSRC去预训练我们的教师ViT-L/14 CLIP模型,我们的学生模型是ViT-B/16 CLIP模型。
注意,这里的预训练不是必须的一步,选择去预训练教师模型,是为了让教师有一个更好的性能,从而有更好的学生蒸馏结果。如果直接使用vanilla ViT-L/14 CLIP作为教师,相比于baseline,也取得了明显的性能提升,具体结果请参考表4。
第二阶段,学生CLIP模型的蒸馏。
第三阶段,学生的推断。
最后再来一个简洁明了的流程概括图:
图6. 计算流程
实验结果
我们的PromptKD方法在prompt learning的11个benchmark dataset上都达到了SOTA的性能。
Base-to-novel实验:
表1. Base-to-novel实验结果。
图7. HM分数在11个数据集上的总揽图。
Cross-dataset实验
表2. Cross-dataset实验结果。
消融实验
为了实验快速进行,消融实验里使用的不是全量数据集,而是64 shots per class进行的训练。所以会与表1中的数据相比略低。
与其他同样使用了无标签数据的工作的性能对比
表3. 在Flowers102数据集上与使用了无标签数据的其他方法的对比结果。
教师预训练方法的选择
在PromptKD中,任意类型的ViT-L/14 CLIP教师模型都可以蒸馏出一个很好的ViT-B/16 CLIP模型,相比于baseline (70.22 HM)都有明显的提升。
这里有一点非常有意思的是,我们可以看到,第四行的Teacher(CLIP) ViT-L/14也就是原始的CLIP模型,在经过PromptKD的蒸馏之后,我们的ViT-B/16 CLIP的结果(表1(b))明显超过了原始的ViT-L/14 CLIP模型。(77.62 vs. 76.52)
表4. 不同教师预训练方法对PromptKD蒸馏效果的影响。
不同容量教师模型的选择
如表5所示,绿色代表学生ViT-B/16 CLIP的HM分数,土黄色代表教师的HM分数。教师的性能越高,越能训练出更好的学生。
图8. 不同容量的CLIP模型作为教师进行蒸馏。
欢迎大家试用PrompKD~
Acknowledgement
这篇论文解读感谢师弟武戈同学的部分论文总结,PromptKD这篇工作也非常感谢蚂蚁的申书恒,张长浩和傅幸同学的讨论和帮助。