C/C++教程

PyTorch中利用ImageFolder和Dataloader加载自制图像数据集

本文主要是介绍PyTorch中利用ImageFolder和Dataloader加载自制图像数据集,对大家解决编程问题具有一定的参考价值,需要的程序猿们随着小编来一起学习吧!

使用的数据集为flower_photos,此数据集中:同一种花的图片分别存放于同一文件夹,文件夹名即为花的品种名称,如图:http://img1.sycdn.imooc.com/5fa35e110001779f06420197.jpg

from __future__ import print_function, division
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

dataset =torchvision.datasets.ImageFolder(root='D:/DATASETS/flower_photos',transform=train_transform)

train_dataset, valid_dataset = train_test_split(dataset,test_size=0.2, random_state=0)
print(len(train_dataset))
print(len(valid_dataset))
train_loader =DataLoader(train_dataset,batch_size=4, shuffle=True,num_workers=0)
#Batch Size定义:一次训练所选取的样本数。 Batch Size的大小影响模型的优化程度和速度。
valid_loader =DataLoader(valid_dataset,batch_size=4, shuffle=True,num_workers=0)
2936
734
下面的代码出现了一个错误:Invalid shape (3, 224, 224) for image data。
因为一般神经网络中对图像处理之后的格式是(3,224,224)这种,分别为通道,高,宽。
但是plt显示的图像格式为(224,224,3)也就是高,宽,通道。
此处使用了方法:img = img.permute(2, 1, 0)解决这个问题。
print(len(train_dataset))
for image in train_loader:
    valid_image,valid_label=image
    print('valid_label:',valid_label[0])
    print('valid_image shape:',valid_image[0].shape)
    print(valid_image[0].dtype)
    plt.imshow(valid_image[0].permute(2, 1, 0))
    plt.show()
    break
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
2936
valid_label: tensor(2)
valid_image shape: torch.Size([3, 224, 224])
torch.float32

http://img4.sycdn.imooc.com/5fa35f180001259602570252.jpg

这篇关于PyTorch中利用ImageFolder和Dataloader加载自制图像数据集的文章就介绍到这儿,希望我们推荐的文章对大家有所帮助,也希望大家多多支持为之网!