前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >Spark跑「DBSCAN」算法,工业级代码长啥样?

Spark跑「DBSCAN」算法,工业级代码长啥样?

作者头像
double
发布2019-11-14 17:11:45
2.2K2
发布2019-11-14 17:11:45
举报
文章被收录于专栏:算法channel算法channel

最近着手的一个项目需要在Spark环境下使用DBSCAN算法,遗憾的是Spark MLlib中并没有提供该算法。调研了一些相关的文章,有些方案是将样本点按照空间位置进行分区,并在每个空间分区中分别跑DBSCAN,但是这种方案容易遇到数据倾斜的问题,并且在分区的边界的结果很有可能是错误的。

经过与一些小伙伴的交流,通过几天的探索尝试,最终在Spark上手工实现了分布式的DBSCAN算法,经过校验结果和Sklearn单机结果完全一致,并且性能也达到了工业级水平。

通过该算法的实现,加深了对Spark的理解,用到了分批次广播和分区迭代计算等技巧,感觉自己还是棒棒哒,特意分享出来供有需要的小伙伴们参考。

一,总体思路

DBSCAN算法的分布式实现需要解决以下一些主要的问题。

1,如何计算样本点中两两之间的距离?

在单机环境下,计算样本点两两之间的距离比较简单,是一个双重遍历的过程。 为了减少计算量,可以用空间索引如Rtree进行加速。

在分布式环境,样本点分布在不同的分区,难以在不同的分区之间直接进行双重遍历。 为了解决这个问题,我的方案是将样本点不同的分区分成多个批次拉到Driver端, 然后依次广播到各个excutor分别计算距离,将最终结果union,从而间接实现双重遍历。

为了减少计算量,广播前对拉到Driver端的数据构建空间索引Rtree进行加速。

2,如何构造临时聚类簇?

这个问题不难,单机环境和分布式环境的实现差不多。

都是通过group的方式统计每个样本点周边邻域半径R内的样本点数量,

并记录它们的id,如果这些样本点数量超过minpoints则构造临时聚类簇,并维护核心点列表。

3,如何合并相连的临时聚类簇得到聚类簇?

这个是分布式实现中最最核心的问题。

在单机环境下,标准做法是对每一个临时聚类簇,判断其中的样本点是否在核心点列表,如果是,则将该样本点所在的临时聚类簇与当前临时聚类簇合并。并在核心点列表中删除该样本点。重复此过程,直到当前临时聚类簇中所有的点都不在核心点列表。

在分布式环境下,临时聚类簇分布在不同的分区,无法直接扫描全局核心点列表进行临时聚类簇的合并。我的方案是先在每一个分区内部对各个临时聚类簇进行合并,然后缩小分区数量重新分区,再在各个分区内部对每个临时聚类簇进行合并。不断重复这个过程,最终将所有的临时聚类簇都划分到一个分区,完成对全部临时聚类簇的合并。

为了降低最后一个分区的存储压力,我采用了不同于标准的临时聚类簇的合并算法。对每个临时聚类簇只关注其中的核心点id,而不关注非核心点id,以减少存储压力。合并时将有共同核心点id的临时聚类簇合并。

为了加快临时聚类的合并过程,分区时并非随机分区,而是以每个临时聚类簇的核心点id中的最小值min_core_id作为分区的Hash参数,具有共同核心点id的临时聚类簇有更大的概率被划分到同一个分区,从而加快了合并过程。

二,核心代码

代码语言:javascript
复制
import org.apache.spark.sql.SparkSession

val spark = SparkSession
.builder()
.appName("dbscan")
.getOrCreate()

val sc = spark.sparkContext
import spark.implicits._

1,寻找核心点形成临时聚类簇。

该步骤一般要采用空间索引 + 广播的方法,此处从略,假定已经得到了临时聚类簇。

代码语言:javascript
复制
//rdd_core的每一行代表一个临时聚类簇:(min_core_id, core_id_set)
//core_id_set为临时聚类簇所有核心点的编号,min_core_id为这些编号中取值最小的编号
var rdd_core = sc.parallelize(List((1L,Set(1L,2L)),(2L,Set(2L,3L,4L)),
                                       (6L,Set(6L,8L,9L)),(4L,Set(4L,5L)),
                                       (9L,Set(9L,10L,11L)),(15L,Set(15L,17L)),
                                       (10L,Set(10L,11L,18L))))
rdd_core.collect.foreach(println)

2,合并临时聚类簇得到聚类簇。

代码语言:javascript
复制
import scala.collection.mutable.ListBuffer
import org.apache.spark.HashPartitioner

//定义合并函数:将有共同核心点的临时聚类簇合并
val mergeSets = (set_list: ListBuffer[Set[Long]]) =>{
var result = ListBuffer[Set[Long]]()
while (set_list.size>0){
var cur_set = set_list.remove(0)
var intersect_idxs = List.range(set_list.size-1,-1,-1).filter(i=>(cur_set&set_list(i)).size>0)
while(intersect_idxs.size>0){
for(idx<-intersect_idxs){
        cur_set = cur_set|set_list(idx)
      }
for(idx<-intersect_idxs){
        set_list.remove(idx)
      }
      intersect_idxs = List.range(set_list.size-1,-1,-1).filter(i=>(cur_set&set_list(i)).size>0)
    }
result = result:+cur_set
  }
result
}

///对rdd_core分区后在每个分区合并,不断将分区数量减少,最终合并到一个分区
//如果数据规模十分大,难以合并到一个分区,也可以最终合并到多个分区,得到近似结果。
//rdd: (min_core_id,core_id_set)

def mergeRDD(rdd: org.apache.spark.rdd.RDD[(Long,Set[Long])], partition_cnt:Int):
org.apache.spark.rdd.RDD[(Long,Set[Long])] = {
  val rdd_merged =  rdd.partitionBy(new HashPartitioner(partition_cnt))
    .mapPartitions(iter => {
      val buffer = ListBuffer[Set[Long]]()
for(t<-iter){
        val core_id_set:Set[Long] = t._2
        buffer.append(core_id_set)
      }
      val merged_buffer = mergeSets(buffer)
var result = List[(Long,Set[Long])]()
for(core_id_set<-merged_buffer){
        val min_core_id = core_id_set.min
result = result:+(min_core_id,core_id_set)
      }
      result.iterator
    })
  rdd_merged
}
代码语言:javascript
复制
//分区迭代计算,可以根据需要调整迭代次数和分区数量
rdd_core = mergeRDD(rdd_core,8)
rdd_core = mergeRDD(rdd_core,4)
rdd_core = mergeRDD(rdd_core,1)
rdd_core.collect.foreach(println)

三,完整范例

完整范例还包括临时聚类簇的生成,以及最终聚类信息的整理。鉴于该部分代码较为冗长,在当前文章中不展示全部代码,仅说明最终结果。

范例的输入数据和《20分钟学会DBSCAN聚类算法》文中完全一致,共500个样本点。

聚类结果输出如下:

该结果中,聚类簇数量为2个。噪声点数量为500-242-241 = 17个,和调用sklearn中的结果完全一致。

添加云哥的公众号,并后台回复关键字:"源码",获取完整范例代码。

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2019-11-04,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 程序员郭震zhenguo 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 一,总体思路
  • 二,核心代码
  • 三,完整范例
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档