Intro

在机器学习领域,监督学习(Supervised Learning)是一种广泛应用的模型训练范式,依赖于大量的标注数据。具体来说,监督学习的训练过程中,模型fθ(x)f_{\theta}(\mathbf x) 的参数θ\theta 是通过含有ll 个带标注的训练数据组成的集合L={(x1,y1),(x2,y2),...,(xl,yl)}\mathfrak L=\{(\mathbf x_1,y_1),(\mathbf x_2,y_2),...,(\mathbf x_l,y_l)\} 学习得到的。但现实世界中数据标注成本是高昂且耗时的。例如,在医学影像分析中,标注一张CT扫描图需要专业医生花费数小时,而未标注数据却可能以TB级存在。如果将这些未标注数据的集合记为U={x1,x2,...,xu}\mathfrak U=\{\mathbf x_1,\mathbf x_2,...,\mathbf x_u\},有ulu\gg l

半监督学习(Semi-Supervised Learning,SSL) 就是一种催生于此种背景下,尝试利用少量标注数据和大量未标注数据共同进行模型训练,提升模型性能的技术。

基本假设

半监督学习得以应用主要基于以下假设:

The Smoothness Assumption.

  • 平滑假设:在输入空间中位于稠密连续数据区域内两个距离很近的样本,它们的输出也应该是相近的,反之则趋于不同。

The Cluster Assumption.

  • 聚类假设:属于同一个聚类簇的两个样本,它们的类标签也应该属于同一类。

The Manifold Assumption.

  • 流形假设:如果两个高维样本恰好可以映射到一个低维的流形结构上,它们的输出应该是相近的。

基本分类

Consistency Regularization (a.k.a Consistency Training).

一致性正则化:该方法主要基于一个假设:对输入的(未标注)数据做一个微小扰动后,它的输出结果应该是不变的。把这种约束加到损失函数中作为一个正则项进行训练时,就能对那些未标注的数据产生正向效果。

Proxy-label Methods.

代理标签法:在未标记数据上生成 proxy-label,并将其与有标注数据共同使用。这些标签是模型本身或其变体生成的,无需额外监督——这意味着它们可能并不能反映事实,是嘈杂的、微弱的,但我们还是能从中提取到一些有用的学习信号。

Generative Models.

  • 生成式模型:生成式模型通常能学习到数据的分布p(x)p(x),进而认为可以将其迁移到带有标签yy 的有监督任务p(xy)p(x|y) 上。

Graph-based Methods.

  • 基于图的半监督方法:将数据样本点视为图的节点,基于图的半监督方法通常是目标是利用两个节点间的相似性将标签从有标签节点传播到无标签节点。

一致性正则化

一致性正则化方法中的一致性就是指对输入的(未标注)数据做一个微小扰动后,它的输出结果应该是不变的。即是所得的目标模型fθf_{\theta} 应该对这种扰动体现出一致性结果。比如对一张小猫的图像随机添加一些噪声之后,这张图还应该被识别成是一只猫。

具体来说,对于未标注数据样本xU\mathbf x\in \mathfrak U 和 它的扰动x^\hat{\mathbf x},可构造一致性约束损失:

Lcons=d(fθ(x),fθ(x^))\mathcal L_{\text{cons}}=d\big(f_{\theta}(\mathbf x),f_{\theta}(\hat{\mathbf x})\big)

其中d(,)d(·,·) 可以是 MSE、KL散度或是 JS散度。

一致性正则化方法的总损失往往被设计为带监督信号的损失和一致性约束损失的组合:

L=Lsup+λLcons.\mathcal L=\mathcal L_{\text{sup}}+\lambda\mathcal L_{\text{cons}}.

Pi-Model

Π-Model 来自 ICRL 2017 的论文《 Temporal Ensembling for Semi-Supervised Learning》,它是 Ladder net 和 Γ-Model 的简化版本(我们之后会介绍到它们),但是它将一致性正则化的核心策略体现了出来。

如图所示,该模型对输入数据进行数据增强(图像增强是在CV中更为常用的表达,相当于前述的某种扰动/加噪),然后共同输入到同一个带 Dropout 的神经网络中(这里带上 dropout 就是一种正则化策略,也是某种意义上的去噪),最后最小化两种输出的 MSE 损失。另一方面,如果该输入同时是带标签的,还需同时加上对应的损失(这里是分类问题,所以采用了交叉熵函数)。最终总的损失函数描述为:

