C/C++教程

Pytorch(3)-Torchvision的使用

本文主要是介绍Pytorch(3)-Torchvision的使用,对大家解决编程问题具有一定的参考价值,需要的程序猿们随着小编来一起学习吧!
import torchvision
# 通过ToTensor()将数据集转为tensor数据类型,并通过compose连接
from torch.utils.tensorboard import SummaryWriter

dataset_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor()
])

# 加载数据集,其中CIFAR10是pytorch提供的一种数据集类型,具体参数介绍:
#   root:要保存数据集的目录
#   train:如果为true创建一个训练数据集,如果为false创建一个测试数据集
#   transform:图像转换后的数据类型
#   download:如果为true则会在网络中下载该数据集,如果为false则不会下载,如果数据集已经存在则会在控制台输出数据集已存在
train_set = torchvision.datasets.CIFAR10(root="F:\\pytorch\\pytorch01_hello\\dataset\\train\\torchvision_image", train=True, transform=dataset_transform, download=True)
test_set = torchvision.datasets.CIFAR10(root="F:\\pytorch\\pytorch01_hello\\dataset\\train\\torchvision_image", train=False, transform=dataset_transform, download=True)
# print(test_set.classes)
# print(test_set[0])

# 将数据集加载到tensorboard查看,这里查看前十张图片
writer = SummaryWriter("p10")
for i in range(10):
    img, target = test_set[i]
    writer.add_image("p10课程", img, i)
writer.close()
这篇关于Pytorch(3)-Torchvision的使用的文章就介绍到这儿,希望我们推荐的文章对大家有所帮助,也希望大家多多支持为之网!