Learning DenseNet from the Source Code

Learn from source code

Posted by 刘知安 on 2021-03-25
Table of Contents
  1. DenseNet
    1. Advantages
    2. Details
    3. Implementation
    4. Saving GPU Memory

DenseNet


1. Advantages

  • The many short connections between layers alleviate the gradient vanishing problem;
  • Feature propagation is strengthened: features flow more directly to later layers;
  • Features are reused to a much greater extent (feature reuse);
  • Somewhat counter-intuitively, the number of parameters is actually smaller, because each layer only produces k = 12 new feature maps (the growth rate)!

2. Details

  • Each DenseLayer inside a DenseBlock consists of BN, ReLU, and a 3×3 Conv; note that this pre-activation ordering is a bit different from the usual Conv-BN-ReLU;
  • Every layer in a DenseBlock takes the feature maps of all preceding layers as input, so later layers inevitably see more and more input channels. To handle this, the authors place a 1×1 convolution (the "bottleneck layer" of the paper) in front of each 3×3 convolution to reduce the dimensionality; the full layout is BN-ReLU-1×1Conv-BN-ReLU-3×3Conv, and dropout (p=0.2) is applied after the final 3×3 Conv. This variant is denoted DenseNet-B in the paper.
  • Note that the layers of the $i$-th dense block do not connect to the individual layers of the $(i-1)$-th dense block; otherwise the channel count, and with it the parameter count, would explode. The first layer of block $i$ only receives the combined output of block $(i-1)$ (of size $[bs,C,H,W]$), and even this $C$ is fairly large, so the paper inserts a transition layer between blocks to compress the channels: BN-ReLU-1×1Conv-2×2AvgPool. This variant is denoted DenseNet-C in the paper (a short channel-count sketch follows this list).
  • A naive implementation is very memory-hungry; gradient checkpointing can trade some extra computation time for memory. See Training Deep Nets with Sublinear Memory Cost, Chen et al. (2016).
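
To make the channel bookkeeping concrete, here is a small standalone sketch of the channel arithmetic for the DenseNet-BC configuration used in the implementation below (k = 12, three dense blocks of 16 layers each, 24 initial channels, compression = 0.5); the variable names are only for illustration:

# Channel arithmetic for DenseNet-BC, assuming the defaults of the code below:
# k=12, 3 blocks of 16 layers, num_init_features=24, compression=0.5.
growth_rate = 12
num_layers_per_block = 16
compression = 0.5

num_features = 24  # channels produced by the input conv (2 * growth_rate)
for block_idx in range(3):
    # every dense layer concatenates growth_rate new feature maps
    num_features += num_layers_per_block * growth_rate
    print(f"dense block {block_idx + 1} output: {num_features} channels")
    if block_idx != 2:  # a transition layer follows every block except the last
        num_features = int(num_features * compression)
        print(f"  after transition {block_idx + 1}: {num_features} channels")

# Expected: 216 -> 108 -> 300 -> 150 -> 342,
# so the final classifier sees 342-dimensional pooled features.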

3. Implementation

The source code comes from efficient-densenet-pytorch. To make things easier to follow, I have drawn out the detailed network architecture, corresponding to the DenseNet-BC (k=12) configuration in the paper, as shown below:

Slightly different from the ImageNet configuration in the paper, the input conv here is a single 3×3 convolution (the network is built for small 32×32 inputs), and each dense block contains 16 dense layers. The Input Conv is just a plain nn.Conv2d(), while the Final Conv consists of BN-ReLU-AdaptiveAvgPool2d.

# This implementation is based on the DenseNet-BC implementation in torchvision
# https://github.com/pytorch/vision/blob/master/torchvision/models/densenet.py

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as cp
from collections import OrderedDict


def _bn_function_factory(norm, relu, conv):
    # Returns a closure that concatenates all incoming feature maps and applies
    # the bottleneck BN-ReLU-1x1Conv. Wrapping this in one function makes it easy
    # to checkpoint with torch.utils.checkpoint.
    def bn_function(*inputs):
        concated_features = torch.cat(inputs, 1)
        bottleneck_output = conv(relu(norm(concated_features)))
        return bottleneck_output

    return bn_function


class _DenseLayer(nn.Module):
    def __init__(self, num_input_features, growth_rate, bn_size, drop_rate, efficient=False):
        super(_DenseLayer, self).__init__()
        self.add_module('norm1', nn.BatchNorm2d(num_input_features))
        self.add_module('relu1', nn.ReLU(inplace=True))
        self.add_module('conv1', nn.Conv2d(num_input_features, bn_size * growth_rate,
                                           kernel_size=1, stride=1, bias=False))
        self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate))
        self.add_module('relu2', nn.ReLU(inplace=True))
        self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate,
                                           kernel_size=3, stride=1, padding=1, bias=False))
        self.drop_rate = drop_rate
        self.efficient = efficient

    def forward(self, *prev_features):
        # Bottleneck: BN-ReLU-1x1Conv on the concatenation of all previous features
        bn_function = _bn_function_factory(self.norm1, self.relu1, self.conv1)
        if self.efficient and any(prev_feature.requires_grad for prev_feature in prev_features):
            # Recompute the concatenation + bottleneck during backward to save memory
            bottleneck_output = cp.checkpoint(bn_function, *prev_features)
        else:
            bottleneck_output = bn_function(*prev_features)
        # BN-ReLU-3x3Conv producing growth_rate new feature maps
        new_features = self.conv2(self.relu2(self.norm2(bottleneck_output)))
        if self.drop_rate > 0:
            new_features = F.dropout(new_features, p=self.drop_rate, training=self.training)
        return new_features


class _Transition(nn.Sequential):
    # Transition layer between dense blocks: BN-ReLU-1x1Conv-2x2AvgPool
    def __init__(self, num_input_features, num_output_features):
        super(_Transition, self).__init__()
        self.add_module('norm', nn.BatchNorm2d(num_input_features))
        self.add_module('relu', nn.ReLU(inplace=True))
        self.add_module('conv', nn.Conv2d(num_input_features, num_output_features,
                                          kernel_size=1, stride=1, bias=False))
        self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2))


class _DenseBlock(nn.Module):
    def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate, efficient=False):
        super(_DenseBlock, self).__init__()
        for i in range(num_layers):
            layer = _DenseLayer(
                num_input_features + i * growth_rate,
                growth_rate=growth_rate,
                bn_size=bn_size,
                drop_rate=drop_rate,
                efficient=efficient,
            )
            self.add_module('denselayer%d' % (i + 1), layer)

    def forward(self, init_features):
        features = [init_features]
        for name, layer in self.named_children():
            # Each layer sees the feature maps of all preceding layers in this block
            new_features = layer(*features)
            features.append(new_features)
        return torch.cat(features, 1)


class DenseNet(nn.Module):
    r"""Densenet-BC model class, based on
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`

    Args:
        growth_rate (int) - how many filters to add each layer (`k` in paper)
        block_config (list of 3 or 4 ints) - how many layers in each pooling block
        num_init_features (int) - the number of filters to learn in the first convolution layer
        bn_size (int) - multiplicative factor for number of bottleneck layers
            (i.e. bn_size * k features in the bottleneck layer)
        drop_rate (float) - dropout rate after each dense layer
        num_classes (int) - number of classification classes
        small_inputs (bool) - set to True if images are 32x32. Otherwise assumes images are larger.
        efficient (bool) - set to True to use checkpointing. Much more memory efficient, but slower.
    """

    def __init__(self, growth_rate=12, block_config=(16, 16, 16), compression=0.5,
                 num_init_features=24, bn_size=4, drop_rate=0,
                 num_classes=10, small_inputs=True, efficient=False):

        super(DenseNet, self).__init__()
        assert 0 < compression <= 1, 'compression of densenet should be between 0 and 1'

        # First convolution
        if small_inputs:
            self.features = nn.Sequential(OrderedDict([
                ('conv0', nn.Conv2d(3, num_init_features, kernel_size=3, stride=1, padding=1, bias=False)),
            ]))
        else:
            self.features = nn.Sequential(OrderedDict([
                ('conv0', nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)),
            ]))
            self.features.add_module('norm0', nn.BatchNorm2d(num_init_features))
            self.features.add_module('relu0', nn.ReLU(inplace=True))
            self.features.add_module('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1,
                                                           ceil_mode=False))

        # Each denseblock
        num_features = num_init_features
        for i, num_layers in enumerate(block_config):
            block = _DenseBlock(
                num_layers=num_layers,
                num_input_features=num_features,
                bn_size=bn_size,
                growth_rate=growth_rate,
                drop_rate=drop_rate,
                efficient=efficient,
            )
            self.features.add_module('denseblock%d' % (i + 1), block)
            num_features = num_features + num_layers * growth_rate
            if i != len(block_config) - 1:  # add a transition layer after every block except the last
                trans = _Transition(num_input_features=num_features,
                                    num_output_features=int(num_features * compression))
                self.features.add_module('transition%d' % (i + 1), trans)
                num_features = int(num_features * compression)

        # Final batch norm
        self.features.add_module('norm_final', nn.BatchNorm2d(num_features))

        # Linear layer
        self.classifier = nn.Linear(num_features, num_classes)

        # Initialization
        for name, param in self.named_parameters():
            if 'conv' in name and 'weight' in name:
                n = param.size(0) * param.size(2) * param.size(3)
                param.data.normal_().mul_(math.sqrt(2. / n))
            elif 'norm' in name and 'weight' in name:
                param.data.fill_(1)
            elif 'norm' in name and 'bias' in name:
                param.data.fill_(0)
            elif 'classifier' in name and 'bias' in name:
                param.data.fill_(0)

    def forward(self, x):
        features = self.features(x)
        out = F.relu(features, inplace=True)
        out = F.adaptive_avg_pool2d(out, (1, 1))
        out = torch.flatten(out, 1)
        out = self.classifier(out)
        return out


if __name__ == '__main__':
    growth_rate = 12
    depth = 100
    block_cfg = [(depth - 4) // 6 for _ in range(3)]  # 16 dense layers per block for depth 100
    efficient = False

    model = DenseNet(
        growth_rate=growth_rate,
        block_config=block_cfg,
        num_init_features=growth_rate * 2,
        num_classes=10,
        small_inputs=True,
        efficient=efficient,
    ).to("cuda:0")  # assumes a CUDA device is available
    # print(model)

    # Print number of parameters
    num_params = sum(p.numel() for p in model.parameters())
    print("Total parameters: ", num_params)

    dummy_in = torch.rand(16, 3, 32, 32).to("cuda:0")
    dummy_out = model(dummy_in)
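
As a quick sanity check on the channel arithmetic from section 2, the classifier built by the code above should accept 342-dimensional features; a short snippet, assuming the DenseNet class defined above is in scope:

# The classifier input size should match the channel arithmetic from section 2.
model = DenseNet(growth_rate=12, block_config=(16, 16, 16),
                 num_init_features=24, small_inputs=True)
assert model.classifier.in_features == 342
print(model.classifier)  # Linear(in_features=342, out_features=10, bias=True)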

4. Saving GPU Memory

The gradient checkpointing trick: TBD.
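
For now, here is a minimal standalone sketch of the underlying idea, which the efficient=True path of _DenseLayer already uses: torch.utils.checkpoint does not store the intermediate activations of the wrapped function during the forward pass and recomputes them during the backward pass instead, trading compute time for memory. The little BN-ReLU-1×1Conv block below is only an illustrative stand-in for the bottleneck function.

# Minimal gradient-checkpointing sketch with torch.utils.checkpoint.
import torch
import torch.nn as nn
import torch.utils.checkpoint as cp

# A small BN-ReLU-1x1Conv block, similar in shape to the bottleneck inside _DenseLayer.
block = nn.Sequential(
    nn.BatchNorm2d(64),
    nn.ReLU(inplace=True),
    nn.Conv2d(64, 48, kernel_size=1, bias=False),
)

x = torch.randn(16, 64, 32, 32, requires_grad=True)
out = cp.checkpoint(block, x)  # activations inside `block` are recomputed during backward
out.sum().backward()
print(x.grad.shape)            # torch.Size([16, 64, 32, 32])

In the implementation above, checkpointing is applied only to the bottleneck function (the concatenation plus BN-ReLU-1×1Conv), which is exactly the part whose concatenated intermediate tensors dominate the memory cost of a naive DenseNet.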