这里对权重ww 的设置也是一个值得注意的 trick,这可以避免模型在训练初期就陷入不稳定的困境。

如果从无监督学习/自监督学习er 的视角来看的话,这个损失函数完全就相当于将自监督损失和有监督损失进行加权求和而已。

Temporal Ensembling

Π-Model 存在一个明显的二次反向传播问题,两个经过扰动的输入经过同一个网络后分别得到两个不同的输出,而损失函数需要对这两个输出进行比较,使用梯度下降法时,亦即在反向传播过程中会分别对两个输出都反向传播到模型里去,这将会影响模型的稳定性和训练速度。

为解决这个问题,Π-Model 的原作者其实已经在同一篇论文里给出了他的方案,提出了Temporal Ensembling。如图所示,Temporal Ensembling 在
Π-Model 的基础上,额外维护一个的输出y^\hat{y},用这个输出和当前xx 的输出y~\tilde{y} 进行 MSE 损失。这样一来就避免了二次传播带来的问题了。

注意,这里的y^\hat{y} 是输入xx 在上一次迭代(上一个 epoch 而不是上一批 batch )的输出结果,所以其实是同一个输入样本的输出,只是因为经过一轮训练和随机扰动之后,输出结果可能不一样了,但是因为来自同一个样本,并且基于一致性假设,它们之间还是应该尽可能接近的。

当然,上面这个只是便于解释,实际上作者并不只是存储了上一个 epoch 的输出,而是把迭代过程中所有的历史输出都参与进来了,并且使用指数移动平均exponential moving average (EMA) 来得到用于实际上和每一轮的输出进行比较的y^\hat{y}。具体公式为:

yema=αyema+(1α)y~1αty_{\text{ema}}=\frac{\alpha y_{\text{ema}}+(1-\alpha)\tilde{y}}{1-\alpha^t}

Temporal Ensembling 的缺点就是由于每个目标在每个epoch中只更新一次,因此学习到的信息总是以较慢的速度被纳入训练过程。

Mean Teacher

Temporal Ensembling 是对每轮的输出进行维护和EMA,而论文《Mean teachers are better role models:Weight-averaged consistency targets improve semi-supervised deep learning results》则是反过来对模型的参数进行EMA。

具体来说,Mean Teacher 将原本的同一个网络 copy 了一份,形成两个结果相同参数不同的孪生网络,分别命名为 Student 学生模型和 Teacher 教师模型。学生模型的权重参数由梯度下降更新得到,而教师模型本身不会自动更新,而是将学生模型连续训练若干轮次后产生的EMA作为自己的更新权重。

由于 EMA 提高了所有层的输出质量,而不仅仅是最后一层的输出,因此模型能够更好地表征中层乃至高层语义信息。这些方面使得了该模型相对于 Temporal ensemble 存在两个优势:首先,能够得到更准确的标签使学生和教师模型之间发生更快的反馈循环,从而产生更好的精度。其次,这种方法可以适用于大型数据集和在线学习。

VAT

Virtual Adversarial Training(VAT) 虚拟对抗训练的灵感顾名思义来源于原始的 Adversarial Training。原始的对抗训练的目标是通过在监督损失上升的方向为输入数据添加扰动,实现了在训练时就“加入过对抗样本”以增强模型的鲁棒性。

在半监督任务中,因为无标注样本没有标签,所以就没法计算监督信号上损失函数上升的方向,所以 VAT 创新性地在于一致性约束损失上进行对抗扰动。具体来说,对于给定的数据样本xx ,我们希望计算出对抗性扰动radvr_{\text{adv}},这个扰动将最大程度地改变模型的预测结果。即:

Lcons=d(fθ(x),fθ(x+radv))\mathcal{L}_{\text{cons}}=d\big(f_{\theta}(x),f_{\theta}(x+r_{\text{adv}})\big)

其中,

radv=argmaxr<ϵd(fθ(x),fθ(x+radv))ϵdd  where d=xd(fθ(x),fθ(x+radv))\begin{aligned} r_{\text{adv}}&=\arg\max_{||r||\lt\epsilon}d\big(f_{\theta}(x),f_{\theta}(x+r_{\text{adv}})\big)\\ &\approx\epsilon\frac{\nabla d}{||\nabla d||} \;\text{where }\nabla d=\nabla_xd\big(f_{\theta}(x),f_{\theta}(x+r_{\text{adv}})\big) \end{aligned}

