前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >python|线性回归问题

python|线性回归问题

作者头像
算法与编程之美
发布2020-04-15 15:31:53
8940
发布2020-04-15 15:31:53
举报
文章被收录于专栏:算法与编程之美

问题描述

线性回归是利用数理统计中回归分析,来确定两种或两种以上变量间相互依赖的定量关系的一种统计分析方法。可以解释为,利用线性回归方程的最小平方函数对一个或多个自变量和因变量之间的关系进行数学建模。这种函数是一个或多个称为回归系数的模型参数的线性组合。其中只有一个自变量的情况称为简单回归,大于一个自变量情况的叫做多元回归。本文将介绍一个二元线性回归问题。

解决方案

1 线性回归原理

回归问题研究的是因变量和自变量之间的关系,在中学阶段学习过以一个二元一次方程y = w*x + b 这样一条直线对线性关系的表述。这样便可以通过几组确定的数据来得到一个精确的求解结果b和w的值。但实际上,由于模型本身的未知性和采集数据偏差等情况,很难精确的求解这两个值。因此,需要用大量样本数据来不断更新演算最终求出一个与真实值最为接近的值。

2 确定b、w最优解

通过数学知识可以知道,函数梯度的方向永远指向函数值变大的方向(如下图1所示),所以,如果向着函数梯度方向的反方向逐步寻查就能得到函数的最小值。

图1 函数梯度方向

因此,需要构造一个loss函数:

对于这个函数模型,又利用:

这两个公式来对b和w进行演算更新,使得通过这个模型求解出来的y’值最为接近真实值y,且最终得到的b、w就是最优解。(注:这里的lr是一个学习率learningrate,可以把它理解为衰减系数,是为了避免b、w在更新时,跨度太大而跳过最小值。)

3 算法流程及代码

(1)构建一个线性模型,遍历points数组,对数组数据进行一个迭代求和算平均值。代码如下:

代码语言:javascript
复制

import numpy as np

def computer_error_for_line_points(b,w,points):

    totalError = 0

    for i in range(0,len(points)):

        x = points[i][0]

        y = points[i][1]

        totalError += ((w * x + b) - y) ** 2

    return totalError / float(len(points))

(2)初始化b、w的值,通过对b、w求偏导来对b、w进行迭代更新。代码如下:

代码语言:javascript
复制
def step_gradient(b_current,w_current,points,learningRate):

    b_gradient = 0

    w_gradient = 0

    N = float(len(points))

    for i in range(0,len(points)):

        x = points[i][0]

        y = points[i][1]

        b_gradient += (2/N) * ((w_current * x + b_current) - y)

        w_gradient += (2/N) * x * ((w_current * x + b_current) - y)

    new_b = b_current - (learningRate * b_gradient)

    new_w = w_current - (learningRate * w_gradient)

    return [new_b,new_w]

(3)重复将新的b’、w’的值赋值给b、w,多次循环最终返回一个最优的b、w值。代码如下:

代码语言:javascript
复制

def gradient_descent_runner(points,starting_b,starting_w,learning_rate,num_iterations):

    b = starting_b

    w = starting_w

    for i in range(num_iterations):

        b,w = step_gradient(b,w,np.array(points),learning_rate)

    return [b,w]

(4)最后,定义一个运行方法。(注:在做模型演算时,往往会先将数据处理成矩阵,一般将矩阵存储为一个.csv文件,放在与.py文件同一级文件夹下。在使用时类似points = np.genfromtxt('data.csv',delimiter=",")进行文件读取。)代码如下:

代码语言:javascript
复制

def run():

    points = np.genfromtxt('data.csv',delimiter=",")

    learning_rate = 0.0001

    initial_b = 0

    initial_w = 0

    num_iterations =1000

    print('Starting gradient descent at b = {0},w = {1},error = {2}'.format(initial_b,initial_w,computer_error_for_line_points(initial_b,initial_w,points)))

    print('Running...')

    [b,w] = gradient_descent_runner(points,initial_b,initial_w,learning_rate,num_iterations)

    print('After {0} iterations b = {1},w = {2},error = {3}'.format(num_iterations,b,w,computer_error_for_line_points(b,w,points)))

4 运行结果

运行之后,得到的b、w就是最优解。

图2 运行结果

结语

通过这样一个简单的线性回归问题,可以初步感受到借助python语言来解决一个数据分析处理的问题的便携性和功能性是十分强大的。不仅如此,在面对其他更为复杂的数学分析问题,利用编程和建立数学模型来解决会十分方便和高效。

END

主 编 | 王文星

责 编 | 吴怡辰

where2go 团队

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2020-04-04,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 算法与编程之美 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档