前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >机器学习系列--KNN分类算法例子

机器学习系列--KNN分类算法例子

作者头像
Dlimeng
发布2023-06-29 16:08:23
1250
发布2023-06-29 16:08:23
举报
文章被收录于专栏:开源心路开源心路

url:机器学习系列--KNN分类算法

用的是spark2.0.2,scala2.11

import org.apache.spark.{SparkConf, SparkContext}

object knntest {

  /**     * 欧式距离     * 计算两点间的距离     * @param rs as r1,r2, ..., rd     * @param ss as s1,s2, ..., sd     * @param d 维数     */   def euclideanDistance(rs: String, ss: String, d: Int): Double = {     val r = rs.split(",").map(_.toDouble)     val s = ss.split(",").map(_.toDouble)

    if (r.length != d || s.length != d) Double.NaN else {       //zip匹配key/value 分区数一样,ri-si的平方的求和再开方,欧式距离       math.sqrt((r, s).zipped.take(d).map {         case (ri, si) => math.pow(ri - si, 2)       }.sum)     }   }

  def main(args: Array[String]): Unit = {     val sparkConf=new SparkConf().setAppName("knntest").setMaster("local[4]")     val sc=new SparkContext(sparkConf)

    //生成矩阵,每行代表一个样本 10为索引,A,B为类别,其它为属性1,2..     val groupes=sc.parallelize(List("10;A;1.0,0.9", "11;A;1.0,1.0", "12;B;0.1,0.2", "13;B;0.0,0.1"))     //100为索引,其它为属性1,2..     val testxs = sc.parallelize(List("100;1.2,1.0","101;0.1,0.3"))     //近邻数     val k = sc.broadcast(3)     //向量维度     val d = sc.broadcast(2)     //笛卡尔     //ArrayBuffer((100;1.2,1.0,10;A;1.0,0.9),     // (100;1.2,1.0,11;A;1.0,1.0),     // (100;1.2,1.0,12;B;0.1,0.2),     // (100;1.2,1.0,13;B;0.0,0.1),     // (101;0.1,0.3,10;A;1.0,0.9),     // (101;0.1,0.3,11;A;1.0,1.0),     // (101;0.1,0.3,12;B;0.1,0.2),     // (101;0.1,0.3,13;B;0.0,0.1))     val cart=testxs.cartesian(groupes)

    val knns=cart.map(p=>{       val testx=p._1//例 100;1.2,1.0       val group2=p._2//例 10;A;1.0,0.9       val testx_index=testx.split(";")(0)       val testx_rs=testx.split(";")(1)

      //类型       val group2_type=group2.split(";")(1)       val group2_ss=group2.split(";")(2)       //欧式距离       val distance =euclideanDistance(testx_rs, group2_ss, d.value)       //ArrayBuffer((100,(0.2236067977499789,A)), (100,(0.19999999999999996,A)),       // (100,(1.3601470508735443,B)), (100,(1.5,B)),       // (101,(1.0816653826391969,A)), (101,(1.140175425099138,A)),       // (101,(0.09999999999999998,B)), (101,(0.22360679774997896,B)))       (testx_index,(distance,group2_type))     })

    val knnGrouped = knns.groupByKey()

    val knnOutput = knnGrouped.mapValues(itr => {       //(100,List((0.19999999999999996,A), (0.2236067977499789,A), (1.3601470508735443,B)))       //(101,List((0.09999999999999998,B), (0.22360679774997896,B), (1.0816653826391969,A)))       val nearestK = itr.toList.sortBy(_._1).take(k.value)       //(101,List((B,1), (B,1), (A,1)))       //(100,List((A,1), (A,1), (B,1)))       //(100,Map(A -> 2, B -> 1))       //(101,Map(A -> 1, B -> 2))       val majority = nearestK.map(f => (f._2, 1)).groupBy(_._1).mapValues(list => {         val (stringList, intlist) = list.unzip         intlist.sum       })       //(100,A)       //(101,B)       majority.maxBy(_._2)._1     })

    knnOutput.foreach(println)     sc.stop()   } }

本文参与 腾讯云自媒体分享计划,分享自作者个人站点/博客。
原始发表:2018-08-10,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 用的是spark2.0.2,scala2.11
  • import org.apache.spark.{SparkConf, SparkContext}
  • object knntest {
  •   /**     * 欧式距离     * 计算两点间的距离     * @param rs as r1,r2, ..., rd     * @param ss as s1,s2, ..., sd     * @param d 维数     */   def euclideanDistance(rs: String, ss: String, d: Int): Double = {     val r = rs.split(",").map(_.toDouble)     val s = ss.split(",").map(_.toDouble)
  •     if (r.length != d || s.length != d) Double.NaN else {       //zip匹配key/value 分区数一样,ri-si的平方的求和再开方,欧式距离       math.sqrt((r, s).zipped.take(d).map {         case (ri, si) => math.pow(ri - si, 2)       }.sum)     }   }
  •   def main(args: Array[String]): Unit = {     val sparkConf=new SparkConf().setAppName("knntest").setMaster("local[4]")     val sc=new SparkContext(sparkConf)
  •     //生成矩阵,每行代表一个样本 10为索引,A,B为类别,其它为属性1,2..     val groupes=sc.parallelize(List("10;A;1.0,0.9", "11;A;1.0,1.0", "12;B;0.1,0.2", "13;B;0.0,0.1"))     //100为索引,其它为属性1,2..     val testxs = sc.parallelize(List("100;1.2,1.0","101;0.1,0.3"))     //近邻数     val k = sc.broadcast(3)     //向量维度     val d = sc.broadcast(2)     //笛卡尔     //ArrayBuffer((100;1.2,1.0,10;A;1.0,0.9),     // (100;1.2,1.0,11;A;1.0,1.0),     // (100;1.2,1.0,12;B;0.1,0.2),     // (100;1.2,1.0,13;B;0.0,0.1),     // (101;0.1,0.3,10;A;1.0,0.9),     // (101;0.1,0.3,11;A;1.0,1.0),     // (101;0.1,0.3,12;B;0.1,0.2),     // (101;0.1,0.3,13;B;0.0,0.1))     val cart=testxs.cartesian(groupes)
  •     val knns=cart.map(p=>{       val testx=p._1//例 100;1.2,1.0       val group2=p._2//例 10;A;1.0,0.9       val testx_index=testx.split(";")(0)       val testx_rs=testx.split(";")(1)
  •       //类型       val group2_type=group2.split(";")(1)       val group2_ss=group2.split(";")(2)       //欧式距离       val distance =euclideanDistance(testx_rs, group2_ss, d.value)       //ArrayBuffer((100,(0.2236067977499789,A)), (100,(0.19999999999999996,A)),       // (100,(1.3601470508735443,B)), (100,(1.5,B)),       // (101,(1.0816653826391969,A)), (101,(1.140175425099138,A)),       // (101,(0.09999999999999998,B)), (101,(0.22360679774997896,B)))       (testx_index,(distance,group2_type))     })
  •     val knnGrouped = knns.groupByKey()
  •     val knnOutput = knnGrouped.mapValues(itr => {       //(100,List((0.19999999999999996,A), (0.2236067977499789,A), (1.3601470508735443,B)))       //(101,List((0.09999999999999998,B), (0.22360679774997896,B), (1.0816653826391969,A)))       val nearestK = itr.toList.sortBy(_._1).take(k.value)       //(101,List((B,1), (B,1), (A,1)))       //(100,List((A,1), (A,1), (B,1)))       //(100,Map(A -> 2, B -> 1))       //(101,Map(A -> 1, B -> 2))       val majority = nearestK.map(f => (f._2, 1)).groupBy(_._1).mapValues(list => {         val (stringList, intlist) = list.unzip         intlist.sum       })       //(100,A)       //(101,B)       majority.maxBy(_._2)._1     })
  •     knnOutput.foreach(println)     sc.stop()   } }
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档