前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >《机器学习实战(Scala实现)》(二)——k-邻近算法

《机器学习实战(Scala实现)》(二)——k-邻近算法

作者头像
小爷毛毛_卓寿杰
发布2019-02-13 11:47:01
5140
发布2019-02-13 11:47:01
举报
文章被收录于专栏:Soul Joy HubSoul Joy Hub

算法流程

1.计算中的set中每一个点与Xt的距离。 2.按距离增序排。 3.选择距离最小的前k个点。 4.确定前k个点所在的label的出现频率。 5.返回频率最高的label作为测试的结果。

实现

python

代码语言:javascript
复制
# -*- coding: utf-8 -*-  
'''
Created on 2017年3月18日

@author: soso
'''
from numpy import *
import operator

def createDataSet():
    group = array([[1.0, 1.1], [1.0, 1.0], [0, 0], [0, 0.1]])
    labels = ['A', 'A', 'B', 'B']
    return group, labels

def classify0(inX, dataSet, labels, k):
    dataSetSize = dataSet.shape[0]
    # 函数形式: tile(A,rep)
    # 功能:重复A的各个维度
    # 参数类型:
    # - A: Array类的都可以
    # - rep:A沿着各个维度重复的次数
    diffMat = tile(inX, (dataSetSize, 1)) - dataSet
    sqDiffMat = diffMat ** 2
    # 当加入axis=1以后就是将一个矩阵的每一行向量相加
    sqDistances = sqDiffMat.sum(axis=1)
    distance = sqDistances * 0.5
    # argsort函数返回的是数组值从小到大的索引值
    sortedDistIndicies = distance.argsort()
    classCount = {}
    for i in range(k):
        votelabel = labels[sortedDistIndicies[i]]
        classCount[votelabel] = classCount.get(votelabel, 0) + 1
    sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]

Scala

代码语言:javascript
复制
import scala.collection.mutable.Map

object kNN {

  def getGroup(): Array[Array[Double]] = {
    return Array(Array(1.0, 1.1), Array(1.0, 1.0), Array(0, 0), Array(0, 0.1))
  }
  def getLabels(): Array[Char] = {
    return Array('A', 'A', 'B', 'B')
  }

  def classify0(inX: Array[Double], dataSet: Array[Array[Double]], labels: Array[Char], k: Int): Char = {
    val dataSetSize = dataSet.length
    val sortedDisIndicies = dataSet.map { x =>
      val v1 = x(0) - inX(0)
      val v2 = x(1) - inX(1)
      v1 * v1 + v2 * v2
    }.zipWithIndex.sortBy(f => f._1).map(f => f._2)
    var classsCount: Map[Char, Int] = Map.empty
    for (i <- 0 to k - 1) {
      val voteIlabel = labels(sortedDisIndicies(i))
      classsCount(voteIlabel) = classsCount.getOrElse(voteIlabel, 0) + 1
    }
    classsCount.toArray.sortBy(f => -f._2).head._1
  }
  def main(args: Array[String]) {
    println(classify0(Array(0, 0), getGroup(), getLabels(), 3))
  }
}
本文参与 腾讯云自媒体分享计划,分享自作者个人站点/博客。
原始发表:2017年03月18日,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 算法流程
  • 实现
    • python
      • Scala
      领券
      问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档