tensorflow | 重新学习 | 了解graph 和 Session

源于工作需要,重新学习tensorflow,好久未使用,忘记的差不多了。


tensorflow的基础框架

tensorflow中是由Graph和Session组成,Graph负责将计算架构搭建起来,Session则负责将数据输入、执行模型、产出结果。分工明确,严格分割开来。

其中,Graph和Session过程也可以细分为一下几个部分:

1、 数据准备

这部分是最起始的部分,将数据集从磁盘读取

2、 定义placeholder容器

placeholder用于存储变量,自变量和因变量。定义如下:

tf.placeholder(dtype, shape=None, name=None)

dtype :定义数据类型; shape:定义维度; name:定义名称。

例子:

batch_size = 128
X = tf.placeholder(tf.float32, [batch_size, 784], name='X_placeholder') 
Y = tf.placeholder(tf.int32, [batch_size, 10], name='Y_placeholder')

3、 初始化参数/权重

这部分是定义权重变量,模型中涉及到的参数需要提前定义。

w = tf.Variable(<initial-value>, name=<optional-name>)

包含初始化值和命名两部分。 例子:

W = tf.Variable(tf.random_normal([1]), name='weight')
b = tf.Variable(tf.random_normal([1]), name='bias')

Wb就是定义的参数 更加深入的研究Variable请看Variable帮助文档

4、 计算预测结果

Y_pred = tf.add(tf.multiply(X, W), b)

通过估计的参数来计算预测值

5、 计算损失函数值

为了估计模型的参数,一般通过定义损失来估计。

loss = tf.square(Y - Y_pred, name='loss')

当然,根据自己数据来定义损失函数最为恰当,这里仅仅给出案例。

6、 初始化optimizer

前面的定义了模型的结果,这部做模型求解。常用的求解算法有很多,要结合自己的数据来定义。通过情况下,我们需要提前定义学习率。

learning_rate = 0.01
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss)

这里是使用梯度下降法来优化模型,损失函数最小。

7、 在session里执行graph

首先,定义迭代次数或迭代停止条件。

# xs是指输入的数据
n_samples = xs.shape[0]
with tf.Session() as sess:
    # 初始化所有变量
    sess.run(tf.global_variables_initializer()) 
    # 记录日志
    writer = tf.summary.FileWriter('./graphs/linear_reg', sess.graph)  

    # 训练模型,这里定义训练50次
    for i in range(50):
        total_loss = 0
        for x, y in zip(xs, ys):
            # 通过feed_dic把数据灌进去
            _, l = sess.run([optimizer, loss], feed_dict={X: x, Y:y}) 
            total_loss += l
        if i%5 ==0:
            print('Epoch {0}: {1}'.format(i, total_loss/n_samples))

    # 关闭writer
    writer.close() 

    # 取出w和b的值
    W, b = sess.run([W, b]) 

这样就可以运行了,同时将最后的参数打印出来。

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏机器学习原理

深度学习——RNN(3)

2795
来自专栏ATYUN订阅号

四个用于Keras的很棒的操作(含代码)

Keras是最广泛使用的深度学习框架之一。它在易于使用的同时,在性能方面也与TensorFlow,Caffe和MXNet等更复杂的库相当。除非你的应用程序需要一...

1564
来自专栏码云1024

游戏中的人物是如何寻路的?

2167
来自专栏每日一篇技术文章

OpenGL ES _ 着色器_纹理图像

玩过游戏的同学们,都知道在游戏人物身上穿的那个叫皮肤,专业点将那个就叫做纹理图像。GLSL 支持在顶点和片段着色器使用纹理图像。

2003
来自专栏机器学习和数学

[编程经验] SciPy之图像处理小结

Python中可以处理图像的module有很多个,比如Opencv,Matplotlib, Numpy, PIL以及今天要分享的SciPy。其他几个后续都会总结...

8027
来自专栏AI研习社

如何在 Keras 中从零开始开发一个神经机器翻译系统?

机器翻译是一项具有挑战性的任务,包含一些使用高度复杂的语言知识开发的大型统计模型。 神经机器翻译的工作原理是——利用深层神经网络来解决机器翻译问题。 在本教程...

37212
来自专栏月色的自留地

《连连看》算法c语言演示(自动连连看)

2949
来自专栏WindCoder

TensorFlow入门:一篇机器学习教程

TensorFlow是一个由Google创建的开源软件库,用于实现机器学习和深度学习系统。这两个名称包含一系列强大的算法,它们共享一个共同的挑战——让计算机学习...

3261
来自专栏PPV课数据科学社区

使用R语言进行异常检测

本文结合R语言,展示了异常检测的案例,主要内容如下: (1)单变量的异常检测 (2)使用LOF(local outlier factor,局部异常因子)进行异常...

3736
来自专栏数据科学与人工智能

【算法】 Keras 四步工作流程

Francois Chollet在他的“深度学习Python”一书中概述了与Keras开发神经网络的概述。 通过本书前面的一个简单的MNIST示例,Cholle...

1332

扫码关注云+社区

领取腾讯云代金券