相关源码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
def vat_loss(model, ul_x, ul_y, xi=1e-6, eps=2.5, num_iters=1):
# find r_adv
d = torch.Tensor(ul_x.size()).normal_()
for i in range(num_iters):
d = xi *_l2_normalize(d)
d = Variable(d.cuda(), requires_grad=True)
y_hat = model(ul_x + d)
delta_kl = kl_div_with_logit(ul_y.detach(), y_hat)
delta_kl.backward()
d = d.clone().cpu()
model.zero_grad()
d = _l2_normalize(d)
d = Variable(d.cuda())
r_adv = eps * d
# compute lds
y_hat = model(ul_x + r_adv.detach())
delta_kl = kl_div_with_logit(ul_y.detach(), y_hat)
return delta_kl

在 VAT 的 原论文 中将前述的一致性约束损失称为 LDS (local distributional smoothness).
更多介绍详见: https://www.twistedwg.com/2018/12/04/VAT.html

Ladder Net

Ladder Networks出自Semi-Supervised Learning with Ladder Networks,这篇由Rasmus等人撰写的文章被誉为2015年深度学习领域“五大佳文”之一。它的目的是让模型对噪声更具鲁棒性。对于每个不含标签的样本,在干净样本中训练好的模型会先给出一个代理标签,然后在加噪样本中训练好的模型会对这个结果进行预测。这样模型就能学会在嘈杂环境中准确提取特征,并在带标签数据集中进行预测

待更

代理标签法

代理标签法通常先使用有标签的数据来训练模型,然后利用该模型去预测无标签数据,并将置信度高的预测结果作为伪标签,参与模型训练和调整。模型同时在真实标签和伪标签数据上进行训练。

Self-training

1995年,Yarowsky等人提出 Self-training 算法,被视为半监督学习的基本雏形。顾名思义,Self-training就是在已有标签基础上,让模型对无标签数据进行分类,然后再把分类预测用于训练,并从中获取额外信息。

如上图所示,Self-training 首先通过带标注数据集L\mathfrak L 训练得到模型mm,然后用这个模型去预测无标注数据集U\mathfrak U 内的样本,如果预测结果达到预设的阈值τ\tau 就将模型预测结果加入到带标注数据集中,反复迭代模型mm 的权重。

该方法存在一个很明显的问题,那就是模型会对自己预测的结果过于“自信”,并且这种自信是盲目的,一旦结果其实是错误的,那么这种错误偏差就会一直积累在训练过程中不断放大。

Multi-view training

既然Self-training算法存在难以自查的缺点,那我们可以从不同数据的视角出发,构建不同的模型。理想情况下,这些多角度信息应该是互补的,因此模型也能通过合作来获得更好的性能。这里的多角度只是一个泛指,它可以是不同特征,也可以是不同的模型架构和不同数据集。而这种方法就被整体归类为 Multi-view training,后面要介绍的 Co-training、Tir-training 等方法都属于此类模型。

Co-training

1998年,CMU 的 Blum 和 Mitchell 在论文 《Combining Labeled and Unlabeled Data with Co-Training》中首次提出Co-training概念。这是多角度训练的一个经典算法,也提出了相对较强的假设:数据 LL 可以被两个条件独立的特征集 L1L^1  和 L2L^2  表示,且每个特征集都能够训练出一个强学习器m1m_1m2m_2。在训练时,一个模型会为另一个模型的输入提供代理标签,如下图所示。

参考

  1. https://www.ruder.io/semi-supervised/
  2. https://blog.csdn.net/by6671715/article/details/122218418
  3. https://zhuanlan.zhihu.com/p/144716386
  4. https://blog.csdn.net/qq_43456016/article/details/132638116
  5. https://blog.csdn.net/qq_41380292/article/details/119248049
  6. https://zhuanlan.zhihu.com/p/562284701
  7. http://www.twistedwg.com/2018/12/04/VAT.html
  8. https://blog.csdn.net/weixin_52261094/article/details/134981502
  9. https://zhuanlan.zhihu.com/p/138085660
  10. https://zhuanlan.zhihu.com/p/296809584
  11. https://blog.csdn.net/IRONFISHER/article/details/120328715
  12. https://zhuanlan.zhihu.com/p/252343352
  13. https://zhuanlan.zhihu.com/p/37747650