8.3 误差后向传播(BP)
原理和推导过程请参考中国大学MOOC(浙江大学)课程:
https://www.icourse163.org/course/ZJU-1003377027
输入值:x1, x2 = 0.5,0.3
目标输出值(真实值):y1, y2 = 0.23, -0.07(注意:sigmoid 的输出范围是 (0, 1),负的目标值 y2 无法被精确拟合)
激活函数:sigmoid
损失函数:均方误差 MSE,E = 1/2 × [(o1 − y1)² + (o2 − y2)²]
初始权值 w1~w8:0.2 -0.4 0.5 0.6 0.1 -0.5 -0.3 0.8(w1、w3:x1、x2→h1;w2、w4:x1、x2→h2;w5、w7:h1、h2→o1;w6、w8:h1、h2→o2)
目标:通过反向传播优化权值
=====正向计算:h1, h2, o1 ,o2=====0.56 0.5 0.48 0.53
=====损失函数:均方误差=====0.21
=====反向传播:误差传给每个权值=====0.01 0.01 0.01 0.01 0.03 0.08 0.03 0.07
=====更新前的权值=====0.2 -0.4 0.5 0.6 0.1 -0.5 -0.3 0.8
=====更新后的权值=====0.19 -0.41 0.49 0.59 0.07 -0.58 -0.33 0.73
import numpy as np


def sigmoid(z):
    """Return the logistic sigmoid 1 / (1 + exp(-z))."""
    a = 1 / (1 + np.exp(-z))
    return a


if __name__ == "__main__":
    # Network: 2 inputs -> 2 sigmoid hidden units -> 2 sigmoid outputs.
    #   in_h1 = w1*x1 + w3*x2      in_h2 = w2*x1 + w4*x2
    #   in_o1 = w5*h1 + w7*h2      in_o2 = w6*h1 + w8*h2
    w1 = 0.2
    w2 = -0.4
    w3 = 0.5
    w4 = 0.6
    w5 = 0.1
    w6 = -0.5
    w7 = -0.3
    w8 = 0.8
    x1 = 0.5
    x2 = 0.3
    y1 = 0.23
    y2 = -0.07
    print("=====输入值:x1, x2;真实输出值:y1, y2=====")
    print(x1, x2, y1, y2)

    # Forward pass
    in_h1 = w1 * x1 + w3 * x2
    out_h1 = sigmoid(in_h1)
    in_h2 = w2 * x1 + w4 * x2
    out_h2 = sigmoid(in_h2)
    in_o1 = w5 * out_h1 + w7 * out_h2
    out_o1 = sigmoid(in_o1)
    in_o2 = w6 * out_h1 + w8 * out_h2
    out_o2 = sigmoid(in_o2)
    print("=====正向计算:h1, h2, o1 ,o2=====")
    print(round(out_h1, 2), round(out_h2, 2), round(out_o1, 2), round(out_o2, 2))

    # Loss: E = 1/2 * [(o1 - y1)^2 + (o2 - y2)^2]
    error = (1 / 2) * (out_o1 - y1) ** 2 + (1 / 2) * (out_o2 - y2) ** 2
    print("=====损失函数:均方误差=====")
    print(round(error, 2))

    # Backward pass.
    # Output deltas dE/d(in_oK) = (out_oK - yK) * sigmoid'(in_oK), where
    # sigmoid'(in) = out * (1 - out).
    delta_o1 = (out_o1 - y1) * out_o1 * (1 - out_o1)
    delta_o2 = (out_o2 - y2) * out_o2 * (1 - out_o2)
    # Hidden -> output weights.
    d_w5 = delta_o1 * out_h1
    d_w7 = delta_o1 * out_h2
    d_w6 = delta_o2 * out_h1
    d_w8 = delta_o2 * out_h2
    # Hidden deltas: each output delta flows back through the weight that
    # links the hidden unit to that output (h1 -> o1 via w5, h1 -> o2 via w6;
    # h2 -> o1 via w7, h2 -> o2 via w8). The weights are still the values used
    # in the forward pass because nothing is updated until all gradients exist.
    delta_h1 = (delta_o1 * w5 + delta_o2 * w6) * out_h1 * (1 - out_h1)
    delta_h2 = (delta_o1 * w7 + delta_o2 * w8) * out_h2 * (1 - out_h2)
    # Input -> hidden weights.
    d_w1 = delta_h1 * x1
    d_w3 = delta_h1 * x2
    d_w2 = delta_h2 * x1
    d_w4 = delta_h2 * x2
    print("=====反向传播:误差传给每个权值=====")
    print(round(d_w1, 2), round(d_w2, 2), round(d_w3, 2), round(d_w4, 2),
          round(d_w5, 2), round(d_w6, 2), round(d_w7, 2), round(d_w8, 2))

    print("=====更新前的权值=====")
    print(round(w1, 2), round(w2, 2), round(w3, 2), round(w4, 2),
          round(w5, 2), round(w6, 2), round(w7, 2), round(w8, 2))
    # One gradient-descent step with step size 1.
    w1 = w1 - d_w1
    w2 = w2 - d_w2
    w3 = w3 - d_w3
    w4 = w4 - d_w4
    w5 = w5 - d_w5
    w6 = w6 - d_w6
    w7 = w7 - d_w7
    w8 = w8 - d_w8
    print("=====更新后的权值=====")
    print(round(w1, 2), round(w2, 2), round(w3, 2), round(w4, 2),
          round(w5, 2), round(w6, 2), round(w7, 2), round(w8, 2))
