反观人类的学习方法,不仅仅是学会了一样任务,更重要的是具备学习能力,能够利用以往学习到的知识来指导学习新的任务。如何设计能够通过少量样本的训练来适应新任务的学习模型,是元学习解决的目标问题,实现的方式包括[1]:根据模型评估指标(如模型预测的精确度)学习一种映射关系函数(如排序),基于新任务的表示,找到对应的最优模型参数;学习任务层面的知识,而不仅仅是任务中的具体内容,如任务的分布、不同任务的特征表示;学习一个基模型,这个基模型的参数是基于以往多种任务的各个特定模型而得到的,等等。
图 1:什么是元学习(图源:http://speech.ee.ntu.edu.tw/~tlkagk/courses/ML_2019/Lecture/Meta1%20(v6).pdf)
下面从元学习的工程优化、解决局部最优和过拟合问题、模型解释性等方面详细解读和分析四篇论文。
一、"TaskNorm: Rethinking Batch Normalization for Meta-Learning"
核心:元模型训练阶段的工程优化
本文是发表于 ICML 2020 中的一篇文章[2],是剑桥大学、Invenia 实验室和微软研究院学者共同合作的研究成果,提出了一种适用于元学习在模型训练时的数据批量标准化方法。
在深度学习中网络模型的训练通常基于梯度下降法,与模型学习效果相关的因素包括了学习步长(学习率)、网络初始化参数,并且当涉及深层网络训练时,还需要考虑梯度消失的问题。标准化层(normalization layer,NL)的提出,使得增加了标准化层的网络在训练时,能够使用更高的学习率,并且能够降低网络对于初始参数的敏感度,对于深层网络的训练更加重要。NL 的一般表示为:
其中,γ和β为学习的参数,μ和σ是标准化的统计量,a_n 和 a’_n 是输入和标准化后的输出。
图 1.1:元学习的训练集。这是图片分类的例子,在不同 episode 中,由不同的子类构成不同的分类任务;在相同的 episode 中,支持集和查询集包含了相同的子类。来自:https://www.jiqizhixin.com/articles/2019-07-01-8
元学习的训练数据集包括了 context set Dτ(也称为 support set,支持集)和 target set Tτ(也称为 query set,查询集),如图 1.1 所示。利用这个数据集进行两个阶段的训练:在内层(inner loop)阶段,使用 context set 来更新参数θ,得到特定任务的参数ψ;在外层(outer loop)阶段(fφ表示由θ生成ψ的一个过程,可能会引入额外的参数φ),对 target set 中的 input 进行预测,并得到目标损失函数:
元学习中的分层框架(inner loop 和 outer loop 两层更新,如图 1.2 所示),可能会使得传统的批标准化方式(batch normalization,BN)失效:BN 的使用具有一定的前提条件,独立同分布 iid 条件,而元学习可能不满足这个条件,如果直接使用 BN 方法在元学习的网络模型中引入标准化层,可能会导致不理想的元模型效果。
作者提出了一种适用于元学习的标准化方式 --- 任务标准化(task normalization,TaskNorm),它能够提升模型训练的速率和稳定性,并且能够保持理想的测试效果;另外,它适用于不同大小的 context set,并没有受到很大的影响;而且这种标准化方式是非直推式的,因此在测试的时候能够适用于更多的情景(即更多样的图像分类任务)。在具体展开介绍 TaskNorm 之前,作者先对元学习的推理方式和几种常见的标准化方法进行简单介绍,并且说明了在元学习中对应不同的标准化方法的统计量μ和σ的计算和使用方式。
1.1 方法介绍
直推学习(transductive meta-learning)和非直推学习(non-transductive meta-learning)
对于元学习,作者讨论了两种方式:直推学习和非直推学习。非直推学习的元测试(meta-test)阶段,在对测试集(和训练集类似,也包括了 context set 和 target set)中的单个样本进行类别预测时,仅仅使用 context set 以及输入的观测值。直推学习的元测试阶段,对单个样本进行预测时,不仅需要 context set 和观测值,还需要测试集中其他样本的观测值。作者认为,元学习中的标准化层需要是 * 非直推式 * 的,因为对于直推学习,作者认为它的两个问题:
1. 对 target set 的分布敏感。在 outer loop 时,需要用到 target set 的其他样本,即当前样本的预测输出还与其他样本的输入相关,因此这种方式相比于非直推学习的泛化性更弱。如果在元测试中使用的 target set 样本的类别平衡情况和训练时有差别,那么模型在测试时的分类效果可能并不会很好。
2. 直推学习利用到了更多的信息(相当于需要依赖的信息更多),因此如果将两种方法直接进行比较是不公平的。
几种基本的标准化方式以及在元学习中的应用
批标准化(batch normalization,BN)。BN 在训练阶段和测试阶段的使用模式是不一样的。在元训练(meta-training)阶段,均值和方差的计算如下所示:
在 BN 中,输入的通道数不变,对每个通道、使用整个 batch 进行变换,这种标准化的方式没有涉及不同通道之间的数据交换。更直观一点,数据集输入的维度表示为 < B,C,W,H>,那么标准化计算量μ和σ的维度表示为 <1,C,1,1>。使用所有 batch 计算统计量有一个前提,就是假设了 batch 中的数据服从独立同分布。在测试阶段,使用的均值和方差是训练集所有数据的均值和方差。
在元学习网络中直接使用批标准化(Conventional Usage of Batch Normalization,CBN),会有两个重要的问题:(1)在元测试阶段,使用的是根据元训练阶段数据集计算得到的μ和σ,可以认为这两个统计量是和元模型等效的参数。然而,训练时的数据集包括了所有不同的任务,独立同分布的条件只是在相同任务的数据之间满足、在不同任务之间不一定满足。作者将 CBN 应用在 MAML 方法 [3] 中,实验结果表明了该方法在预测任务上表现并不好。(2)当训练过程中使用的 batch-size 较小,得到的统计量可能并不准确时,模型的效果也会受到影响。
图 1.3:批标准化(BN),元学习训练和测试过程中直接使用 BN 的方式。图源:[2]
基于实例的标准化(Instance-based Normalization)。基于实例的标准化方式是非直推式的,统计量只根据当前实例(如单张图片)来计算μ和σ,并且不依赖于 context set 数据集的大小。
1. 实例标准化(instance normalization,IN)。针对单张图片的 (H,W) 两个维度计算统计量(即每一张图只对 H 和 W 维度进行归一化),每一张图都有对应的统计量。该计算方式在元训练阶段(使用训练集)和元测试阶段(使用测试集)是一样的。
图 1.4:实例标准化(IN),元学习中 context set 和 target set 使用 IN 的方式。图源[2]
2. 层标准化(layer normalization,LN)。LN 针对图片单独进行变换,并考虑到了多个通道的维度。该计算方式在元训练阶段(使用训练集)和元测试阶段(使用测试集)是一样的。作者在后续提供的实验结果中,指出 LN 相比于其他标准化方式,在训练效率方面的表现较不足。
图 1.5:层标准化(LN),元学习中 context set 和 target set 使用 LN 的方式。图源:[2]
直推批标准化(transductive batch normalization,TBN)。相比于 CBN,TBN 的标准化方式在元测试阶段,并不是使用元训练阶段数据集的统计量,而是使用测试数据集(包括 context-set 或者是 target-set)来计算μ和σ。另外,TBN 会根据不同的任务分别计算各自的统计量。
虽然这种方法能够获得更好的效果,但是在元测试时,对于 target-set 的标准化处理使用了 target-set 全局的统计量,相当于测试的数据之间是存在某种信息交流和利用的,给了更多的先验信息,提升测试的准确率。这种方式在信息利用方面和非直推学习方式并不是对等的,因此不能直接比较 TBN 和其他的非直推方式。
图 1.6:直推式批标准化(TBN)。图源:[2]
1.2 任务标准化(Task Normalization, TASKNORM)
本质上,找到适用于元学习的标准化方法,关键在于找到合适的统计量μ和σ。根据标准化处理对于数据的独立同分布条件要求,对于元学习来说, μ和σ应该是任务级别的统计量,在一定程度上是融入任务模型参数ψ中。ψ是元模型通过适应 context set 而得到的任务模型的参数,因此在任务模型的推理阶段,用到的统计量μ和σ也应该能够从 context set 计算得到。
结合上述元学习对于标准化统计量的要求,作者首先提出了一种元批量标准化方法( meta-batch normalization,MetaBN)。对于每个任务,在 context set 中计算各自的均值和方差,这个统计量共用于 context set 和 target set;在元训练阶段和元测试阶段,是分别根据训练集中的 context set 和测试集中的 target set 得到各阶段的标准化统计量。但是,这种标准化方法仍然会受到 context set 大小的影响:当 context set 的 batch size 较小时,统计量的准确度不够高,会影响模型的预测效果。
图 1.7:MetaBN 方法和 TaskNorm 方法(包括 TaskNorm-L 和 TaskNorm-I)。图源:[2]
进一步地,作者保留了 MetaBN 的优点,结合基于实例的标准化方法不依赖数据集大小的特点,提出了本文的核心内容:任务标准化(TASKNORM)。TASKNORM 方法是在 MetaBN 的基础上,结合了 LN 或者是 IN,可以具体分为 TaskNorm-L 以及 TaskNorm-I 两种标准化方法:元训练(元测试)阶段,使用训练集(测试集)的 context set 得到统计量,context set 和 target set 都使用该统计量以及各自的 LN 或者 IN 的加权和,得到最终用于标准化的统计量,其中两部分统计量的权重由超参数α控制。此时的μ和σ的计算由下式得到:
其中,μ_{BN}和σ^2_{BN}是根据 context set 计算的统计量,μ+ 和σ+ 是根据层标准化(LN)或者是实例标准化(IN)得到的非直推式的统计量。这种结合方式的出发点是 * 解决使用少样本学习时存在的样本数量相关问题 *:当 context set 的样本量很少时,仅根据该数量集得到的统计量可能会得到关于该任务的不准确的数据;当结合其他统计量时,有助于提升训练效率以及模型的预测效果。
作者将权重α定义为一个参数化的变量,它和 context set 大小具有线性关系,表示为:α=sigmoid(scale|Dt| + offset)。其中 Dt 为 context set 元素个数,scale 和 offset 在元训练阶段是可学习的。α和 support set 大小之间存在线性关系式,表示为:α=sigmoid(scale|Dt| + offset)。其中 Dt 为 context set 的大小,scale 和 offset 是在元训练时学习得到的。
1.3 实验介绍
作者分别在小规模数据集和大规模数据集上进行少样本(few-shot)分类任务,对比几种标准化方法,验证本文提出的几个猜想:1)元学习对于标准化方式是比较敏感的;2)直推批标准化(TBN)比非直推批标准化的效果普遍要好;3)考虑了元学习数据集特性的方法如 TaskNorm,MetaBN 以及 RN 的效果,会比 CBN,BRN(batch renormalization),IN,LN 等没有考虑元学习数据特性的方法要好。在实验中,作者关注的指标包括模型预测的准确度和训练效率。
表 1.1:基于小数据集(mini imagenet 和 omniglot)的分类实验,此时仅考虑固定大小的 context set 和 target set。来自:[2]
表 1.2:基于大数据集 meta-dataset(包含了 13 个图像分类的数据集)的分类实验。来自:[2]
图 1.8:不同标准化方法得到的模型准确度和训练过程的对比图。图源:[2]
1.4 小结
本文提出了一种适用于元学习的标准化方法 TASKNORM,基于传统批标准化方法对统计量的计算进行改进。在计算用于数据标准化的统计量均值μ和方差σ^2 时,该方法考虑了任务内数据的独立同分布、任务间的数据不满足独立同分布条件,context set 大小的影响,以及考虑非直推式的学习方式,从而使得元学习模型能够应用在更多的场景。通过大量的对比实验,验证了使用 TASKNORM 方法能够提升元学习模型的训练效率和预测效果。
二、 "Meta-Learning with Warped Gradient Descent" (ICLR2020)
核心:解决基于梯度的元学习方法的参数局部最优问题
本文是发表于 ICLR 2020 中的一篇满分论文[4],由曼彻斯特大学、Alan 图灵研究机构和 DeepMind 的研究员提出了元学习中的梯度预处理计算方法。
在元学习领域有一个重要的问题,是学会一种更新规则,能够快速适应新的任务。处理这个问题的方式通常有两种:训练网络来产生更新(学习更新方式);或者是学习一个比较好的初始化模型或者是比例因子,应用于基于梯度更新的学习方法(学习和梯度更新相关的因素)。前者容易导致不收敛的效果,后者在少样本(few-shot 任务中的适应效果可能不太好。
作者结合前面说的两种方式,提出一种弯曲梯度下降(warped gradient descent)的方法,它主要学习一个参数化预处理矩阵,该矩阵是通过在 task-learner 网络模型的各层之间交叉放置非线性激活层(即弯曲层,warped layers)而产生。在网络训练时,这些 warp 层提供了一种更新方式,而它的参数是 meta-learned,在模型训练过程中是不经过梯度回传的。
为了验证这种梯度更新方式的有效性,作者还将这种弯曲梯度方法应用在少样本学习,标准的有监督学习,持续学习和强化学习等多种设定下进行实验。
2.1 方法介绍
在基于梯度更新的元学习中,task-learner 元参数的更新规则表示为 U(θ; ξ):= θ-α∇L(θ),初始参数θ_0 的元学习过程可表示为:
这类方法由于依赖于梯度更新的轨迹,会存在一些问题:梯度的计算会涉及到较大的计算量;容易受到梯度爆炸或者是梯度消失情况的影响;置信度分配问题。将损失函数 L 抽象成一个曲面,该曲面的情况会影响参数调整的效果,并且此时的参数空间不一定是合理的、不一定适用于不同任务的空间。
针对这几个问题,作者首先了介绍了一种结合预处理的梯度更新通用规则,表示为:
其中,P 表示一个用于预处理梯度的曲面。为了更好地拆分预处理模块的参数和 task-learner 的参数,作者使用了一种更为灵活的结构:在多层网络模型中插入全局参数化的 warp 层。最为简单的一种插入方式表示为:
h 是网络的隐藏层,w 是插入的 warp 层。在梯度回传时,对于 warp 层使用的是 Jacobian 矩阵(Dx 和 Dθ)来计算:
warp-layers 的具体原理和计算流程
如图 2.1 所示,是 warp 层在 task-learner 中的使用和计算流程。对于 task learner f(x),隐藏层之间(h1 和 h2)嵌入 warp 层(ω1 和ω2):在前向计算时,warp 层相当于激活层;在任务适应阶段(task adaptation)的后向回传中,warp 层通过 Dω来提供梯度。这就是本文提出的用于网络参数更新的 WarpGrad 方法。
图 2.1:warp 层及 WarpGrad 计算的示意图。图源:[4]
通过曲面的图示来更形象地展示 WarpGrad 起到的作用,如图 2.2 所示。在理想的 W 空间曲面,能够产生梯度上的预处理,找出梯度下降的最大方向。
图 2.2:上一行表示 WarpGrad 学习到的元几何(meta-geometry)P 曲面;下一行表示不同任务的损失函数 W 曲面,其中黑线是普通梯度下降的方向,紫色是利用元几何得到的梯度下降的方向。图源:[4]
考虑到 warp 层具有几何曲面的表示意义,作者提出 warp 层实际上是近似一个矩阵 G,该矩阵是一个正定的矩阵向量,用于度量流形的曲率。
Ω表示 warp-layers 起到的作用,它相当于通过重参数化(ω)来近似于最快的梯度下降方向:
在 P - 空间和 W - 空间上的梯度表示为:
其中,γ=Ω(θ; φ)表示从 P 空间映射到 W 空间的映射参数,并且
P 空间的参数梯度和 W 空间的参数梯度之间的转换关系如图 2.3 所示:
图 2.3:P 空间的θ参数梯度等价于 W 空间的γ参数梯度。图源:[4]
Warp 层参数控制了理想曲面的生成,本质上控制了 task learner 的收敛目标。因此,为了积累所有任务的信息帮助提升任务适应的过程,warp 层参数是通过元学习来训练得到的,目标函数表示为:
Warp-layer 参数的学习方式
作者定义了一个高层的任务τ=(h, L_{meta}, L_{task}),L_{meta}作为元训练的目标损失函数,用于 warp 参数的适应学习;L_{task}作为任务适应的目标函数,用于θ参数的适应学习。
上式对于φ的学习,依赖于 L-task,会涉及到二阶梯度的计算。作者进一步做梯度截断(stop gradient),使得φ的更新只涉及一阶梯度。
图 2.4:warpgrad 应用于在线元学习和离线元学习的算法流程。图源:[4]
2.2 实验介绍
在实验部分,作者在元学习方法 MAML[3]和 Leap[5]方法中引入 WarpGrad 的更新方式,在两个数据集(miniImageNet 和 tieredImageNet)上做少样本(few-shot)学习和多样本(multi-shot)学习,使用了 WarpGrad 方法的元学习模型能够超过普通元学习模型在分类任务上的准确率。
图 2.5:使用 warpgrad 方法进行少样本学习和多样本学习的对比实验。图源:[4]
作者还验证了 WarpGrad 方法对模型在不同任务上的泛化能力的作用。如图 2.6 所示,在不同任务数量的实验中,Warp-Leap 模型的测试准确率明显高于其他几种基准方法。
图 2.6:对比不同方法在不同任务数量实验中的准确率。图源:[4]
2.3 小结
本文提出了一种更为泛化的基于梯度的元学习方法 WarpGrad,在网络中引入 warp 层用于预处理原始梯度,该方法的特点包括:(1)WarpGrad 方法本质上是一种基于梯度的更新方式,它的创新之处在于对梯度进行了预处理,所以它也具有梯度下降法的特性,能够保证训练模型的收敛;(2)warp 层构造了梯度预处理的分布,而这个分布所具有的几何曲面能够从任务学习者中分离出来;(3)warp 层的参数是通过任务和对应轨迹来元学习得到的,根据局部的信息来获得任务分布相关的属性;(4)相比于用预处理矩阵来直接对梯度进行处理,warp 层在网络模型中同时参与了前向计算和后向梯度回传,是一种更为有效的学习方法。
三、"Meta-Learning without Memorization"
核心:解决任务层面的过拟合问题
本文是由 Google brain 团队和 UT Austin 学者发表于 ICLR 2020 中的一篇文章[6],它探讨了元学习模型的记忆问题并提出解决方法。
在分类任务中,当图片和类别标签并不是互斥的(mutually-exclusive)时(如在分类任务 1 中,狗的类别标签是 2;在分类任务 3 中,狗的类别标签仍然是 2),分类模型做的事情其实是直接将类别标签和图片中的数据特征对应起来。此时,训练得到元模型可能 * 无法 * 很好地应用在新的分类任务上:在训练阶段,模型不需要适应训练数据集、就可以在测试数据集上达到较好地效果;而在推理阶段,适应能力较弱的模型,则无法适应新任务的训练数据集,很难在新任务的测试数据集上达到理想效果。
图 3.1:Meta-learning 的图模型表示。图源:[6]
结合元学习的图模型来进一步理解这个问题的定义。M 是元训练数据集,包括了在元训练阶段的训练数据集 D(support set)和测试数据集 D*(query set),θ是元模型参数,φ是特定任务模型参数(task-specific parameters)。q(θ|M)表示基于元训练数据的元参数分布,q(φ|D, theta)表示基于任务训练(per-task training)的任务参数分布,q(y*|x*, φ, θ)表示预测的分布:
那什么是记忆问题?就是 y * 的计算,可以独立于φ和 Di,完全依赖于θ和 x*,即 q(y*|x*, φ, θ)=q(y*|x*, θ)。此时,在测试数据集上的预测结果可以直接根据元模型参数θ来得到,而不需要经过通过适应 D 而得到优化后的参数φ来进行预测的过程。
3.1 方法介绍
在本文中,作者给出了记忆问题的数学形式,引入互信息(mutual information)这个概念:在元学习中的完全记忆,指的是模型在预测 y 时忽略任务训练数据集 D 的信息,即 y 和 D 之间的互信息为 0,表示为 I(y;D|x,θ)=0。为了同时达到低误差,以及 y * 和 (x*,θ) 之间的低互信息,需要利用任务训练数据 D 来做预测,即增大 I(y*;D|x*, θ),从而减少记忆问题。
在本文中,作者提出元正则项(meta-regularizer, MR),基于信息论来提供一个通用的、不需要在任务分布上设置限制条件的方法,解决元学习的记忆问题。更具体地,分别是:激活项上的元正则化(meta regularization on activations),权重上的元正则化(meta regularization on weights)。
激活项上的元正则化。在上图中,当给定 theta 时,y * 和 x * 之间的信息流,包括了 y * 和 x * 之间的直接依赖,以及经过数据集 D 的间接依赖。作者提出,通过引入了一个中间变量 z*,有 q(ˆy* |x* , φ, θ) = ∫ q(ˆy* |z* , φ, θ)q(z* |x* , θ) dz*,控制 \ hat{y}* 和 x * 之间的信息流来解决记忆问题。
图 3.2:引入中间变量 z 的元学习的图模型,。图源:[6]
此时,为了引导模型有效地利用任务训练数据 D,增大的互信息目标变为 I(y*;D|z*, θ),通过如下的推导,等价于增大互信息 I(x*;y*|θ)和减小 KL 散度 E[D_{KL}(q(z*|x*,θ) || r(z*))]:
对于上式左项的互信息,假如 I(x*;y*|θ)=0,并且存在记忆问题(I(y*;D|x*,θ)=0)时,那么有 q(y*|x*, θ, D)=q(y*|x*, θ)=q(y*|θ),即预测结果 y * 并不依赖于观测值 x*,显然这样的模型并不会得到理想的预测准确度。因此,最小化损失函数(如式 (1))有助于引导互信息 I(y*;D|x*,θ) 或者是 I(x*;y*|θ)的最大化,所以在引入中间变量 z * 后,需要做的就是最小化 KL 散度,最终的损失函数表示为:
但是,作者在实验过程中发现这种方法在一些情况下并不能避免记忆问题,并进一步提出了另一种元正则化方法。
权重上的元正则化。作者提出,通过惩罚元模型参数,减少元参数所带有的任务信息,从而降低模型对于任务的记忆能力、解决记忆问题。对于元参数θ中包含的训练任务信息,可以表示为 I(y*1:N,D1:N; θ|x*1:N ),它的上确界有:
元参数的惩罚项即为最后的 KL 散度,该惩罚项实际上是限制模型参数的复杂度:如果模型需要去记住所有任务的信息,那么模型非常复杂;所以限制模型的复杂度,在一定程度上能够减少元参数包含的任务信息。但是,作者并没有完全限制模型参数的复杂度,在实际应用中,仍允许部分模型参数对任务训练数据进行处理,因此只是在部分参数θ上执行该惩罚项(模型的其他参数则表示为θ~),最后损失函数可以表示为:
3.2 实验介绍
本文分别在分类任务和回归任务上进行对比实验,在这些任务中图片标签和图片数据本身是非互斥的,用于验证元正则化方法在记忆问题上的有效性。如表 3.1 和 3.2 所示,使用了元正则化(MR)的方法,相比于其他的元学习基准方法,在分类任务和回归任务上都能明显获得更好的效果。
表 3.1:图片标签非互斥的回归任务(均方差),A 表示使用了激活项上的元正则化,W 表示使用了权重上的元正则化。来自:[6]
表 3.2:图片标签非互斥的分类任务(准确率)。来自:[6]
3.3 小结
本文从信息论的角度,提出了一种适用于不同的元学习方法的元正则化(MR)方法。该方法可以用在标签没有打乱(或者是很难打乱)的任务中,能够提升元学习方法在更多场景中的适用性和可行性,在一定程度上解决元学习的记忆问题。
四、"Unraveling Meta-Learning: Understanding Feature Representations for Few-Shot Tasks"
核心:探讨元模型特征表示模块的作用(元学习方法的可解释性)
本文是由马里兰大学的学者发表于 ICML 2020 中的一篇文章[7]。在少样本分类(few-shot classification)任务的场景中,元学习方法能够提供一个快速适应新任务(new tasks)或者是新域(new domains)的基础模型。然而,很少有工作去探讨模型达到不错效果的深层原因,如元学习方法中特征提取模块(feature extractor)提取得到的特征表示的不同之处是什么。
本文提出,相比于普通学习得到的特征表示,元学习得到的特征表示(meta-learned representations)是有区别的、更有助于少样本学习。使用元学习的特征表示能够提升少样本学习的效果,本文作者归为两种不同的机制:(1)固定特征提取模块参数,只更新(微调)最后的分类层(classification layer)参数。在这种机制下,类别数据点在特征空间中会更加聚集,那么在微调时,分类边界对于提供的样本会没那么敏感。(2)在模型参数空间寻找最优点作为基础模型,该最优点接近大部分特定任务(task-specific)模型参数的最优点,那么在面对新的特定任务时,能够通过几步的梯度计算,将基础模型更新为适用于新任务的特定模型。
进一步地,作者分别探讨上述两种机制的作用,定义了几种正则项,并结合正则项提出了几种带正则化的模型训练方法,通过实验验证了相关猜想以及正则化训练方法的有效性。
4.1 基于特征聚集的正则化方法
4.1.1 在特征空间的类别特征点聚集
作者先讨论第一种机制,即微调时固定特征提取模块、只更新分类层,使用这类机制的元学习方法包括 ProtoNet[8],R2-D2[9]和 MetaOptNet[10]。这类方法能够达到好的分类效果,猜想是特征提取模块已经能够做到很好的特征区分、从而对于新的分类任务也能够实现少样本学习。
特征点聚集对于少样本学习的重要性。如下图所示,当类别的特征点是分散的、类间相隔较近时,选取少量样本来训练分割平面容易导致较大的分割误差;而当类别的特征点是聚集的、类间相隔较远,训练得到的分割平面准确度较高,分割平面对于样本选取的依赖较弱。
图 4.1:特征点聚集对于少样本训练分割平面准确度的重要性。图源:[7]
然后,作者通过对比元学习的 ProtoNet 和传统训练的网络模型的特征提取效果,验证了元学习方法在特征点聚集上做得更好,虽然没有直接证明特征点聚集对于少样本学习的必要性,但是为接下来提出的基于特征点聚集的正则项提供了重要的思路和启发。
图 4.2:ProtoNet 和经典分类网络在 mini-ImageNet 数据集上提取的特征进行可视化(使用 LDA 处理元学习和经典分类器提取的特征,可视化映射到二维空间的特征)。图源:[7]
本文考虑特征聚集的评估指标(feature clustering, FC),定义为类内方差和类间方差的占比。根据 FC 的定义,本文给出了特征聚集的正则项(feature clustering regularizer, R_fc)定义:
其中,f_{θ}(x_i,j)是特征提取模块 f_{θ}对样本 x 给出的特征表示,μ_i 是第 i 类的特征向量均值,μ是所有数据的特征向量均值。作者基于 R2-D2 和 MetaOptNet 的网络结构,结合交叉熵损失函数和该正则项,作为传统的训练方法的损失函数,在 mini-ImageNet 数据集和 CIFAR-FS 数据集上进行 1-shot 和 5-shot 的实验,对比使用元训练的方法和不使用该正则项的传统训练方法。
如表 4.1 所示,相比于没有用 R_fc 训练的网络效果,使用 R_fc 来训练网络,能够和元学习网络达到类似的高分。这进一步说明了使用 R_fc 可以得到类似于元学习网络得到的特征表示,那么元学习方法实际上也有做特征聚集的工作。
更进一步地,作者探讨 特征点聚集和分割平面对数据样本不变性两者之间的联系,提出了超平面方差的正则项(hyperplane variation regularizer):
对于两个类别的特征点(A 类的 x1 和 x2,B 类的 y1 和 y2),该正则项衡量了不同类别数据点之间的距离向量的差异。当超平面对于数据样本有较强不变性时,该正则项的值越小。同样地,作者使用该正则项进行对比实验,效果和 Rfc 类似,比没有使用 Rhv 的传统训练方法的到的模型的分类效果要好。
表 4.1:使用 Rfc 或者是 Rhv 的对比实验结果。来自:[7]
前面的实验中,考虑的元学习训练方式是第一种机制,那对于微调时不会固定特征提取模块的元学习训练方式(比如 MAML 方法),情况又是怎样的呢?作者将 MAML 方法和迁移学习方法对比,发现 MAML 模型的效果并没有比传统训练模型的 feature seperation 效果更优,说明了特征聚集的提升作用,并不是元学习训练中会有的普遍现象,而是特定地存在于使用第一种机制的元训练模型中。于是接下来,作者对于元学习第二种机制的有效性进行了探讨和分析。
4.2 权重聚集的正则化方法(weight-clustering regularization)
4.1.2 在参数空间的任务损失函数的最优点聚集
接下来讨论没有固定特征提取模块的元模型,这类模型的参数能够很好地适应新任务。对于 Reptile[10],作者提出了一种假设:该方法寻找的模型参数,是接近于很多任务的最优点,所以能够在微调之后在这些任务上达到较好的效果。为了验证这个猜想,本文将 Reptile 方法表示为类似于一致性最优化方法的形式(consensus optimization,使用一项惩罚来促进各个特定任务的模型收敛到共同的参数),最小化的目标函数为:
θ~ 是 task-specific 参数,θ是一致值(实际上是元参数),左项是针对任务 p 的损失函数,右项是距离惩罚项,引导模型参数收敛到一个一致值的附近。虽然 Reptile 实际上并没有很明显地使用第二项来得到最优的 task-specific 参数,但是它使用了θ作为 task-specific 模型的初始化参数,隐式地促使θ~ 是在θ附近。
为了验证参数聚集的作用,作者在原始 reptile 算法中内部循环(inner loop)的损失函数加上如下一项,进而提出权重聚集(Weight Clustering)方法:
该项给出了针对某个任务 i 的模型参数θ^~_i 与当前训练批次所有任务的模型参数θ^~_p 的均值之间的距离。通过将 Reptile 方法结合该正则项,能够更显式地促使训练模型的参数聚集,在 1-shot 和 5-shot 实验中都能获得更优于传统训练方法、一阶 MAML 方法(FOMAML)和原始 Reptile 方法的效果。
图 4.4:使用了参数聚集正则化的 Reptile 算法(红色椭圆即为参数聚集相关的正则项)。图源:[7]
表 4.2:通过在 mini-ImageNet 上的对比实验,验证了增加惩罚项 Ri(即表中 W-Clustering 所示)对于模型效果的提升作用。来自:[7]
4.3 小结
本文对于元学习训练方法在少样本学习场景中的有效性进行了深入探讨,并提出了元学习得到的数据特征表示是不同于普通训练方法得到的数据特征表示的猜想。本文根据这个猜想设计了具有特征聚集特性和权重聚集特性两种正则项,并分别应用到迁移学习方法和原始元学习方法中,验证了正则项对于提升模型效果的作用。