首页
学习
活动
专区
工具
TVP
发布
精选内容/技术社群/优惠产品,尽在小程序
立即前往

Spark MLlib 课堂学习笔记-逻辑回归

关于逻辑回归的算法原理Spark官方文档里有说明,另外网上也有中文翻译文档可参考。本笔记是学习MLlib的辑回归API使用时一道练习题记录,通过这道练习,可以掌握基本使用。MLLib提供了两种算法实现,分别是SGD梯度下降法和LBFGS。

1. 数据文件

交通事故的统计文件,四列,accident(去年是否出过事故,1表示出过事故,0表示没有),age(年龄 数值型),vision(视力状况,分类型,1表示好,0表示有问题),drive(驾车教育,分类型,1表示参加过驾车教育,0表示没有)。第1列是因变量,其它3列是特征。这是一个用空格分隔的文本文件,要使用MLLib算法库,首先要读文件并转成LabeledPoint数据类型的RDD。

[plain]view plaincopy

1 17 1 1

1 44 0 0

1 48 1 0

1 55 0 0

1 75 1 1

0 35 0 1

0 42 1 1

0 57 0 0

0 28 0 1

0 20 0 1

0 38 1 0

0 45 0 1

0 47 1 1

0 52 0 0

0 55 0 1

1 68 1 0

1 18 1 0

1 68 0 0

1 48 1 1

1 17 0 0

1 70 1 1

1 72 1 0

1 35 0 1

1 19 1 0

1 62 1 0

0 39 1 1

0 40 1 1

0 55 0 0

0 68 0 1

0 25 1 0

0 17 0 0

0 45 0 1

0 44 0 1

0 67 0 0

0 55 0 1

1 61 1 0

1 19 1 0

1 69 0 0

1 23 1 1

1 19 0 0

1 72 1 1

1 74 1 0

1 31 0 1

1 16 1 0

1 61 1 0

2. SGD算法

[plain]view plaincopy

package classify

/*

accident.txt

accident(去年是否出过事故,1表示出过事故,0表示没有)

age(年龄 数值型)

vision(视力状况,分类型,1表示好,0表示有问题)

drive(驾车教育,分类型,1表示参加过驾车教育,0表示没有)

*/

import org.apache.spark.mllib.linalg.

import org.apache.spark.mllib.regression.LabeledPoint

import org.apache.spark.mllib.classification.LogisticRegressionWithSGD

import org.apache.spark.

object LogisticSGD {

def parseLine(line: String): LabeledPoint = {

val parts = line.split(" ")

val vd: Vector = Vectors.dense(parts(1).toDouble, parts(2).toDouble, parts(3).toDouble)

return LabeledPoint(parts(0).toDouble, vd )

}

def main(args: Array[String]){

val conf = new SparkConf().setMaster(args(0)).setAppName("LogisticSGD")

val sc = new SparkContext(conf)

val data = sc.textFile(args(1)).map(parseLine(_))

val splits = data.randomSplit(Array(0.6, 0.4), seed=11L)

val trainData = splits(0)

val testData = splits(1)

val model = LogisticRegressionWithSGD.train(trainData, 50)

println(model.weights.size)

println(model.weights)

println(model.weights.toArray.filter(_ != 0).size)

val predictionAndLabel = testData.map(p => (model.predict(p.features), p.label))

predictionAndLabel.foreach(println)

}

}

parseLine函数将文本文件的每一行转成一个LabeledPoint数据类型,randomSplit用例把数据集分成训练和测试两部分。val model = LogisticRegressionWithSGD.train(trainData, 50) 执行训练并得到模型,这里的50为迭代次数。val predictionAndLabel = testData.map(p => (model.predict(p.features), p.label))中的model.predict执行预测,testData.map测试数据集的特征值传递给model去预测,并将预测值与原有的label合并形成一个新的map。

3. LBFGS算法

[plain]view plaincopy

package classify

import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS

import org.apache.spark.

import org.apache.spark.mllib.linalg.

import org.apache.spark.mllib.regression.LabeledPoint

object LogisticLBFGS {

def parseLine(line: String): LabeledPoint = {

val parts = line.split(" ")

val vd: Vector = Vectors.dense(parts(1).toDouble, parts(2).toDouble, parts(3).toDouble)

return LabeledPoint(parts(0).toDouble, vd )

}

def main(args: Array[String]){

val conf = new SparkConf().setMaster(args(0)).setAppName("LogisticLBFGS")

val sc = new SparkContext(conf)

val data = sc.textFile(args(1)).map(parseLine(_))

val splits = data.randomSplit(Array(0.6, 0.4), seed=11L)

val trainData = splits(0)

val testData = splits(1)

val model = new LogisticRegressionWithLBFGS().setNumClasses(2).run(trainData)

println(model.weights.size)

println(model.weights)

println(model.weights.toArray.filter(_ != 0).size)

val prediction = testData.map(p => (model.predict(p.features), p.label))

//println(prediction)

prediction.foreach(println)

}

}

val model = new LogisticRegressionWithLBFGS().setNumClasses(2).run(trainData)中的setNumClasses(2)设置分类数。

对于这个列子,LBFGS的效果比SGD的效果好。

  • 发表于:
  • 原文链接http://kuaibao.qq.com/s/20180106A097PW00?refer=cp_1026
  • 腾讯「腾讯云开发者社区」是腾讯内容开放平台帐号(企鹅号)传播渠道之一,根据《腾讯内容开放平台服务协议》转载发布内容。
  • 如有侵权,请联系 cloudcommunity@tencent.com 删除。

扫码

添加站长 进交流群

领取专属 10元无门槛券

私享最新 技术干货

扫码加入开发者社群
领券