=====第6轮=====
正向计算:h1, h2, o1 ,o2
0.55 0.48 0.44 0.43
损失函数:均方误差
0.15
import numpy as np


def sigmoid(z):
    """Return the logistic sigmoid 1 / (1 + exp(-z))."""
    a = 1 / (1 + np.exp(-z))
    return a


def forward_propagate(x1, x2, y1, y2, w1, w2, w3, w4, w5, w6, w7, w8):
    """Run the 2-2-2 sigmoid network forward, printing activations and loss.

    Topology: in_h1 = w1*x1 + w3*x2, in_h2 = w2*x1 + w4*x2,
    in_o1 = w5*h1 + w7*h2, in_o2 = w6*h1 + w8*h2.

    Returns:
        (out_o1, out_o2, out_h1, out_h2)
    """
    in_h1 = w1 * x1 + w3 * x2
    out_h1 = sigmoid(in_h1)
    in_h2 = w2 * x1 + w4 * x2
    out_h2 = sigmoid(in_h2)
    in_o1 = w5 * out_h1 + w7 * out_h2
    out_o1 = sigmoid(in_o1)
    in_o2 = w6 * out_h1 + w8 * out_h2
    out_o2 = sigmoid(in_o2)
    print("正向计算:h1, h2, o1 ,o2")
    print(round(out_h1, 2), round(out_h2, 2), round(out_o1, 2), round(out_o2, 2))
    error = (1 / 2) * (out_o1 - y1) ** 2 + (1 / 2) * (out_o2 - y2) ** 2
    print("损失函数:均方误差")
    print(round(error, 2))
    return out_o1, out_o2, out_h1, out_h2


