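The dataset used here is flower_photos. In this dataset, all images of one flower species are stored in the same folder, and the folder name is the species name, as shown in the figure: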
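```python
from __future__ import print_function, division

import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

# Training-time preprocessing: random crop + horizontal flip for augmentation,
# then convert to a CHW float tensor and normalize with the ImageNet mean/std.
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# ImageFolder uses each sub-folder name as the class label.
dataset = torchvision.datasets.ImageFolder(root='D:/DATASETS/flower_photos', transform=train_transform)

# 80/20 train/validation split with a fixed random seed.
train_dataset, valid_dataset = train_test_split(dataset, test_size=0.2, random_state=0)
print(len(train_dataset))
print(len(valid_dataset))

# Batch size: the number of samples used in one training step.
# It affects both how well the model is optimized and how fast training runs.
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=0)
# Validation does not need shuffling, so keep a fixed order.
valid_loader = DataLoader(valid_dataset, batch_size=4, shuffle=False, num_workers=0)
```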
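```
2936
734
```

Plotting a sample directly with plt.imshow(valid_image[0]) fails with the error: Invalid shape (3, 224, 224) for image data. This is because images prepared for a neural network are usually laid out as (3, 224, 224), i.e. channels, height, width, whereas plt.imshow expects (224, 224, 3), i.e. height, width, channels. The fix is to move the channel axis to the end with img = img.permute(1, 2, 0). Note that permute(2, 1, 0) also yields a shape plt accepts for these square 224×224 crops, but it produces (width, height, channels), so the image would be displayed transposed.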
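A non-square tensor makes the difference between the two axis orders visible in the shape itself. The sketch below uses arbitrary sizes (height 100, width 200) and marks the top row so the orientation can be checked after permuting.

```python
import torch

# A CHW tensor, as produced by transforms.ToTensor(): 3 channels, height 100, width 200.
chw = torch.zeros(3, 100, 200)
# Set the red channel of the top row to 1 so the orientation can be traced.
chw[0, 0, :] = 1.0

hwc = chw.permute(1, 2, 0)  # (H, W, C): the layout plt.imshow expects
whc = chw.permute(2, 1, 0)  # (W, H, C): height and width swapped

print(hwc.shape)  # torch.Size([100, 200, 3])
print(whc.shape)  # torch.Size([200, 100, 3])

# In hwc the red row is still the top row; in whc it has become the left column.
print(bool(hwc[0, :, 0].all()), bool(whc[:, 0, 0].all()))  # True True
```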
```python
print(len(train_dataset))
# Take one batch from train_loader and inspect its first sample.
for image in train_loader:
    valid_image, valid_label = image
    print('valid_label:', valid_label[0])
    print('valid_image shape:', valid_image[0].shape)
    print(valid_image[0].dtype)
    # CHW -> HWC, the layout plt.imshow expects
    plt.imshow(valid_image[0].permute(1, 2, 0))
    plt.show()
    break
```
```
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
2936
valid_label: tensor(2)
valid_image shape: torch.Size([3, 224, 224])
torch.float32
```