Isolation Forest算法实现详解

前言

要想对一个算法有深入的了解,不仅仅要懂得其基本的原理,更要学会如何实现该算法才行。因为从原理到实现的过程还有很多未知的在理论中无法体现出来的实现技巧在里面。本文介绍的 Isolation Forest 算法原理的详细解释请参看之前文章:Isolation Forest异常检测算法原理详解。在此我们只介绍详细的代码实现过程细节,请读者一定要记住学习源码实现才是提升编码能力的根本,一定要坚持下去,如遇到任何不明白的地方请留言,燕哥会第一时间给你回复。当然,由于本人能力有限难免有疏忽错误之处,如果觉得源码实现中有改进或者错误之处也请不吝赐教。

1、ITree的设计与实现

首先,我们参看原论文中的ITree的构造伪代码:

1.1 设计ITree类的数据结构

由原论文[1,2]以及上述伪代码可知,ITree是一个二叉树,并且构建ITree的算法采用的是递归构建。同时构造的结束条件是:

  1. 当前节点的高度超过了算法设置的阈值 l ;
  2. 当前子树只包含一个叶节点;
  3. 当前子树的所有节点值的所有属性完全一致。

并且在递归的时候,我们需要随机的选择属性集 Q 中的一个属性Qi以及该属性在给出的输入数据上对应的最大值和最小值之间的一个值 q ,来将当前节点包含的样本分为左右子树。因此我们为了后续算法设计的方便,需要记录被选中的属性Qi的索引值 attrIndex,以及算出来的q值attrValue,因为后面算法需要根据该节点为根节点的子树包含的叶节点总数估计该节点总高度,因此我们还需要定义一个变量leafNodes记录树的叶子节点总数,同时我们还需要一个成员变量curHeight来记录该节点的实际高度。当然,二叉树少不了定义左右子树指针 lTree 和 rTree。

因此,我设计了如下的数据结构 ITree

public class ITree {

    // 被选中的属性索引
    public int attrIndex;

    // 被选中的属性的一个具体的值
    public double attrValue;

    // 树的总叶子节点数
    public int leafNodes;

    // 该节点在树种的高度
    public int curHeight;

    // 左右孩子书
    public ITree lTree, rTree;

    // 构造函数,初始化ITree中的值
    public ITree(int attrIndex, double attrValue) {
        // 默认高度,树的高度从0开始计算
      this.curHeight = 0;

        this.lTree = null;
        this.rTree = null;
        this.leafNodes = 1;
        this.attrIndex = attrIndex;
        this.attrValue = attrValue;
    }
    ...
 }

1.2 递归地构造二叉树ITree

根据原论文中的算法2的伪代码,我们知道递归地构造二叉树ITree分为两个部分:

  • 第一,首先判断是否满足1.1节列出的三个递归结束条件;
  • 第二,随机的选取属性集中的一个属性以及该属性集下的一个具体的值,然后根据该属性以及生成的属性值将父节点中包含的样本数据划分到左右子树,并递归地创建左右子树。
  • 同时记录每个节点包含的叶子节点数和当前节点在整个树中的实际高度。

参看如下的详细代码实现:

/**
 * 根据samples样本数据递归的创建 ITree 树
 */
public static ITree createITree(double[][] samples, 
                    int curHeight, int limitHeight) {

    ITree iTree = null;

    /** 第一步:判断递归是否满足结束条件 **/
    if (samples.length == 0) {
        return iTree;
    } else if (curHeight >= limitHeight || 
        samples.length == 1) {
        iTree = new ITree(0, samples[0][0]);
        iTree.leafNodes = 1;
        iTree.curHeight = curHeight;
        return iTree;
    }

    int rows = samples.length;
    int cols = samples[0].length;

    // 判断是否所有样本都一样,如果都一样构建也终止
    boolean isAllSame = true;
    break_label:
    for (int i = 0; i < rows - 1; i++) {
        for (int j = 0; j < cols; j++) {
            if (samples[i][j] != samples[i + 1][j]) {
                isAllSame = false;
                break break_label;
            }
        }
    }

    // 所有的样本都一样,构建终止,返回的是叶节点
    if (isAllSame == true) {
        iTree = new ITree(0, samples[0][0]);
        iTree.leafNodes = samples.length;
        iTree.curHeight = curHeight;
        return iTree;
    }


    /***第二步:不满足递归结束条件,继续递归产生子树 **/
    Random random = new Random(
            System.currentTimeMillis());

    int attrIndex = random.nextInt(cols);

    // 找这个被选维度的最大值和最小值
    double min, max;
    min = samples[0][attrIndex];
    max = min;
    for (int i = 1; i < rows; i++) {
        if (samples[i][attrIndex] < min) {
            min = samples[i][attrIndex];
        }
        if (samples[i][attrIndex] > max) {
            max = samples[i][attrIndex];
        }
    }

    // 计算划分属性值
   double attrValue = random.nextDouble() 
                        * (max - min) + min;

    // 将所有的样本的attrIndex对应的属性与
   // attrValue 进行比较以选出左右子树对应的样本
   int lnodes = 0, rnodes = 0;
    double curValue;
    for (int i = 0; i < rows; i++) {
        curValue = samples[i][attrIndex];
        if (curValue < attrValue) {
            lnodes++;
        } else {
            rnodes++;
        }
    }

    double[][] lSamples = new double[lnodes][cols];
    double[][] rSamples = new double[rnodes][cols];

    lnodes = 0;
    rnodes = 0;
    for (int i = 0; i < rows; i++) {
        curValue = samples[i][attrIndex];
        if (curValue < attrValue) {
            lSamples[lnodes++] = samples[i];
        } else {
            rSamples[rnodes++] = samples[i];
        }
    }
    // 创建父节点
   ITree parent = new ITree(attrIndex, attrValue);
    parent.leafNodes = rows;
    parent.curHeight = curHeight;
    parent.lTree = createITree(lSamples,
                         curHeight + 1, limitHeight);
    parent.rTree = createITree(rSamples, 
                         curHeight + 1, limitHeight);
    return parent;
}

