Many kinds of activation functions are used in deep neural networks; this post introduces tanh. Following the definition on Wikipedia (https://en.wikipedia.org/wiki/Activation_function), its formula is

    tanh(x) = (e^x - e^(-x)) / (e^x + e^(-x))

and its derivative can be written in terms of its own output, which is why the derivative code below takes tanh values as input:

    tanh'(x) = 1 - tanh(x)^2
tanh, also known as the hyperbolic tangent, fixes sigmoid's non-zero-centered output problem: its range is (-1, 1), centered at zero, and it is nonlinear. However, it still saturates for large |x|, so it does not completely solve the vanishing gradient problem.
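As a quick illustration of both properties, here is a minimal NumPy sketch (our own example, not code from the repositories linked below): the outputs are symmetric around zero, and the gradient 1 - tanh(x)^2 decays toward zero as |x| grows.

import numpy as np

x = np.linspace(-6.0, 6.0, 7)            # symmetric sample points
y = np.tanh(x)                           # outputs lie in (-1, 1), centered at 0
grad = 1.0 - y**2                        # tanh'(x) = 1 - tanh(x)^2

print("outputs:  ", np.round(y, 4))      # zero-centered, unlike sigmoid
print("gradients:", np.round(grad, 4))   # nearly 0 for |x| >= 4: gradients still vanish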
The C++ implementation is as follows:
#include <cstdio>
#include <cmath>
#include <vector>

// tanh(x) = (e^x - e^(-x)) / (e^x + e^(-x))
template<typename _Tp>
int activation_function_tanh(const _Tp* src, _Tp* dst, int length)
{
	for (int i = 0; i < length; ++i) {
		_Tp ep = std::exp(src[i]);
		_Tp em = std::exp(-src[i]);
		dst[i] = (ep - em) / (ep + em);
	}
	return 0;
}

// note: src is expected to already hold tanh(x) values, since tanh'(x) = 1 - tanh(x)^2
template<typename _Tp>
int activation_function_tanh_derivative(const _Tp* src, _Tp* dst, int length)
{
	for (int i = 0; i < length; ++i) {
		dst[i] = (_Tp)1. - src[i] * src[i];
	}
	return 0;
}

int test_activation_function()
{
	std::vector<float> src{ 1.1f, -2.2f, 3.3f, 0.4f, -0.5f, -1.6f };
	int length = src.size();
	std::vector<float> dst(length);

	fprintf(stderr, "source vector: \n");
	fbc::print_matrix(src);

	fprintf(stderr, "calculate activation function:\n");
	fprintf(stderr, "type: tanh result: \n");
	fbc::activation_function_tanh(src.data(), dst.data(), length);
	fbc::print_matrix(dst);

	fprintf(stderr, "type: tanh derivative result: \n");
	fbc::activation_function_tanh_derivative(dst.data(), dst.data(), length);
	fbc::print_matrix(dst);

	return 0;
}
Running test_activation_function() prints the source vector, the tanh values, and the tanh derivative values.
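One caveat about the (e^x - e^(-x)) / (e^x + e^(-x)) form used above: for large |x|, exp overflows, and the naive quotient becomes inf/inf = NaN, while the library tanh saturates safely. A small Python sketch of the effect (again our own illustration, not repository code):

import numpy as np

def naive_tanh(x):
    # same formula as the C++ implementation above
    return (np.exp(x) - np.exp(-x)) / (np.exp(x) + np.exp(-x))

with np.errstate(over='ignore', invalid='ignore'):  # silence overflow warnings
    print(naive_tanh(710.0))  # nan: exp(710) overflows to inf, and inf/inf is nan
print(np.tanh(710.0))         # 1.0: the library routine saturates safely

In production code one would simply call std::tanh, np.tanh, or torch.tanh; the explicit formula is written out here for exposition.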
The Python (NumPy) and PyTorch implementations are as follows:
import numpy as np
import torch

data = [1.1, -2.2, 3.3, 0.4, -0.5, -1.6]

# NumPy implementation
def tanh(x):
    lists = list()
    for i in range(len(x)):
        lists.append((np.exp(x[i]) - np.exp(-x[i])) / (np.exp(x[i]) + np.exp(-x[i])))
    return lists

def tanh_derivative(x):
    return 1 - np.power(tanh(x), 2)

output = [round(value, 4) for value in tanh(data)]  # keep 4 decimal places via round
print("numpy tanh:", output)
print("numpy tanh derivative:", [round(value, 4) for value in tanh_derivative(data)])
print("numpy tanh derivative2:", [round(1. - value*value, 4) for value in tanh(data)])

# call the PyTorch interface
input = torch.FloatTensor(data)
m = torch.nn.Tanh()
output2 = m(input)
print("pytorch tanh:", output2)
print("pytorch tanh derivative:", 1. - output2*output2)
Running the script prints the NumPy tanh and derivative values, followed by the PyTorch tensor results.
As the execution results show, the C++, Python, and PyTorch implementations produce exactly the same values.
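As a further sanity check on the hand-written derivative, PyTorch's autograd can compute the gradient directly; a minimal sketch, reusing the same data list as above:

import torch

data = [1.1, -2.2, 3.3, 0.4, -0.5, -1.6]
x = torch.tensor(data, requires_grad=True)
y = torch.tanh(x)
y.sum().backward()            # d(sum tanh(x))/dx_i = tanh'(x_i)

print("autograd:", x.grad)               # gradients computed by autograd
print("manual:  ", 1.0 - y.detach()**2)  # matches 1 - tanh(x)^2

The two print lines produce the same values, which is also what activation_function_tanh_derivative computes in the C++ code from the stored tanh outputs.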
GitHub:
https://github.com/fengbingchun/NN_Test
https://github.com/fengbingchun/PyTorch_Test