C/C++教程

机器学习-网络模型的保存于读取(pytorch环境)

本文主要是介绍机器学习-网络模型的保存于读取(pytorch环境),对大家解决编程问题具有一定的参考价值,需要的程序猿们随着小编来一起学习吧!

例子

import torchvision
from torch import nn

vgg16_false = torchvision.models.vgg16(pretrained=False)
vgg16_true = torchvision.models.vgg16(pretrained=True)

vgg16_true.classifier.add_module('add_liner',nn.Linear(1000,10))的作用:在classifier的Sequential中添加一个名为‘add_linear'的层

vgg16_false.classifier[6] = nn.Linear(4096,10)的作用:将classifier中的Sequential第7个元素修改为nn.Linear(4096,10)

这篇关于机器学习-网络模型的保存于读取(pytorch环境)的文章就介绍到这儿,希望我们推荐的文章对大家有所帮助,也希望大家多多支持为之网!