Neither TensorFlow, nor Keras, nor the datasets library that ships with PyTorch's torchvision offers a Dataset class that is both flexible and concise enough for my needs.
So I decided to write my own simple, easily extensible single-node dataset utility class.
```python
import glob

import cv2
import numpy as np
from collections.abc import Iterator  # 'collections.Iterator' was removed in Python 3.10


class PairedGraySet(Iterator):
    """In-memory iterator over paired grayscale images.

    Yields (epoch, iteration, im1, im2) tuples, where im1/im2 are
    float32 arrays of shape (batch_size, 1, H, W).
    """

    def __init__(self, root_path, num_epoch, batch_size,
                 shuffle=True, max_image_size=(640, 480)):
        self.batch_size = batch_size
        self.num_epoch = num_epoch
        self.shuffle = shuffle
        self.epoch = 0
        self.iter = 0

        # collect the paired file lists
        files_main = sorted(glob.glob(root_path + '/*/*/*_siam_main.png'))
        files_aux = sorted(glob.glob(root_path + '/*/*/*_siam_aux.png'))
        assert len(files_aux) == len(files_main)
        assert len(files_main) > batch_size

        # try to load all images into memory
        n_pairs = len(files_main)
        self.images_main = [None] * n_pairs
        self.images_aux = [None] * n_pairs
        for i in range(n_pairs):
            main_ = cv2.imread(files_main[i], -1)  # grayscale image
            aux_ = cv2.imread(files_aux[i], -1)    # grayscale image
            assert len(main_.shape) == 2
            assert len(aux_.shape) == 2
            assert main_.shape[0] == aux_.shape[0]
            assert main_.shape[1] == aux_.shape[1]
            # the interpolation flag must be passed by keyword --
            # the third positional argument of cv2.resize is 'dst'
            self.images_main[i] = cv2.resize(main_, max_image_size,
                                             interpolation=cv2.INTER_LINEAR)
            self.images_aux[i] = cv2.resize(aux_, max_image_size,
                                            interpolation=cv2.INTER_LINEAR)

        # determine how many batches to run before an epoch runs out
        self.num_samples = n_pairs
        self.max_iter = self.num_samples // self.batch_size

        # the first index sequence
        if self.shuffle:
            self.ind_seq = np.random.permutation(self.num_samples)
        else:
            self.ind_seq = np.arange(self.num_samples)

        # preallocate the batch buffers for speed
        h, w = self.images_main[0].shape
        self.im1 = np.zeros([self.batch_size, 1, h, w], np.float32)
        self.im2 = np.zeros([self.batch_size, 1, h, w], np.float32)

    def __next__(self):
        cur_epoch = self.epoch
        cur_iter = self.iter
        if cur_epoch >= self.num_epoch:
            print('training is over, return None!')
            return cur_epoch, cur_iter, None, None

        # assemble a batch of image pairs
        beg_ = cur_iter * self.batch_size
        end_ = beg_ + self.batch_size
        for pair_id in range(beg_, end_):
            self.im1[pair_id - beg_, 0, :, :] = self.images_main[self.ind_seq[pair_id]]
            self.im2[pair_id - beg_, 0, :, :] = self.images_aux[self.ind_seq[pair_id]]

        # advance the iteration/epoch counters
        self.iter += 1
        if self.iter == self.max_iter:
            self.epoch += 1
            self.iter = 0
            # reshuffle the index sequence for the next epoch
            if self.shuffle:
                self.ind_seq = np.random.permutation(self.num_samples)
        return cur_epoch, cur_iter, self.im1, self.im2
```
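One detail worth noting: the number of batches per epoch is computed with floor division, so any remainder of `num_samples / batch_size` is silently dropped in each pass. A quick illustration, using made-up numbers rather than anything from the dataset above:

```python
# Illustrative values only -- not taken from an actual dataset.
num_samples, batch_size = 100, 8
max_iter = num_samples // batch_size           # 12 batches per epoch
dropped = num_samples - max_iter * batch_size
print(max_iter, dropped)                       # 12 4 -> the last 4 indices of
                                               # each permutation go unused
```

With shuffling enabled, a different tail of the permutation is dropped each epoch, so no image pair is permanently excluded from training.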
A usage example:
```python
# resolve the package root path
import os
import sys

parent = os.path.dirname(os.path.abspath(__file__))
package_root = os.path.dirname(parent)
sys.path.append(package_root)

os.environ['CUDA_VISIBLE_DEVICES'] = '3'

import torch
import torch.nn as nn
import torch.optim

from model import StereoSiamNet
from dataset import PairedGraySet


def train_ssn(model_path, model_name, data_path):
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(device)

    # net_ = StereoSiamNet()  # start from scratch
    net_ = torch.load('ssn_bremen-cross_decoding-150.pth')  # resume from a checkpoint
    net_ = net_.to(device)

    loss_fn = nn.SmoothL1Loss(reduction='mean')
    learning_rate = 1e-3
    opt_ = torch.optim.Adam(net_.parameters(), lr=learning_rate)

    num_epoch = 2000
    batch_size = 8
    print_freq = 100
    save_freq = 50

    # load the dataset
    data_loader = PairedGraySet(data_path, num_epoch, batch_size)

    net_.train()
    counter_ = 0
    for i_epo, i_itr, im1, im2 in data_loader:
        if im1 is None or im2 is None:
            break  # the loader signals exhaustion with None payloads
        counter_ += 1
        x1 = torch.Tensor(im1).to(device)
        x2 = torch.Tensor(im2).to(device)
        x1r, x2r, _, _ = net_(x1, x2)
        loss_ = loss_fn(x1, x1r) + loss_fn(x2, x2r)
        opt_.zero_grad()
        loss_.backward()
        opt_.step()
        if counter_ % print_freq == 0:
            print('Epoch: %03d Iter: %5d Loss %8.5f' % (i_epo, i_itr, loss_.item()))
        if i_epo > 0 and (i_epo % save_freq == 0) and (i_itr == 0):
            torch.save(net_, os.path.join(model_path, '%s-%03d.pth' % (model_name, i_epo)))
            print('model saved.')

    torch.save(net_, os.path.join(model_path, '%s-%03d.pth' % (model_name, num_epoch)))
    print('model saved.')
    print('training done.')


if __name__ == '__main__':
    model_dir = '../Models/SSN/Bremen/'
    model_name = 'ssn_bremen-cross_decoding'
    data_dir = '../Datasets/SSN/Bremen/'
    train_ssn(model_dir, model_name, data_dir)
    print('training done!')
    print('trained model saved as: ' + os.path.join(model_dir, model_name))
```
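One point that makes this loop safe: `PairedGraySet` reuses its preallocated `im1`/`im2` buffers, so each call to `__next__` overwrites the arrays yielded previously. `torch.Tensor(ndarray)` copies the data, which is why the loop works; `torch.from_numpy` would alias the buffer instead. A minimal sketch of the difference:

```python
import numpy as np
import torch

buf = np.zeros((2, 2), np.float32)  # stands in for the loader's reused buffer

copied = torch.Tensor(buf)          # copies the data
aliased = torch.from_numpy(buf)     # shares memory with buf

buf[0, 0] = 1.0                     # simulate the next batch overwriting the buffer
print(copied[0, 0].item())          # 0.0 -- the copy is unaffected
print(aliased[0, 0].item())         # 1.0 -- the alias sees the overwrite
```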
As the training loop shows, the dataset's behavior is deliberately simple: much like PyTorch's `DataLoader` (`torch.utils.data.DataLoader`), it is just an iterator.
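The only contract the class has to honor is the iterator protocol: subclassing `collections.abc.Iterator` supplies `__iter__` (returning `self`) for free, so defining `__next__` is enough to make the object usable in a `for` loop. A toy example of the same pattern:

```python
from collections.abc import Iterator

class CountDown(Iterator):
    """Minimal iterator: only __next__ is needed; __iter__ is inherited."""

    def __init__(self, n):
        self.n = n

    def __next__(self):
        if self.n <= 0:
            raise StopIteration  # the standard way to end a for-loop
        self.n -= 1
        return self.n

for i in CountDown(3):
    print(i)  # prints 2, 1, 0
```

Note that `PairedGraySet` deviates from the standard protocol in one respect: instead of raising `StopIteration` when the epochs run out, it keeps returning tuples with `None` payloads, which is why the training loop above checks for `None` and breaks explicitly.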