1. 修改 ssd 中的代码,将:
if phase == 'test':
self.softmax = nn.Softmax(dim=-1)
self.detect = Detect(num_classes, 0, 200, 0.01, 0.45)
改成
if phase == 'test':
self.softmax = nn.Softmax()
self.detect = Detect()
另外,将该文件中 forward() 方法里的:
if self.phase == "test":
output = self.detect(
loc.view(loc.size(0), -1, 4), # loc preds
self.softmax(conf.view(conf.size(0), -1,
self.num_classes)), # conf preds
self.priors.type(type(x.data)) # default boxes
)
改为:
if self.phase == "test":
output = self.detect.apply(self.num_classes, 0, 200, 0.01, 0.45,
loc.view(loc.size(0), -1, 4), # loc preds
self.softmax(conf.view(-1,
self.num_classes)), # conf preds
self.priors.type(type(x.data)) # default boxes
)
2. 将 layers 中 function 目录下 detection.py 里的 Detect 类修改为如下代码:
class Detect(Function):
    """Test-time detection step for SSD.

    Decodes location predictions against the prior boxes, thresholds the
    class confidences, runs non-maximum suppression class by class, and
    finally keeps only the ``top_k`` highest-scoring detections per image.

    Every setting is an argument of :meth:`forward`, so the class is used
    through ``Detect.apply(...)`` and never instantiated.
    """

    @staticmethod
    def forward(self, num_classes, bkg_label, top_k, conf_thresh, nms_thresh,
                loc_data, conf_data, prior_data):
        """Turn raw network predictions into per-class detections.

        Args:
            self: Context object supplied by ``Function.apply``. The name is
                kept for compatibility; it is only used to stash the
                settings below as attributes.
            num_classes: (int) Number of classes, including class index 0.
            bkg_label: (int) Background label. It is stored on the context
                only; the class loop always skips index 0.
            top_k: (int) Maximum number of detections kept per class and,
                after the final filtering step, per image.
            conf_thresh: (float) Scores must be strictly greater than this
                value to be considered.
            nms_thresh: (float) Overlap threshold forwarded to ``nms``.
                Must be greater than 0.
            loc_data: (tensor) Loc preds from loc layers
                Shape: [batch,num_priors,4]
            conf_data: (tensor) Conf preds from conf layers
                Shape: [batch*num_priors,num_classes]
            prior_data: (tensor) Prior boxes from priorbox layers
                Shape: [num_priors,4]

        Returns:
            (tensor) Shape: [batch,num_classes,top_k,5]. Each filled row is
            ``(score, box[0], box[1], box[2], box[3])`` with the box taken
            from ``decode``; unused rows are all zeros.

        Raises:
            ValueError: If ``nms_thresh`` is not greater than 0.
        """
        # Parameters used in nms.
        if nms_thresh <= 0:
            raise ValueError('nms_threshold must be positive.')

        self.num_classes = num_classes
        self.background_label = bkg_label
        self.top_k = top_k
        self.nms_thresh = nms_thresh
        self.conf_thresh = conf_thresh
        self.variance = cfg['variance']

        num = loc_data.size(0)  # batch size
        num_priors = prior_data.size(0)
        output = torch.zeros(num, num_classes, top_k, 5)
        # Reorder to [batch, num_classes, num_priors] so that one class's
        # scores over all priors can be addressed as conf_preds[i][cl].
        conf_preds = conf_data.view(num, num_priors,
                                    num_classes).transpose(2, 1)

        # Decode predictions into bboxes.
        for i in range(num):
            decoded_boxes = decode(loc_data[i], prior_data, self.variance)
            conf_scores = conf_preds[i].clone()
            # For each class, perform nms (class index 0 is skipped).
            for cl in range(1, num_classes):
                c_mask = conf_scores[cl].gt(conf_thresh)
                scores = conf_scores[cl][c_mask]
                if scores.size(0) == 0:
                    continue
                # Row-wise boolean selection keeps the [k, 4] box layout.
                boxes = decoded_boxes[c_mask]
                # idx of highest scoring and non-overlapping boxes per class
                ids, count = nms(boxes, scores, nms_thresh, top_k)
                kept = ids[:count]
                output[i, cl, :count] = torch.cat(
                    (scores[kept].unsqueeze(1), boxes[kept]), 1)

        # Keep only the top_k best detections per image across all classes.
        # `output` comes straight from torch.zeros and is contiguous, so
        # `flt` is a view and the in-place fill below writes into `output`.
        # (Indexing with a boolean mask and then calling fill_ would only
        # modify a temporary copy.)
        flt = output.view(num, num_classes * top_k, 5)
        _, idx = flt[:, :, 0].sort(1, descending=True)
        _, rank = idx.sort(1)  # rank[b, j] = position of row j by score
        flt.masked_fill_((rank >= top_k).unsqueeze(-1), 0)
        return output