深度残差收缩网络针对的是数据中含有强噪或冗余信息的情况,是深度残差网络的一种升级,将软阈值化放进深度残差网络的内部,通过消除冗余特征,增强高层特征的判别性。以下翻译了部分原文,仅以学习为目的。
【题目】Deep Residual Shrinkage Networks for Fault Diagnosis
【翻译】基于深度残差收缩网络的故障诊断
Abstract (摘要)
Abstract: This paper develops new deep learning methods, namely, deep residual shrinkage networks, to improve the feature learning ability from highly noised vibration signals and achieve a high fault diagnosing accuracy. Soft thresholding is inserted as nonlinear transformation layers into the deep architectures to eliminate unimportant features. Moreover, considering that it is generally challenging to set proper values for the thresholds, the developed deep residual shrinkage networks integrate a few specialized neural networks as trainable modules to automatically determine the thresholds, so that professional expertise on signal processing is not required. The efficacy of the developed methods is validated through experiments with various types of noise.
摘要:本文提出了一种新的深度学习方法,名为深度残差收缩网络,来提高深度学习方法从强噪声信号中学习特征的能力,并且取得较高的故障诊断准确率。软阈值化作为非线性层,嵌入到深度神经网络之中,以消除不重要的特征。更进一步地,考虑到软阈值化中的阈值是难以设定的,本文所提出的深度残差收缩网络,采用了一个子网络,来自动地设置这些阈值,从而回避了信号处理领域的专业知识。该方法的有效性通过多种不同噪声下的实验进行了验证。
【Keywords】Deep learning, deep residual networks, fault diagnosis, soft thresholding, vibration signal.
【翻译】深度学习,深度残差网络,故障诊断,软阈值化,振动信号
【翻译】旋转机械在制造业、电力供应、运输业和航天工业都是很重要的。然而,因为这些旋转机械工作在严酷的工作环境下,其机械传动系统不可避免地会遭遇一些故障,并且会导致事故和经济损失。准确的机械传动系统故障诊断,能够用来安排维修计划、延长服役寿命和确保人身安全。
【翻译】现有的机械传动系统故障诊断算法可分为两类:一类是基于信号分析的方法,另一类是基于机器学习的方法。通常,基于信号分析的故障诊断方法通过检测故障相关的振动成分或者特征频率,来确定故障类型。然而,对于大型旋转机械,其振动信号往往是由很多不同的振动信号混叠而成的,包括齿轮啮合频率、轴和轴承的旋转频率等。更重要地,当故障处于初期阶段,故障相关的振动成分往往是比较微弱的,容易被其他的振动成分和谐波所淹没。总而言之,传统基于信号分析的故障诊断方法经常难以检测到故障相关的振动成分和特征频率。
【翻译】从另一方面来讲,基于机器学习的故障诊断方法,在诊断故障的时候不需要确定故障相关的成分和特征频率。首先,一组统计特征(例如峭度、均方根值、能量、熵)能够被提取来表征健康状态;然后一个分类器(例如多分类支持向量机、单隐含层的神经网络、朴素贝叶斯分类器)能够被训练来诊断故障。然而,所提取的统计特征经常是判别性不足的,难以区分故障,从而导致了低的诊断准确率。因此,寻找一个判别性强的特征集,是基于机器学习的故障诊断中一个长期的挑战。
【翻译】近年来,深度学习方法,即有多个非线性映射层的机器学习方法,成为了基于振动信号进行故障诊断的有力工具。深度学习方法能够自动地从原始振动数据中学习特征,以取代传统的统计特征,来获得高的诊断准确率。例如,Ince等人采用一维卷积神经网络,从电流信号中学习特征,应用于实时电机故障诊断。Shao等人采用一种卷积深度置信网络,应用于电机轴承的故障诊断。然而,一个问题是,误差函数的梯度,在逐层反向传播的过程中,逐渐变得不准确。因此,在输入层附近的一些层的参数不能够被很好地优化。
【翻译】深度残差网络是卷积神经网络的一个新颖的变体,采用了恒等路径来减轻参数优化的难度。在深度残差网络中,梯度不仅逐层地反向传播,而且通过恒等路径直接传递到之前的层。由于优越的参数优化能力,深度残差网络在最近的一些研究中,已经被应用于故障诊断。例如,Ma等人将一种集成了解调时频特征的深度残差网络,应用于不稳定工况下的行星齿轮箱故障诊断。Zhao等人使用深度残差网络,来融合多组小波包系数,应用于故障诊断。相较于普通卷积神经网络,深度残差网络的优势已经在这些论文中得到了验证。
【翻译】从大型旋转机械(例如风电、机床、重型卡车)所采集的振动信号,经常包含着大量的噪声。在处理强噪声振动信号的时候,深度残差网络的特征学习能力经常会降低。深度残差网络中的卷积核,其实就是滤波器,在噪声的干扰下,可能不能检测到故障特征。在这种情况下,在输出层所学习到的高层特征,就会判别性不足,不能够准确地进行故障分类。因此,开发新的深度学习方法,应用于强噪声情况下旋转机械的故障诊断,是非常必要的。
【翻译】本文提出了两种深度残差收缩网络,即通道间共享阈值的深度残差收缩网络、通道间不同阈值的深度残差收缩网络,来提高从强噪声振动信号中学习特征的能力,最终提高故障诊断准确率。本文的主要贡献总结如下:
(1) 软阈值化(也就是一种流行的收缩方程)作为非线性层,被嵌入深度结构之中,以有效地消除噪声相关的特征。
(2) 采用特殊设计的子网络,来自适应地设置阈值,从而每段振动信号都有着自己独特的一组阈值。
(3) 在软阈值化中,共考虑了两种阈值,也就是通道间共享的阈值、通道间不同的阈值。这也是所提出方法名称的由来。
【翻译】本文的剩余部分安排如下。第二部分简要地回顾了经典的深度残差网络,并且详细阐述了所提出的深度残差收缩网络。第三部分进行了实验对比。第四部分进行了总结。
【翻译】如第一部分所述,作为一种潜在的、能够从强噪声振动信号中学习判别性特征的方法,本研究考虑了深度学习和软阈值化的集成。相对应地,本部分注重于开发深度残差网络的两个改进的变种,即通道间共享阈值的深度残差收缩网络、通道间不同阈值的深度残差收缩网络。对相关理论背景和必要的想法进行了详细介绍。
【翻译】不管是深度残差网络,还是所提出的深度残差收缩网络,都有一些基础的组成,是和传统卷积神经网络相同的,包括卷积层、整流线性单元激活函数、批标准化、全局均值池化、交叉熵误差函数。这些基础组成的概念在下面进行了介绍。
【翻译】卷积层是使得卷积神经网络不同于传统全连接神经网络的关键。卷积层能够大量减少所需要训练的参数的数量。这是通过用卷积,取代乘法矩阵,来实现的。卷积核中的参数,比全连接层中的权重,少得多。更进一步地,当参数较少时,深度学习不容易遭遇过拟合,从而能够在测试集上获得较高的准确率。输入特征图和卷积核之间的卷积运算,附带着加上偏置,能够用公式表示为…。卷积可以通过重复一定次数,来获得输出特征图。
【翻译】图1展示了卷积的过程。如图1(a)-(b)所示,特征图和卷积核实际上是三维张量。在本文中,一维振动信号是输入,所以特征图和卷积核的高度始终是1。如图1©所示,卷积核在输入特征图上滑动,从而得到输出特征图的一个通道。在每个卷积层中,通常有多于一个卷积核,从而输出特征图有多个通道。
【翻译】图1 (a) 特征图,(b) 卷积核和(c)卷积过程示意图
【翻译】批标准化是一种嵌入到深度结构的内部、作为可训练层的一种特征标准化方法。批标准化的目的在于减轻内部协方差漂移的问题,即特征的分布经常在训练过程中持续变化。在这种情况下,所需训练的参数就要不断地适应变化的特征分布,从而增大了训练的难度。批标准化,在第一步对特征进行标准化,来获得一个固定的分布,然后在训练过程中自适应地调整这个分布。后续介绍公式。
【翻译】激活函数通常是神经网络中必不可少的一部分,一般是用来实现非线性变换的。在过去的几十年中,很多种激活函数被提出来,例如sigmoid,tanh和ReLU。其中,ReLU激活函数最近得到了很多关注,这是因为ReLU能够很有效地避免梯度消失的问题。ReLU激活函数的导数要么是1,要么是0,能够帮助控制特征的取值范围大致不变,在特征在层间传递的时候。ReLU的函数表达式为max(x,0)。
【翻译】 全局均值池化是从特征图的每个通道计算一个平均值的运算。通常,全局均值池化是在最终输出层之前使用的。全局均值池化可以减少全连接输出层的权重数量,从而降低深度神经网络遭遇过拟合的风险。全局均值池化还可以解决平移变化问题,从而深度神经网络所学习得到的特征,不会受到故障冲击位置变化的影响。
【翻译】交叉熵损失函数通常作为多分类问题的目标函数,朝着最小的方向进行优化。相较于传统的均方差损失函数,交叉熵损失函数经常能够提供更快的训练速度。这是因为,交叉熵损失函数对于权重的梯度,相较于均方差损失函数,不容易减弱到零。为了计算交叉熵损失函数,首先要用softmax函数将特征转换到零一区间。然后交叉熵损失函数可以根据公式进行计算。在获得交叉熵损失函数之后,梯度下降法可以用来优化参数。在一定的迭代次数之后,深度神经网络就能够得到充分的训练。
【翻译】深度残差网络是一种新兴的深度学习方法,在近年来受到了广泛的关注。残差构建模块是基本的组成部分。如图2a所示,残差构建模块包含了两个批标准化、两个整流线性单元、两个卷积层和一个恒等路径。恒等路径是让深度残差网络优于卷积神经网络的关键。交叉熵损失函数的梯度,在普通的卷积神经网络中,是逐层反向传播的。当使用恒等路径的时候,梯度能够更有效地流回前面的层,从而参数能够得到更有效的更新。
图2b-2c展示了两种残差构建模块,能够输出不同尺寸的特征图。在这里,减小输出特征图尺寸的原因在于,减小后续层的运算量;增加通道数的原因在于,方便将不同的特征集成为强判别性的特征。
图2d展示了深度残差网络的整体框架,包括一个输入层、一个卷积层、一定数量的残差构建模块、一个批标准化、一个ReLU激活函数、一个全局均值池化和一个全连接输出层。同时,深度残差网络作为本研究的基准,以求进一步改进。
【翻译】图2 3种残差构建模块:(a) 输入特征图的尺寸=输出特征图的尺寸,(b)输出特征图的宽度减半,(c)输出特征图的宽度减半、通道数翻倍。(d)深度残差网络的整体框架。
【翻译】这一小节首先介绍了提出深度残差收缩网络的原始驱动,然后详细介绍了所提出深度残差收缩网络的结构。
【翻译】在过去的20年中,软阈值化经常被作为许多信号降噪算法中的关键步骤。通常,信号被转换到一个域。在这个域中,接近零的特征,是不重要的。然后,软阈值化将这些接近于零的特征置为零。例如,作为一种经典的信号降噪算法,小波阈值化通常包括三个步骤:小波分解、软阈值化和小波重构。为了保证信号降噪的效果,小波阈值化的一个关键任务是设计一个滤波器。这个滤波器能够将有用的信息转换成比较大的特征,将噪声相关的信息转换成接近于零的特征。然而,设计这样的滤波器需要大量的信号处理方面的专业知识,经常是非常困难的。深度学习提供了一种解决这个问题的新思路。这些滤波器可以通过反向传播算法自动优化得到,而不是由专家进行设计。因此,软阈值化和深度学习的结合是一种有效地消除噪声信息和构建高判别性特征的方式。软阈值化将接近于零的特征直接置为零,而不是像ReLU那样,将负的特征置为零,所以负的、有用的特征能够被保留下来。
【翻译】软阈值化的过程如图3(a)所示。可以看出,软阈值化的输出对于输入的导数要么是1,要么是0,所以在避免梯度消失和梯度爆炸的问题上,也是很有效的。
【翻译】图3 (a)软阈值化,(b)它的偏导
【翻译】在传统的信号降噪算法中,经常难以给阈值设置一个合适的值。同时,对于不同的样本,最优的阈值往往是不同的。针对这个问题,深度残差收缩网络的阈值,是在深度网络中自动确定的,从而避免了人工的操作。深度残差收缩网络中,这种设置阈值的方式,在后续文中进行了介绍。
【翻译】所提出的通道间共享阈值的深度残差收缩网络,是深度残差网络的一个变种,使用了软阈值化来消除与噪声相关的特征。软阈值化作为非线性层嵌入到残差构建模块之中。更重要地,阈值是在残差构建模块中自动学习得到的,介绍如下。
【翻译】图4 (a)通道间共享阈值的残差模块,(b)通道间共享阈值的深度残差收缩网络,(c)通道间不同阈值的残差模块,(d) 通道间不同阈值的深度残差收缩网络
【翻译】如图4(a)所示,名为“通道间共享阈值的残差收缩构建模块”,与图2(a)中残差构建模块是不同的,有一个特殊模块来估计软阈值化所需要的阈值。在这个特殊模块中,全局均值池化被应用在特征图的绝对值上面,来获得一维向量。然后,这个一维向量被输入到一个两层的全连接网络中,来获得一个尺度化参数。Sigmoid函数将这个尺度化参数规整到零和一之间。然后,这个尺度化参数,乘以特征图的绝对值得平均值,作为阈值。这样的话,就可以把阈值控制在一个合适的范围内,不会使输出特征全部为零。
【翻译】所提出的通道间共享阈值的深度残差收缩网络的结构简图如图4(b)所示,和图2(d)中经典深度残差网络是相似的。唯一的区别在于,通道间共享阈值的残差收缩模块(RSBU-CS),替换了普通的残差构建模块。一定数量的RSBU-CS被堆叠起来,从而噪声相关的特征被逐渐削减。另一个优势在于,阈值是自动学习得到的,而不是由专家手工设置的,所以在实施通道间共享阈值的深度残差收缩网络的时候,不需要信号处理领域的专业知识。
【翻译】道间不同阈值的深度残差收缩网络,是深度残差网络的另一个变种。与通道间共享阈值的深度残差收缩网络的区别在于,特征图的每个通道有着自己独立的阈值。通道间不同阈值的残差模块如图4©所示。特征图x首先被压缩成了一个一维向量,并且输入到一个两层的全连接层中。全连接层的第二层有多于一个神经元,并且神经元的个数等于输入特征图的通道数。全连接层的输出被强制到零和一之间。之后计算出阈值。与通道间共享阈值的深度残差收缩网络相似,阈值始终是正数,并且被保持在一个合理范围内,从而防止输出特征都是零的情况。
【翻译】通道间不同阈值的深度残差收缩网络的整体框架如图4(d)所示。一定数量的模块被堆积起来,从而判别性特征能够被学习得到。其中,软阈值化,作为收缩函数,用于非线性变换,来消除噪声相关的信息。
Reference:
M. Zhao, S. Zhong, X. Fu, B. Tang, M. Pecht, Deep Residual Shrinkage Networks for Fault Diagnosis, IEEE Transactions on Industrial Informatics, 2019, DOI: 10.1109/TII.2019.2943898
https://ieeexplore.ieee.org/document/8850096
Keras示例代码
以MNIST数据为例:
#!/usr/bin/env python3 # -*- coding: utf-8 -*- """ Created on Sat Dec 28 23:24:05 2019 Implemented using TensorFlow 1.0.1 and Keras 2.2.1 M. Zhao, S. Zhong, X. Fu, et al., Deep Residual Shrinkage Networks for Fault Diagnosis, IEEE Transactions on Industrial Informatics, 2019, DOI: 10.1109/TII.2019.2943898 @author: super_9527 """ from __future__ import print_function import keras import numpy as np from keras.datasets import mnist from keras.layers import Dense, Conv2D, BatchNormalization, Activation from keras.layers import AveragePooling2D, Input, GlobalAveragePooling2D from keras.optimizers import Adam from keras.regularizers import l2 from keras import backend as K from keras.models import Model from keras.layers.core import Lambda K.set_learning_phase(1) # Input image dimensions img_rows, img_cols = 28, 28 # The data, split between train and test sets (x_train, y_train), (x_test, y_test) = mnist.load_data() if K.image_data_format() == 'channels_first': x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols) x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols) input_shape = (1, img_rows, img_cols) else: x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1) x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1) input_shape = (img_rows, img_cols, 1) # Noised data x_train = x_train.astype('float32') / 255. + 0.5*np.random.random([x_train.shape[0], img_rows, img_cols, 1]) x_test = x_test.astype('float32') / 255. + 0.5*np.random.random([x_test.shape[0], img_rows, img_cols, 1]) print('x_train shape:', x_train.shape) print(x_train.shape[0], 'train samples') print(x_test.shape[0], 'test samples') # convert class vectors to binary class matrices y_train = keras.utils.to_categorical(y_train, 10) y_test = keras.utils.to_categorical(y_test, 10) def abs_backend(inputs): return K.abs(inputs) def expand_dim_backend(inputs): return K.expand_dims(K.expand_dims(inputs,1),1) def sign_backend(inputs): return K.sign(inputs) def pad_backend(inputs, in_channels, out_channels): pad_dim = (out_channels - in_channels)//2 inputs = K.expand_dims(inputs,-1) inputs = K.spatial_3d_padding(inputs, ((0,0),(0,0),(pad_dim,pad_dim)), 'channels_last') return K.squeeze(inputs, -1) # Residual Shrinakge Block def residual_shrinkage_block(incoming, nb_blocks, out_channels, downsample=False, downsample_strides=2): residual = incoming in_channels = incoming.get_shape().as_list()[-1] for i in range(nb_blocks): identity = residual if not downsample: downsample_strides = 1 residual = BatchNormalization()(residual) residual = Activation('relu')(residual) residual = Conv2D(out_channels, 3, strides=(downsample_strides, downsample_strides), padding='same', kernel_initializer='he_normal', kernel_regularizer=l2(1e-4))(residual) residual = BatchNormalization()(residual) residual = Activation('relu')(residual) residual = Conv2D(out_channels, 3, padding='same', kernel_initializer='he_normal', kernel_regularizer=l2(1e-4))(residual) # Calculate global means residual_abs = Lambda(abs_backend)(residual) abs_mean = GlobalAveragePooling2D()(residual_abs) # Calculate scaling coefficients scales = Dense(out_channels, activation=None, kernel_initializer='he_normal', kernel_regularizer=l2(1e-4))(abs_mean) scales = BatchNormalization()(scales) scales = Activation('relu')(scales) scales = Dense(out_channels, activation='sigmoid', kernel_regularizer=l2(1e-4))(scales) scales = Lambda(expand_dim_backend)(scales) # Calculate thresholds thres = keras.layers.multiply([abs_mean, scales]) # Soft thresholding sub = keras.layers.subtract([residual_abs, thres]) zeros = keras.layers.subtract([sub, sub]) n_sub = keras.layers.maximum([sub, zeros]) residual = keras.layers.multiply([Lambda(sign_backend)(residual), n_sub]) # Downsampling using the pooL-size of (1, 1) if downsample_strides > 1: identity = AveragePooling2D(pool_size=(1,1), strides=(2,2))(identity) # Zero_padding to match channels if in_channels != out_channels: identity = Lambda(pad_backend, arguments={'in_channels':in_channels,'out_channels':out_channels})(identity) residual = keras.layers.add([residual, identity]) return residual # define and train a model inputs = Input(shape=input_shape) net = Conv2D(8, 3, padding='same', kernel_initializer='he_normal', kernel_regularizer=l2(1e-4))(inputs) net = residual_shrinkage_block(net, 1, 8, downsample=True) net = BatchNormalization()(net) net = Activation('relu')(net) net = GlobalAveragePooling2D()(net) outputs = Dense(10, activation='softmax', kernel_initializer='he_normal', kernel_regularizer=l2(1e-4))(net) model = Model(inputs=inputs, outputs=outputs) model.compile(loss='categorical_crossentropy', optimizer=Adam(), metrics=['accuracy']) model.fit(x_train, y_train, batch_size=100, epochs=5, verbose=1, validation_data=(x_test, y_test)) # get results K.set_learning_phase(0) DRSN_train_score = model.evaluate(x_train, y_train, batch_size=100, verbose=0) print('Train loss:', DRSN_train_score[0]) print('Train accuracy:', DRSN_train_score[1]) DRSN_test_score = model.evaluate(x_test, y_test, batch_size=100, verbose=0) print('Test loss:', DRSN_test_score[0]) print('Test accuracy:', DRSN_test_score[1])
TFLearn示例代码
以cifar10数据为例:
#!/usr/bin/env python3 # -*- coding: utf-8 -*- """ Created on Mon Dec 23 21:23:09 2019 Implemented using TensorFlow 1.0 and TFLearn 0.3.2 M. Zhao, S. Zhong, X. Fu, B. Tang, M. Pecht, Deep Residual Shrinkage Networks for Fault Diagnosis, IEEE Transactions on Industrial Informatics, 2019, DOI: 10.1109/TII.2019.2943898 @author: super_9527 """ from __future__ import division, print_function, absolute_import import tflearn import numpy as np import tensorflow as tf from tflearn.layers.conv import conv_2d # Data loading from tflearn.datasets import cifar10 (X, Y), (testX, testY) = cifar10.load_data() # Add noise X = X + np.random.random((50000, 32, 32, 3))*0.1 testX = testX + np.random.random((10000, 32, 32, 3))*0.1 # Transform labels to one-hot format Y = tflearn.data_utils.to_categorical(Y,10) testY = tflearn.data_utils.to_categorical(testY,10) def residual_shrinkage_block(incoming, nb_blocks, out_channels, downsample=False, downsample_strides=2, activation='relu', batch_norm=True, bias=True, weights_init='variance_scaling', bias_init='zeros', regularizer='L2', weight_decay=0.0001, trainable=True, restore=True, reuse=False, scope=None, name="ResidualBlock"): # residual shrinkage blocks with channel-wise thresholds residual = incoming in_channels = incoming.get_shape().as_list()[-1] # Variable Scope fix for older TF try: vscope = tf.variable_scope(scope, default_name=name, values=[incoming], reuse=reuse) except Exception: vscope = tf.variable_op_scope([incoming], scope, name, reuse=reuse) with vscope as scope: name = scope.name #TODO for i in range(nb_blocks): identity = residual if not downsample: downsample_strides = 1 if batch_norm: residual = tflearn.batch_normalization(residual) residual = tflearn.activation(residual, activation) residual = conv_2d(residual, out_channels, 3, downsample_strides, 'same', 'linear', bias, weights_init, bias_init, regularizer, weight_decay, trainable, restore) if batch_norm: residual = tflearn.batch_normalization(residual) residual = tflearn.activation(residual, activation) residual = conv_2d(residual, out_channels, 3, 1, 'same', 'linear', bias, weights_init, bias_init, regularizer, weight_decay, trainable, restore) # get thresholds and apply thresholding abs_mean = tf.reduce_mean(tf.reduce_mean(tf.abs(residual),axis=2,keep_dims=True),axis=1,keep_dims=True) scales = tflearn.fully_connected(abs_mean, out_channels//4, activation='linear',regularizer='L2',weight_decay=0.0001,weights_init='variance_scaling') scales = tflearn.batch_normalization(scales) scales = tflearn.activation(scales, 'relu') scales = tflearn.fully_connected(scales, out_channels, activation='linear',regularizer='L2',weight_decay=0.0001,weights_init='variance_scaling') scales = tf.expand_dims(tf.expand_dims(scales,axis=1),axis=1) thres = tf.multiply(abs_mean,tflearn.activations.sigmoid(scales)) # soft thresholding residual = tf.multiply(tf.sign(residual), tf.maximum(tf.abs(residual)-thres,0)) # Downsampling if downsample_strides > 1: identity = tflearn.avg_pool_2d(identity, 1, downsample_strides) # Projection to new dimension if in_channels != out_channels: if (out_channels - in_channels) % 2 == 0: ch = (out_channels - in_channels)//2 identity = tf.pad(identity, [[0, 0], [0, 0], [0, 0], [ch, ch]]) else: ch = (out_channels - in_channels)//2 identity = tf.pad(identity, [[0, 0], [0, 0], [0, 0], [ch, ch+1]]) in_channels = out_channels residual = residual + identity return residual # Real-time data preprocessing img_prep = tflearn.ImagePreprocessing() img_prep.add_featurewise_zero_center(per_channel=True) # Real-time data augmentation img_aug = tflearn.ImageAugmentation() img_aug.add_random_flip_leftright() img_aug.add_random_crop([32, 32], padding=4) # Build a Deep Residual Shrinkage Network with 3 blocks net = tflearn.input_data(shape=[None, 32, 32, 3], data_preprocessing=img_prep, data_augmentation=img_aug) net = tflearn.conv_2d(net, 16, 3, regularizer='L2', weight_decay=0.0001) net = residual_shrinkage_block(net, 1, 16) net = residual_shrinkage_block(net, 1, 32, downsample=True) net = residual_shrinkage_block(net, 1, 32, downsample=True) net = tflearn.batch_normalization(net) net = tflearn.activation(net, 'relu') net = tflearn.global_avg_pool(net) # Regression net = tflearn.fully_connected(net, 10, activation='softmax') mom = tflearn.Momentum(0.1, lr_decay=0.1, decay_step=20000, staircase=True) net = tflearn.regression(net, optimizer=mom, loss='categorical_crossentropy') # Training model = tflearn.DNN(net, checkpoint_path='model_cifar10', max_checkpoints=10, tensorboard_verbose=0, clip_gradients=0.) model.fit(X, Y, n_epoch=100, snapshot_epoch=False, snapshot_step=500, show_metric=True, batch_size=100, shuffle=True, run_id='model_cifar10') training_acc = model.evaluate(X, Y)[0] validation_acc = model.evaluate(testX, testY)[0]