专栏首页算法channelSpark跑「DBSCAN」算法,工业级代码长啥样?

Spark跑「DBSCAN」算法,工业级代码长啥样?

最近着手的一个项目需要在Spark环境下使用DBSCAN算法,遗憾的是Spark MLlib中并没有提供该算法。调研了一些相关的文章,有些方案是将样本点按照空间位置进行分区,并在每个空间分区中分别跑DBSCAN,但是这种方案容易遇到数据倾斜的问题,并且在分区的边界的结果很有可能是错误的。

经过与一些小伙伴的交流,通过几天的探索尝试,最终在Spark上手工实现了分布式的DBSCAN算法,经过校验结果和Sklearn单机结果完全一致,并且性能也达到了工业级水平。

通过该算法的实现,加深了对Spark的理解,用到了分批次广播和分区迭代计算等技巧,感觉自己还是棒棒哒,特意分享出来供有需要的小伙伴们参考。

一,总体思路

DBSCAN算法的分布式实现需要解决以下一些主要的问题。

1,如何计算样本点中两两之间的距离?

在单机环境下,计算样本点两两之间的距离比较简单,是一个双重遍历的过程。 为了减少计算量,可以用空间索引如Rtree进行加速。

在分布式环境,样本点分布在不同的分区,难以在不同的分区之间直接进行双重遍历。 为了解决这个问题,我的方案是将样本点不同的分区分成多个批次拉到Driver端, 然后依次广播到各个excutor分别计算距离,将最终结果union,从而间接实现双重遍历。

为了减少计算量,广播前对拉到Driver端的数据构建空间索引Rtree进行加速。

2,如何构造临时聚类簇?

这个问题不难,单机环境和分布式环境的实现差不多。

都是通过group的方式统计每个样本点周边邻域半径R内的样本点数量,

并记录它们的id,如果这些样本点数量超过minpoints则构造临时聚类簇,并维护核心点列表。

3,如何合并相连的临时聚类簇得到聚类簇?

这个是分布式实现中最最核心的问题。

在单机环境下,标准做法是对每一个临时聚类簇,判断其中的样本点是否在核心点列表,如果是,则将该样本点所在的临时聚类簇与当前临时聚类簇合并。并在核心点列表中删除该样本点。重复此过程,直到当前临时聚类簇中所有的点都不在核心点列表。

在分布式环境下,临时聚类簇分布在不同的分区,无法直接扫描全局核心点列表进行临时聚类簇的合并。我的方案是先在每一个分区内部对各个临时聚类簇进行合并,然后缩小分区数量重新分区,再在各个分区内部对每个临时聚类簇进行合并。不断重复这个过程,最终将所有的临时聚类簇都划分到一个分区,完成对全部临时聚类簇的合并。

为了降低最后一个分区的存储压力,我采用了不同于标准的临时聚类簇的合并算法。对每个临时聚类簇只关注其中的核心点id,而不关注非核心点id,以减少存储压力。合并时将有共同核心点id的临时聚类簇合并。

为了加快临时聚类的合并过程,分区时并非随机分区,而是以每个临时聚类簇的核心点id中的最小值min_core_id作为分区的Hash参数,具有共同核心点id的临时聚类簇有更大的概率被划分到同一个分区,从而加快了合并过程。

二,核心代码

import org.apache.spark.sql.SparkSession

val spark = SparkSession
.builder()
.appName("dbscan")
.getOrCreate()

val sc = spark.sparkContext
import spark.implicits._

1,寻找核心点形成临时聚类簇。

该步骤一般要采用空间索引 + 广播的方法,此处从略,假定已经得到了临时聚类簇。

//rdd_core的每一行代表一个临时聚类簇:(min_core_id, core_id_set)
//core_id_set为临时聚类簇所有核心点的编号,min_core_id为这些编号中取值最小的编号
var rdd_core = sc.parallelize(List((1L,Set(1L,2L)),(2L,Set(2L,3L,4L)),
                                       (6L,Set(6L,8L,9L)),(4L,Set(4L,5L)),
                                       (9L,Set(9L,10L,11L)),(15L,Set(15L,17L)),
                                       (10L,Set(10L,11L,18L))))
rdd_core.collect.foreach(println)

2,合并临时聚类簇得到聚类簇。

import scala.collection.mutable.ListBuffer
import org.apache.spark.HashPartitioner

//定义合并函数:将有共同核心点的临时聚类簇合并
val mergeSets = (set_list: ListBuffer[Set[Long]]) =>{
var result = ListBuffer[Set[Long]]()
while (set_list.size>0){
var cur_set = set_list.remove(0)
var intersect_idxs = List.range(set_list.size-1,-1,-1).filter(i=>(cur_set&set_list(i)).size>0)
while(intersect_idxs.size>0){
for(idx<-intersect_idxs){
        cur_set = cur_set|set_list(idx)
      }
for(idx<-intersect_idxs){
        set_list.remove(idx)
      }
      intersect_idxs = List.range(set_list.size-1,-1,-1).filter(i=>(cur_set&set_list(i)).size>0)
    }
result = result:+cur_set
  }
result
}

