tensorflow学习笔记(三十八):损失函数加上正则项

tensorflow Regularizers

在损失函数上加上正则项是防止过拟合的一个重要方法,下面介绍如何在TensorFlow中使用正则项.

tensorflow中对参数使用正则项分为两步: 1. 创建一个正则方法(函数/对象) 2. 将这个正则方法(函数/对象),应用到参数上

如何创建一个正则方法函数

tf.contrib.layers.l1_regularizer(scale, scope=None)

返回一个用来执行L1正则化的函数,函数的签名是func(weights). 参数:

  • scale: 正则项的系数.
  • scope: 可选的scope name

tf.contrib.layers.l2_regularizer(scale, scope=None)

返回一个执行L2正则化的函数.

tf.contrib.layers.sum_regularizer(regularizer_list, scope=None)

返回一个可以执行多种(个)正则化的函数.意思是,创建一个正则化方法,这个方法是多个正则化方法的混合体.

参数: regularizer_list: regulizer的列表

已经知道如何创建正则化方法了,下面要说明的就是如何将正则化方法应用到参数上

应用正则化方法到参数上

tf.contrib.layers.apply_regularization(regularizer, weights_list=None)

先看参数

  • regularizer:就是我们上一步创建的正则化方法
  • weights_list: 想要执行正则化方法的参数列表,如果为None的话,就取GraphKeys.WEIGHTS中的weights.

函数返回一个标量Tensor,同时,这个标量Tensor也会保存到GraphKeys.REGULARIZATION_LOSSES中.这个Tensor保存了计算正则项损失的方法.

tensorflow中的Tensor是保存了计算这个值的路径(方法),当我们run的时候,tensorflow后端就通过路径计算出Tensor对应的值

现在,我们只需将这个正则项损失加到我们的损失函数上就可以了.

如果是自己手动定义weight的话,需要手动将weight保存到GraphKeys.WEIGHTS中,但是如果使用layer的话,就不用这么麻烦了,别人已经帮你考虑好了.(最好自己验证一下tf.GraphKeys.WEIGHTS中是否包含了所有的weights,防止被坑)

其它

在使用tf.get_variable()tf.variable_scope()的时候,你会发现,它们俩中有regularizer形参.如果传入这个参数的话,那么variable_scope内的weights的正则化损失,或者weights的正则化损失就会被添加到GraphKeys.REGULARIZATION_LOSSES中. 示例:

import tensorflow as tf
from tensorflow.contrib import layers

regularizer = layers.l1_regularizer(0.1)
with tf.variable_scope('var', initializer=tf.random_normal_initializer(), 
regularizer=regularizer):
    weight = tf.get_variable('weight', shape=[8], initializer=tf.ones_initializer())
with tf.variable_scope('var2', initializer=tf.random_normal_initializer(), 
regularizer=regularizer):
    weight2 = tf.get_variable('weight', shape=[8], initializer=tf.ones_initializer())

regularization_loss = tf.reduce_sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))

参考资料

https://www.tensorflow.org/versions/r0.12/api_docs/python/contrib.layers/regularizers

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏Petrichor的专栏

tensorflow编程: Constants, Sequences, and Random Values

  注意: start 和 stop 参数都必须是 浮点型;     取值范围也包括了 stop; tf.lin_space 等同于 tf.lins...

432
来自专栏人工智能

Tensorflow下Char-RNN项目代码详解

前言 Char-RNN,字符级循环神经网络,出自于Andrej Karpathy写的The Unreasonable Effectiveness of Recu...

50910
来自专栏marsggbo

Tensorflow datasets.shuffle repeat batch方法

由结果我们可以知道TensorFlow能很好地帮我们自动处理最后一个batch的数据。

2782
来自专栏有趣的Python

py编程技巧-2.5-如何在一个for语句中迭代多个可迭代队象(并行&串行)?

实际案例: 某班学生期末考试成绩,语文,数学,英语分布存储在三个列表当中 同时迭代三个列表,计算每个学生的总分 某年级有四个班,某次考试每班英语成绩分布存储在...

3556
来自专栏深度学习那些事儿

pytorch中autograd以及hook函数详解

pytorch中的Autograd mechanics(自动求梯度机制)是实现前向以及后向反馈运算极为重要的一环,pytorch官方专门针对这个机制进行了一个版...

2235
来自专栏专知

【干货】计算机视觉实战系列01——用Python做图像处理

【导读】在当今互联网飞速发展的社会中,数量庞大的图像和视频充斥着我们的生活,让我们需要对图片进行检索、分类等操作时,利用人工手段显然是不现实的,于是,计算机视觉...

63912
来自专栏人工智能

如何使用 scikit-learn 为机器学习准备文本数据

文本数据需要特殊处理,然后才能开始将其用于预测建模。

7898
来自专栏禹都一只猫博客

TensorFlow小入门

1935
来自专栏图形学与OpenGL

机械版CG 实验3 变换

进一步掌握二维、三维变换的数学知识、变换原理、变换种类、变换方法;进一步理解采用齐次坐标进行二维、三维变换的必要性;利用OpenGL实现二维、三维图形变换。

301
来自专栏PPV课数据科学社区

使用R语言进行异常检测

本文结合R语言,展示了异常检测的案例,主要内容如下: (1)单变量的异常检测 (2)使用LOF(local outlier factor,局部异常因子)进行异常...

2916

扫码关注云+社区