深度残差收缩网络是一种比较新颖的深度学习方法,实际上是深度残差网络的升级版本,能够在一定程度上提高深度学习方法在含噪数据上的特征学习效果。
首先,简单回顾一下深度残差网络,深度残差网络的基本模块如下图所示。相较于普通的卷积神经网络,深度残差网络引入了跨层的恒等映射,来减小模型训练的难度,提高准确率。
然后,在深度残差网络的基础上,深度残差收缩网络引入了一个小型的子网络,用这个子网络学习得到一组阈值,对特征图的各个通道进行软阈值化。这个过程其实可以看成一个可训练的特征选择的过程。具体而言,就是通过前面的卷积层将重要的特征转换成绝对值较大的值,将冗余信息所对应的特征转换成绝对值较小的值;通过子网络学习得到二者之间的界限,并且通过软阈值化将冗余特征置为零,同时使重要的特征有着非零的输出。
深度残差收缩网络其实是一种通用的方法,不仅可以用于含噪数据,也可以用于不含噪声的情况。这是因为,深度残差收缩网络中的阈值是根据样本情况自适应确定的。换言之,如果样本中不含冗余信息、不需要软阈值化,那么阈值可以被训练得非常接近于零,从而软阈值化就相当于不存在了。
最后,堆叠一定数量的基本模块,就得到了完整的网络结构。
利用深度残差收缩网络进行MNIST数据集的分类,可以看到,效果还是不错的。下面是深度残差收缩网络的程序:
#!/usr/bin/env python3 # -*- coding: utf-8 -*- """ Created on Thu Dec 26 07:46:00 2019 Implemented using TensorFlow 1.0 and TFLearn 0.3.2 M. Zhao, S. Zhong, X. Fu, et al., Deep Residual Shrinkage Networks for Fault Diagnosis, IEEE Transactions on Industrial Informatics, 2019, DOI: 10.1109/TII.2019.2943898 @author: me """ import tflearn import tensorflow as tf from tflearn.layers.conv import conv_2d # Data loading from tflearn.datasets import mnist X, Y, testX, testY = mnist.load_data(one_hot=True) X = X.reshape([-1,28,28,1]) testX = testX.reshape([-1,28,28,1]) def residual_shrinkage_block(incoming, nb_blocks, out_channels, downsample=False, downsample_strides=2, activation='relu', batch_norm=True, bias=True, weights_init='variance_scaling', bias_init='zeros', regularizer='L2', weight_decay=0.0001, trainable=True, restore=True, reuse=False, scope=None, name="ResidualBlock"): # residual shrinkage blocks with channel-wise thresholds residual = incoming in_channels = incoming.get_shape().as_list()[-1] # Variable Scope fix for older TF try: vscope = tf.variable_scope(scope, default_name=name, values=[incoming], reuse=reuse) except Exception: vscope = tf.variable_op_scope([incoming], scope, name, reuse=reuse) with vscope as scope: name = scope.name #TODO for i in range(nb_blocks): identity = residual if not downsample: downsample_strides = 1 if batch_norm: residual = tflearn.batch_normalization(residual) residual = tflearn.activation(residual, activation) residual = conv_2d(residual, out_channels, 3, downsample_strides, 'same', 'linear', bias, weights_init, bias_init, regularizer, weight_decay, trainable, restore) if batch_norm: residual = tflearn.batch_normalization(residual) residual = tflearn.activation(residual, activation) residual = conv_2d(residual, out_channels, 3, 1, 'same', 'linear', bias, weights_init, bias_init, regularizer, weight_decay, trainable, restore) # get thresholds and apply thresholding abs_mean = tf.reduce_mean(tf.reduce_mean(tf.abs(residual),axis=2,keep_dims=True),axis=1,keep_dims=True) scales = tflearn.fully_connected(abs_mean, out_channels//4, activation='linear',regularizer='L2',weight_decay=0.0001,weights_init='variance_scaling') scales = tflearn.batch_normalization(scales) scales = tflearn.activation(scales, 'relu') scales = tflearn.fully_connected(scales, out_channels, activation='linear',regularizer='L2',weight_decay=0.0001,weights_init='variance_scaling') scales = tf.expand_dims(tf.expand_dims(scales,axis=1),axis=1) thres = tf.multiply(abs_mean,tflearn.activations.sigmoid(scales)) residual = tf.multiply(tf.sign(residual), tf.maximum(tf.abs(residual)-thres,0)) # Downsampling if downsample_strides > 1: identity = tflearn.avg_pool_2d(identity, 1, downsample_strides) # Projection to new dimension if in_channels != out_channels: if (out_channels - in_channels) % 2 == 0: ch = (out_channels - in_channels)//2 identity = tf.pad(identity, [[0, 0], [0, 0], [0, 0], [ch, ch]]) else: ch = (out_channels - in_channels)//2 identity = tf.pad(identity, [[0, 0], [0, 0], [0, 0], [ch, ch+1]]) in_channels = out_channels residual = residual + identity return residual # Real-time data preprocessing img_prep = tflearn.ImagePreprocessing() img_prep.add_featurewise_zero_center(per_channel=True) # Building A Deep Residual Shrinkage Network net = tflearn.input_data(shape=[None, 28, 28, 1]) net = tflearn.conv_2d(net, 8, 3, regularizer='L2', weight_decay=0.0001) net = residual_shrinkage_block(net, 1, 8, downsample=True) net = tflearn.batch_normalization(net) net = tflearn.activation(net, 'relu') net = tflearn.global_avg_pool(net) # Regression net = tflearn.fully_connected(net, 10, activation='softmax') mom = tflearn.Momentum(0.1, lr_decay=0.1, decay_step=40000, staircase=True) net = tflearn.regression(net, optimizer=mom, loss='categorical_crossentropy') # Training model = tflearn.DNN(net, checkpoint_path='model_mnist', max_checkpoints=10, tensorboard_verbose=0, clip_gradients=0.) model.fit(X, Y, n_epoch=200, snapshot_epoch=False, snapshot_step=500, show_metric=True, batch_size=100, shuffle=True, run_id='model_mnist') training_acc = model.evaluate(X, Y)[0] validation_acc = model.evaluate(testX, testY)[0]
接下来是深度残差网络ResNet的程序:
#!/usr/bin/env python3 # -*- coding: utf-8 -*- """ Created on Thu Dec 26 07:46:00 2019 Implemented using TensorFlow 1.0 and TFLearn 0.3.2 K. He, X. Zhang, S. Ren, J. Sun, Deep Residual Learning for Image Recognition, CVPR, 2016. @author: me """ import tflearn # Data loading from tflearn.datasets import mnist X, Y, testX, testY = mnist.load_data(one_hot=True) X = X.reshape([-1,28,28,1]) testX = testX.reshape([-1,28,28,1]) # Real-time data preprocessing img_prep = tflearn.ImagePreprocessing() img_prep.add_featurewise_zero_center(per_channel=True) # Building a deep residual network net = tflearn.input_data(shape=[None, 28, 28, 1]) net = tflearn.conv_2d(net, 8, 3, regularizer='L2', weight_decay=0.0001) net = tflearn.residual_block(net, 1, 8, downsample=True) net = tflearn.batch_normalization(net) net = tflearn.activation(net, 'relu') net = tflearn.global_avg_pool(net) # Regression net = tflearn.fully_connected(net, 10, activation='softmax') mom = tflearn.Momentum(0.1, lr_decay=0.1, decay_step=40000, staircase=True) net = tflearn.regression(net, optimizer=mom, loss='categorical_crossentropy') # Training model = tflearn.DNN(net, checkpoint_path='model_mnist', max_checkpoints=10, tensorboard_verbose=0, clip_gradients=0.) model.fit(X, Y, n_epoch=200, snapshot_epoch=False, snapshot_step=500, show_metric=True, batch_size=100, shuffle=True, run_id='model_mnist') training_acc = model.evaluate(X, Y)[0] validation_acc = model.evaluate(testX, testY)[0]
上述两个程序构建了只有1个基本模块的小型网络,MNIST数据集中没有添加噪声。准确率如下表所示(每次运行结果会有些波动),可以看到,即使是对于不含噪声的数据,深度残差收缩网络的结果也是不错的:
(不加噪MNIST数据) | 深度残差网络 | 深度残差收缩网络 |
---|---|---|
训练准确率 | 98.562% | 98.738% |
测试准确率 | 97.770% | 97.950% |
M. Zhao, S. Zhong, X. Fu, et al., Deep residual shrinkage networks for fault diagnosis, IEEE Transactions on Industrial Informatics, 2019, DOI: 10.1109/TII.2019.2943898
https://ieeexplore.ieee.org/d...