斯坦福tensorflow教程-tensorflow 实现线性回归代码结果

代码

import os
os.environ['TF_CCP_MIN_LOG_LEVEL'] = '2'
import time

import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf

import utils

DATA_FILE = 'data/birth_life_2010.txt'

# 步骤 1:读取数据
data, n_samples = utils.read_birth_life_data(DATA_FILE)
# print(data, n_samples)

# 步骤 2:给X和Y创建占位符
X = tf.placeholder(tf.float32, name='X')
Y = tf.placeholder(tf.float32, name='Y')

# 步骤 3:创建weights 和 bias ,初始化为0
w = tf.get_variable('weights', initializer=tf.constant(0.0))
b = tf.get_variable('bias', initializer=tf.constant(0.0))

# 步骤 4:创建模型
Y_predicted = w * X + b

# 步骤 5:使用方差squared error 作为损失函数 loss function
# 也可以使用其他平均方差作为损失函数 或者 Huber loss
loss = tf.square(Y - Y_predicted, name='loss')
# loss = utils.huber_loss(Y, Y_predicted)

# 步骤 6:使用梯度下降算法最小化损失, 学习率为0.001
optimizer = tf.train.GradientDescentOptimizer(
    learning_rate=0.001).minimize(loss)

start = time.time()
writer = tf.summary.FileWriter('./graphs/linear_reg', tf.get_default_graph())
with tf.Session() as sess:
    # 初始化变量
    sess.run(tf.global_variables_initializer())

    # 训练模型 100 epochs
    for i in range(100):
        total_loss = 0
        for x, y in data:
            _, l = sess.run([optimizer, loss], feed_dict={X: x, Y: y})
            total_loss += l
        print('Epoch {0}:{1}'.format(i, total_loss / n_samples))

    # 关闭witer
    writer.close()

    # 步骤 9:输出 w 和 b
    w_out, b_out = sess.run([w, b])
print('Took :%f seconds' % (time.time() - start))

# 画图
plt.plot(data[:, 0], data[:, 1], 'bo', label='Real data')
plt.plot(data[:, 0], data[:, 0] * w_out + b_out, 'r', label='Predicted data')
plt.legend()
plt.show()

结果

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏Java Edge

AbstractList源码解析1 实现的方法2 两种内部迭代器3 两种内部类3 SubList 源码分析4 RandomAccessSubList 源码:AbstractList 作为 Lis

它实现了 List 的一些位置相关操作(比如 get,set,add,remove),是第一个实现随机访问方法的集合类,但不支持添加和替换

422
来自专栏项勇

笔记68 | 切换fragmengt的replace和add方法笔记

1444
来自专栏java闲聊

JDK1.8 ArrayList 源码解析

当运行 ArrayList<Integer> list = new ArrayList<>() ; ,因为它没有指定初始容量,所以它调用的是它的无参构造

1192
来自专栏Phoenix的Android之旅

Java 集合 Vector

List有三种实现,ArrayList, LinkedList, Vector, 它们的区别在于, ArrayList是非线程安全的, Vector则是线程安全...

662
来自专栏ml

朴素贝叶斯分类器(离散型)算法实现(一)

1. 贝叶斯定理:        (1)   P(A^B) = P(A|B)P(B) = P(B|A)P(A)   由(1)得    P(A|B) = P(B|...

3427
来自专栏alexqdjay

HashMap 多线程下死循环分析及JDK8修复

1K4
来自专栏xingoo, 一个梦想做发明家的程序员

AOE关键路径

这个算法来求关键路径,其实就是利用拓扑排序,首先求出,每个节点最晚开始时间,再倒退求每个最早开始的时间。 从而算出活动最早开始的时间和最晚开始的时间,如果这两个...

2507
来自专栏刘君君

JDK8的HashMap源码学习笔记

3008
来自专栏计算机视觉与深度学习基础

Leetcode 114 Flatten Binary Tree to Linked List

Given a binary tree, flatten it to a linked list in-place. For example, Given...

1938
来自专栏后端之路

LinkedList源码解读

List中除了ArrayList我们最常用的就是LinkedList了。 LInkedList与ArrayList的最大区别在于元素的插入效率和随机访问效率 ...

19510

扫码关注云+社区