def back_propagate(out_o1, out_o2, out_h1, out_h2,
                   x1, x2, y1, y2, w5, w6, w7, w8):
    """Return (dE/dw1, ..., dE/dw8) for E = 1/2 * sum((o - y)^2); also print them.

    The activations come from forward_propagate. Inputs, targets and the
    hidden-to-output weights are passed explicitly; w5..w8 must be the values
    used in that forward pass.
    """
    # Output deltas: dE/d(in_oK) = (out_oK - yK) * out_oK * (1 - out_oK).
    delta_o1 = (out_o1 - y1) * out_o1 * (1 - out_o1)
    delta_o2 = (out_o2 - y2) * out_o2 * (1 - out_o2)
    # Hidden -> output weights.
    d_w5 = delta_o1 * out_h1
    d_w7 = delta_o1 * out_h2
    d_w6 = delta_o2 * out_h1
    d_w8 = delta_o2 * out_h2
    # Hidden deltas: each output delta is weighted by the connection it flows
    # back through (h1: w5 to o1, w6 to o2; h2: w7 to o1, w8 to o2).
    delta_h1 = (delta_o1 * w5 + delta_o2 * w6) * out_h1 * (1 - out_h1)
    delta_h2 = (delta_o1 * w7 + delta_o2 * w8) * out_h2 * (1 - out_h2)
    # Input -> hidden weights.
    d_w1 = delta_h1 * x1
    d_w3 = delta_h1 * x2
    d_w2 = delta_h2 * x1
    d_w4 = delta_h2 * x2
    print("反向传播:误差传给每个权值")
    print(round(d_w1, 2), round(d_w2, 2), round(d_w3, 2), round(d_w4, 2),
          round(d_w5, 2), round(d_w6, 2), round(d_w7, 2), round(d_w8, 2))
    return d_w1, d_w2, d_w3, d_w4, d_w5, d_w6, d_w7, d_w8


if __name__ == "__main__":
    w1, w2, w3, w4, w5, w6, w7, w8 = 0.2, -0.4, 0.5, 0.6, 0.1, -0.5, -0.3, 0.8
    x1, x2 = 0.5, 0.3
    y1, y2 = 0.23, -0.07
    print("=====输入值:x1, x2;真实输出值:y1, y2=====")
    print(x1, x2, y1, y2)
    print("=====更新前的权值=====")
    print(*(round(w, 2) for w in (w1, w2, w3, w4, w5, w6, w7, w8)))

    step = 1  # step size (learning rate)
    rounds = 6  # rounds 1..5 update the weights; the last round only evaluates
    for rnd in range(1, rounds + 1):
        print("=====第" + str(rnd) + "轮=====")
        out_o1, out_o2, out_h1, out_h2 = forward_propagate(
            x1, x2, y1, y2, w1, w2, w3, w4, w5, w6, w7, w8)
        if rnd == rounds:
            break
        # All gradients are computed before any weight changes, so the
        # hidden-layer terms see the same w5..w8 as the forward pass.
        d_w1, d_w2, d_w3, d_w4, d_w5, d_w6, d_w7, d_w8 = back_propagate(
            out_o1, out_o2, out_h1, out_h2, x1, x2, y1, y2, w5, w6, w7, w8)
        w1 = w1 - step * d_w1
        w2 = w2 - step * d_w2
        w3 = w3 - step * d_w3
        w4 = w4 - step * d_w4
        w5 = w5 - step * d_w5
        w6 = w6 - step * d_w6
        w7 = w7 - step * d_w7
        w8 = w8 - step * d_w8
        if rnd == 1:
            print("第1轮更新后的权值")
            print(*(round(w, 2) for w in (w1, w2, w3, w4, w5, w6, w7, w8)))
    print("更新后的权值")
    print(*(round(w, 2) for w in (w1, w2, w3, w4, w5, w6, w7, w8)))
=====第6轮=====
正向计算:o1 ,o2
0.23 0.03
损失函数:均方误差
0.01
import numpy as np


def sigmoid(z):
    """Return the logistic sigmoid 1 / (1 + exp(-z))."""
    a = 1 / (1 + np.exp(-z))
    return a


def forward_propagate(x1, x2, y1, y2, w1, w2, w3, w4, w5, w6, w7, w8):
    """Run the 2-2-2 sigmoid network forward, printing outputs and loss.

    Topology: in_h1 = w1*x1 + w3*x2, in_h2 = w2*x1 + w4*x2,
    in_o1 = w5*h1 + w7*h2, in_o2 = w6*h1 + w8*h2.

    Returns:
        (out_o1, out_o2, out_h1, out_h2)
    """
    in_h1 = w1 * x1 + w3 * x2
    out_h1 = sigmoid(in_h1)
    in_h2 = w2 * x1 + w4 * x2
    out_h2 = sigmoid(in_h2)
    in_o1 = w5 * out_h1 + w7 * out_h2
    out_o1 = sigmoid(in_o1)
    in_o2 = w6 * out_h1 + w8 * out_h2
    out_o2 = sigmoid(in_o2)
    print("正向计算:o1 ,o2")
    print(round(out_o1, 2), round(out_o2, 2))
    error = (1 / 2) * (out_o1 - y1) ** 2 + (1 / 2) * (out_o2 - y2) ** 2
    print("损失函数:均方误差")
    print(round(error, 2))
    return out_o1, out_o2, out_h1, out_h2


