当前位置: 移动技术网 > IT编程>数据库>Mysql > SSD原理及Pytorch代码解读——网络架构(二):特征提取网络及总体计算过程

SSD原理及Pytorch代码解读——网络架构(二):特征提取网络及总体计算过程

2020年07月28日  | 移动技术网IT编程  | 我要评论

特征提取网络

前面我们已经知道了SSD采用PriorBox机制,也知道了SSD多层特征图来做物体检测,浅层的特征图检测小物体,深层的特征图检测大物体。上一篇博客也看到了SSD是如何在VGG基础的网络结构上进行一下改进。但现在的问题是SSD是使用哪些卷积层输出的特征图来做目标检测的?如下图所示:

从上图中可以看到,SSD使用了第4、7、8、9、10、11层的这6个卷积层输出作为特征图来做目标检测,但是这些特征图通道大小不一且数量很大,所以SSD在每一个特征图后面都接上了一个分类与位置卷积层使得输出的通道数符合要求。还有也可以从上图看出这6个特征图尺寸越来越小,而其对应的感受野越来越大。6个特征图上的每一个点分别对应4、6、6、6、4、4个PriorBox。接下来分别利用3×3的卷积,即可得到每一个PriorBox对应的类别与位置预测量。
举个例子,第8个卷积层得到的特征图大小为10×10×512,每个点对应6个PriorBox,一共有600个PriorBox。由于采用的PASCAL VOC数据集的物体类别为21类,因此3×3卷积后得到的类别特征维度为6×21=126,位置特征维度为6×4=24。

源码

代码文件为ssd.py。

# 每个特征图上一个点对应的PriorBox数量
mbox = [4, 6, 6, 6, 4, 4]

def multibox(vgg, extra_layers, cfg, num_classes):
	"""
	建立特征提取网络
	parameter:
		vgg: 基础VGG结构层列表,type:list
		extra_layers: 深度卷积层列表,type:list
		cfg:# 每个特征图上一个点对应的PriorBox数量
		num_classes:类别数量
	return:
		vgg: 基础VGG结构层列表,type:list
		extra_layers: 深度卷积层列表,type:list
		(loc_layers, conf_layers):元组,分别是每一个特征图上的回归层输出列表和分类层输出列表
	"""
    loc_layers = []		# 回归层输出
    conf_layers = []	# 分类层输出
    vgg_source = [21, -2]
    # 取第4、7卷积层输出并接上3×3的卷积
    for k, v in enumerate(vgg_source):
        loc_layers += [nn.Conv2d(vgg[v].out_channels,
                                 cfg[k] * 4, kernel_size=3, padding=1)]
        conf_layers += [nn.Conv2d(vgg[v].out_channels,
                        cfg[k] * num_classes, kernel_size=3, padding=1)]
                        
	# 取第8、9、10、11卷积层输出并接上3×3的卷积
    for k, v in enumerate(extra_layers[1::2], 2):
        loc_layers += [nn.Conv2d(v.out_channels, cfg[k]
                                 * 4, kernel_size=3, padding=1)]
        conf_layers += [nn.Conv2d(v.out_channels, cfg[k]
                                  * num_classes, kernel_size=3, padding=1)]
    return vgg, extra_layers, (loc_layers, conf_layers)

总体网络计算过程

为了更好地梳理网络的前向过程,将从代码角度讲述SSD网络的整个前向过程。

