本文主要是介绍PyTorch程序设计,对大家解决编程问题具有一定的参考价值,需要的程序猿们随着小编来一起学习吧!
PyTorch程序设计
- 1. 程序示例
- 2. 加载数据
- 3. 创建网络(BN、卷积、ReLU)
- 4. 优化器(优化器比较、参数lr、权重衰减)
- 5. 损失函数比较
- 6. 模型保存与加载
1. 程序示例
import torch, os
from torch.utils.data import Dataset, DataLoader
import numpy as np
class MyDataSet(Dataset):
def __init__(self, images, labels):
self.images = images
self.labels = labels
def __getitem__(self, item):
image = self.images[item].reshape(1, 2, 2) # cwh
label = self.labels[item].reshape(1, 1, 1)
return image, label
def __len__(self):
return len(self.labels)
class Net(torch.nn.Module):
def __init__(self):
super(Net, self).__init__()
layers = []
layers.append(torch.nn.Conv2d(1, 1, kernel_size=2, stride=1, bias=True, padding=0))
self.net = torch.nn.ModuleList(layers)
def forward(self, input):
return self.net[0](input)
def train():
# 1.输入数据,[w, h, c] = [2, 2, 1]。
x = np.random.randint(low=0, high=255, size=(50, 2, 2), dtype=np.int)
x = np.divide(x, 255).astype(np.float32)
# 2.权重和偏置。
w = np.array([[1, 2], [3, 4]], dtype=np.float32)
b = 5
# 3.标签。
y = np.array([(np.multiply(w, x[i])).sum() + b for i in range(x.shape[0])])
# 4.加载数据、网络、优化器、损失函数。
train_data = MyDataSet(x, y)
data_loader = DataLoader(dataset=train_data, batch_size=1, shuffle=False)
net = Net()
if os.path.exists('parameters.pth'):
net.load_state_dict(torch.load('parameters.pth'))
optimizer = torch.optim.Adam(net.parameters(), lr=0.001, weight_decay=0.0005)
loss_op = torch.nn.L1Loss(size_average=True, reduce=True)
# 5.迭代。
for epoch in range(400):
optimizer.param_groups[0]['lr'] = 0.001 # 设置学习率。
for x, y in data_loader:
out = net(x)
loss = loss_op(y, out)
optimizer.zero_grad()
loss.backward()
optimizer.step()
loss_numpy = loss.cpu().detach().numpy()
if loss_numpy < 1e-10:
break
print('--epoch:', epoch, 'loss:', loss_numpy)
for k, v in net.named_parameters():
print(k, v.cpu().detach().numpy())
torch.save(net.state_dict(), 'parameters.pth') # 保存参数。
torch.save(net, 'model.pth') # 保存网络结构和参数。
def testNet():
net = torch.load('model.pth')
y = net(torch.tensor([[1, 2], [3, 4]], dtype=torch.float32).reshape(1, 1, 2, 2))
print(y)
def testParams():
net = Net()
net.load_state_dict(torch.load('parameters.pth'))
y = net(torch.tensor([[1, 2], [3, 4]], dtype=torch.float32).reshape(1, 1, 2, 2))
print(y)
if __name__ == '__main__':
train()
testNet()
testParams()
2. 加载数据
3. 创建网络(BN、卷积、ReLU)
4. 优化器(优化器比较、参数lr、权重衰减)
5. 损失函数比较
6. 模型保存与加载
这篇关于PyTorch程序设计的文章就介绍到这儿,希望我们推荐的文章对大家有所帮助,也希望大家多多支持为之网!