首先注意pytorch中模型保存有两种格式,pth和pkl,其中,pth是pytorch默认格式,pkl还支持pickle库,不过一般如果没有特殊需求的时候,推荐使用默认pth格式保存
pytorch中有两种数据保存方法,一种是存储整个模型,一种只存储参数
#保存 torch.save(model1, 'net.pth') #读取 model1 = torch.load('net.pth')
#保存 torch.save(model.state_dict(), 'checkpoint.pth') #提取 state_dict = torch.load('checkpoint.pth') model.load_state_dict(state_dict)
state_dict 包含了模型使用的所有参数(Parameter类型),如果自定义的模型参数没有用Parameter封装,那么不会出现在state_dict中, 所以使用的时候,自定义参数一定不要忘记使用Parameter进行封装。
class MLP(nn.Module): def __init__(self): super(MLP, self).__init__() self.w1 = torch.randn(10,2) self.w2 = nn.Parameter(torch.randn(2,1)) self.l1 = nn.Linear(10,1) def forward(self,x): pass net = MLP() net.state_dict()
输出,可以发现只有w2和l1
OrderedDict([('w2', tensor([[0.9826], [0.4665]])), ('l1.weight', tensor([[ 0.3098, 0.0985, -0.2566, -0.1024, 0.0449, -0.1681, -0.1743, 0.2985, -0.0644, -0.0181]])), ('l1.bias', tensor([-0.2871]))])
在训练的时候,可以保存训练中的中间状态,只需要把参数都保存到state字典中就可以了。 例如,在断点续传任务中,可以把epoch,模型状态,优化器状态,初始learning rate 等进行保存。
state = { 'state_dict': net.state_dict(), 'optimizer': optim.optimizer.state_dict(), 'lr_base': optim.lr_base 'epoch': epoch } torch.save( state, self.CKPTS_PATH + 'ckpt_' + self.VERSION + '/epoch'+ str(epoch) + '.pkl' )
加载
state = torch.load( self.CKPTS_PATH + 'ckpt_' + self.VERSION + '/epoch'+ str(epoch) + '.pkl' ) net.load_state_dict(state['state_dict']) optim.optimizer.load_state_dict(state['optimizer']) optim.lr_base = state['lr_base'] start_epoch = state['epoch']