def back_propagate(out_o1, out_o2, out_h1, out_h2,
                   x1, x2, y1, y2, w5, w6, w7, w8):
    """Return (dE/dw1, ..., dE/dw8) for E = 1/2 * sum((o - y)^2); also print them.

    The activations come from forward_propagate. Inputs, targets and the
    hidden-to-output weights are passed explicitly; w5..w8 must be the values
    used in that forward pass.
    """
    # Output deltas: dE/d(in_oK) = (out_oK - yK) * out_oK * (1 - out_oK).
    delta_o1 = (out_o1 - y1) * out_o1 * (1 - out_o1)
    delta_o2 = (out_o2 - y2) * out_o2 * (1 - out_o2)
    # Hidden -> output weights.
    d_w5 = delta_o1 * out_h1
    d_w7 = delta_o1 * out_h2
    d_w6 = delta_o2 * out_h1
    d_w8 = delta_o2 * out_h2
    # Hidden deltas: each output delta is weighted by the connection it flows
    # back through (h1: w5 to o1, w6 to o2; h2: w7 to o1, w8 to o2).
    delta_h1 = (delta_o1 * w5 + delta_o2 * w6) * out_h1 * (1 - out_h1)
    delta_h2 = (delta_o1 * w7 + delta_o2 * w8) * out_h2 * (1 - out_h2)
    # Input -> hidden weights.
    d_w1 = delta_h1 * x1
    d_w3 = delta_h1 * x2
    d_w2 = delta_h2 * x1
    d_w4 = delta_h2 * x2
    print("反向传播:误差传给每个权值")
    print(round(d_w1, 2), round(d_w2, 2), round(d_w3, 2), round(d_w4, 2),
          round(d_w5, 2), round(d_w6, 2), round(d_w7, 2), round(d_w8, 2))
    return d_w1, d_w2, d_w3, d_w4, d_w5, d_w6, d_w7, d_w8


def update_w(w1, w2, w3, w4, w5, w6, w7, w8, grads, step=50):
    """Take one gradient-descent step: w_k <- w_k - step * dE/dw_k.

    Args:
        w1..w8: current weights.
        grads: the 8-tuple returned by back_propagate (dE/dw1 .. dE/dw8).
        step: step size (learning rate); 50 is the value this demo uses.

    Returns:
        The updated (w1, ..., w8).
    """
    d_w1, d_w2, d_w3, d_w4, d_w5, d_w6, d_w7, d_w8 = grads
    w1 = w1 - step * d_w1
    w2 = w2 - step * d_w2
    w3 = w3 - step * d_w3
    w4 = w4 - step * d_w4
    w5 = w5 - step * d_w5
    w6 = w6 - step * d_w6
    w7 = w7 - step * d_w7
    w8 = w8 - step * d_w8
    return w1, w2, w3, w4, w5, w6, w7, w8


