交叉验证法,就是把一个大的数据集分为 k k k 个小数据集,其中 k − 1 k-1 k−1 个作为训练集,剩下的 1 1 1 个作为测试集,在训练和测试的时候依次选择训练集和它对应的测试集。这种方法也被叫做 k k k 折交叉验证法(k-fold cross validation)。最终的结果是这 k 次验证的均值。
此外,还有一种交叉验证方法就是 留一法(Leave-One-Out,简称LOO),顾名思义,就是使 k k k 等于数据集中数据的个数,每次只使用一个作为测试集,剩下的全部作为训练集,这种方法得出的结果与训练整个测试集的期望值最为接近,但是成本过于庞大。
我们用SKlearn库来实现一下LOO:
from sklearn.model_selection import LeaveOneOut # 一维示例数据 data_dim1 = [1, 2, 3, 4, 5] # 二维示例数据 data_dim2 = [[1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3], [4, 4, 4, 4], [5, 5, 5, 5]] loo = LeaveOneOut() # 实例化LOO对象 # 取LOO训练、测试集数据索引 for train_idx, test_idx in loo.split(data_dim1): # train_idx 是指训练数据在总数据集上的索引位置 # test_idx 是指测试数据在总数据集上的索引位置 print("train_index: %s, test_index %s" % (train_idx, test_idx)) # 取LOO训练、测试集数据值 for train_idx, test_idx in loo.split(data_dim1): # train_idx 是指训练数据在总数据集上的索引位置 # test_idx 是指测试数据在总数据集上的索引位置 train_data = [data_dim1[i] for i in train_idx] test_data = [data_dim1[i] for i in test_idx] print("train_data: %s, test_data %s" % (train_data, test_data))
data_dim1的输出:
train_index: [1 2 3 4], test_index [0] train_index: [0 2 3 4], test_index [1] train_index: [0 1 3 4], test_index [2] train_index: [0 1 2 4], test_index [3] train_index: [0 1 2 3], test_index [4] train_data: [2, 3, 4, 5], test_data [1] train_data: [1, 3, 4, 5], test_data [2] train_data: [1, 2, 4, 5], test_data [3] train_data: [1, 2, 3, 5], test_data [4] train_data: [1, 2, 3, 4], test_data [5]
data_dim2的输出:
train_index: [1 2 3 4], test_index [0] train_index: [0 2 3 4], test_index [1] train_index: [0 1 3 4], test_index [2] train_index: [0 1 2 4], test_index [3] train_index: [0 1 2 3], test_index [4] train_data: [[2, 2, 2, 2], [3, 3, 3, 3], [4, 4, 4, 4], [5, 5, 5, 5]], test_data [[1, 1, 1, 1]] train_data: [[1, 1, 1, 1], [3, 3, 3, 3], [4, 4, 4, 4], [5, 5, 5, 5]], test_data [[2, 2, 2, 2]] train_data: [[1, 1, 1, 1], [2, 2, 2, 2], [4, 4, 4, 4], [5, 5, 5, 5]], test_data [[3, 3, 3, 3]] train_data: [[1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3], [5, 5, 5, 5]], test_data [[4, 4, 4, 4]] train_data: [[1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3], [4, 4, 4, 4]], test_data [[5, 5, 5, 5]]