2、IForest的设计与实现

原论文的算法1的伪代码如下所示:

由上图的伪代码,我们知道,IForest类主要作用就是用来做两件事:

  1. 对输入数据进行子采样后构建ITree;
  2. 将所有构建的ITree合并,构成检测森林。

2.1 设计IForest类的数据结构

因此,我们设计了如下的基本数据结构类。其中IForest 中的两个属性值center0 和 center1 是用于记录最终的异常和异常类别的异常因子中心(后续判断异常类中心的时候你就知道为什么叫异常因子中心)的,分类预测需要用这两个变量。subSampleSize是整个算法的子采样数目(默认值256)。iTreeList 是一个用于保存所有ITree的列表。

public class IForest {

    // center0代表异常类中心,center1代表正常类中心
    private Double center0;
    private Double center1;

    // 样本集子采样的数目
    private int subSampleSize;

    // IForest中包含的ITree链表
    private List<ITree> iTreeList;

    /**
     * 无参构造函数,contamination设置为默认值0.1
     */
    public IForest() {
        this.center0 = null;
        this.center1 = null;
        this.subSampleSize = 256;
        this.iTreeList = new ArrayList<>();
    }
   ...
 }

2.2 构建森林

初始化玩IForest之后的第一件事,当然是构建一颗一颗的ITree,并将它们添加到 iTreeList 中组合在一起,形成检测树森林。当然,首先我们得设置树的高度上限 limitHeight 。

/**
 * 创建IForest
 */
