首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >基于梯度下降算法求解线性回归

基于梯度下降算法求解线性回归

作者头像
OpenCV学堂
发布2018-04-04 11:10:13
6290
发布2018-04-04 11:10:13
举报

基于梯度下降算法求解线性回归

一:线性回归(Linear Regression)

梯度下降算法在机器学习方法分类中属于监督学习。利用它可以求解线性回归问题,计算一组二维数据之间的线性关系,假设有一组数据如下下图所示

其中X轴方向表示房屋面积、Y轴表示房屋价格。我们希望根据上述的数据点,拟合出一条直线,能跟对任意给定的房屋面积实现价格预言,这样求解得到直线方程过程就叫线性回归,得到的直线为回归直线,数学公式表示如下:

二:梯度下降

三:代码实现各步

训练数据读入

List<DataItem> items = new ArrayList<DataItem>();

File f = new File(fileName);

try {

    if (f.exists()) {

        BufferedReader br = new BufferedReader(new FileReader(f));

        String line = null;

        while((line = br.readLine()) != null) {

            String[] data = line.split(",");

            if(data != null && data.length == 2) {

                DataItem item = new DataItem();

                item.x = Integer.parseInt(data[0]);

                item.y = Integer.parseInt(data[1]);

                items.add(item);

            }

        }

        br.close();

    }

} catch (IOException ioe) {

    System.err.println(ioe);

}

return items;

归一化处理

float min = 100000;

float max = 0;

for(DataItem item : items) {

    min = Math.min(min, item.x);

    max = Math.max(max, item.x);

}

float delta = max - min;

for(DataItem item : items) {

    item.x = (item.x - min) / delta;

}

梯度下降

int repetion = 1500;

float learningRate = 0.1f;

float[] theta = new float[2];

Arrays.fill(theta, 0);

float[] hmatrix = new float[items.size()];

Arrays.fill(hmatrix, 0);

int k=0;

float s1 = 1.0f / items.size();

float sum1=0, sum2=0;

for(int i=0; i<repetion; i++) {

    for(k=0; k<items.size(); k++ ) {

        hmatrix[k] = ((theta[0] + theta[1]*items.get(k).x) - items.get(k).y);

    }


    for(k=0; k<items.size(); k++ ) {

        sum1 += hmatrix[k];

        sum2 += hmatrix[k]*items.get(k).x;

    }


    sum1 = learningRate*s1*sum1;

    sum2 = learningRate*s1*sum2;


    // 更新 参数theta

    theta[0] = theta[0] - sum1;

    theta[1] = theta[1] - sum2;

}


return theta;

价格预言 - theta表示参数矩阵

float result = theta[0] + theta[1]*input;return result;

线性回归Plot绘制

int w = 500;

int h = 500;

BufferedImage plot = new BufferedImage(w, h, BufferedImage.TYPE_INT_ARGB);

Graphics2D g2d = plot.createGraphics();

g2d.setRenderingHint(RenderingHints.KEY_ANTIALIASING, RenderingHints.VALUE_ANTIALIAS_ON);

g2d.setPaint(Color.WHITE);

g2d.fillRect(0, 0, w, h);

g2d.setPaint(Color.BLACK);

int margin = 50;

g2d.drawLine(margin, 0, margin, h);

g2d.drawLine(0, h-margin, w, h-margin);

float minx=Float.MAX_VALUE, maxx=Float.MIN_VALUE;

float miny=Float.MAX_VALUE, maxy=Float.MIN_VALUE;

for(DataItem item : series1) {

    minx = Math.min(item.x, minx);

    maxx = Math.max(maxx, item.x);

    miny = Math.min(item.y, miny);

    maxy = Math.max(item.y, maxy);

}

for(DataItem item : series2) {

    minx = Math.min(item.x, minx);

    maxx = Math.max(maxx, item.x);

    miny = Math.min(item.y, miny);

    maxy = Math.max(item.y, maxy);

}

// draw X, Y Title and Aixes

g2d.setPaint(Color.BLACK);

g2d.drawString("价格(万)", 0, h/2);

g2d.drawString("面积(平方米)", w/2, h-20);


// draw labels and legend

g2d.setPaint(Color.BLUE);

float xdelta = maxx - minx;

float ydelta = maxy - miny;

float xstep = xdelta / 10.0f;

float ystep = ydelta / 10.0f;

int dx = (w - 2*margin) / 11;

int dy = (h - 2*margin) / 11;


// draw labels

for(int i=1; i<11; i++) {

    g2d.drawLine(margin+i*dx, h-margin, margin+i*dx, h-margin-10);

    g2d.drawLine(margin, h-margin-dy*i, margin+10, h-margin-dy*i);

    int xv = (int)(minx + (i-1)*xstep);

    float yv = (int)((miny + (i-1)*ystep)/10000.0f);

    g2d.drawString(""+xv, margin+i*dx, h-margin+15);

    g2d.drawString(""+yv, margin-25, h-margin-dy*i);

}


// draw point

g2d.setPaint(Color.BLUE);

for(DataItem item : series1) {

    float xs = (item.x - minx) / xstep + 1;

    float ys = (item.y - miny) / ystep + 1;

    g2d.fillOval((int)(xs*dx+margin-3), (int)(h-margin-ys*dy-3), 7,7);

}

g2d.fillRect(100, 20, 20, 10);

g2d.drawString("训练数据", 130, 30);


// draw regression line

g2d.setPaint(Color.RED);

for(int i=0; i<series2.size()-1; i++) {

    float x1 = (series2.get(i).x - minx) / xstep + 1;

    float y1 = (series2.get(i).y - miny) / ystep + 1;

    float x2 = (series2.get(i+1).x - minx) / xstep + 1;

    float y2 = (series2.get(i+1).y - miny) / ystep + 1;

    g2d.drawLine((int)(x1*dx+margin-3), (int)(h-margin-y1*dy-3), (int)(x2*dx+margin-3), (int)(h-margin-y2*dy-3));

}

g2d.fillRect(100, 50, 20, 10);

g2d.drawString("线性回归", 130, 60);



g2d.dispose();

saveImage(plot);

四:总结

本文通过最简单的示例,演示了利用梯度下降算法实现线性回归分析,使用更新收敛的算法常被称为LMS(Least Mean Square)又叫Widrow-Hoff学习规则,此外梯度下降算法还可以进一步区分为增量梯度下降算法与批量梯度下降算法,这两种梯度下降方法在基于神经网络的机器学习中经常会被提及,对此感兴趣的可以自己进一步探索与研究。

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2017-06-19,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 OpenCV学堂 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 基于梯度下降算法求解线性回归
    • 一:线性回归(Linear Regression)
      • 二:梯度下降
        • 三:代码实现各步
          • 四:总结
          领券
          问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档