安全矩阵

 找回密码
 立即注册
搜索
查看: 1612|回复: 0

ResNet 原理与代码复现

[复制链接]

251

主题

270

帖子

1797

积分

金牌会员

Rank: 6Rank: 6

积分
1797
发表于 2022-8-4 23:30:56 | 显示全部楼层 |阅读模式
ResNet 原理与代码复现
转载于: 悬鱼铭 [url=]CV算法恩仇录[/url] 2022-08-03 20:02 发表于北京
收录于合集
#计算机视觉23个
#深度学习14个
#人工智能12个
#网络架构3个
ResNet 模型原理
VGG 网络在特征表示上有极大的优势,但深度网络训练起来非常困难。为了解决这个问题,研究者提出了一系列的训练技巧,如 Dropout、归一化(批量正则化,Batch Normalization)。
2015年,何凯明为了降低网络训练难度,解决梯度消失的问题,提出了残差网络(Residual Network,ResNet)。
图1 梯度消失
ResNet 通过引入跳跃结构(skip connection),让 CNN学习残差映射。残差结构(Bottleneck)如图 2 所示。
图2 残差结构
图2 的残差结构中,输入 x ,先是 1 x 1 卷积核,64 卷积层,最后是 1 x 1 卷积核,256 卷积层,维度先变小再变大。网络的输出为 H(x),如果没有引入跳跃结构分支, H(x) = F(x),根据链式法则对 x 求导,梯度变得越来越小。引入分支之后,H(x) = F(x) + x,对 x 求导,得到的局部梯度为 1,且当梯度进行反向传播时,梯度也不会消失。
图 3 是 ResNet 的结构,图中展示了 18 层、34 层、50 层、101 层、152 层框架细节,图中 “ x 2” 和 “ x 23 ” 表示该卷积层重复 2 次或 23 次。我们可以发现所有的网络都分成 5 部分,分别是 conv1、conv2_x、conv3_x、conv4_x、conv5_x。
图3 ResNet的结构
图 3 中 conv1 使用的是 7 x 7 的卷积核。当通道数一致时,卷积参数的计算量是 7 x 7 的卷积核 大于 3 x 3 的卷积核 ;当通道数不一致时,若通道数小,则可以采用大的卷积核。
对于第一个卷积层的通道数为 3 时,3个 3 x 3 卷积核与 1 个 7 x 7 卷积核的感受野效果一样,但 1 个 7 x 7 却比 3 个 3 x 3 的参数多。在 VGG 19 层和 ResNet 34 层里,参数的计算量如图 4 所示,ResNet 34 层采用 1 个 7 x 7 的卷积核的计算量远小于 VGG 19 层采用 3 个 3 x 3 的卷积核。
图4 参数的计算量
图 3 中卷积层 conv2_x 和 conv3_x 的输出(output size)的大小分别为56 x 56 和28 x 28,如果卷积层 conv2_x 采用跳跃结构到 conv3_x,由于特征图的维度不一致,不能直接相加,此时的跳跃结构可采用卷积,以保证特征图的维度一致,特征图可以进行相加操作。
图 3 中最后一行的 FLOPs (floating-point operations) 指的是浮点运算次数,可以衡量框架的复杂度。框架的复杂度与权重和偏差(bias)有关。输入图像的高、宽、通道数分别用 H_in、 W_in、D_in 表示;输出的特征图的高、宽、通道数分别用 H_out、 W_out、D_out 表示;卷积核的宽和高分别用 F_w、F_h表示;N_p表示特征图一个点的计算量,其计算公式如下:
一次卷积的 FLOPs 的计算公式如下:
对于全连接层,输入的特征图会拉伸为1 x N_in 的向量,输出的向量维度为 1 x N_out,则一次全连接层的 FLOPs 计算公式如下:
可以使用工具包 Flops 在 PyTorch 中计算网络的复杂度。
图5 ResNet 34 与 VGG 16 网络的 FLOPs
ResNet 代码复现
ResNet 网络参考了 VGG 19 网络,在其基础上进行了修改,变化主要体现在 ResNet直接使用 stride=2 的卷积做下采样,并且用 Global Average Pool 层替换了全连接层。
ResNet 使用两种残差结构,如下图 5 所示。左图对应的是浅层网络,当输入和输出维度一致时,可以直接将输入加到输出上。右图对应的是深层网络。对于维度不一致时(对应的是维度增加一倍),采用 1 x 1 的卷积,先降维再升维。
图5 残差结构
两种残差结构的代码实现如下,class BasicBlock(nn.Module) 指的是浅层网络 ResNet 18/34 的残差单元:
class BottleNeck(nn.Module)指的是深层网络 ResNet 50/101/152 的残差单元:
ResNet 的整体结构如下:
在 ResNet 类中的 forward( )函数规定了网络数据的流向:
(1)数据进入网络后先经过卷积(conv1),再进行下采样pool(f1);
(2)然后进入中间卷积部分(conv2_x, conv3_x, conv4_x, conv5_x);
(3)最后数据经过一个平均池化(avgpool)和全连接层(fc)输出得到结果;
中间卷积部分主要是下图中的蓝框部分,红框部分中的 [2, 2, 2, 2] 和 [3, 4, 6, 3] 等则代表了 bolck 的重复次数。
ResNet18和其他Res系列网络的差异主要在于 conv2_x ~conv5_x,其他的部件都是相似的。
参考资料:

后台回复【ResNet】即可获得代码

关注+星标【CV算法恩仇录】
学习不迷路**
**

回复

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

小黑屋|安全矩阵

GMT+8, 2024-11-29 18:31 , Processed in 0.012692 second(s), 18 queries .

Powered by Discuz! X4.0

Copyright © 2001-2020, Tencent Cloud.

快速回复 返回顶部 返回列表