TensorFlow从0到1丨第十六篇 L2正则化对抗“过拟合”

前面的第十四篇 交叉熵损失函数——防止学习缓慢第十五篇 重新思考神经网络初始化从学习缓慢问题入手,尝试改进神经网络的学习。本篇讨论过拟合问题,并引入与之相对的L2正则化(Regularization)方法。

图1.overfitting,来源:http://blog.algotrading101.com

无处不在的过拟合

模型对于已知数据的描述适应性过高,导致对新数据的泛化能力不佳,我们称模型对于数据过拟合(overfitting)。

过拟合无处不在。

罗素的火鸡对自己的末日始料未及,曾真理般存在的牛顿力学沦为狭义相对论在低速情况下的近似,次贷危机破灭了美国买房只涨不跌的神话,血战钢锯岭的医疗兵Desmond也并不是懦夫。

凡是基于经验的学习,都存在过拟合的风险。动物、人、机器都不能幸免。

图2.Russell's Turkey,来源:http://chaospet.com/115-russells-turkey/g

谁存在过拟合?

对于一些离散的二维空间中的样本点,下面两条曲线谁存在过拟合?

图3.谁存在过拟合?

遵循奥卡姆剃刀的一派,主张“如无必要,勿增实体”。他们相信相对简单的模型泛化能力更好:上图中的蓝色直线,虽然只有很少的样本点直接落在它上面,但是不妨认为这些样本点或多或少包含一些噪声。基于这种认知,可以预测新样本也会在这条直线附近出现。

或许很多时候,倾向简单会占上风,但是真实世界的复杂性深不可测。虽然在自然科学中,奥卡姆剃刀被作为启发性技巧来使用,帮助科学家发展理论模型工具,但是它并没有被当做逻辑上不可辩驳的定理或者科学结论。总有简单模型表达不了,只能通过复杂模型来描述的事物存在。很有可能红色的曲线才是对客观世界的真实反映。

康德为了对抗奥卡姆剃刀产生的影响,创建了他自己的反剃刀:“存在的多样性不应被粗暴地忽视”。

阿尔伯特·爱因斯坦告诫:“科学理论应该尽可能简单,但不能过于简单。”

所以仅从上图来判断,一个理性的回答是:不知道。即使是如此简单的二维空间情况下,在没有更多的新样本数据做出验证之前,不能仅通过模型形式的简单或复杂来判定谁存在过拟合。

过拟合的判断

二维、三维的模型,本身可以很容易的绘制出来,当新的样本出现后,通过观察即可大致判断模型是否存在过拟合。

然而现实情况要复杂的多。对MNIST数字识别所采用的3层感知器——输入层784个神经元,隐藏层30个神经元,输出层10个神经元,就包含了23860个参数(23860 = 784 x 30 + 30 x 10 + 30 + 10),靠绘制模型来观察是不现实的。

最有效的方式是通过识别精度判断模型是否存在过拟合:比较模型对验证集和训练集的识别精度,如果验证集识别精度大幅低于训练集,则可以判断模型存在过拟合。

至于为什么是验证集而不是测试集,请复习第十一篇 74行Python实现手写体数字识别中“验证集与超参数”一节。

然而静态的比较已训练模型对两个集合的识别精度无法回答一个问题:过拟合是什么时候发生的?

要获得这个信息,就需要在模型训练过程中动态的监测每次迭代(Epoch)后训练集和验证集的识别精度,一旦出现训练集识别率继续上升而验证集识别率不再提高,就说明过拟合发生了。

这种方法还会带来一个额外的收获:确定作为超参数之一的迭代数(Epoch Number)的量级。更进一步,甚至可以不设置固定的迭代次数,以过拟合为信号,一旦发生就提前停止(early stopping)训练,避免后续无效的迭代。

过拟合监测

了解了过拟合的概念以及监测方法,就可以开始分析我们训练MNIST数字识别模型是否存在过拟合了。

