Java教程

2021-10-19

本文主要是介绍2021-10-19,对大家解决编程问题具有一定的参考价值,需要的程序猿们随着小编来一起学习吧!

江大白共学计划课程–Pytorch课程作业(1)

import torch
import torchvision.models as models
from thop import profile
# 加载模型结构
model = models.resnet18()
# 读取权重和载入权重
pretrained_state_dict = torch.load('Lesson2/weights/resnet18-5c106cde.pth')
model_state_dict = model.load_state_dict(pretrained_state_dict, strict=False)


# 让模型变为推理状态
model.eval()
# 将模型放置到cuda上
model.to(torch.device('cuda'))

# 构建一个项目推理时需要的输入大小的单精度Tensor,并且放置模型所在的设备(cpu或cuda)
inputs = torch.ones([1, 3, 224, 224]).type(torch.float32).to(torch.device('cuda'))
flops, params = profile(model=model, inputs=(inputs))
print('Model:{:.2f} GFLOPs and {:.2f}M parameters'.format(flops/1e9, params/1e6))

# 生成onnx:将训练好的模型(包括结构和权重)保存成onnx。
torch.onnx.export(model, inputs, 'Lesson2/weights/resnet18.onnx', verbose = False)
输入:[1, 3, 224, 224]
输出:resnet18模型的参数量和计算量
Model:0.31 GFLOPs and 3.50M parameters

输入:[1, 3, 448, 448]
输出:resnet18模型的参数量和计算量
Model:1.25 GFLOPs and 3.50M parameters

Netron可视化ResNet18的网络结构

在这里插入图片描述

这篇关于2021-10-19的文章就介绍到这儿,希望我们推荐的文章对大家有所帮助,也希望大家多多支持为之网!