CenterNet算法笔记
- 1.核心思想
- 2.网络结构
- 2.1Backbone网络
- 2.2neck网络
- 2.3head网络
- 2.3.12D检测
- 2.3.1.1 keypoints heatmap
- 2.3.1.2 local offset
- 2.3.1.3 object size
- 2.3.2 3D检测
- 3.推理流程
- 4.我对CenterNet的几个问题
- 4.1 为什么CenterNet不需要NMS?
- 4.1.1 原论文的解释
- 4.1.2 对原论文解释的疑惑
- 4.1.3 CenterNet的heatmap和YOLOv3的置信度特征图有什么区别?
- 4.2 推理时获取peaks后是否需要阈值处理?
- 4.3 CenterNet是否易于硬件部署?
名为CenterNet的算法有两个,本文梳理的是论文《Objects as Points》,而不是《CenterNet:Keypoint Triplets for Object Detection》,望周知。由于本人目前的工作主要集中在目标检测方面,因此这里只对CenterNet算法的2D检测部分进行梳理,并简单介绍下3D检测,并不会涉及关键点检测部分。
1.核心思想
现有的检测算法都需要对大量的候选位置进行分类,并且还需要后处理(主要是NMS),比如二阶段的Faster RCNN、一阶段的YOLOv3、anchor-free的FCOS、CornerNet等算法,都是如此。
CenterNet将目标建模为目标的中心点,将目标检测问题转化为关键点检测问题(这一点与CornerNet一样,只不过CornerNet是检测gt box的左上和右下点,而CenterNet是检测gt box的中心点)。将图像输入网络,生成keypoints heatmap,heatmap上的峰值点(peaks,特征图上值>=8邻域内其它点的中心点,称为peak)对应的就是目标的中心点。检测到目标中心点后,利用每个peak点的特征回归目标框的其它属性,比如2D box的宽和高,3D box 的长宽高、深度、角度等。
CenterNet是anchor-free的,正样本的分配极其简单,一个目标只唯一对应heatmap上的一个peak,是无需NMS的。
2.网络结构
2.1Backbone网络
CenterNet把目标检测问题转化为关键点检测问题,因此使用的也都是关键点检测网络的backbone,这一点和CornerNet一样。CenterNet的backbone网络下采样倍数(stride)都是4,并且整体网络结构都是先下采样获得很大的stride,然后上采样回到stride=4,主要使用了三种backbone网络,如下:
(1)up-conv ResNet
ResNet网络的变种,使用了可变形卷积改变通道数,并且使用了转置卷积进行上采样。
(2)hourglass(沙漏)
和CornerNet使用的backbone一样,有非常复杂的skip connections,整体网络呈现对称特性。
(3)DLA(deep layer aggregation)
使用的是修改后的DLA-34。
2.2neck网络
CenterNet不使用任何FPN、PAN等结构,没有neck网络。
2.3head网络
Backbone网络的输出特征送入一个3x3卷积(256通道)、ReLu、一个1x1卷积得到最终的输出,输出的通道数与具体任务有关。
2.3.12D检测
主要包括keypoints heatmap、local offset、object size三部分的检测,通道数分别为C、2、2。
2.3.1.1 keypoints heatmap
C指的是类别个数,即keypoint heatmap的通道数是C,对COCO数据集来讲是80。每个通道的peaks表示该类目标出现的位置,peaks的值也是该目标的置信度。
(1)正负样本分配
计算每个gt box的中心点落在heatmap上的位置点,则该点即为正样本,回归标签为1,其它位置点为负样本,回归标签根据高斯分布来获得,高斯标准差根据目标的尺寸确定。如果一个负样本点处于多个高斯分布重叠区域,取较大的值作为回归标签。
(2)损失函数
对于损失函数用logistic regression + focal loss。单位:中心点坐标关于特征图上尺寸的归一化。
2.3.1.2 local offset
local offset预测特征图有2个通道分别表示中心点x、y方向的offset,因为backbone网络的stride=4,因此会有位置的偏移,这个offset预测就是用来修正目标中心点位置的。
(1)正负样本分配
只计算正样本,没有负样本。
(2)损失函数
L1 loss。单位:关于特征图上尺寸的归一化。
2.3.1.3 object size
object size,即宽和高,预测特征图具有2个通道。
(1)正负样本分配
只计算正样本,没有负样本。
(2)损失函数
L1 loss,并且作者进行了实验,发现对object size的回归来讲,L1 loss比Smooth L1 更好。单位:使用的是原始像素为单位,不进行标准化。
2.3.2 3D检测
不需要2D box的宽和高的预测,需要中心点预测,还需要另外三个预测:深度、3D box、角度。深度占一个通道,3D box占3个通道,对角度回归进行编码,占8个通道。
3.推理流程
这里只介绍2D检测的流程。
(1)前向传播:图像输入网络,获得输出特征图,包括三个部分,keypoints heatmap、local offset、object size。
(2)峰值检测:每个类别的heatmap单独进行peak处理。>=8邻域内其它点的中心点都算作peak,所有类别的所有peaks都提取出来,然后保留前100个。(注意,这里没有明说,但是我猜后续应该还会设置一个置信度阈值,对这100个检测结果进行过滤,具体有没有需要看代码)
(3)bbox解码:根据峰值的位置,进行offset调整,并获得object size。没有复杂的box解码过程,进一步提高了后处理速度。
总结:后处理过程没有NMS,也没有复杂的box解码过程,整体流程非常简洁。
4.我对CenterNet的几个问题
4.1 为什么CenterNet不需要NMS?
4.1.1 原论文的解释
先简单回顾一下Faster RCNN的正负样本分配。对于一个gt box,Faster RCNN需要计算其和所有anchors的IOU,**并取IOU大于0.7的所有anchor都作为正样本,来预测这个gt box,**这就导致在测试时,可能有多个pred box(预测框)对应的都是一个gt box,因此需要NMS去掉这些冗余的。
而对CenterNet来讲,原论文的说法是,每个目标只产生一个正样本点peak点,不像Faster RCNN那样一个gt box分给多个anchor去预测,没有冗余,因此CenterNet不需要NMS。
4.1.2 对原论文解释的疑惑
我对这种解释有些怀疑,我们先回顾一下YOLOv3的正负样本分配,对一个gt box,YOLOv3计算其和所有anchor的IOU,然后只把最大IOU的anchor作为正样本预测这个gt box,即YOLOv3也是一个gt box只分配给一个anchor、只产生一个正样本,和CenterNet一样,但YOLOv3是需要NMS的,为什么呢?
我猜测可能的原因是YOLOv3引入了忽略样本的原因,正样本邻域网格可能与gt box也有很高的IOU,但是邻域网格属于忽略样本,训练过程中不用于计算置信度损失,也就是说,训练过程并不会对邻域的置信度预测进行抑制,只是忽略而已,因此在测试时,邻域网格也可能输出了和gt box很高IOU的pred box,并且置信度分数也很高,和正样本网格非常接近,单纯靠置信度阈值处理无法去掉这些冗余的pred box,所以需要用NMS。
但是CenterNet在分配正样本时,目标中心点对应的heatmap上的点对应的回归标签就是1,而邻域根据高斯分布对回归值进行抑制,获得较小的回归值,没有重叠的问题,此只要选择peak作为目标就好,不需要NMS。不知道我的理解对不对?欢迎大家讨论。
但是如果这么考虑的话,只要把YOLOv3的忽略样本全部改成负样本,并且负样本的置信度标签比正样本对应的标签小,那么YOLOv3不也可以是NMS-free的吗?其实YOLOv5以及这么做了,yolov5的置信度标签是gt box和anchor的IOU值(也可能是CIOU?我没有核实),正样本就比负样本的置信度标签大,只不过YOLOv5的一个gt box对应多个正样本,仍然需要NMS。
4.1.3 CenterNet的heatmap和YOLOv3的置信度特征图有什么区别?
其实仔细一想,CenterNet是选择特征图(heatmap)的峰值peak作为目标中心点,无需NMS,而YOLOv3是先对置信度特征图进行阈值处理(取的不是唯一的极大值peak),然后把获得的pred box进行NMS处理,这样考虑的话,CenterNet的置信度特征图和YOLOv3的heatmap没有特别本质的区别,都是为了获取高质量的候选样本,只不过CenterNet把候选样本直接作为检测结果(因为正样本分配和回归尽量避免了冗余的情况),而YOLOv3对候选样本进行了NMS。不知道我的理解对不对?
4.2 推理时获取peaks后是否需要阈值处理?
论文中没有明说,这个需要看代码才能知道,或者哪位好心大佬告知一下?
4.3 CenterNet是否易于硬件部署?
我本人没有做过深度学习算法硬件落地方面的工作,但是我知道算法落地不仅考虑计算量的问题,还会考虑内存等方面的占用,所以对CenterNet这种使用hourglass、DLA等具有很多跳路连接的网络,以及可变形卷积、转置卷积等各种奇怪的运算等(硬件加速是否支持这些运算?),其内存占用相比YOLOv3要显著增加?也许后面做到算法落地的时候再考虑吧。