斯坦福tensorflow教程(三) 线性和逻辑回归1. 线性回归:根据出生率来预测平均寿命

1. 线性回归:根据出生率来预测平均寿命

相信大家对线性回归很熟悉了,在这里不介绍了。我们将简单地构建一个神经网络,只包含一层,用来预测自变量X与因变量Y之间的线性关系。

  • 问题描述 下面图片是关于出生率和平均寿命关系的可视化图片,数据来自全世界不同的国家。你会发现一个有趣的结论:对于一个地区,儿童越多,平均寿命就越短。详细请见.

问题是我们可以量化X与Y之间的关系吗?换句话说,如果一个国家的出生率是X,平均寿命是Y,我们能够找到线性函数吗,例如Y=f(X)?如果我们量化这种关系,给出一个国家的出生率,我们就能预测这个国家的平均寿命。 完整数据集:https://datacatalog.worldbank.org/dataset/world-development-indicators 为了简便,我们仅适用2010年的数据集:https://github.com/chiphuyen/stanford-tensorflow-tutorials/blob/master/examples/data/birth_life_2010.txt

  • 数据描述 Name: Birth rate - life expectancy in 2010 X = birth rate. Type: float. Y = life expectancy. Type: foat. Number of datapoints: 190
  • 方法 首先,我们假设出生率和寿命的关系是线性的,这就意味着我们可以找到类似Y=wX+b这种方程。 为了计算出w和b,我们将在一层神经网络使用反向传播算法。对于损失函数,使用均方差,在训练每一轮之后,我们计算出实际值与预测值Y之间的均方差。 03_linreg_starter.py
# -*- coding: utf-8 -*-
# @Author: yanqiang
# @Date:   2018-05-10 22:31:37
# @Last Modified by:   yanqiang
# @Last Modified time: 2018-05-10 23:05:47
import tensorflow as tf
import utils
import matplotlib.pyplot as plt

DATA_FILE = 'data/birth_life_2010.txt'

# Step 1: read in data from the .txt file
# data is a numpy array of shape (190, 2), each row is a datapoint
data, n_samples = utils.read_birth_life_data(DATA_FILE)

# Step 2: create placeholders for X (birth rate) and Y (life expectancy)
X = tf.placeholder(tf.float32, name='X')
Y = tf.placeholder(tf.float32, name='Y')

# Step 3: create weight and bias, initialized to 0
w = tf.get_variable('weights', initializer=tf.constant(0.0))
b = tf.get_variable('bias', initializer=tf.constant(0.0))

# Step 4: construct model to predict Y (life expectancy from birth rate)
Y_predicted = w * X + b

# Step 5: use the square error as the loss function
loss = tf.square(Y - Y_predicted, name='loss')

# Step 6: using gradient descent with learning rate of 0.01 to minimize loss
optimizer = tf.train.GradientDescentOptimizer(
    learning_rate=0.001).minimize(loss)

with tf.Session() as sess:
    # Step 7: initialize the necessary variables, in this case, w and b
    sess.run(tf.global_variables_initializer())

    # Step 8: train the model
    for i in range(100):  # run 100 epochs
        for x, y in data:
            # Session runs train_op to minimize loss
            sess.run(optimizer, feed_dict={X: x, Y: y})
    # Step 9: output the values of w and b
    w_out, b_out = sess.run([w, b])


# uncomment the following lines to see the plot
plt.plot(data[:, 0], data[:, 1], 'bo', label='Real data')
plt.plot(data[:, 0], data[:, 0] * w_out + b_out, 'r', label='Predicted data')
plt.legend()
plt.show()

[utils.py以及以后其他代码都在github](https://github.com/chiphuyen/stanford-tensorflow-tutorials) 预测结果:

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏开源FPGA

基于FPGA的均值滤波算法实现

  我们为了实现动态图像的滤波算法,用串口发送图像数据到FPGA开发板,经FPGA进行图像处理算法后,动态显示到VGA显示屏上,前面我们把硬件平台已经搭建完成了...

2225
来自专栏贾老师の博客

Reed-Solomon 编码算法

1722
来自专栏懒人开发

(2.8)James Stewart Calculus 5th Edition:Derivatives

如果极限存在,则可以简单写成 导数(衍生品?) (其实,自己对于导数 这个名词的翻译,为什么这样翻译一直不理解)

862
来自专栏进击的程序猿

经典检索算法:BM25原理

bm25 是一种用来评价搜索词和文档之间相关性的算法,它是一种基于概率检索模型提出的算法,再用简单的话来描述下bm25算法:我们有一个query和一批文档Ds,...

1271
来自专栏Ldpe2G的个人博客

图像素描风格生成

论文链接:Combining Sketch and Tone for Pencil Drawing Production

2267
来自专栏机器学习算法全栈工程师

手把手教你搭建目标检测器-附代码

翻译:刘威威 编辑:祝鑫泉 前 言 本文译自:[http://www.hackevolve.com/create-your...

3294
来自专栏CVer

OpenCV实战:人脸关键点检测(FaceMark)

Summary:利用OpenCV中的LBF算法进行人脸关键点检测(Facial Landmark Detection) Author: Amusi Dat...

8177
来自专栏Ldpe2G的个人博客

图像素描风格生成

1132
来自专栏量化投资与机器学习

深度学习Matlab工具箱代码注释之cnnapplygrads.m

%%========================================================================= %...

17610
来自专栏数值分析与有限元编程

有限元 | 经典梁单元刚度矩阵推导

经典欧拉梁单元不考虑剪切变形。基于试函数的能量方法(也称为泛函极值方法),基本要点是不需求解原微分方程,但需要假设一个满足位移边界条件的许可位移场。因此,如何寻...

3987

扫码关注云+社区