男女的身高和体重有着显著的差别,此次Python程序的任务是根据一个人的身高和体重,简单判断他(她)的性别。
采用最简单的单层神经网络,logistic regression模型,模型输入一个人身高和体重,判断性别男女。
训练样本是sex_train.txt的文本,部分训练样本数据如下,第一列数字为身高(m);第二列数字为体重(kg);第三列数字指的是性别,其中1指代男性,2指代女性。
代码需要自定义Dataset类和getitem读取数据
Dataset中读取数据并放入变量Data中,通过strip去掉空格和换行符,里面的words【0】,words【1】,words【2】分别代表读取身高,体重和性别。getitem根据索引取出数据。
len函数获取样本个数
网络模型的输入输出
训练的函数
实际训练是以Dataloader加载训练样本,并以批次进行训练,batchsize表示训练单元。
正式训练设置epochs表示学习的轮数,epochs=100,进行100轮的训练。
model.train()进入训练模式,model.eval进入检测模式。
第二次for循环是取一个批次的样本进行输出。
epochs = 100 for epoch in range(epochs): # training----------------------------------- model.train() train_loss = 0 train_acc = 0 for batch, (batch_x, batch_y) in enumerate(train_loader): batch_x, batch_y = Variable(batch_x), Variable(batch_y) out = model(batch_x) loss = loss_func(out, batch_y) train_loss += loss.item() pred = torch.max(out, 1)[1] train_correct = (pred == batch_y).sum() train_acc += train_correct.item() print('epoch: %2d/%d batch %3d/%d Train Loss: %.3f, Acc: %.3f' % (epoch + 1, epochs, batch, math.ceil(len(train_data) / batchsize), loss.item(), train_correct.item() / len(batch_x))) optimizer.zero_grad() loss.backward() optimizer.step() scheduler.step() # 更新learning rate print('Train Loss: %.6f, Acc: %.3f' % (train_loss / (math.ceil(len(train_data)/batchsize)), train_acc / (len(train_data)))) # evaluation-------------------------------- model.eval() eval_loss = 0 eval_acc = 0 for batch_x, batch_y in val_loader: batch_x, batch_y = Variable(batch_x), Variable(batch_y) out = model(batch_x) loss = loss_func(out, batch_y) eval_loss += loss.item() pred = torch.max(out, 1)[1] num_correct = (pred == batch_y).sum() eval_acc += num_correct.item() print('Val Loss: %.6f, Acc: %.3f' % (eval_loss / (math.ceil(len(val_data)/batchsize)), eval_acc / (len(val_data)))) # save model -------------------------------- if (epoch + 1) % 1 == 0: torch.save(model.state_dict(), 'output/params_' + str(epoch + 1) + '.pth')
最后训练样本的效果
<iframe allowfullscreen="true" data-mediaembed="bilibili" id="MdllqxtL-1641700155130" src="https://player.bilibili.com/player.html?aid=850510347"></iframe>20220109