前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >sklearn 源码分析系列:neighbors(2)

sklearn 源码分析系列:neighbors(2)

作者头像
用户1147447
发布2019-05-26 19:45:16
1.1K0
发布2019-05-26 19:45:16
举报
文章被收录于专栏:机器学习入门机器学习入门

sklearn 源码分析系列:neighbors(2)

by DemonSonggithub源码链接(https://github.com/demonSong/DML)

by\space DemonSong\\ github源码链接(https://github.com/demonSong/DML)

我起初一直在纠结是否需要把kd_tree的实现也放在这一篇中讲,如果讲算法实现,就违背了源码分析的初衷,过早钻入细节,是阅读源码的大忌。算法和框架的分析应属两部分内容,所以最终决定,所有sklearn源码分析系列不涉及具体算法,而是保证每个方法调用的连通性,重点关注架构,以及一些必要的python实现细节。

Note:

这篇文章主要分析Neighbors包中的Unsupervised Nearest Neighbors相关接口,对应于官方文档1.6.1章节,详见文档

Finding the Nearest Neighbors实操

详细实操代码可参考Github kaggle项目,详见链接

在实现最近邻算法时,常用的算法有”kd_tree”,”ball_tree”,”brute”三种,它们对应于不同的应用场景,这里不再赘述。

数据生成与可视化

代码语言:javascript
复制
# 1.6.1 Unsupervised Nearest Neighbors

from sklearn.neighbors import NearestNeighbors
import numpy as np
import matplotlib.pyplot as plt


# 1.6.1.1 Finding the Nearest Neighbors
X = np.array([[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1], [3, 2]])

plt.figure()
plt.scatter(X[:,0],X[:,1])
plt.xlim(X[:,0].min()-1,X[:,0].max()+1)
plt.ylim(X[:,1].min()-1,X[:,1].max()+1)
plt.title("Unsupervised nearest neighbors")
plt.show()

# k个最近的点中包含自己
nbrs = NearestNeighbors(n_neighbors=3, algorithm='ball_tree').fit(X)

distances,indices = nbrs.kneighbors(X)

# k个最近点的下标,按升序排列
indices
alt text
alt text

输出:

代码语言:javascript
复制
array([[0, 1, 2],
       [1, 0, 2],
       [2, 1, 0],
       [3, 4, 5],
       [4, 3, 5],
       [5, 4, 3]], dtype=int64)
代码语言:javascript
复制
# k个最近点的最短距离,按升序排列
distances

Out[2]:
array([[ 0.        ,  1.        ,  2.23606798],
       [ 0.        ,  1.        ,  1.41421356],
       [ 0.        ,  1.41421356,  2.23606798],
       [ 0.        ,  1.        ,  2.23606798],
       [ 0.        ,  1.        ,  1.41421356],
       [ 0.        ,  1.41421356,  2.23606798]])

kneighbors(X)默认返回两个参数,其中k个最近邻中还包含了自己,距离和下标均按照升序排列。

代码语言:javascript
复制
# k个最近点生成的邻接矩阵
nbrs.kneighbors_graph(X).toarray()

Out [3]:
array([[ 1.,  1.,  1.,  0.,  0.,  0.],
       [ 1.,  1.,  1.,  0.,  0.,  0.],
       [ 1.,  1.,  1.,  0.,  0.,  0.],
       [ 0.,  0.,  0.,  1.,  1.,  1.],
       [ 0.,  0.,  0.,  1.,  1.,  1.],
       [ 0.,  0.,  0.,  1.,  1.,  1.]])

# 1.6.1.2 KD Tree and Ball Tree Classes
from sklearn.neighbors import KDTree
import numpy as np

# 可直接用KDtree实现最近邻查找
X = np.array([[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1], [3, 2]])
kdt = KDTree(X, leaf_size=30, metric='euclidean')
kdt.query(X,k = 3,return_distance = False)

Out [4]:
array([[0, 1, 2],
       [1, 0, 2],
       [2, 1, 0],
       [3, 4, 5],
       [4, 3, 5],
       [5, 4, 3]], dtype=int64)

源码剖析

我们先从整体上来看看,实现NearestNeighbors所需关联到的python文件及对应的文件结构是什么样子的。

alt text
alt text

相比于Neighbors(1)中的内容,它多了unsupervised.py文件而已。所以,我们直接顺藤摸瓜开始分析。

unsupervised.py

代码语言:javascript
复制
class NearestNeighbors(NeighborsBase, KNeighborsMixin,
                       RadiusNeighborsMixin, UnsupervisedMixin):
    def __init__(self, n_neighbors=5, radius=1.0,
                     algorithm='auto', leaf_size=30, metric='minkowski',
                     p=2, metric_params=None, n_jobs=1, **kwargs):
            self._init_params(n_neighbors=n_neighbors,
                              radius=radius,
                              algorithm=algorithm,
                              leaf_size=leaf_size, metric=metric, p=p,
                              metric_params=metric_params, n_jobs=n_jobs, **kwargs)

这是一个明显的子类继承多个父类的情况,其中KNeighborsMixinRadiusNeighborsMixin属于功能相同,但具体实现细节有所差异,只单独分析一例。

先来看看它的构造方法吧,构造方法中传入了,9个参数,都是带默认值的。但令人奇怪的是,它同样是空有型而无内容的【初始化类】,该类只与客户端打交道,而真正的参数初始化都交给了其中的某个父类的__init__params()方法。为什么要这么做?不急,先看看到底是哪个父类完成了参数初始化。

所有父类集中在neighbors包下的base.py文件中。

alt text
alt text

经过一番寻找总算找到了初始化参数方法,在类neighborsBase

代码语言:javascript
复制
class NeighborsBase(six.with_metaclass(ABCMeta, BaseEstimator)):
    """Base class for nearest neighbors estimators."""

    @abstractmethod
    def __init__(self):
        pass

    def _init_params(self, n_neighbors=None, radius=None,
                     algorithm='auto', leaf_size=30, metric='minkowski',
                     p=2, metric_params=None, n_jobs=1):

        self.n_neighbors = n_neighbors
        self.radius = radius
        self.algorithm = algorithm
        self.leaf_size = leaf_size
        self.metric = metric
        self.metric_params = metric_params
        self.p = p
        self.n_jobs = n_jobs

        if algorithm not in ['auto', 'brute',
                             'kd_tree', 'ball_tree']:
            raise ValueError("unrecognized algorithm: '%s'" % algorithm)

        if algorithm == 'auto':
            if metric == 'precomputed':
                alg_check = 'brute'
            else:
                alg_check = 'ball_tree'
        else:
            alg_check = algorithm

        if callable(metric):
            if algorithm == 'kd_tree':
                # callable metric is only valid for brute force and ball_tree
                raise ValueError(
                    "kd_tree algorithm does not support callable metric '%s'"
                    % metric)
        elif metric not in VALID_METRICS[alg_check]:
            raise ValueError("Metric '%s' not valid for algorithm '%s'"
                             % (metric, algorithm))

        if self.metric_params is not None and 'p' in self.metric_params:
            warnings.warn("Parameter p is found in metric_params. "
                          "The corresponding parameter from __init__ "
                          "is ignored.", SyntaxWarning, stacklevel=3)
            effective_p = metric_params['p']
        else:
            effective_p = self.p

        if self.metric in ['wminkowski', 'minkowski'] and effective_p < 1:
            raise ValueError("p must be greater than one for minkowski metric")

        # 重点关注
        self._fit_X = None
        self._tree = None
        self._fit_method = None

喔,原来NeighborsBase是要作为整个Neighbors最具领导力的类?起码这家伙拿到了全局信息吧,我的一个猜测是,除了unsupervised需要用到这些参数之外,其他类也同样需要用这些参数做些有趣的事吧?所以既然大家都要复用这些参数!那就放在一个基类中吧,此处就叫NeighborsBase吧。(待检验)

我们关注下方法本身中的参数: 1. self.n_neighbors = n_neighbors ## k近邻中的k 2. self.radius = radius ## 不知 3. self.algorithm = algorithm ## 使用何种k近邻算法,如’kd_tree’ 4. self.leaf_size = leaf_size ## 生成’kd_tree’树需要传入的参数 5. self.metric = metric ## 计算其他各种形式的两点间距离 6. self.metric_params = metric_params ## 不知 7. self.p = p ## 不知 8. self.n_jobs = n_jobs ## 并发创建的线程数

除此之外,在初始化最后,还占了三个位: 1. self._fit_X = None ## fit_X 和传入的X之间有何关系? 2. self._tree = None ## _tree表示返回的树结构 3. self._fit_method = None ## fit传入的算法

NeighborsBase就这些内容,它还有一个_fit()方法,稍后分析。总的来说,当客户端调用诸如nbrs = NearestNeighbors(n_neighbors=3, algorithm='kd_tree',leaf_size=30)的构造方法时,NearestNeighbors什么都没做,把参数初始化任务交给了它的父类NeighborsBase(该小组的老大!),而这老大具体也没做什么具体的事,把该初始化的参数初始化,并做一些参数合法性的检查,完工。

模型参数初始完毕之后,自然到了fit步骤,正如,客户端调用那样nbrs = NearestNeighbors(n_neighbors=3, algorithm='kd_tree',leaf_size=30).fit(X)我把数据X,传给了谁?谁来拟合这些数据呢?

记得NearestNeighbors中的几个父类吧,完成fit操作的是UnsupervisedMixin类,接着来看看它的代码。

代码语言:javascript
复制
class UnsupervisedMixin(object):
    def fit(self, X, y=None):
        """Fit the model using X as training data

        Parameters
        ----------
        X : {array-like, sparse matrix, BallTree, KDTree}
            Training data. If array or matrix, shape [n_samples, n_features],
            or [n_samples, n_samples] if metric='precomputed'.
        """
        return self._fit(X)

非常简短,针对非监督的数据,全部交给了自己的self._fit(X)方法,所以它又是个代理类?这个代理类更狠,什么都没做,直接转交给NearestNeighbors中的某个父类来完成。调用_fit()方法后,就又回到了NeighborsBase中去了,所以当客户端要调用fit方法时,先交给了NeighborsBase的手下UnsupervisedMixin做一些前期的处理操作,但这手下学会了偷懒,什么都没做直接交给了领导,直接让领导来处理咯,真坏。那领导真的有功夫,有能力处理这个fit任务?领导也不傻,我们看看领导怎么做的。

代码语言:javascript
复制
def _fit(self, X):

        ......
        # 做些必要的检查
        X = check_array(X, accept_sparse='csr')

        # 还是在做检查
        n_samples = X.shape[0]
        if n_samples == 0:
            raise ValueError("n_samples must be greater than 0")
        ......

        #前面占的位子给补上
        self._fit_method = self.algorithm
        self._fit_X = X

        ......

        # 嘿,领导开始派发任务了
        if self._fit_method == 'ball_tree':
            self._tree = BallTree(X, self.leaf_size,
                                  metric=self.effective_metric_,
                                  **self.effective_metric_params_)
        # 看到了熟悉的kd_tree了                       
        elif self._fit_method == 'kd_tree':
            self._tree = KDTree(X, self.leaf_size,
                                metric=self.effective_metric_,
                                **self.effective_metric_params_)
        elif self._fit_method == 'brute':
            self._tree = None
        else:
            raise ValueError("algorithm = '%s' not recognized"
                             % self.algorithm)

        # 检查,为什么不放在一开始做?
        if self.n_neighbors is not None:
            if self.n_neighbors <= 0:
                raise ValueError(
                    "Expected n_neighbors > 0. Got %d" %
                    self.n_neighbors
                )

        return self

唉,领导也没有干活啊,做了一些检查,根据来的参数,交给对应的具体执行者去做!但返回的还是自己,因为我要和客户端打交道。我们来分析下具体的执行者做了些什么操作。看如下代码,

代码语言:javascript
复制
elif self._fit_method == 'kd_tree':
            self._tree = KDTree(X, self.leaf_size,                               metric=self.effective_metric_,                          **self.effective_metric_params_)

NeighborsBasefit()方法中,并不是返回某个模型对象,而是把模型对象内嵌到了NeighborsBase中的self._tree中去,这是为什么?kd_tree模型本身有查询最近邻的方法,为什么不直接暴露给客户端呢?在这里我并不理解它这样做的用意是什么。(待解决)

所以对于数据真正的fit()是交给具体算法来完成的,咱们接下来就看看kd_tree.py吧。关于kd_tree的算法细节,可以参考之前我的一篇博文【K近邻法学习笔记】。关于sklearn中kd_tree的具体分析,不作为本文内容,日后单独开辟一章来讲解。本文重点关注各接口的实现与内在联系。

alt text
alt text

所以当NeighborsBase构造了kd_tree时,就调用了它的构造方法,走。

代码语言:javascript
复制
def __init__(self, data, leafsize=10):
        self.data = np.asarray(data)
        self.n, self.m = np.shape(self.data)
        self.leafsize = int(leafsize)
        if self.leafsize < 1:
            raise ValueError("leafsize must be at least 1")
        self.maxes = np.amax(self.data,axis=0)
        self.mins = np.amin(self.data,axis=0)

        # 关键步骤
        self.tree = self.__build(np.arange(self.n), self.maxes, self.mins)

前面也是做了一些初始化操作,接着开始构建kd_tree的数据结构了。调用__build()方法,由传入的数据的生成了对应的数据结构。到这里,数据到结构的映射完成了。

Created with Raphaël 2.1.0数据X到结构的映射ClientClientNearestNeighborsNearestNeighborsNeighborsBaseNeighborsBaseUnsupervisedMixinUnsupervisedMixinKDTreeKDTree__init__()_init_params()fit()_fit()__init__()_build()

总结下,NearsetNeighbors和客户端打交到,而NeighborsBase统筹规划所有调度。

既然有了数据X到结构的映射,那自然要做真正的查询操作了(k近邻查询),我们继续来看看,客户端调用如下distances,indices = nbrs.kneighbors(X),在NearestNeighbors中只要初始化方法,并没有kneighbors(X)方法,该方法在它的另外一个父类KNeighborsMixin中。

代码语言:javascript
复制
class KNeighborsMixin(object):
    def kneighbors(self, X=None, n_neighbors=None, return_distance=True):
    ......
     n_samples, _ = X.shape
        sample_range = np.arange(n_samples)[:, None]
    ......
    elif self._fit_method in ['ball_tree', 'kd_tree']:
            if issparse(X):
                raise ValueError(
                    "%s does not work with sparse matrices. Densify the data, "
                    "or set algorithm='brute'" % self._fit_method)
            result = Parallel(n_jobs, backend='threading')(
                delayed(self._tree.query, check_pickle=False)(
                    X[s], n_neighbors, return_distance)
                for s in gen_even_slices(X.shape[0], n_jobs)
            )
            if return_distance:
                dist, neigh_ind = tuple(zip(*result))
                result = np.vstack(dist), np.vstack(neigh_ind)
            else:
                result = np.vstack(result)

很多东西都可以忽略不看,只需要关注一行代码就可以了。

代码语言:javascript
复制
result = Parallel(n_jobs, backend='threading')(
                delayed(self._tree.query, check_pickle=False)(
                    X[s], n_neighbors, return_distance)
                for s in gen_even_slices(X.shape[0], n_jobs)
            )

前面它包了一个并发的类,咱们不去研究,在delay方法中,传入了self._tree.query这是一个方法名,在之前KDTree类的接口中,有相应的实现,也就是说KNeighborsMixin类也不做任何查询操作,同样把查询交给了KDTree来完成,的确如此,只有KDTree中存放了相应的数据结构,不是它做查询谁来做查询,KNeighborsMixin只是简单的把KDTree返回的查询结果交给客户端就可以了,别无其他。

Created with Raphaël 2.1.0查询结果的返回过程ClientClientNearestNeighborsNearestNeighborsKNeighborsMixinKNeighborsMixinKDTreeKDTreek近邻查询kneighbors(X)query(X)查询结果查询结果

综上,整个关于数据X到kd_tree的结构映射调用就完成了,也没有太多东西,理清各个类之间的关系就可以了。同样的,当要进行k近邻查询时,交给了NearestNeighbors中的父类KNeighborsMixin来代理查询,真正的查询操作还是kd_tree来完成,前期都是些琐碎的调用流程,而算法的核心在于kd_tree,起码数据在到kd_tree之前,能够做很多前期处理,保证了算法对数据的要求。看来是时候研究下kd_tree的核心算法了。

本文参与 腾讯云自媒体分享计划,分享自作者个人站点/博客。
原始发表:2017年03月16日,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • sklearn 源码分析系列:neighbors(2)
    • Finding the Nearest Neighbors实操
      • 源码剖析
      领券
      问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档