private void createIForest(double[][] samples, int t) 
                                   throws Exception {

    // 方法参数合法性检验
    if (samples == null || samples.length == 0) {
        throw new Exception("Samples is null or empty, 
                                    please check...");
    } else if (t <= 0) {
        throw new Exception("Number of subtree t 
                                must be a positive...");
    } else if (subSampleSize <= 0) {
        throw new Exception("subSampleSize 
                                must be a positive...");
    }

    int limitHeight = (int) Math.ceil(
    Math.log(subSampleSize) / Math.log(2));

    ITree iTree;
    double[][] subSample;

    for (int i = 0; i < t; i++) {
        subSample = this.getSubSamples(samples,
                                     subSampleSize);
        iTree = ITree.createITree(subSample, 0, limitHeight);
        this.iTreeList.add(iTree);
    }
}

2.3 计算样本的异常指数

IForest构建好了之后我们就可以对每一个样本计算他们的异常指数了,异常指数的计算方法请参看我的另一篇博文,结合代码就知道了。computeCn(subSampleSize) 的具体计算方法请参看我的另一篇博文以及具体的实现方法请看我的GitHub源代码。

在这里,我们就能看出来,我们计算的 ehx 是当前样本在所有ITree上的平均高度,其目的是为了增强树高度估计的鲁棒性,这也是为什么引入森林的原因。

/**
 * 计算某一个样本的异常指数
 */
private double computeAnomalyIndex(double[] sample) 
throws Exception {

    if (iTreeList == null || iTreeList.size() == 0) {
        throw new Exception("iTreeList is 
              empty,please create IForest...");
    } else if (sample == null || sample.length == 0) {
        throw new Exception("Sample is 
               null or empty, please check...");
    }

    // 样本在所有iTree上的平均高度(改进后的)
    double ehx = 0;
    double pathLength = 0;
    for (ITree iTree : iTreeList) {
        pathLength = computePathLength(sample, iTree);
        ehx += pathLength;
    }
    ehx /= iTreeList.size();

    double cn = computeCn(subSampleSize);
    double index = ehx / cn;

    double anomalyIndex = Math.pow(2, -index);
    return anomalyIndex;
}

2.4 计算路径高度

路径高度计算函数用于“估算” (为什么是估算请看上一篇博文或者原论文) 样本点在当前ITree上的高度。

详细代码实现如下:

/**
 * 计算样本sample在ITree上的PathLength
 */
private double computePathLength(double[] sample,
final ITree iTree) throws Exception {

    // 参数合法性检查
    if (sample == null || sample.length == 0) {
        throw new Exception("Sample is 
          null or empty, please check...");
    } else if (iTree == null || iTree.leafNodes == 0) {
        throw new Exception("iTree is 
             null or empty, please check...");
    }

    double pathLength = -1;
    double attrValue;
    ITree tmpITree = iTree;

    while (tmpITree != null) {
        pathLength += 1;
        attrValue = sample[tmpITree.attrIndex];

        if (tmpITree.lTree == null || 
            tmpITree.rTree == null || 
            attrValue == tmpITree.attrValue) {
            break;
        } else if (attrValue < tmpITree.attrValue) {
            tmpITree = tmpITree.lTree;
        } else {
            tmpITree = tmpITree.rTree;
        }
    }

    return pathLength + computeCn(tmpITree.leafNodes);
}

3、训练过程

细心的读者都会发现,上述的所有成员函数都是私有成员。这就是说无论是类还是对象都不能在主程序中调用它们,所以为了接口使用的方便一个简单的 train 方法,所有的这些私有成员方法都是在 train 方法的内部被调用。这样就体现了API接口设计的用户友好的原则,否则万一使用者不知道它们之间调用的先后顺序就出问题了。读者以后设计接口的时候也请注意!!!

public int[] train(double[][] samples, int t, 
    int subSampleSize, int iters) throws Exception {

    this.subSampleSize = subSampleSize;
    if (this.subSampleSize > samples.length) {
        this.subSampleSize = samples.length;
    }

    // 第一步:创建Isolation Forest
    createIForest(samples, t);

    // 第二步:计算所有样本的异常指数
    double[] scores = computeAnomalyIndex(samples);

    // 第三步:获取类标,并设置聚类中心
    int[] labels = classifyByCluster(scores, iters);
    return labels;
}

观察训练过程,我们知道总共的训练分三步:

  1. 第一步:创建Isolation Forest;
  2. 第二步:计算所有样本的异常指数;
  3. 第三步:获取类标,并设置类标中心。

根据前面的讲解,我们能明白前两步,但是第三步我们是如何获取类标和计算类标中心的呢?原论文只讲到异常指数趋向于0就是正常样本,趋向于1就是异常样本,如果全体都趋向于0.5左右,那么全体都是正常的。

但是就这些我们也没办法去判断到底多少是异常多少是正常,怎么去找这个界限或者说阈值?

4、计算类标中心

对于计算这个阈值,我曾经想过留给用户作为算法的参数,因为不同的情况下阈值根本不一样,所以我们不能在算法中固定死它的具体取值。但是为了减少算法参数以及用户的训练次数不至于多次尝试选取阈值浪费时间的考虑,我本人灵机一动,结合KMeans聚类的思想想出了用聚类的方法计算两类异常指数的中心点。

具体的方法是这样的,因为原论文虽然没有给出如何计算和判定类标属性,但是给出了大致的方针:趋向于0就是正常样本,趋向于1就是异常样本。

根据KMeans的思想,我们首先对所有的一维数据异常指数进行KMeans计算(K=2),这样我们就能计算到两个类的类标中心。我们知道了每个类的类标中心我们不就知道每个样本是哪个类了:离哪个近就是哪个类啊!我是不是很聪明? 并且在这里,由于上述方针,我们在进行KMeans计算类中心之前可以先将直接将初始类中心点设置为所有的异常指数的最大值和最小值,这也解决了KMeans方法在选初始类中心时可能导致算法不准确的问题(想一想为什么)。 这样判断类标的问题就完美解决了。好了,看代码吧:

/**
 * 通过使用聚类的思想,根据anomalyIndex进行分类获取类标
 */
private int[] classifyByCluster(double[] scores, int iters) {
     // 两个聚类中心
    center0 = scores[0]; // 异常类的聚类中心
    center1 = scores[0]; // 正常类的聚类中心
    
    /* 根据原论文,异常指数接近1说明是异常点,
         接近0为正常点。所以,将center0、center1
         分别初始化为scores中的最大值和最小值。这
         样就相当于KMeans聚类的初始点的选择,解决
         了KMeans聚类的不稳定性。*/
    for (int i = 1; i < scores.length; i++) {
        if (scores[i] > center0) {
            center0 = scores[i];
        }
        if (scores[i] < center1) {
            center1 = scores[i];
        }
    }
    int cnt0, cnt1;
    double diff0, diff1;
    int[] labels = new int[scores.length];
    // 迭代聚类(迭代iters次)
    for (int n = 0; n < iters; n++) {
        // 判断每个样本的类别
      cnt0 = 0;
        cnt1 = 0;

        for (int i = 0; i < scores.length; i++) {
            // 计算当前点与两个聚类中心的距离
         diff0 = Math.abs(scores[i] - center0);
            diff1 = Math.abs(scores[i] - center1);
            // 根据与聚类中心的距离,判断类标
            if (diff0 < diff1) {
                labels[i] = 0;
                cnt0++;
            } else {
                labels[i] = 1;
                cnt1++;
            }
        }
        // 保存旧的聚类中心
      diff0 = center0;
        diff1 = center1;
        // 重新计算聚类中心
      center0 = 0.0;
        center1 = 0.0;
        for (int i = 0; i < scores.length; i++) {
            if (labels[i] == 0) {
                center0 += scores[i];
            } else {
                center1 += scores[i];
            }
        }
        center0 /= cnt0;
        center1 /= cnt1;
        // 提前迭代终止条件
        if (center0 - diff0 <= 1e-6 && 
               center1 - diff1 <= 1e-6) {
            break;
        }
    }
    return labels;
}

5、对未知样本的预测

有了类标中心 center0 和 center1 那么对于新来的未知样本,我们可以首先计算它的异常指数值,然后看它离那个样本中心近,本代码中 center0 大于 center1 。所以与 center0接近就是异常点,反之为正常点。不解释了,直接上代码:

/**
 * 预测样本 sample 是否为异常值,正常返回1,异常返回-1
 */
public int predict(double[] sample) throws Exception {
    double score = computeAnomalyIndex(sample);
    double dis0 = Math.abs(score - center0);
    double dis1 = Math.abs(score - center1);

    // 与哪个中心近说明改点被判断为哪一类
    if (dis0 > dis1) {
        return 1;
    } else {
        return 0;
    }
}

参考文献

  • http://cs.nju.edu.cn/zhouzh/zhouzh.files/publication/icdm08b.pdf
  • https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/ensemble/iforest.py
  • http://scikit-learn.org/dev/modules/generated/sklearn.ensemble.IsolationForest.html

理解了Isolation Forest算法吗?

原文发布于微信公众号 - 机器学习算法全栈工程师(Jeemy110)

原文发表时间:2017-08-28

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏tkokof 的技术,小趣及杂念

数学小记

结果为 2n + 1, 考虑到任一奇数都可以表示成这种形式,所以使用以下构造方法即可立即得到上述的b和c:

1003
来自专栏AI科技评论

开发|简单有趣的 NLP 教程:手把手教你用 PyTorch 辨别自然语言(附代码)

最近在学pyTorch的实际应用例子。这次说个简单的例子:给定一句话,判断是什么语言。这个例子是比如给定一句话: Give it to me 判断是 ENGLI...

2866
来自专栏机器学习原理

深度学习——CNN(4)分析

1554
来自专栏青玉伏案

算法与数据结构(五) 普利姆与克鲁斯卡尔的最小生成树(Swift版)

上篇博客我们聊了图的物理存储结构邻接矩阵和邻接链表,然后在此基础上给出了图的深度优先搜索和广度优先搜索。本篇博客就在上一篇博客的基础上进行延伸,也是关于图的。今...

2097
来自专栏AI研习社

手把手教你用 PyTorch 辨别自然语言(附代码)

最近在学pyTorch的实际应用例子。这次说个简单的例子:给定一句话,判断是什么语言。这个例子是比如给定一句话: Give it to me 判断是 ENGLI...

3285
来自专栏机器之心

教程 | 在Python和TensorFlow上构建Word2Vec词嵌入模型

选自adventuresinmachinelearning 机器之心编译 参与:李诗萌、刘晓坤 本文详细介绍了 word2vector 模型的模型架构,以及 T...

3697
来自专栏机器学习原理

图像处理和数据增强图片处理数据增强颜色空间转换噪音数据的加入样本不均衡

2954
来自专栏CreateAMind

keras doc 10终结篇 激活函数 回调函数 正则项 约束项 预训练模型

激活函数可以通过设置单独的激活层实现,也可以在构造层对象时通过传递activation参数实现。

1343
来自专栏和蔼的张星的图像处理专栏

opencv小项目练习之数独求解

对于给定的数独照片(尽可能干净整齐),进行一系列处理,提取位置和数字信息,这中间可能要用到一系列图像处理的基本算法,数字识别时初步打算用knn来做,knn对手写...

421
来自专栏yl 成长笔记

模拟退火算法实例(c++ 与 c# 实现)

此片文章主要参考CSDN博主里头的一篇文章, 将自己的理解写下来,以方便后期的查阅。

603

扫描关注云+社区