[YOLOv5 6.0 Source Code Walkthrough] --- utils/datasets.py

This post walks through utils/datasets.py in the YOLOv5 6.0 source code; hopefully it is a useful reference for anyone reading the data pipeline.

How YOLOv5 reads data and converts it into its training format

It touches on four main points:

  1. Data loading
  2. Label caching (*.cache)
  3. Data augmentation while keeping labels in sync
  4. A few other helper functions

What follows is my own understanding; if there are mistakes, corrections are welcome. A short usage sketch comes first, then the line-by-line walkthrough.
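The dataset path and hyperparameter values below are illustrative assumptions (in the spirit of data/hyps/hyp.scratch.yaml), not values taken from the source:

from torch.utils.data import DataLoader
from utils.datasets import LoadImagesAndLabels

# illustrative augmentation hyperparameters (only the keys __getitem__ reads)
hyp_dict = {'mosaic': 1.0, 'mixup': 0.0, 'degrees': 0.0, 'translate': 0.1, 'scale': 0.5,
            'shear': 0.0, 'perspective': 0.0, 'hsv_h': 0.015, 'hsv_s': 0.7, 'hsv_v': 0.4,
            'flipud': 0.0, 'fliplr': 0.5}
dataset = LoadImagesAndLabels('data/coco128/images/train2017',  # hypothetical dataset path
                              img_size=640, batch_size=16, augment=True, hyp=hyp_dict)
loader = DataLoader(dataset, batch_size=16, shuffle=True,
                    collate_fn=LoadImagesAndLabels.collate_fn)
for imgs, targets, paths, shapes in loader:
    # imgs: uint8 tensor (16, 3, 640, 640), RGB; targets: (num_labels, 6) rows of
    # [image_index_in_batch, class, x_center, y_center, width, height] (normalized)
    break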

class LoadImagesAndLabels(Dataset)

class LoadImagesAndLabels(Dataset):
    # YOLOv5 train_loader/val_loader, loads images and labels for training and validation
    cache_version = 0.5  # dataset labels *.cache version

    # Initialization
    def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, rect=False, image_weights=False,
                 cache_images=False, single_cls=False, stride=32, pad=0.0, prefix=''):
        self.img_size = img_size  # target image size
        self.augment = augment  # whether to apply image augmentation
        self.hyp = hyp  # hyperparameters
        self.image_weights = image_weights  # use weighted image sampling
        self.rect = False if image_weights else rect  # rectangular training: keep aspect ratios rather than resizing to a square
        self.mosaic = self.augment and not self.rect  # load 4 images at a time into a mosaic (only during training)
        self.mosaic_border = [-img_size // 2, -img_size // 2]  # border offsets used when building a mosaic
        self.stride = stride  # model stride
        self.path = path  # dataset path
        self.albumentations = Albumentations() if augment else None  # use the Albumentations library for extra augmentation

        try:
            f = []  # image files
            for p in path if isinstance(path, list) else [path]:
                p = Path(p)  # convert the string path to a Path object (os-agnostic)
                if p.is_dir():  # dir: glob returns every matching file as a list; recursive=True also searches subdirectories
                    f += glob.glob(str(p / '**' / '*.*'), recursive=True)
                    # f = list(p.rglob('**/*.*'))  # pathlib
                elif p.is_file():  # file: a text file that stores image paths, one per line (e.g. the lists referenced by coco.yaml)
                    with open(p, 'r') as t:
                        t = t.read().strip().splitlines()  # list of path strings, one per line
                        parent = str(p.parent) + os.sep
                        f += [x.replace('./', parent) if x.startswith('./') else x for x in t]  # local to global path
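                        # e.g. with parent = '/datasets/coco128/', the line './images/train2017/0001.jpg'
                        # becomes '/datasets/coco128/images/train2017/0001.jpg' (illustrative paths)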
                        # f += [p.parent / x.lstrip(os.sep) for x in t]  # local to global path (pathlib)
                else:
                    raise Exception(f'{prefix}{p} does not exist')
            self.img_files = sorted([x.replace('/', os.sep) for x in f if x.split('.')[-1].lower() in IMG_FORMATS])
            # self.img_files = sorted([x for x in f if x.suffix[1:].lower() in img_formats])  # pathlib
            assert self.img_files, f'{prefix}No images found'
        except Exception as e:
            raise Exception(f'{prefix}Error loading data from {path}: {e}\nSee {HELP_URL}')

        # Check cache
        self.label_files = img2label_paths(self.img_files)  # map each image path to its corresponding label path
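        # e.g. img2label_paths(['.../images/train2017/0001.jpg']) -> ['.../labels/train2017/0001.txt']
        # (it swaps the last '/images/' path segment for '/labels/' and the extension for '.txt')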
        cache_path = (p if p.is_file() else Path(self.label_files[0]).parent).with_suffix('.cache')
        try:
            cache, exists = np.load(cache_path, allow_pickle=True).item(), True  # load dict
            assert cache['version'] == self.cache_version  # same version
            assert cache['hash'] == get_hash(self.label_files + self.img_files)  # same hash
        except:  # otherwise, rebuild the label cache
            cache, exists = self.cache_labels(cache_path, prefix), False  # cache

        # Display cache
        nf, nm, ne, nc, n = cache.pop('results')  # found, missing, empty, corrupted, total
        if exists:
            d = f"Scanning '{cache_path}' images and labels... {nf} found, {nm} missing, {ne} empty, {nc} corrupted"
            tqdm(None, desc=prefix + d, total=n, initial=n)  # display cache results
            if cache['msgs']:
                logging.info('\n'.join(cache['msgs']))  # display warnings
        assert nf > 0 or not augment, f'{prefix}No labels in {cache_path}. Can not train without labels. See {HELP_URL}'

        # Read cache
        [cache.pop(k) for k in ('hash', 'version', 'msgs')]  # remove these three keys, leaving the per-image (labels, shape, segments) entries
        labels, shapes, self.segments = zip(*cache.values())  # unzip the cached values into tuples
        self.labels = list(labels)
        self.shapes = np.array(shapes, dtype=np.float64)
        self.img_files = list(cache.keys())  # update: the remaining keys are the image paths
        self.label_files = img2label_paths(cache.keys())  # update: regenerate the label paths from the image paths
        if single_cls:  # single-class mode: merge all classes into class 0
            for x in self.labels: 
                x[:, 0] = 0

        n = len(shapes)  # number of images 
        bi = np.floor(np.arange(n) / batch_size).astype(np.int)  # batch index of each image
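        # e.g. n=10, batch_size=4 -> bi = [0, 0, 0, 0, 1, 1, 1, 1, 2, 2] and nb (below) = 3 (illustrative)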
        nb = bi[-1] + 1  # number of batches
        self.batch = bi  # batch index of image
        self.n = n
        self.indices = range(n)

        # Rectangular Training
        if self.rect:
            # samples keep their original (non-square) aspect ratios
            # Sort by aspect ratio
            s = self.shapes  # wh
            ar = s[:, 1] / s[:, 0]  # aspect ratio
            irect = ar.argsort()  # indices that sort the aspect ratios in ascending order
            self.img_files = [self.img_files[i] for i in irect]  # reorder image files by ascending aspect ratio
            self.label_files = [self.label_files[i] for i in irect]  # reorder label files the same way
            self.labels = [self.labels[i] for i in irect]  # reorder labels the same way
            self.shapes = s[irect]  # wh
            ar = ar[irect]  # sorted aspect ratios (h/w)

            # Set training image shapes
            shapes = [[1, 1]] * nb 
            for i in range(nb):  # for each batch
                ari = ar[bi == i]  # aspect ratios of the images in this batch
                mini, maxi = ari.min(), ari.max()
                if maxi < 1:  # every image is wider than tall (h/w < 1): shrink the height
                    shapes[i] = [maxi, 1]
                elif mini > 1:  # every image is taller than wide (h/w > 1): shrink the width
                    shapes[i] = [1, 1 / mini]
            # the batch shape follows the most extreme aspect ratio in the batch; rounding up to a multiple
            # of stride keeps feature-map sizes integral after up/down-sampling
            self.batch_shapes = np.ceil(np.array(shapes) * img_size / stride + pad).astype(np.int) * stride
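            # e.g. shapes[i] = [0.75, 1] with img_size=640, stride=32, pad=0 -> batch_shapes[i] = [480, 640] (illustrative)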

        # Cache images into memory for faster training (WARNING: large datasets may exceed system RAM)
        self.imgs, self.img_npy = [None] * n, [None] * n
        if cache_images:  # optionally cache images
            if cache_images == 'disk':  # cache decoded images to disk as .npy files
                self.im_cache_dir = Path(Path(self.img_files[0]).parent.as_posix() + '_npy')  # image cache directory
                self.img_npy = [self.im_cache_dir / Path(f).with_suffix('.npy').name for f in self.img_files]  # one .npy file per image
                self.im_cache_dir.mkdir(parents=True, exist_ok=True)  # create the cache directory
            gb = 0  # Gigabytes of cached images
            self.img_hw0, self.img_hw = [None] * n, [None] * n
            results = ThreadPool(NUM_THREADS).imap(lambda x: load_image(*x), zip(repeat(self), range(n)))  # load images in a thread pool; load_image resizes each image and returns it with its original and resized (h, w)
            pbar = tqdm(enumerate(results), total=n)  # display a progress bar
            for i, x in pbar:
                if cache_images == 'disk':
                    if not self.img_npy[i].exists():  # save the .npy file if it does not exist yet
                        np.save(self.img_npy[i].as_posix(), x[0])
                    gb += self.img_npy[i].stat().st_size  # accumulate the cached file size
                else:
                    self.imgs[i], self.img_hw0[i], self.img_hw[i] = x  # im, hw_orig, hw_resized = load_image(self, i)
                    gb += self.imgs[i].nbytes
                pbar.desc = f'{prefix}Caching images ({gb / 1E9:.1f}GB {cache_images})'
            pbar.close()

    def cache_labels(self, path=Path('./labels.cache'), prefix=''):
        # Cache dataset labels, check images and read shapes
        x = {}  # dict
        nm, nf, ne, nc, msgs = 0, 0, 0, 0, []  # number missing, found, empty, corrupt, messages
        desc = f"{prefix}Scanning '{path.parent / path.stem}' images and labels..."
        with Pool(NUM_THREADS) as pool:
            # Pool, a multiprocessing helper, provides a fixed number of worker processes; it suits many small
            # tasks where managing processes by hand would be tedious, whereas the Process class suits a few
            # tasks that do not need a capped process count
            pbar = tqdm(pool.imap(verify_image_label, zip(self.img_files, self.label_files, repeat(prefix))),
                        desc=desc, total=len(self.img_files))
            # tqdm renders the progress bar;
            # pool.imap lazily applies the function to each item of the iterable and returns an iterator;
            # verify_image_label checks that each image/label pair is readable and converts the label into a
            # uniform format for later use
            for im_file, l, shape, segments, nm_f, nf_f, ne_f, nc_f, msg in pbar:
                nm += nm_f
                nf += nf_f
                ne += ne_f
                nc += nc_f
                if im_file:  # if both the image and its label exist, store (labels, shape, segments) keyed by the image path
                    x[im_file] = [l, shape, segments]
                if msg:  # collect the warning message for any missing/empty/corrupt file
                    msgs.append(msg)
                pbar.desc = f"{desc}{nf} found, {nm} missing, {ne} empty, {nc} corrupted"

        pbar.close()
        if msgs:
            logging.info('\n'.join(msgs))
        if nf == 0:
            logging.info(f'{prefix}WARNING: No labels found in {path}. See {HELP_URL}')
        x['hash'] = get_hash(self.label_files + self.img_files)  # hash of the label and image paths, used to detect dataset changes
        x['results'] = nf, nm, ne, nc, len(self.img_files)  # found, missing, empty, corrupt, total
        x['msgs'] = msgs  # warnings
        x['version'] = self.cache_version  # cache version
        try:
            np.save(path, x)  # save cache for next time
            path.with_suffix('.cache.npy').rename(path)  # np.save appends '.npy'; rename *.cache.npy back to *.cache
            logging.info(f'{prefix}New cache created: {path}')
        except Exception as e:
            logging.info(f'{prefix}WARNING: Cache directory {path.parent} is not writeable: {e}')  # path not writeable
        return x

    def __len__(self):
        return len(self.img_files)  # number of images

    # def __iter__(self):
    #     self.count = -1
    #     print('ran dataset iter')
    #     #self.shuffled_vector = np.random.permutation(self.nF) if self.augment else np.arange(self.nF)
    #     return self

    def __getitem__(self, index):
        index = self.indices[index]  # linear, shuffled, or image_weights

        hyp = self.hyp  # hyperparameters
        mosaic = self.mosaic and random.random() < hyp['mosaic']  # apply mosaic with probability hyp['mosaic']
        if mosaic:
            # Load mosaic
            img, labels = load_mosaic(self, index)  # returns the mosaic image and its labels
            shapes = None

            # MixUp augmentation (only ever applied on top of a mosaic)
            if random.random() < hyp['mixup']:
                img, labels = mixup(img, labels, *load_mosaic(self, random.randint(0, self.n - 1)))  # blend the current mosaic with a second, randomly chosen one

        else:
            # Load image
            img, (h0, w0), (h, w) = load_image(self, index)  # returns the decoded image, its original (h, w) and its resized (h, w)

            # Letterbox
            shape = self.batch_shapes[self.batch[index]] if self.rect else self.img_size  # final letterboxed shape
            img, ratio, pad = letterbox(img, shape, auto=False, scaleup=self.augment)  # letterbox-resize img to the target shape, padding to preserve aspect ratio
            shapes = (h0, w0), ((h / h0, w / w0), pad)  # for COCO mAP rescaling

            labels = self.labels[index].copy()
            if labels.size:  # normalized xywh to pixel xyxy format
                labels[:, 1:] = xywhn2xyxy(labels[:, 1:], ratio[0] * w, ratio[1] * h, padw=pad[0], padh=pad[1])
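                # xywhn2xyxy maps normalized (xc, yc, bw, bh) to pixel (x1, y1, x2, y2),
                # e.g. x1 = (ratio[0] * w) * (xc - bw / 2) + pad[0] (formula shown for reference)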

            if self.augment:
                img, labels = random_perspective(img, labels,
                                                 degrees=hyp['degrees'],
                                                 translate=hyp['translate'],
                                                 scale=hyp['scale'],
                                                 shear=hyp['shear'],
                                                 perspective=hyp['perspective'])

        nl = len(labels)  # number of labels
        if nl:  # convert back to normalized YOLO xywh format
            labels[:, 1:5] = xyxy2xywhn(labels[:, 1:5], w=img.shape[1], h=img.shape[0], clip=True, eps=1E-3)
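            # xyxy2xywhn is the inverse mapping, e.g. xc = ((x1 + x2) / 2) / w (formula shown for reference)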

        if self.augment:  # further augmentation
            # Albumentations
            img, labels = self.albumentations(img, labels)
            nl = len(labels)  # update after albumentations

            # HSV color-space
            augment_hsv(img, hgain=hyp['hsv_h'], sgain=hyp['hsv_s'], vgain=hyp['hsv_v'])

            # Flip up-down
            if random.random() < hyp['flipud']:
                img = np.flipud(img)
                if nl:
                    labels[:, 2] = 1 - labels[:, 2]
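                    # labels[:, 2] is the normalized y-center, so a vertical flip maps y -> 1 - y
                    # (the left-right flip below likewise maps the x-center, labels[:, 1], to 1 - x)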

            # Flip left-right
            if random.random() < hyp['fliplr']:
                img = np.fliplr(img)
                if nl:
                    labels[:, 1] = 1 - labels[:, 1]

            # Cutouts
            # labels = cutout(img, labels, p=0.5)

        labels_out = torch.zeros((nl, 6))  # 6 columns: 1 image-index-in-batch (filled by collate_fn) + 1 class + 4 box coords
        if nl:
            labels_out[:, 1:] = torch.from_numpy(labels)

        # Convert
        img = img.transpose((2, 0, 1))[::-1]  # HWC to CHW, BGR to RGB
        img = np.ascontiguousarray(img)  # make the (now non-contiguous) array contiguous in memory

        return torch.from_numpy(img), labels_out, self.img_files[index], shapes

    @staticmethod
    def collate_fn(batch):
        img, label, path, shapes = zip(*batch)  # transposed
        for i, l in enumerate(label):
            l[:, 0] = i  # add target image index for build_targets()
        return torch.stack(img, 0), torch.cat(label, 0), path, shapes

    @staticmethod
    def collate_fn4(batch):
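        # quad collate: every 4 samples collapse into 1 sample, either one image upsampled 2x or four
        # images tiled into a 2x2 grid (this is what YOLOv5's --quad training option uses)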
        img, label, path, shapes = zip(*batch)  # transposed
        n = len(shapes) // 4
        img4, label4, path4, shapes4 = [], [], path[:n], shapes[:n]

        ho = torch.tensor([[0., 0, 0, 1, 0, 0]])
        wo = torch.tensor([[0., 0, 1, 0, 0, 0]])
        s = torch.tensor([[1, 1, .5, .5, .5, .5]])  # scale
        for i in range(n):  # zidane torch.zeros(16,3,720,1280)  # BCHW
            i *= 4
            if random.random() < 0.5:
                im = F.interpolate(img[i].unsqueeze(0).float(), scale_factor=2., mode='bilinear', align_corners=False)[
                    0].type(img[i].type())
                l = label[i]
            else:
                im = torch.cat((torch.cat((img[i], img[i + 1]), 1), torch.cat((img[i + 2], img[i + 3]), 1)), 2)
                l = torch.cat((label[i], label[i + 1] + ho, label[i + 2] + wo, label[i + 3] + ho + wo), 0) * s
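                # label columns are (img_idx, cls, xc, yc, bw, bh) in normalized units: ho/wo shift the
                # centers of the bottom/right tiles by one image, then s rescales everything to the 2x2 canvas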
            img4.append(im)
            label4.append(l)

        for i, l in enumerate(label4):
            l[:, 0] = i  # add target image index for build_targets()

        return torch.stack(img4, 0), torch.cat(label4, 0), path4, shapes4
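
Because __getitem__ converts labels from normalized xywh to pixel xyxy before the geometric augmentations and back again afterwards, a simplified round-trip of the two conversions may help. These are my own minimal re-implementations for illustration (box columns only, no clipping), not the exact code from utils/general.py:

import numpy as np

def xywhn2xyxy_sketch(x, w, h, padw=0, padh=0):
    # normalized (xc, yc, bw, bh) -> pixel (x1, y1, x2, y2)
    y = np.copy(x)
    y[:, 0] = w * (x[:, 0] - x[:, 2] / 2) + padw  # top-left x
    y[:, 1] = h * (x[:, 1] - x[:, 3] / 2) + padh  # top-left y
    y[:, 2] = w * (x[:, 0] + x[:, 2] / 2) + padw  # bottom-right x
    y[:, 3] = h * (x[:, 1] + x[:, 3] / 2) + padh  # bottom-right y
    return y

def xyxy2xywhn_sketch(x, w, h):
    # pixel (x1, y1, x2, y2) -> normalized (xc, yc, bw, bh)
    y = np.copy(x)
    y[:, 0] = ((x[:, 0] + x[:, 2]) / 2) / w  # x center
    y[:, 1] = ((x[:, 1] + x[:, 3]) / 2) / h  # y center
    y[:, 2] = (x[:, 2] - x[:, 0]) / w  # width
    y[:, 3] = (x[:, 3] - x[:, 1]) / h  # height
    return y

boxes = np.array([[0.5, 0.5, 0.25, 0.25]])     # one box in normalized xywh
xyxy = xywhn2xyxy_sketch(boxes, w=640, h=640)  # -> [[240., 240., 400., 400.]]
back = xyxy2xywhn_sketch(xyxy, w=640, h=640)   # round-trips to the original values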

That concludes this walkthrough of utils/datasets.py in the YOLOv5 6.0 source code; I hope it helps.