所用代码:tf_16_mnist_loss_weight.py(链接https://github.com/EthanYuan/TensorFlow/blob/master/TF1_1/tf_16_mnist_loss_weight.py)。它在第十二篇 TensorFlow构建3层NN玩转MNIST代码的基础上,使用了交叉熵损失,以及1/sqrt(nin)权重初始化。

训练过程中,分别对训练集和验证集的识别精度进行了跟踪,如下图所示,其中红线代表训练集识别率,蓝线代表测试集识别率。图中显示,大约在第15次迭代前后,测试集的识别精度稳定在95.5%不再提高,而训练集的识别精度仍然继续上升,直到30次迭代全部结束后达到了98.5%,两者相差3%。

由此可见,模型存在明显的过拟合的特征。

图4.训练集和验证集识别精度(基于TensorBoard绘制)

过拟合的对策:L2正则化

对抗过拟合最有效的方法就是增加训练数据的完备性,但它昂贵且有限。另一种思路是减小网络的规模,但它可能会因为限制了模型的表达潜力而导致识别精度整体下降。

本篇引入L2正则化(Regularization),可以在原有的训练数据,以及网络架构不缩减的情况下,有效避免过拟合。L2正则化即在损失函数C的表达式上追加L2正则化项

L2正则化

上式中的C0代表原损失函数,可以替换成均方误差、交叉熵等任何一种损失函数表达式。

关于L2正则化项的几点说明:

  • 求和∑是对网络中的所有权重进行的;
  • λ(lambda)为自定义参数(超参数);
  • n是训练样本的数量(注意不是所有权重的数量!);
  • L2正则化并没有偏置参与;

L2正则化表达式暗示着一种倾向:训练尽可能的小的权重,较大的权重需要保证能显著降低原有损失C0才能保留。实际上L2正则化对于缓解过拟合的数学解释并不充分,更多是依据经验的。

L2正则化的实现

因为在原有损失函数中追加了L2正则化项,那么是不是得修改现有反向传播算法(BP1中有用到C的表达式)?答案是不需要。

C对w求偏导数,可以拆分成原有C0对w求偏导,以及L2正则项对w求偏导。前者继续利用原有的反向传播计算方法,而后者可以直接计算得到:

C对于偏置b求偏导保持不变:

基于上述,就可以得到权重w和偏置b的更新方法:

TensorFlow实现L2正则化

TensorFlow的最优化方法tf.train.GradientDescentOptimizer包办了梯度下降、反向传播,所以基于TensorFlow实现L2正则化,并不能按照上节的算法直接干预权重的更新,而要使用TensorFlow方式:

tf.add_to_collection(tf.GraphKeys.WEIGHTS, W_2)

tf.add_to_collection(tf.GraphKeys.WEIGHTS, W_3)

regularizer = tf.contrib.layers.l2_regularizer(scale=5.0/50000)

reg_term = tf.contrib.layers.apply_regularization(regularizer)

loss = (tf.reduce_mean(

tf.nn.sigmoid_cross_entropy_with_logits(labels=y_,

logits=z_3)) + reg_term)

对上述代码的一些说明:

  • 将网络中所有层中的权重,依次通过tf.add_to_collectio加入到tf.GraphKeys.WEIGHTS中;
  • 调用tf.contrib.layers.l2_regularizer生成L2正则化方法,注意所传参数scale=λ/n(n为训练样本的数量);
  • 调用tf.contrib.layers.apply_regularization来生成损失函数的L2正则化项reg_term,所传第一个参数为上面生成的正则化方法,第二个参数为none时默认值为tf.GraphKeys.WEIGHTS;
  • 最后将L2正则化reg_term项追加到损失函数表达式;

在模型和训练设置均保持不变(除了学习率做了调整以适应正则化项的介入)的情况下,向原有损失函数追加L2正则化项后,重新运行训练。跟踪训练集和验证集的识别精度,如下图所示。图中显示,在整个30次迭代中,训练集和验证集的识别率均持续上升(都超过95%),最终两者的差距控制在0.5%,过拟合程度显著的减轻了。

图6.L2正则化(基于TensorBoard绘制)

附完整代码

import argparse
import sys
from tensorflow.examples.tutorials.mnist import input_dataimport tensorflow as tf
FLAGS = None


def main(_):
    # Import data
    mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True,
                                      validation_size=10000)

    # Create the model
    x = tf.placeholder(tf.float32, [None, 784])
    W_2 = tf.Variable(tf.random_normal([784, 100]) / tf.sqrt(784.0))
    '''W_2 = tf.get_variable(
        name="W_2",
        regularizer=regularizer,
        initializer=tf.random_normal([784, 30], stddev=1 / tf.sqrt(784.0)))'''
    b_2 = tf.Variable(tf.random_normal([100]))
    z_2 = tf.matmul(x, W_2) + b_2
    a_2 = tf.sigmoid(z_2)

    W_3 = tf.Variable(tf.random_normal([100, 10]) / tf.sqrt(100.0))
    '''W_3 = tf.get_variable(
        name="W_3",
        regularizer=regularizer,
        initializer=tf.random_normal([30, 10], stddev=1 / tf.sqrt(30.0)))'''
    b_3 = tf.Variable(tf.random_normal([10]))
    z_3 = tf.matmul(a_2, W_3) + b_3
    a_3 = tf.sigmoid(z_3)

    # Define loss and optimizer
    y_ = tf.placeholder(tf.float32, [None, 10])

    tf.add_to_collection(tf.GraphKeys.WEIGHTS, W_2)
    tf.add_to_collection(tf.GraphKeys.WEIGHTS, W_3)
    regularizer = tf.contrib.layers.l2_regularizer(scale=5.0 / 50000)
    reg_term = tf.contrib.layers.apply_regularization(regularizer)

    loss = (tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(labels=y_, logits=z_3)) +
        reg_term)
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(loss)

    sess = tf.InteractiveSession()
    tf.global_variables_initializer().run()

    correct_prediction = tf.equal(tf.argmax(a_3, 1), tf.argmax(y_, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    scalar_accuracy = tf.summary.scalar('accuracy', accuracy)
    train_writer = tf.summary.FileWriter(
        'MNIST/logs/tf16_reg/train', sess.graph)
    validation_writer = tf.summary.FileWriter(
        'MNIST/logs/tf16_reg/validation')

    # Train
    best = 0
    for epoch in range(30):
        for _ in range(5000):
            batch_xs, batch_ys = mnist.train.next_batch(10)
            sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
        sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
        # Test trained model
        accuracy_currut_train = sess.run(accuracy,
                                         feed_dict={x: mnist.train.images,
                                                    y_: mnist.train.labels})

        accuracy_currut_validation = sess.run(
            accuracy,
            feed_dict={x: mnist.validation.images,
                       y_: mnist.validation.labels})

        sum_accuracy_train = sess.run(
            scalar_accuracy,
            feed_dict={x: mnist.train.images,
                       y_: mnist.train.labels})

        sum_accuracy_validation = sess.run(
            scalar_accuracy,
            feed_dict={x: mnist.validation.images,
                       y_: mnist.validation.labels})

        train_writer.add_summary(sum_accuracy_train, epoch)
        validation_writer.add_summary(sum_accuracy_validation, epoch)

        print("Epoch %s: train: %s validation: %s"
              % (epoch, accuracy_currut_train, accuracy_currut_validation))
        best = (best, accuracy_currut_validation)[
            best <= accuracy_currut_validation]

    # Test trained model
    print("best: %s" % best)
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_dir', type=str, default='../MNIST/',
                        help='Directory for storing input data')
    FLAGS, unparsed = parser.parse_known_args()
    tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

下载 tf_16_mnist_loss_weight_reg.py(链接https://github.com/EthanYuan/TensorFlow/blob/master/TF1_1/tf_16_mnist_loss_weight_reg.py)

原文发布于微信公众号 - 人工智能LeadAI(atleadai)

原文发表时间:2017-08-29

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏AI研习社

用 Keras 搭建 GAN:图像去模糊中的应用(附代码)

2014年 Ian Goodfellow 提出了生成对抗网络(GAN)。这篇文章主要介绍在Keras中搭建GAN实现图像去模糊。所有的Keras代码可点击这里。

802
来自专栏null的专栏

可扩展机器学习——梯度下降(Gradient Descent)

注:这是一份学习笔记,记录的是参考文献中的可扩展机器学习的一些内容,英文的PPT可见参考文献的链接。这个只是自己的学习笔记,对原来教程中的内容进行了梳理,有些图...

3517
来自专栏专知

图像和文本的融合表示学习——Text2Image和Image2Text

【导读】图像和文本之间的相互转换涉及到图像的场景识别与理解、目标的检测和识别、图像融合等,它可以使得计算机具有“看图说话”、“看书作图”的能力,可以说是图像理解...

712
来自专栏绿巨人专栏

机器学习实战 - 读书笔记(08) - 预测数值型数据:回归

34711
来自专栏深度学习

循环神经网络

循环神经网络的神经网络体系结构,它针对的不是自然语言数据,而是处理连续的时间数据,如股票市场价格。在本文结束之时,你将能够对时间序列数据中的模式进行建模,以对未...

3708
来自专栏奇点大数据

Pytorch神器(6)

作者介绍:高扬,奇点大数据创始人。技术畅销书《白话大数据与机器学习》、《白话深度学习与Tensorflow》、《数据科学家养成手册》著书人。重庆工商大学研究生导...

1083
来自专栏机器之心

教程 | 经得住考验的「假图片」:用TensorFlow为神经网络生成对抗样本

选自arXiv 作者:Anish Athalye 机器之心编译 参与:李泽南 用于识别图片中物体的神经网络可以被精心设计的对抗样本欺骗,而这些在人类看起来没有什...

3678
来自专栏量子位

GAN入门教程 | 从0开始,手把手教你学会最火的神经网络

安妮 编译自 O’Reilly 量子位出品 | 公众号 QbitAI 生成式对抗网络是20年来机器学习领域最酷的想法。 ——Yann LeCun 自从两年前...

4143
来自专栏AI研习社

Inception Network 各版本演进史

Inception 网络是卷积神经网络 (CNN) 分类器发展中的一个重要里程碑。在 inception 之前, 大多数流行的 CNN 只是将卷积层堆叠得越来越...

923
来自专栏社区的朋友们

跬步神经网络:基本模型解析

最近开始看NN,很多疑问。微积分什么的早丢了,边看边查,记录备忘。 本篇主要是针对最基本的网络模型,解释反向传播(backpropagation)原理。

2792

扫描关注云+社区