class SSD(nn.Module):
    """Single Shot Multibox Architecture
    The network is composed of a base VGG network followed by the
    added multibox conv layers.  Each multibox layer branches into
        1) conv2d for class conf scores
        2) conv2d for localization predictions
        3) associated priorbox layer to produce default bounding
           boxes specific to the layer's feature map size.
    See: https://arxiv.org/pdf/1512.02325.pdf for more details.

    Args:
        phase: (string) 模型所处阶段,为"test"或者"train"
        size: 输入图像大小
        base: 基础VGG16结构层列表,输入尺寸为300或者500,type:list
        extras: 深度卷积层列表,type:list
        head: "multibox head" 元组,分别是每一个特征图上的回归层输出列表和分类层输出列表
        num_classes:类别数量
    """

    def __init__(self, phase, size, base, extras, head, num_classes):
        super(SSD, self).__init__()
        self.phase = phase
        self.num_classes = num_classes
        self.cfg = voc	# voc为配置信息,用于生成prior box
        self.priorbox = PriorBox(self.cfg)	
        #import pdb
        #pdb.set_trace()
        self.priors = self.priorbox.forward()	# 生成每个特征图上的prior box
        self.size = size

        # SSD network
        self.vgg = nn.ModuleList(base)	# 生成基础VGG结构网络
        # Layer learns to scale the l2 normalized features from conv4_3
        self.L2Norm = L2Norm(512, 20)
        self.extras = nn.ModuleList(extras)	# 生成深度卷积层结构网络

        self.loc = nn.ModuleList(head[0])	# 生成回归网络结构
        self.conf = nn.ModuleList(head[1])	# 生成分类网络结构

        if phase == 'test':
            self.softmax = nn.Softmax(dim=-1)
            self.detect = Detect(num_classes, 0, 200, 0.01, 0.45)

    def forward(self, x):
        """.
		SSD前向传播过程
        Args:
            x: 批量数据. Shape: [batch,3,300,300].

        Return:
            Depending on phase:
            test:
                Variable(tensor) of output class label predictions,
                confidence score, and corresponding location predictions for
                each object detected. Shape: [batch,topk,7]

            train:
                list of concat outputs from:
                    1: 回归网络输出, Shape: [batch,num_priors,4]
                    2:分类网络输出, Shape: [batch,num_priors,num_classes]
                    3: prior box, Shape: [num_priors,4]
        """
        # sources保存特征图,loc与conf保存所有PriorBox的位置与类别预测特征
        sources = list()
        loc = list()
        conf = list()

        # 对输入图像卷积到conv4_3,将特征添加到sources中
        for k in range(23):
            x = self.vgg[k](x)

        s = self.L2Norm(x)
        sources.append(s)

        # 继续卷积到conv7,将特征添加到sources中
        for k in range(23, len(self.vgg)):
            x = self.vgg[k](x)
        sources.append(x)

        # 继续利用额外的卷积层计算,并将特征添加到sources中
        for k, v in enumerate(self.extras):
            x = F.relu(v(x), inplace=True)
            if k % 2 == 1:
                sources.append(x)

        # 对sources中的特征图利用类别与位置网络进行卷积计算,并保存到loc与conf中
        # 列表元素的尺寸为【batch, f_h, f_w, priors, f_h, f_w是该特征图的高和宽,priors是该特征图上每一个点对应的priorbox数量
        for (x, l, c) in zip(sources, self.loc, self.conf):
            loc.append(l(x).permute(0, 2, 3, 1).contiguous())
            conf.append(c(x).permute(0, 2, 3, 1).contiguous())
		
		# 合并多层特征并修改尺寸,num_priors为priors box总数量,如果输入图像为300*300,那么就一共又8732个priors box
        loc = torch.cat([o.view(o.size(0), -1) for o in loc], 1)	# shape[batch, num_priors*4]
        conf = torch.cat([o.view(o.size(0), -1) for o in conf], 1)	# shape[batch, num_priors*num_classes]
        if self.phase == "test":
            output = self.detect(
                loc.view(loc.size(0), -1, 4),                   # loc preds
                self.softmax(conf.view(conf.size(0), -1,
                             self.num_classes)),                # conf preds
                self.priors.type(type(x.data))                  # default boxes
            )
        else:
            # 对于训练来说,output包括了loc与conf的预测值以及PriorBox的信息
            output = (
                loc.view(loc.size(0), -1, 4),
                conf.view(conf.size(0), -1, self.num_classes),
                self.priors
            )
        return output

本文地址:https://blog.csdn.net/weixin_41693877/article/details/107580985

如对本文有疑问, 点击进行留言回复!!

相关文章:

验证码:
移动技术网