The order of loading data in PyTorch is:
① Create a Dataset object.
② Create a DataLoader object.
③ Loop over the DataLoader object to get (data, label) batches and feed them to the model for training (a minimal sketch of these three steps follows below).
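For orientation, here is a minimal sketch of the three steps. It uses torch.utils.data.TensorDataset with random tensors as a stand-in dataset; the rest of this post builds a custom Dataset instead:

import torch
from torch.utils.data import TensorDataset, DataLoader

# Step 1: create a dataset object (random tensors stand in for real data)
x = torch.randn(100, 3, 32, 32)
y = torch.randint(0, 10, (100,))
dataset = TensorDataset(x, y)

# Step 2: create a dataloader object
loader = DataLoader(dataset, batch_size=16, shuffle=True)

# Step 3: loop over the dataloader to get (data, label) batches for training
for data, label in loader:
    print(data.shape, label.shape)  # torch.Size([16, 3, 32, 32]) torch.Size([16])
    break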
You need to define your own class that inherits from the parent class Dataset and override at least the following three functions:
① __init__: receives or loads the data.
② __len__: returns how many items the dataset contains.
③ __getitem__: returns one training sample and converts it to a tensor.
Example code:
import numpy as np
from PIL import Image
from torch.utils.data import Dataset

class MyData(Dataset):
    def __init__(self, x_patches, y_patches, transform=None):
        # store the blurred inputs, the clean targets, and an optional transform
        self.y_patches = y_patches
        self.x_patches = x_patches
        self.transform = transform

    def __len__(self):
        # total number of items in the dataset
        return len(self.y_patches)

    def __getitem__(self, idx):
        y_image = self.y_patches[idx]
        x_image = self.x_patches[idx]
        # convert to uint8 numpy arrays, then to PIL images so torchvision transforms can be applied
        y_image = Image.fromarray(np.asarray(y_image).astype(np.uint8))
        x_image = Image.fromarray(np.asarray(x_image).astype(np.uint8))
        if self.transform:
            y_image = self.transform(y_image)
            x_image = self.transform(x_image)
        return x_image, y_image
The main DataLoader parameters are:
dataset: the dataset to load from (the Dataset object created above)
shuffle=True: whether to shuffle the data
collate_fn: lets you customize how the samples of each batch are combined; reference: https://blog.csdn.net/kahuifu/article/details/108654421 (see the sketch below)
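As a rough illustration of collate_fn (the function name here is hypothetical, not part of this post's pipeline): a custom collate function receives the list of samples returned by __getitem__ for one batch and decides how to combine them, for example by stacking the tensors manually:

import torch
from torch.utils.data import DataLoader

def my_collate(batch):
    # `batch` is a list of (x, y) pairs, one per call to __getitem__
    xs = torch.stack([x for x, _ in batch], dim=0)
    ys = torch.stack([y for _, y in batch], dim=0)
    return xs, ys

# usage sketch:
# data_loader = DataLoader(dataset, batch_size=16, shuffle=True, collate_fn=my_collate)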
Example code:
from torch.utils.data import DataLoader
from torchvision import transforms

dataset = MyData(x_patches, y_patches,
                 transform=transforms.Compose([transforms.ToTensor(),
                                               transforms.Normalize([0.5], [0.5])]))
bs = 16
data_loader = DataLoader(dataset, batch_size=bs, shuffle=True)
num_batches = len(data_loader)  # number of batches per epoch
Finally, loop over the dataloader and feed the data to the model for training:
for n_batch, (x_batch, y_batch) in enumerate(data_loader):
    x_data = x_batch.float().cuda()
    y_data = y_batch.float().cuda()
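For completeness, here is a hedged sketch of what a full training step might look like inside this loop. The model, loss function, and optimizer below are placeholders that are not defined in the original post:

import torch
import torch.nn as nn

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# placeholder model: a tiny convolutional network for single-channel patches
model = nn.Sequential(nn.Conv2d(1, 8, 3, padding=1), nn.ReLU(),
                      nn.Conv2d(8, 1, 3, padding=1)).to(device)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for n_batch, (x_batch, y_batch) in enumerate(data_loader):
    x_data = x_batch.float().to(device)
    y_data = y_batch.float().to(device)

    optimizer.zero_grad()           # clear gradients from the previous step
    pred = model(x_data)            # forward pass on the blurred input
    loss = criterion(pred, y_data)  # compare prediction with the clean target
    loss.backward()                 # backpropagate
    optimizer.step()                # update the weights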