if __name__ == "__main__":
    w1, w2, w3, w4, w5, w6, w7, w8 = 0.2, -0.4, 0.5, 0.6, 0.1, -0.5, -0.3, 0.8
    x1, x2 = 0.5, 0.3
    y1, y2 = 0.23, -0.07
    print("=====输入值:x1, x2;真实输出值:y1, y2=====")
    print(x1, x2, y1, y2)
    print("=====更新前的权值=====")
    print(*(round(w, 2) for w in (w1, w2, w3, w4, w5, w6, w7, w8)))

    rounds = 6  # rounds 1..5 update the weights; the last round only evaluates
    for rnd in range(1, rounds + 1):
        print("=====第" + str(rnd) + "轮=====")
        out_o1, out_o2, out_h1, out_h2 = forward_propagate(
            x1, x2, y1, y2, w1, w2, w3, w4, w5, w6, w7, w8)
        if rnd == rounds:
            break
        grads = back_propagate(out_o1, out_o2, out_h1, out_h2,
                               x1, x2, y1, y2, w5, w6, w7, w8)
        w1, w2, w3, w4, w5, w6, w7, w8 = update_w(
            w1, w2, w3, w4, w5, w6, w7, w8, grads)
        if rnd == 1:
            print("第1轮更新后的权值")
            print(*(round(w, 2) for w in (w1, w2, w3, w4, w5, w6, w7, w8)))
    print("更新后的权值")
    print(*(round(w, 2) for w in (w1, w2, w3, w4, w5, w6, w7, w8)))
=====第999轮=====
正向计算:o1 ,o2
0.23038 0.00954
损失函数:均方误差
0.00316
import numpy as np


def sigmoid(z):
    """Return the logistic sigmoid 1 / (1 + exp(-z))."""
    a = 1 / (1 + np.exp(-z))
    return a


def forward_propagate(x1, x2, y1, y2, w1, w2, w3, w4, w5, w6, w7, w8):
    """Run the 2-2-2 sigmoid network forward, printing outputs and loss (5 d.p.).

    Topology: in_h1 = w1*x1 + w3*x2, in_h2 = w2*x1 + w4*x2,
    in_o1 = w5*h1 + w7*h2, in_o2 = w6*h1 + w8*h2.

    Returns:
        (out_o1, out_o2, out_h1, out_h2)
    """
    in_h1 = w1 * x1 + w3 * x2
    out_h1 = sigmoid(in_h1)
    in_h2 = w2 * x1 + w4 * x2
    out_h2 = sigmoid(in_h2)
    in_o1 = w5 * out_h1 + w7 * out_h2
    out_o1 = sigmoid(in_o1)
    in_o2 = w6 * out_h1 + w8 * out_h2
    out_o2 = sigmoid(in_o2)
    print("正向计算:o1 ,o2")
    print(round(out_o1, 5), round(out_o2, 5))
    error = (1 / 2) * (out_o1 - y1) ** 2 + (1 / 2) * (out_o2 - y2) ** 2
    print("损失函数:均方误差")
    print(round(error, 5))
    return out_o1, out_o2, out_h1, out_h2


def back_propagate(out_o1, out_o2, out_h1, out_h2,
                   x1, x2, y1, y2, w5, w6, w7, w8):
    """Return (dE/dw1, ..., dE/dw8) for E = 1/2 * sum((o - y)^2); also print them.

    The activations come from forward_propagate. Inputs, targets and the
    hidden-to-output weights are passed explicitly; w5..w8 must be the values
    used in that forward pass.
    """
    # Output deltas: dE/d(in_oK) = (out_oK - yK) * out_oK * (1 - out_oK).
    delta_o1 = (out_o1 - y1) * out_o1 * (1 - out_o1)
    delta_o2 = (out_o2 - y2) * out_o2 * (1 - out_o2)
    # Hidden -> output weights.
    d_w5 = delta_o1 * out_h1
    d_w7 = delta_o1 * out_h2
    d_w6 = delta_o2 * out_h1
    d_w8 = delta_o2 * out_h2
    # Hidden deltas: each output delta is weighted by the connection it flows
    # back through (h1: w5 to o1, w6 to o2; h2: w7 to o1, w8 to o2).
    delta_h1 = (delta_o1 * w5 + delta_o2 * w6) * out_h1 * (1 - out_h1)
    delta_h2 = (delta_o1 * w7 + delta_o2 * w8) * out_h2 * (1 - out_h2)
    # Input -> hidden weights.
    d_w1 = delta_h1 * x1
    d_w3 = delta_h1 * x2
    d_w2 = delta_h2 * x1
    d_w4 = delta_h2 * x2
    print("反向传播:误差传给每个权值")
    print(round(d_w1, 5), round(d_w2, 5), round(d_w3, 5), round(d_w4, 5),
          round(d_w5, 5), round(d_w6, 5), round(d_w7, 5), round(d_w8, 5))
    return d_w1, d_w2, d_w3, d_w4, d_w5, d_w6, d_w7, d_w8


