最近要开始使用Transformer去做一些事情了,特地把与此相关的知识点记录下来,构建相关的、完整的知识结构体系,
以下是要写的文章,本文是这个系列的第十五篇:
对于Bert的改进,大家可能注意到我的句式都是Bert很好很强大,但是……,今天这篇也不例外,只是改进的方向有点出乎我的意料。
Bert很好很强大,但是它的训练太低效,为什么呢?我们来回顾一下,在训练Bert的时候,在输入上,把15%的词语给替换成Mask,然后这其中有80%是Mask,有10%是替换成其他词语,最后剩下10%保持原来的词语。
可以看到,Bert的训练中,每次相当于只有15%的输入上是会有loss的,而其他位置是没有的,这就导致了每一步的训练并没有被完全利用上,导致了训练速度慢。
于是,就有了Electra,Electra是Efficiently Learning an Encoder that Classifies Token Replacement Accurately的缩写。从名字中可以看出,相对于Bert的去预测Mask的正确值,Electra则是去预测Token是不是被替换了。那么具体是如何做的呢?
如下图所示,首先会训练一个生成器来生成假样本,然后Electra去判断每个token是不是被替换了。
大家看到这张图,可能会想到,这不是对抗生成网络咩?其实不是的,不是的奥秘就在损失函数上。
再来仔细的看一下算法流程,首先,输入经过随机选择设置为[MASK],然后输入给Generator,Generator负责把[MASK]变成替换过的词。
但此时Generator并不像对抗神经网络那样需要等Discriminator中传回来的梯度,而是像Bert一样那样去尝试预测正确的词语,从而计算损失。这就是Electra不是GAN的根本原因。
因此,极端情况下如果Generator的预测准确率是100%,那么Discriminator就学习不到什么了,因为所有的token都是正确词语。但所幸,Generator一般是个小模型,所以效果达不到这么高,同时,Generator刚开始就要和Discriminator联合训练,所以刚开始也不会达到这么高。
Discriminator则是去预测每个位置上的词语是不是被替换过。Discriminator是训练完之后我们得到的预训练模型,Generator在训练完之后就没有用了。
Electra另外一点和对抗生成网络不同的是,如果Generator生成的是和原始输入一样的token,那么这个token会被当做是没有替换,而在对抗生成网络中所有来源于生成器的数据都是fake数据。
用公式来解释上述过程,如下:
损失函数如下:
所以,最后的损失函数如下:
注意到GAN的损失函数是minG maxDV(D, G),跟这个损失函数大有不同。
终极目标就是能在计算量等同的情况下,超过同等体量的模型效果。
当然,还有很多其他的设置:
实验结果如下,可以看到,Discriminator的宽度为768,Generator宽度为256时效果最好。同时,GAN和Two-stage训练都不如Electra。
实验还分别比较了小模型和大模型。可以看到,无论是大模型还是小模型,Electra都可以超过Bert。
为了验证Electra的提升到底是哪里来的,做了一些Bert和Electra中间设置的一些模型的实验,包括:
结果如下,可以看到,All-tokens MLM提升最大,但还可以看到,Electra相对于Bert,不仅仅在训练速度上有提升,在最终的结果上也有提升。
如下图所示,Electra可以达到Bert达不到的高度。
勤思考,多提问是Engineer的良好品德。
关注公众号【雨石记】,答案会在后续的文章中。