CLIP: Learning Transferable Visual Models From Natural Language Supervision

OpenAI的神作CLIP,文章发表在ICML-2021,于2021年3月挂在arXiv上的。

摘要

当前的计算机视觉(CV)模型通常被训练用于预测有限的物体类别。这种严格的监督训练方式限制了模型的泛化性和实用性,因为这样的模型通常还需要额外的标注数据来完成训练时未曾见过的视觉“概念”。**直接从图片的描述文本中学习是一个有潜力的选择,因为这样我们可以获取更多的监督信号。**这篇文章中,我们证明了利用一个简单的预训练任务(即预测哪个文本描述对应当前图像)在一个从互联网上搜集的4亿个(图像,文本)对的数据集上可以取得SOTA的图像表征。预训练完之后,在下游任务上,我们可以通过用自然语言(文本)匹配视觉概念(图像)从而实现zero-shot transfer。我们在30个不同类型的下游CV 任务上进行了基准测试,并展示了我们模型强大的迁移能力,其在很多下游任务上不需要任何额外的数据也能比拟完全supervised的模型。比如,我们的模型在ImageNet上的zero-shot accuracy能达到在ImageNet上全监督训练的ResNet-50的性能。

Motivation

在NLP中,预训练的方法目前其实已经被验证很成功了,像BERT和GPT系列之类的。其中,GPT-3从网上搜集了400 billion byte-pair-encoded tokens进行预训练然后可以在很多下游任务上实现SOTA性能和zero-shot learning。这其实说明从web-scale的数据中学习是可以超过高质量的人工标注的NLP数据集的。 然而,对于CV领域,目前预训练模型基本都是基于人工标注的ImageNet数据集(含有1400多万张图像),那么借鉴NLP领域的GPT-3从网上搜集大量数据的思路,我们能不能也从网上搜集大量图像数据用于训练视觉表征模型呢? 作者先是回顾了并总结了和上述相关的两条表征学习路线: (1)构建image和text的联系,比如利用已有的(image,text)pair数据集,从text中学习image的表征; (2)获取更多的数据(不要求高质量,也不要求full labeled)然后做弱监督预训练,就像谷歌使用的JFT-300M数据集进行预训练一样(在JFT数据集中,类别标签是有噪声的)。具体来说,JFT中一共有18291个类别,这能教模型的概念比ImageNet的1000类要多得多,但尽管已经有上万类了,其最后的分类器其实还是静态的、有限的,因为你最后还是得固定到18291个类别上进行分类,那么这样的类别限制还是限制了模型的zero-shot能力。 这两条路线其实都展现了相当的潜力,前者证明paired text-image可以用来训练视觉表征,后者证明扩充数据能极大提升性能,即使数据有noise。于是high-level上,作者考虑从网上爬取大量的(text,image)pair以扩充数据,同时这样的pairs是可以用来训练视觉表征的。作者随即在互联网上采集了4亿个(text,image)对,准备开始训练模型。

Model

3.1 Objective 海量的(image,text)数据有了,问题是怎么设计并高效地训练模型。作者提出CLIP的模型,可以认为是ConVIRT的简化版。这里先简单回顾下ConVIRT (咋一看是不是觉得CLIP和ConVIRT一摸一样... ).

CLIP: Learning Transferable Visual Models From Natural Language Supervision文献_clip

VonVIRT用(image,text)对来训练模型,其有一个image encoder和一个text encoder,训练目标是让两路的representation尽可能得一致(对偶地最大化表征的agreement),其中gv和gu函数是一个non-linear得projection head,负责分别将图像和文本表征投影到一个shared的空间,从而计算距离。

CLIP: Learning Transferable Visual Models From Natural Language Supervision文献_zero-shot_02

CLIP: Learning Transferable Visual Models From Natural Language Supervision文献_clip_03

其实就是构造了一个对称的contrastive loss,在一个batch内预测谁是正样本

基于ConVIRT,CLIP主要做出了以下简化:

  • ConVIRT中的image encoder的参数是ImageNet初始化的,而CLIP直接用random初始化;
  • ConVIRT的projection head是non-linear的,而CLIP采用linear的projection;
  • CLIP去掉了ConVIRT中text transformation(指均匀从text中采样句子),因为CLIP数据集中有很多只出现过一次的(image,text);
  • CLIP的image transformation只用了resize和squared crop;
  • CLIP loss中的temperature参数τ是可学的。


CLIP: Learning Transferable Visual Models From Natural Language Supervision文献_zero-shot_04

一个batch里有N对(image,text),然后和ConVIRT一样做对称的contrastive learning,伪代码如下:

CLIP: Learning Transferable Visual Models From Natural Language Supervision文献_clip_05

3.2 Inference / Zero-shot prediction

一旦CLIP训练好了,我们就可以做zero-shot prediction了,如图所示:

CLIP: Learning Transferable Visual Models From Natural Language Supervision文献_多模态_06

步骤可以整理成下面这样:

  • Sample所有N个class,得到N个input text,都经过text encoder编码得到对应的N个class text embedding(我这里之所叫embedding而不叫representation是想说明这个特征是经过encoding和projection得到的);
  • Sample一个要预测的image,得到其image embedding;
  • 以N个text embedding为key,以当前image embedding为query,算cosine相似度,相似度最高的即为Top-1的prediction class。
import os
import clip
import torch
from torchvision.datasets import CIFAR100

# Load the model
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load('ViT-B/32', device)

# Download the dataset
cifar100 = CIFAR100(root=os.path.expanduser("~/.cache"), download=True, train=False)

# Prepare the inputs
image, class_id = cifar100[3637]
image_input = preprocess(image).unsqueeze(0).to(device)
text_inputs = torch.cat([clip.tokenize(f"a photo of a {c}") for c in cifar100.classes]).to(device)

# Calculate features
with torch.no_grad():
    image_features = model.encode_image(image_input)
    text_features = model.encode_text(text_inputs)

# Pick the top 5 most similar labels for the image
image_features /= image_features.norm(dim=-1, keepdim=True)
text_features /= text_features.norm(dim=-1, keepdim=True)
similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
values, indices = similarity[0].topk(5)

# Print the result
print("\nTop predictions:\n")
for value, index in zip(values, indices):
    print(f"{cifar100.classes[index]:>16s}: {100 * value.item():.2f}%")