C/C++教程

PyTorch程序设计

本文主要是介绍PyTorch程序设计,对大家解决编程问题具有一定的参考价值,需要的程序猿们随着小编来一起学习吧!

PyTorch程序设计

  • 1. 程序示例
  • 2. 加载数据
  • 3. 创建网络(BN、卷积、ReLU)
  • 4. 优化器(优化器比较、参数lr、权重衰减)
  • 5. 损失函数比较
  • 6. 模型保存与加载

1. 程序示例

import torch, os
from torch.utils.data import Dataset, DataLoader
import numpy as np


class MyDataSet(Dataset):
    def __init__(self, images, labels):
        self.images = images
        self.labels = labels

    def __getitem__(self, item):
        image = self.images[item].reshape(1, 2, 2)  # cwh
        label = self.labels[item].reshape(1, 1, 1)
        return image, label

    def __len__(self):
        return len(self.labels)


class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        layers = []
        layers.append(torch.nn.Conv2d(1, 1, kernel_size=2, stride=1, bias=True, padding=0))
        self.net = torch.nn.ModuleList(layers)

    def forward(self, input):
        return self.net[0](input)


def train():
    # 1.输入数据,[w, h, c] = [2, 2, 1]。
    x = np.random.randint(low=0, high=255, size=(50, 2, 2), dtype=np.int)
    x = np.divide(x, 255).astype(np.float32)
    # 2.权重和偏置。
    w = np.array([[1, 2], [3, 4]], dtype=np.float32)
    b = 5
    # 3.标签。
    y = np.array([(np.multiply(w, x[i])).sum() + b for i in range(x.shape[0])])
    # 4.加载数据、网络、优化器、损失函数。
    train_data = MyDataSet(x, y)
    data_loader = DataLoader(dataset=train_data, batch_size=1, shuffle=False)
    net = Net()
    if os.path.exists('parameters.pth'):
        net.load_state_dict(torch.load('parameters.pth'))
    optimizer = torch.optim.Adam(net.parameters(), lr=0.001, weight_decay=0.0005)
    loss_op = torch.nn.L1Loss(size_average=True, reduce=True)
    # 5.迭代。
    for epoch in range(400):
        optimizer.param_groups[0]['lr'] = 0.001  # 设置学习率。
        for x, y in data_loader:
            out = net(x)
            loss = loss_op(y, out)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loss_numpy = loss.cpu().detach().numpy()
            if loss_numpy < 1e-10:
                break
        print('--epoch:', epoch, 'loss:', loss_numpy)
    for k, v in net.named_parameters():
        print(k, v.cpu().detach().numpy())
    torch.save(net.state_dict(), 'parameters.pth')  # 保存参数。
    torch.save(net, 'model.pth')  # 保存网络结构和参数。


def testNet():
    net = torch.load('model.pth')
    y = net(torch.tensor([[1, 2], [3, 4]], dtype=torch.float32).reshape(1, 1, 2, 2))
    print(y)


def testParams():
    net = Net()
    net.load_state_dict(torch.load('parameters.pth')) 
    y = net(torch.tensor([[1, 2], [3, 4]], dtype=torch.float32).reshape(1, 1, 2, 2))
    print(y)


if __name__ == '__main__':
    train()
    testNet()
    testParams()

2. 加载数据

3. 创建网络(BN、卷积、ReLU)

4. 优化器(优化器比较、参数lr、权重衰减)

5. 损失函数比较

6. 模型保存与加载

这篇关于PyTorch程序设计的文章就介绍到这儿,希望我们推荐的文章对大家有所帮助,也希望大家多多支持为之网!