# -*- coding: utf-8 -*-
"""
Created on Wed Feb 23 20:37:01 2022

@author: koneko

Fit y = sin(x) on [-pi, pi] with a tiny 1-3-2-1 sigmoid network trained by
hand-written full-batch gradient descent.
"""

import numpy as np


def sigmoid(x):
    """Return the logistic function 1 / (1 + exp(-x)) element-wise.

    The value is computed as exp(-log(1 + exp(-x))) via ``np.logaddexp`` so
    that large negative inputs do not overflow inside ``exp``.
    """
    return np.exp(-np.logaddexp(0.0, -x))


def mean_squared_error(y, t):
    """Return half the sum of squared differences between *y* and *t*.

    Despite the name, the result is summed over all elements, not averaged.
    """
    return 0.5 * np.sum((y - t) ** 2)


class Sigmoid:
    """Sigmoid activation layer that caches its output for the backward pass."""

    def __init__(self):
        # Output of the most recent forward() call; None until forward() runs.
        self.out = None

    def forward(self, x):
        """Apply the sigmoid to *x*, remember the result and return it."""
        out = sigmoid(x)
        self.out = out
        return out

    def backward(self, dout):
        """Return the gradient w.r.t. the input given upstream gradient *dout*.

        Uses the cached forward output: d(sigmoid)/dx = out * (1 - out).
        """
        dx = dout * (1.0 - self.out) * self.out
        return dx


def train(epochs=30000, lr=0.001, n_samples=1000, plot_every=100,
          verbose=True, plot=True):
    """Train the 1-3-2-1 network to approximate sin(x) on [-pi, pi].

    Parameters
    ----------
    epochs : int
        Number of full-batch gradient-descent steps.
    lr : float
        Learning rate. Gradients are summed over all samples (not averaged),
        so the effective step grows with *n_samples*.
    n_samples : int
        Number of evenly spaced training points in [-pi, pi].
    plot_every : int
        Redraw the target/prediction curves every this many steps.
    verbose : bool
        Print the loss at every step.
    plot : bool
        Draw with matplotlib; when False matplotlib is never imported.

    Returns
    -------
    params : dict
        Trained arrays under the keys "W1", "b1", "W2", "b2", "W3", "b3".
    losses : list
        Loss value recorded at each step, measured before that step's update.
    """
    if plot:
        # Imported lazily so the numeric code stays usable without matplotlib.
        import matplotlib.pyplot as plt

    x = np.linspace(-np.pi, np.pi, n_samples)
    y = np.sin(x)
    if plot:
        plt.plot(x, y)

    # Row vectors of shape (1, n_samples): one column per training sample.
    x = x.reshape(1, x.size)
    y = y.reshape(1, y.size)

    # Initialize weights (draw order kept fixed so a given seed reproduces).
    W1 = np.random.randn(3, 1)
    b1 = np.random.randn(3, 1)
    W2 = np.random.randn(2, 3)
    b2 = np.random.randn(2, 1)
    W3 = np.random.randn(1, 2)
    b3 = np.random.randn(1, 1)

    sig1 = Sigmoid()
    sig2 = Sigmoid()

    losses = []
    for i in range(epochs):
        # Forward pass.
        a1 = W1 @ x + b1
        c1 = sig1.forward(a1)
        a2 = W2 @ c1 + b2
        c2 = sig2.forward(a2)
        y_pred = W3 @ c2 + b3

        loss = mean_squared_error(y, y_pred)
        losses.append(loss)
        if verbose:
            print(f"Loss[{i}]: {loss}")

        # Backward pass: gradient of the half-sum-of-squares loss.
        dy_pred = y_pred - y
        dc2 = W3.T @ dy_pred
        da2 = sig2.backward(dc2)
        dc1 = W2.T @ da2
        da1 = sig1.backward(dc1)

        # Partial derivatives of the loss w.r.t. each layer's parameters.
        dW3 = dy_pred @ c2.T
        db3 = np.sum(dy_pred, axis=1, keepdims=True)
        dW2 = da2 @ c1.T
        db2 = np.sum(da2, axis=1, keepdims=True)
        dW1 = da1 @ x.T
        db1 = np.sum(da1, axis=1, keepdims=True)

        # Gradient-descent update, in place.
        W3 -= lr * dW3
        b3 -= lr * db3
        W2 -= lr * dW2
        b2 -= lr * db2
        W1 -= lr * dW1
        b1 -= lr * db1

        if plot and plot_every and (i + 1) % plot_every == 0:
            # Redraw target and current prediction (pre-update y_pred).
            plt.cla()
            plt.plot(x.T, y.T)
            plt.plot(x.T, y_pred.T)

    params = {"W1": W1, "b1": b1, "W2": W2, "b2": b2, "W3": W3, "b3": b3}
    return params, losses


def main():
    """Run the training with the script's original settings."""
    train()


if __name__ == "__main__":
    main()