///对rdd_core分区后在每个分区合并,不断将分区数量减少,最终合并到一个分区
//如果数据规模十分大,难以合并到一个分区,也可以最终合并到多个分区,得到近似结果。
//rdd: (min_core_id,core_id_set)

def mergeRDD(rdd: org.apache.spark.rdd.RDD[(Long,Set[Long])], partition_cnt:Int):
org.apache.spark.rdd.RDD[(Long,Set[Long])] = {
  val rdd_merged =  rdd.partitionBy(new HashPartitioner(partition_cnt))
    .mapPartitions(iter => {
      val buffer = ListBuffer[Set[Long]]()
for(t<-iter){
        val core_id_set:Set[Long] = t._2
        buffer.append(core_id_set)
      }
      val merged_buffer = mergeSets(buffer)
var result = List[(Long,Set[Long])]()
for(core_id_set<-merged_buffer){
        val min_core_id = core_id_set.min
result = result:+(min_core_id,core_id_set)
      }
      result.iterator
    })
  rdd_merged
}
//分区迭代计算,可以根据需要调整迭代次数和分区数量
rdd_core = mergeRDD(rdd_core,8)
rdd_core = mergeRDD(rdd_core,4)
rdd_core = mergeRDD(rdd_core,1)
rdd_core.collect.foreach(println)

三,完整范例

完整范例还包括临时聚类簇的生成,以及最终聚类信息的整理。鉴于该部分代码较为冗长,在当前文章中不展示全部代码,仅说明最终结果。

范例的输入数据和《20分钟学会DBSCAN聚类算法》文中完全一致,共500个样本点。

聚类结果输出如下:

该结果中,聚类簇数量为2个。噪声点数量为500-242-241 = 17个,和调用sklearn中的结果完全一致。

添加云哥的公众号,并后台回复关键字:"源码",获取完整范例代码。

本文分享自微信公众号 - Python与机器学习算法频道(alg-channel)

原文出处及转载信息见文内详细说明,如有侵权,请联系 yunjia_community@tencent.com 删除。

原始发表时间:2019-11-04

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

我来说两句

0 条评论
登录 后参与评论

相关文章

  • 3分钟极简掌握matplotlib绘图原理

    matplotlib是基于Python语言的开源项目,旨在为Python提供一个数据绘图包。我将在这篇文章中介绍matplotlib API的核心对象,并介绍如...

    double
  • 列表的常用操作,这十张图把它说的明明白白!

    列表( list)作为Python中最常用的数据类型之一,是一个可增加、删除元素的可变(mutable)容器。

    double
  • 机器学习是万能的吗?AI落地有哪些先决条件?

    这段时间,有幸聆听了几场大牛报告,一位是第四范式,目前工业界应用AI经验最丰富的之一,曾经在百度与吴恩达共同推进AI在工业界的落地;另一位来自学术界,新加坡国立...

    double
  • 「中高级前端」窥探数据结构的世界- ES6版

    数据结构是在计算机中组织和存储数据的一种特殊方式,使得数据可以高效地被访问和修改。更确切地说,数据结构是数据值的集合,表示数据之间的关系,也包括了作用在数据上的...

    Nealyang
  • 窥探数据结构的世界

    数据结构是在计算机中组织和存储数据的一种特殊方式,使得数据可以高效地被访问和修改。更确切地说,数据结构是数据值的集合,表示数据之间的关系,也包括了作用在数据上的...

    用户1462769
  • 「中高级前端」窥探数据结构的世界- ES6版

    数据结构是在计算机中组织和存储数据的一种特殊方式,使得数据可以高效地被访问和修改。更确切地说,数据结构是数据值的集合,表示数据之间的关系,也包括了作用在数据上的...

    前端劝退师
  • 「中高级前端」窥探数据结构的世界- ES6版

    数据结构是在计算机中组织和存储数据的一种特殊方式,使得数据可以高效地被访问和修改。更确切地说,数据结构是数据值的集合,表示数据之间的关系,也包括了作用在数据上的...

    coder_koala
  • Netty粘包拆包解决方案

    前言 本篇文章是Netty专题的第六篇,前面五篇文章如下: 高性能NIO框架Netty入门篇 高性能NIO框架Netty-对象传输 高性能NIO框架Netty...

    猿天地
  • Spring MVC组件

    根据request找到相应的处理器Handler和Interceptors,HanddlerMaping接口只有一个方法,getHandler()

    用户2909867
  • Node程序debug小记

    今天调试一个程序,用到了一个很久之前的NPM包,名为formstream,用来将form表单数据转换为流的形式进行接口调用时的数据传递。

    贾顺名

扫码关注云+社区

领取腾讯云代金券