C/C++教程

Pytorch——Dataset类和DataLoader类

本文主要是介绍Pytorch——Dataset类和DataLoader类,对大家解决编程问题具有一定的参考价值,需要的程序猿们随着小编来一起学习吧!

  这篇文章主要探讨一下,Dataset类以及DataLoader类的使用以及注意事项。Dataset类主要是用于原始数据的读取或者基本的数据处理(比如在NLP任务中常常需要把文字转化为对应字典ids,这个步骤就可以放在Dataset中执行)。DataLoader,是进一步对Dataset的处理,Dataset得到的数据集你可以理解为是个"列表"(可以根据index取出某个特定位置的数据),而DataLoder就是把这个数据集(Dataset)根据你设定的batch_size划分成很多个“子数据集”,每个“子数据集”中的元素数量就是batch_size。

  DataLoader为什么要把Dataset划分成多个”子数据集“呢?因为一次性把所有的数据放进模型会导致内存溢出,而且模型的迭代会很慢。下面我们就深度解析下Dataset和DataLoader的使用方式。

一、Dataset的使用

  这里说到的Dataset其实就是,torch.utils.data.Dataset类 ,换句话说我们需要创建一个Dataset类,使用类的继承就可以了。既然是继承类,那么肯定会修改一些父类(torch.utils.data.Dataset类 )的方法来适应我们的真实数据和逻辑。而我们主要要重写的就是,__init__()__len__()__getitem__(),这三个方法分别是以下作用:

  • __init__方法:进行类的初始化,一般是用来读取原始数据。
  • __getitem__方法:根据下标对每一个数据进行进一步的处理。return:希望通过dataset[index]在数据集中取出的元素
  • __len__方法:return:数据集的数量(int)

  下面用一个例子来大致说明下Dataset该怎么构建,并且如何使用。

 

from torch.utils.data import Dataset 
import torch

def MyTokenizer(sentence):
    src_vocab = {'度':0,'上':1,'世':2,'中':3,
    '为':4,'人':5,'伟':6,'你':7,'务':8,'国':9,
    '大':10,'我':11,'是':12,'最':13,'服':14,
    '民':15,'爱':16,'界':17,'的':18}
    enc_input = [src_vocab[n] for n in sentence]
    if len(enc_input) < 12:                     ## 如果enc_input的长度小于12,则用100来补足,使得enc_input长度为12.
        enc_input = enc_input+(12-len(enc_input))*[100]
    return enc_input


class MyDataset(Dataset):
    def __init__(self, data):
        self.data = data
        self.tokenizer = MyTokenizer

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        sentence = self.data[index]
        return torch.tensor(self.tokenizer(sentence))

data = ['我爱你中国', '中国是世界上最伟大的国度', '为人民服务','你爱我']
dataset = MyDataset(data)

 

  简单介绍一下,函数MyTokenizer(sentence):把一个句子根据字典,转化为一个列表[...],例如”我爱你“——>[11, 16, 7, 100,100,...100]。这里为了简便,我是用的数据就是四句话。我在初始化(__init__)我的MyDataset类时,把数据储存下来,并且定义了我的编码器Mytokenizer。(这里为什么输出列表末尾会有100呢,主要是使得每一个数据长度是一样的,为后面进入DataLoader做准备,其实这个操作就叫做padding)

  __len__(self):这里返回了我传入数据集的大小。而__getitem__(self,index):中index指的是数据下标,根据这个下标提取出原始数据(self.data中的一句话),并且把这句话传入到self.tokenizer进行编码,最后返回编码的结果(一个列表[....]),可以看到__getitem__这个函数就是根据index来处理每一个拿出来的原始数据的,你对原始数据的所有处理都可以放在这里。我们最后一行代码是完成了MyDataset的实例化。看一看这个实例化之后的结果。

print(len(dataset))
print(dataset[0])
3
[11, 16, 7, 3, 9]

  这里返回的tensor([11,16,7,3,9]),其实就是”我爱你中国“经过编码之后的结果(可以对照上面的字典看看)。

  其实说白了Dataset就是一个数据处理器,把数据收集起来,并且进行对每一个index的数据进行处理,最后输出。有人会问为啥不先处理好这些数据呢?其实是因为DataLoader只能接受torch.utils.data.Dataset类作为传入参数,因此用其他任意的数据结构都没办法放到DataLoader里面,这样就没法自动根据batch_size拆分成”子数据集“。因此Dataset是我们必须构建的,就算是我的数据不想进一步处理,也必须写一个以上的最简单的MyDataset类(直接传入啥,输出啥的类)。

二、DataLoader的使用

  先放官方文档:官方文档

  刚才说到Dataset的构建是为了放进到DataLoader里,为啥非要放到这里面呢?其根本原因是DataLoader中有很多好用的设置可以让我们更好的处理数据,比如参数shuffle,可以让Dataset中的数据打乱重新排列再进行分批次,num_workers参数可以设定安排多少个进程来加载数据(加速)。一般情况下我们不需要重写DataLoader类,只需要实例化就可以了。例如我们把上面创建好的Dataset实例——dataset传入到DataLoader中构建实例。

  这里一定要注意,每个batch(子集)里的长度一定要一致,不然会报错“RuntimeError: each element in list of batch should be of equal size”。(这也就是为什么,在建立Dataset的时候我会用100来吧不足12长度的句子填充成统一长度,因为我举的例子中没有超过12的句子,所以不存在切割句子,真实情况需要按你自己的数据需求,但是一定要保证出来的数据要一样长,至于为什么一会后面说)。