def update_w(w1, w2, w3, w4, w5, w6, w7, w8, grads, step=5):
    """Take one gradient-descent step: w_k <- w_k - step * dE/dw_k.

    Args:
        w1..w8: current weights.
        grads: the 8-tuple returned by back_propagate (dE/dw1 .. dE/dw8).
        step: step size (learning rate); 5 is the value this demo uses.

    Returns:
        The updated (w1, ..., w8).
    """
    d_w1, d_w2, d_w3, d_w4, d_w5, d_w6, d_w7, d_w8 = grads
    w1 = w1 - step * d_w1
    w2 = w2 - step * d_w2
    w3 = w3 - step * d_w3
    w4 = w4 - step * d_w4
    w5 = w5 - step * d_w5
    w6 = w6 - step * d_w6
    w7 = w7 - step * d_w7
    w8 = w8 - step * d_w8
    return w1, w2, w3, w4, w5, w6, w7, w8


if __name__ == "__main__":
    w1, w2, w3, w4, w5, w6, w7, w8 = 0.2, -0.4, 0.5, 0.6, 0.1, -0.5, -0.3, 0.8
    x1, x2 = 0.5, 0.3
    y1, y2 = 0.23, -0.07
    print("=====输入值:x1, x2;真实输出值:y1, y2=====")
    print(x1, x2, y1, y2)
    print("=====更新前的权值=====")
    print(*(round(w, 2) for w in (w1, w2, w3, w4, w5, w6, w7, w8)))

    n_iter = 1000  # number of gradient-descent iterations
    for i in range(n_iter):
        print("=====第" + str(i) + "轮=====")
        out_o1, out_o2, out_h1, out_h2 = forward_propagate(
            x1, x2, y1, y2, w1, w2, w3, w4, w5, w6, w7, w8)
        grads = back_propagate(out_o1, out_o2, out_h1, out_h2,
                               x1, x2, y1, y2, w5, w6, w7, w8)
        w1, w2, w3, w4, w5, w6, w7, w8 = update_w(
            w1, w2, w3, w4, w5, w6, w7, w8, grads)
    print("更新后的权值")
    print(*(round(w, 2) for w in (w1, w2, w3, w4, w5, w6, w7, w8)))
优化后的源代码:
import numpy as np


def sigmoid(z):
    """Return the logistic sigmoid 1 / (1 + exp(-z))."""
    a = 1 / (1 + np.exp(-z))
    return a


def forward_propagate(x1, x2, y1, y2, w1, w2, w3, w4, w5, w6, w7, w8):
    """Run the 2-2-2 sigmoid network forward.

    Topology: in_h1 = w1*x1 + w3*x2, in_h2 = w2*x1 + w4*x2,
    in_o1 = w5*h1 + w7*h2, in_o2 = w6*h1 + w8*h2.

    Returns:
        (out_o1, out_o2, out_h1, out_h2, error) where
        error = 1/2 * [(out_o1 - y1)^2 + (out_o2 - y2)^2].
    """
    in_h1 = w1 * x1 + w3 * x2
    out_h1 = sigmoid(in_h1)
    in_h2 = w2 * x1 + w4 * x2
    out_h2 = sigmoid(in_h2)
    in_o1 = w5 * out_h1 + w7 * out_h2
    out_o1 = sigmoid(in_o1)
    in_o2 = w6 * out_h1 + w8 * out_h2
    out_o2 = sigmoid(in_o2)
    error = (1 / 2) * (out_o1 - y1) ** 2 + (1 / 2) * (out_o2 - y2) ** 2
    return out_o1, out_o2, out_h1, out_h2, error


