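1. 修改 ssd.py 中的代码（新版 PyTorch 不再支持 forward 为非静态方法的旧式 autograd Function，因此 `Detect` 需改为通过 `Detect.apply()` 调用）。将：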

# Original (legacy) code: Detect is constructed with its NMS settings, which
# the new-style version passes to Detect.apply() in forward() instead.
if phase == 'test':
             self.softmax = nn.Softmax(dim=-1)
             self.detect = Detect(num_classes, 0, 200, 0.01, 0.45)

改为：

if phase == 'test':
    # Keep an explicit dim: forward() feeds a 2-D (N, num_classes) tensor,
    # so dim=-1 is the class axis, and passing it avoids PyTorch's
    # implicit-dimension deprecation warning.
    self.softmax = nn.Softmax(dim=-1)
    # Detect.forward is a @staticmethod, so store the Function class itself;
    # forward() calls self.detect.apply(...), which works on the class and
    # avoids instantiating an autograd Function (deprecated in PyTorch).
    self.detect = Detect

另外，将该文件中 `forward()` 方法里的以下代码：

# Original (legacy) code: the Detect instance is called directly with the
# predictions; conf keeps its batch dimension here.
if self.phase == "test":
             output = self.detect(
                 loc.view(loc.size(0), -1, 4),                   # loc preds
                 self.softmax(conf.view(conf.size(0), -1,
                              self.num_classes)),                # conf preds
                 self.priors.type(type(x.data))                  # default boxes
             )

改为:

# New-style call: arguments map positionally onto Detect.forward as
# num_classes, bkg_label=0, top_k=200, conf_thresh=0.01, nms_thresh=0.45,
# then loc_data, conf_data, prior_data. conf is flattened to
# (-1, num_classes) here; Detect.forward views it back per image.
if self.phase == "test":
             output = self.detect.apply(self.num_classes, 0, 200, 0.01, 0.45,
                 loc.view(loc.size(0), -1, 4),                   # loc preds
                 self.softmax(conf.view(-1,
                              self.num_classes)),                # conf preds
                 self.priors.type(type(x.data))                  # default boxes
             )

2. 修改 layers/functions/detection.py 中的代码，将 `Detect` 改写为 forward 为静态方法的新式 autograd Function：

class Detect(Function):
    """SSD test-time detection layer written as a new-style autograd Function.

    ``forward`` is a ``@staticmethod``, so every setting is passed
    positionally to ``Detect.apply(...)`` rather than to a constructor.
    For each image the layer decodes the location predictions against the
    prior boxes and keeps, per class, the boxes whose score exceeds
    ``conf_thresh``. It then runs NMS on them and writes the surviving
    (score, box) rows into a fixed-size output tensor.
    """

    @staticmethod
    def forward(ctx, num_classes, bkg_label, top_k, conf_thresh, nms_thresh,
                loc_data, conf_data, prior_data):
        """Decode predictions and apply per-class NMS for a batch.

        Args:
            ctx: autograd context supplied by ``Function.apply`` (unused).
            num_classes: number of score columns per prior, including
                class 0, which is skipped.
            bkg_label: accepted for API compatibility but not consulted;
                class index 0 is always skipped.
            top_k: per-class capacity of the output; also passed to ``nms``.
            conf_thresh: a box enters NMS for a class only if its score is
                strictly greater than this value.
            nms_thresh: overlap threshold passed to ``nms``; must be > 0.
            loc_data: (tensor) Loc preds from loc layers,
                shape [batch, num_priors, 4].
            conf_data: (tensor) Conf preds from conf layers, e.g.
                [batch*num_priors, num_classes]. Any layout that can be
                viewed as [batch, num_priors, num_classes] is accepted.
            prior_data: (tensor) Prior boxes from the priorbox layer;
                ``prior_data.size(0)`` must equal num_priors
                (e.g. [num_priors, 4]).

        Returns:
            Tensor of shape [batch, num_classes, top_k, 5]. For image ``i``
            and class ``cl``, the first ``count`` rows are (score, 4 box
            values) in the order given by ``nms``. All other rows, including
            every row of class 0, are zero.

        Raises:
            ValueError: if ``nms_thresh`` is not positive.
        """
        # Validate before doing any work.
        if nms_thresh <= 0:
            raise ValueError('nms_threshold must be positive.')

        variance = cfg['variance']
        num = loc_data.size(0)  # batch size
        num_priors = prior_data.size(0)

        # NOTE(review): allocated with torch.zeros defaults (default dtype and
        # device), not those of the inputs; result rows are copied in below.
        output = torch.zeros(num, num_classes, top_k, 5)
        # [batch, num_classes, num_priors]: one row of prior scores per class.
        conf_preds = conf_data.view(num, num_priors, num_classes).transpose(2, 1)

        for i in range(num):
            decoded_boxes = decode(loc_data[i], prior_data, variance)
            # Only read below, so no defensive copy is needed.
            conf_scores = conf_preds[i]

            # Class 0 is skipped (presumably background).
            for cl in range(1, num_classes):
                c_mask = conf_scores[cl].gt(conf_thresh)
                scores = conf_scores[cl][c_mask]
                if scores.size(0) == 0:
                    continue
                l_mask = c_mask.unsqueeze(1).expand_as(decoded_boxes)
                boxes = decoded_boxes[l_mask].view(-1, 4)
                # idx of highest scoring and non-overlapping boxes per class
                ids, count = nms(boxes, scores, nms_thresh, top_k)
                keep = ids[:count]
                output[i, cl, :count] = torch.cat(
                    (scores[keep].unsqueeze(1), boxes[keep]), 1)

        # A previous version ended with a global sort/rank pass followed by
        # ``flt[mask].fill_(0)``. Boolean-mask indexing returns a copy, so that
        # fill never modified ``output``. The dead pass has been removed and
        # the returned tensor is unchanged.
        return output