from torch.utils.data import DataLoader
myDataloader = DataLoader(dataset, shuffle=True, batch_size=2)

  这个myDataloader就是DataLoader的实例,已经被分为了2个数据为一个batch,接下来我们打印一下每个batch(由于我们只有4句话,2个样本为一个batch那么其实就只有2个batch,所以可以打印来看看)。

for batch in myDataloader:
    print(batch)
    print('===============================')
tensor([[  7,  16,  11, 100, 100, 100, 100, 100, 100, 100, 100, 100],
        [  4,   5,  15,  14,   8, 100, 100, 100, 100, 100, 100, 100]])
===============================
tensor([[  3,   9,  12,   2,  17,   1,  13,   6,  10,  18,   9,   0],
        [ 11,  16,   7,   3,   9, 100, 100, 100, 100, 100, 100, 100]])
===============================

  可以看到每个batch其实是一个tensor,维度是(2,12)。每个tensor的每一行其实就是一个dataset里的一个样本。并且要注意每个样本已经不是按照原本的顺序排列了。

 

三、collate_fn参数的使用

  在DataLoader里,除了上面提到的shuffle参数和batch_size参数以外,还有一个非常重要的传入参数collate_fn,这个参数传入的是一个函数,这个函数主要是对每个batch进行处理,最终输出一个batch的返回值,换句话说collate_fn函数的返回值,就是遍历DataLoader的时候每个“batch”的返回值了(类似于上面例子中的二维tensor)。下面我写一个函数,让大家看看到底是怎么处理的。

def mycollate(item):
    sample1, sample2 = item
    return {'第一个样本':sample1,'第二个样本':sample2}

from torch.utils.data import DataLoader
myDataloader = DataLoader(dataset, shuffle=True, batch_size=2, collate_fn=mycollate)

  我们现在再来打印一下myDataloader的每个元素。

for batch in myDataloader:
    print(batch)
    print('===============================')
{'第一个样本': tensor([ 11,  16,   7,   3,   9, 100, 100, 100, 100, 100, 100, 100]), '第二个样本': tensor([  7,  16,  11, 100, 100, 100, 100, 100, 100, 100, 100, 100])}
===============================
{'第一个样本': tensor([ 3,  9, 12,  2, 17,  1, 13,  6, 10, 18,  9,  0]), '第二个样本': tensor([  4,   5,  15,  14,   8, 100, 100, 100, 100, 100, 100, 100])}
===============================

  可以看到,这个时候打印myDataloader的每个元素,就变成我在mycollate()函数中的返回值了。或许会不明白,我在mycollate()函数中这个传入的item是什么?其实这个item是个元组,元组的每个元素就是dataset的每个元素(tensor([3,9,12,....])),item的元素个数其实就是batch_size,这里的batch_size是2,所以我在mycollate()中用了两个变量来接收(换句话说要是我把batch_size换成2以外的其他数字,就会报错了)。

  可以看到其实我们在DataLoader的时候依然可以使用函数来处理我们的数据,换句话说我们完全可以把tokenizer函数放到mycollate()函数中。

  好了现在我们可以来解释为什么我在第二节的时候要求每个batch的数据要一样长了,那是因为当你不给定collate_fn这个参数的时候,会自动调用一个函数叫做default_collate(),大家可以粗略的看看这个内置函数的源码:

def default_collate(batch):
    r"""Puts each data field into a tensor with outer dimension batch size"""

    elem = batch[0]
    elem_type = type(elem)
    if isinstance(elem, torch.Tensor):
        out = None
        if torch.utils.data.get_worker_info() is not None:
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum([x.numel() for x in batch])
            storage = elem.storage()._new_shared(numel)
            out = elem.new(storage)
        return torch.stack(batch, 0, out=out)
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap':
            # array of string classes and object
            if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
                raise TypeError(default_collate_err_msg_format.format(elem.dtype))

            return default_collate([torch.as_tensor(b) for b in batch])
        elif elem.shape == ():  # scalars
            return torch.as_tensor(batch)
    elif isinstance(elem, float):
        return torch.tensor(batch, dtype=torch.float64)
    elif isinstance(elem, int_classes):
        return torch.tensor(batch)
    elif isinstance(elem, string_classes):
        return batch
    elif isinstance(elem, container_abcs.Mapping):
        return {key: default_collate([d[key] for d in batch]) for key in elem}
    elif isinstance(elem, tuple) and hasattr(elem, '_fields'):  # namedtuple
        return elem_type(*(default_collate(samples) for samples in zip(*batch)))
    elif isinstance(elem, container_abcs.Sequence):
        # check to make sure that the elements in batch have consistent size
        it = iter(batch)
        elem_size = len(next(it))
        if not all(len(elem) == elem_size for elem in it):
            raise RuntimeError('each element in list of batch should be of equal size')
        transposed = zip(*batch)
        return [default_collate(samples) for samples in transposed]

    raise TypeError(default_collate_err_msg_format.format(elem_type))

  看到倒数第三行了么?这就是为什么会报错的原因了。所以如果可以,我建议还是自己设定mycollate()函数,因为源码里如果你的dataset输出的元素不是tensor类型,那么将会按照它的方式来重新组织来返回,不同类别返回的东西是不一样的,大家可以看看源码。

 

 

 

 

参考网站:

Pytorch的第一步:(1) Dataset类的使用 - 简书 (jianshu.com)

Pytorch 中的数据类型 torch.utils.data.DataLoader 参数详解_Never-Giveup的博客-CSDN博客_dataloader参数

RuntimeError: each element in list of batch should be of equal size_NLP新手村成员的博客-CSDN博客

这篇关于Pytorch——Dataset类和DataLoader类的文章就介绍到这儿,希望我们推荐的文章对大家有所帮助,也希望大家多多支持为之网!