def back_propagate(out_o1, out_o2, out_h1, out_h2,
                   x1, x2, y1, y2, w5, w6, w7, w8):
    """Return (dE/dw1, ..., dE/dw8) for the loss computed by forward_propagate.

    Inputs, targets and the hidden-to-output weights are passed explicitly;
    w5..w8 must be the values used in the forward pass that produced the
    activations.
    """
    # Output deltas: dE/d(in_oK) = (out_oK - yK) * out_oK * (1 - out_oK).
    delta_o1 = (out_o1 - y1) * out_o1 * (1 - out_o1)
    delta_o2 = (out_o2 - y2) * out_o2 * (1 - out_o2)
    # Hidden -> output weights.
    d_w5 = delta_o1 * out_h1
    d_w7 = delta_o1 * out_h2
    d_w6 = delta_o2 * out_h1
    d_w8 = delta_o2 * out_h2
    # Hidden deltas: each output delta is weighted by the connection it flows
    # back through (h1: w5 to o1, w6 to o2; h2: w7 to o1, w8 to o2).
    delta_h1 = (delta_o1 * w5 + delta_o2 * w6) * out_h1 * (1 - out_h1)
    delta_h2 = (delta_o1 * w7 + delta_o2 * w8) * out_h2 * (1 - out_h2)
    # Input -> hidden weights.
    d_w1 = delta_h1 * x1
    d_w3 = delta_h1 * x2
    d_w2 = delta_h2 * x1
    d_w4 = delta_h2 * x2
    return d_w1, d_w2, d_w3, d_w4, d_w5, d_w6, d_w7, d_w8


def update_w(step, w1, w2, w3, w4, w5, w6, w7, w8, grads):
    """Gradient descent: return w_k - step * dE/dw_k for k = 1..8.

    Args:
        step: step size (learning rate).
        w1..w8: current weights.
        grads: the 8-tuple returned by back_propagate (dE/dw1 .. dE/dw8).
    """
    d_w1, d_w2, d_w3, d_w4, d_w5, d_w6, d_w7, d_w8 = grads
    w1 = w1 - step * d_w1
    w2 = w2 - step * d_w2
    w3 = w3 - step * d_w3
    w4 = w4 - step * d_w4
    w5 = w5 - step * d_w5
    w6 = w6 - step * d_w6
    w7 = w7 - step * d_w7
    w8 = w8 - step * d_w8
    return w1, w2, w3, w4, w5, w6, w7, w8


if __name__ == "__main__":
    # Imported here so the functions above can be reused without matplotlib.
    import matplotlib.pyplot as plt

    # Could be random; fixed values are used to match the lecture slides.
    w1, w2, w3, w4, w5, w6, w7, w8 = 0.2, -0.4, 0.5, 0.6, 0.1, -0.5, -0.3, 0.8
    x1, x2 = 0.5, 0.3  # inputs
    # Targets. A sigmoid output always lies in (0, 1), so a target inside that
    # range (y1) can be reached, while a negative target (y2) never can.
    y1, y2 = 0.23, -0.07
    N = 10  # number of iterations
    step = 10  # step size (learning rate)
    print("输入值:x1, x2;", x1, x2, "输出值:y1, y2:", y1, y2)
    iterations = []
    losses = []
    for i in range(N):
        print("=====第" + str(i) + "轮=====")
        # Forward pass
        out_o1, out_o2, out_h1, out_h2, error = forward_propagate(
            x1, x2, y1, y2, w1, w2, w3, w4, w5, w6, w7, w8)
        print("正向传播:", round(out_o1, 5), round(out_o2, 5))
        print("损失函数:", round(error, 2))
        # Backward pass (uses the same w5..w8 as the forward pass)
        grads = back_propagate(out_o1, out_o2, out_h1, out_h2,
                               x1, x2, y1, y2, w5, w6, w7, w8)
        # Gradient descent: update the weights
        w1, w2, w3, w4, w5, w6, w7, w8 = update_w(
            step, w1, w2, w3, w4, w5, w6, w7, w8, grads)
        iterations.append(i)
        losses.append(error)
    plt.plot(iterations, losses)
    plt.ylabel('Loss')
    # The x axis is the iteration index, not a weight.
    plt.xlabel('Iteration')
    plt.show()