tensorflow dropout用法

     dropout(x, keep_prob, noise_shape=None, seed=None, name=None)
  • 函数作用就是使得矩阵x的一部分(概率大约为keep_prob)变为0,其余变为element/keep_prob,
  • noise_shape可以使得矩阵x一部分行全为0或者部分列全为0
  • 用在tensorflow中使得部分神经元随机为0不参与训练,如果算法过拟合了,可以试试这个办法。
with tf.Session() as sess:
    d = tf.to_float(tf.reshape(tf.range(1,17),[4,4]))
    sess.run(tf.global_variables_initializer())
    print(sess.run(tf.shape(d)))
    print(sess.run(d[0]))
    
    # 矩阵有一半左右的元素变为element/0.5,其余为0
    dropout_a44 = tf.nn.dropout(d, 0.5, noise_shape = None)
    result_dropout_a44 = sess.run(dropout_a44)
    print(result_dropout_a44)

    # 行大小相同4,行同为0,或同不为0
    dropout_a41 = tf.nn.dropout(d, 0.5, noise_shape = [4,1])
    result_dropout_a41 = sess.run(dropout_a41)
    print(result_dropout_a41)
    
    # 列大小相同4,列同为0,或同不为0
    dropout_a24 = tf.nn.dropout(d, 0.5, noise_shape = [1,4])
    result_dropout_a24 = sess.run(dropout_a24)
    print(result_dropout_a24)
    #不相等的noise_shape只能为1

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏HansBug's Lab

算法模板——线段树6(二维线段树:区域加法+区域求和)(求助phile)

实现功能——对于一个N×M的方格,1:输入一个区域,将此区域全部值作加法;2:输入一个区域,求此区域全部值的和 其实和一维线段树同理,只是不知道为什么速度比想象...

48350
来自专栏技术随笔

[RNN] Simple LSTM代码实现 & BPTT理论推导

51340
来自专栏从流域到海域

使用Python生成一张用于登陆验证的字符图片

Python Pillow库的简单使用 使用Python生成一张用于登陆验证的字符图片, 代码使用了Pillow,Anaconda已经默认安装此库,如果你...

22190
来自专栏人工智能

如何使用 scikit-learn 为机器学习准备文本数据

文本数据需要特殊处理,然后才能开始将其用于预测建模。

93880
来自专栏xingoo, 一个梦想做发明家的程序员

布线问题-分支限界法

问题描述:   印刷电路板不限区域划分成n*m个方格阵列。如下图所示 ?   精确的电路布线问题要求确定连接方格a的中点,到连接方格b的中点的最短布线方案。  ...

228100
来自专栏PPV课数据科学社区

【学习】ggplot2绘图入门系列之二:图层控制与直方图

如前文所述,ggplot2使用图层将各种图形元素逐步添加组合,从而形成最终结果。第一层必须是原始数据层,其中data参数控制数据来源,注意数据形式...

26960
来自专栏程序生活

Leetcode-Easy 887. Projection Area of 3D Shapes

当时自己没有想到好办法,就是按部就班的分别求三个面的面积,注意求xy的面积的时候需要考虑grid[i][j]值是否为0

10620
来自专栏zingpLiu

机器学习之线性代数

  完整内容已上传到github:https://github.com/ZingP/machine-learning/tree/master/linear_al...

20310
来自专栏鸿的学习笔记

写给开发者的机器学习指南(十三)

在我们实际使用支持向量机(SVM)之前,我先简要介绍一下SVM是什么。 基本SVM是一个二元分类器,它通过选取代表数据点之间最大间隔的超平面将数据集分成2部分。...

8510
来自专栏CVer

TensorFlow从入门到精通 | 01 简单线性模型(上篇)

[TensorFlow从入门到精通] 01 简单线性模型(上)介绍了TensorFlow如何加载MNIST、定义数据维度、TensorFlow图、占位符变量和O...

10920

扫码关注云+社区

领取腾讯云代金券