Conditional DETR for Fast Training Convergence
- Conditional DETR
- DETR Decoder Cross-Attention
- Conditional Cross-Attention
- Experiments
- Ablations
- The effect of the way forming the conditional spatial query
- Focal loss and offset regression with respect to learned reference point
- The effect of linear projections T forming the transformation
论文连接:https://arxiv.org/abs/2108.06152v2
源码连接:https://github.com/Atten4Vis/ConditionalDETR
DETR中的交叉注意力模块高度依赖于内容嵌入来定位四端和预测方框,这增加了对高质量内容嵌入的需求,从而增加了训练的难度
Conditional DETR的主要内容是从decoder embedding学习出一个conditional spatial query(条件空间查询)
其好处是,通过条件空间查询,每个交叉注意头都能够关注一个不同区域,例如,一个对象的末端或对象框内的一个区域
这缩小了定位对象类别和盒子区域的空间范围,从而放松了对内容嵌入的依赖,简化了训练
DETR方法的训练收敛速度缓慢,需要500个epochs才能获得良好的性能
在交叉注意力中,内容嵌入是起着最主要的作用,空间嵌入是作为次要的贡献
如果在DETR中,移除key中的空间位置编码,,移除第二个decoder以后的object queries
只使用内容嵌入和query,mAP下降的不多
第一行是Conditional 训练50轮
第二行是DETR训练50轮
第三行是DETR训练500轮
可以看出第二行最后两个预测还没有学习好
原因为:
- 空间查询,即对象查询,只给出一般的注意力权重图,而没有利用特定的图像信息
- 由于训练时间较短,content queries不足以很好地匹配spatial keys,因为它们也被用于匹配content keys
Conditional DETR方法,该方法从相应的解码器输出嵌入中学习每个查询的条件空间嵌入,以形成所谓的解码器多头交叉注意的条件空间查询
条件空间查询是通过将回归对象框的信息映射到嵌入空间
Conditional DETR
Conditional DETR的模型构成与DETR相同:
- backbone
- encoder
- decoder
- ffn
Conditional DETR主要是修改了decoder部分,其他部分是相同的
DETR Decoder Cross-Attention
DETR解码器交叉注意机制采用三种输入:query,key,value。
每个键(key)由内容键(content key )(Ck)(来自编码器的content embedding输出)和一个空间键(spatial key)(Pk)(相应的归一化二维坐标的位置嵌入)来形成的
value是就是编码器(encoder)的content embedding输出
在原始的DETR方法中,每个查询(query)由内容查询(content query)(Cq)(来自解码器自注意的embedding output)和 一个空间查询(spatial query)(Pq)(即对象查询 Object query Oq)来形成的。
注意权重是基于查询和键之间的点积
Conditional Cross-Attention
交叉注意权重由内容注意权重和空间注意权重两个组成部分组成
与DETR交叉注意不同,我们的机制将内容查询和空间查询的角色分开,使空间查询和内容查询分别关注空间注意权重和内容注意权重
另一个重要的任务是从前一个解码器层的嵌入f中计算空间查询Pq。我们首先确定了不同区域的空间信息是由解码器嵌入和参考点这两个因素共同决定的
Conditional spatial query prediction
我们从嵌入的f和参考点s来预测条件空间查询
可视化了每个头的注意力权重图
- 第一行是 spatial attention weight
- 第二行是 content attention weight
- 第三行是 这两者的结合
decoder的self attention的输出作为query,同时需要查询出内容以及空间位置信息
那么在原始的DETR,就会需要很长的时间才能够学好
而条件空间查询就是在有意的把一份空间信息concat到self attention的输出上
Experiments
Ablations
The effect of the way forming the conditional spatial query
- CSQ-C 表示只使用decoder的content embedding
- CSQ-T 表示最有转换
- CSQ-P 表示只有位置编码 Ps
- CSQ-I 表示 与
Focal loss and offset regression with respect to learned reference point
- OR表示偏移量回归(offset regression)
- FL 表示focal loss
- CSQ 就是 Conditioanal spatial query
The effect of linear projections T forming the transformation
- 一个单位矩阵,意味着不学习线性投影
- 一个单个标量
- 一个块对角矩阵,意味着每个头部有一个学习的32×32线性投影矩阵
- 一个没有约束的完整矩阵
- 一个对角矩阵
有趣的是,单标量有助于提高性能,这可能是由于缩小了目标区域的空间范围