专栏首页腾讯智能钛AI开发者【技术分享】高斯混合模型
原创

【技术分享】高斯混合模型

本文原作者:尹迪,经授权发布。

| 导语 现有的高斯模型有单高斯模型(SGM)和高斯混合模型(GMM)两种。从几何上讲,单高斯分布模型在二维空间上近似于椭圆,在三维空间上近似于椭球。在很多情况下,属于同一类别的样本点并不满足“椭圆”分布的特性,所以我们需要引入混合高斯模型来解决这种情况。

1 单高斯模型

  多维变量X服从高斯分布时,它的概率密度函数PDF定义如下:

  在上述定义中,x是维数为D的样本向量,mu是模型期望,sigma是模型协方差。对于单高斯模型,可以明确训练样本是否属于该高斯模型,所以我们经常将mu用训练样本的均值代替,将sigma用训练样本的协方差代替。 假设训练样本属于类别C,那么上面的定义可以修改为下面的形式:

  这个公式表示样本属于类别C的概率。我们可以根据定义的概率阈值来判断样本是否属于某个类别。

2 高斯混合模型

  高斯混合模型,顾名思义,就是数据可以看作是从多个高斯分布中生成出来的。从中心极限定理可以看出,高斯分布这个假设其实是比较合理的。 为什么我们要假设数据是由若干个高斯分布组合而成的,而不假设是其他分布呢?实际上不管是什么分布,只K取得足够大,这个XX Mixture Model就会变得足够复杂,就可以用来逼近任意连续的概率密度分布。只是因为高斯函数具有良好的计算性能,所GMM被广泛地应用。

  每个GMMK个高斯分布组成,每个高斯分布称为一个组件(Component),这些组件线性加成在一起就组成了GMM的概率密度函数 (1):

  

根据上面的式子,如果我们要从GMM分布中随机地取一个点,需要两步:

  • 随机地在这K个组件之中选一个,每个组件被选中的概率实际上就是它的系数pi_k
  • 选中了组件之后,再单独地考虑从这个组件的分布中选取一个点。

  怎样用GMM来做聚类呢?其实很简单,现在我们有了数据,假定它们是由GMM生成出来的,那么我们只要根据数据推出GMM的概率分布来就可以了,然后GMMK个组件实际上就对应了K个聚类了。 在已知概率密度函数的情况下,要估计其中的参数的过程被称作“参数估计”。

  我们可以利用最大似然估计来确定这些参数,GMM的似然函数 (2) 如下(此处公式有误,括号中的x应该为x_i):

  可以用EM算法来求解这些参数。EM算法求解的过程如下:

  • 1 E-步。求数据点由各个组件生成的概率(并不是每个组件被选中的概率)。对于每个数据$x_{i}$来说,它由第k个组件生成的概率为公式 (3)

  在上面的概率公式中,我们假定musigma均是已知的,它们的值来自于初始化值或者上一次迭代。

  • 2 M-步。估计每个组件的参数。由于每个组件都是一个标准的高斯分布,可以很容易分布求出最大似然所对应的参数值,分别如下公式 (4), (5), (6), (7)

3 源码分析

3.1 实例

  在分析源码前,我们还是先看看高斯混合模型如何使用。

import org.apache.spark.mllib.clustering.GaussianMixture
import org.apache.spark.mllib.clustering.GaussianMixtureModel
import org.apache.spark.mllib.linalg.Vectors
// 加载数据
val data = sc.textFile("data/mllib/gmm_data.txt")
val parsedData = data.map(s => Vectors.dense(s.trim.split(' ').map(_.toDouble))).cache()
// 使用高斯混合模型聚类
val gmm = new GaussianMixture().setK(2).run(parsedData)
// 保存和加载模型
gmm.save(sc, "myGMMModel")
val sameModel = GaussianMixtureModel.load(sc, "myGMMModel")
// 打印参数
for (i <- 0 until gmm.k) {
  println("weight=%f\nmu=%s\nsigma=\n%s\n" format
    (gmm.weights(i), gmm.gaussians(i).mu, gmm.gaussians(i).sigma))
}

  由上面的代码我们可以知道,使用高斯混合模型聚类使用到了GaussianMixture类中的run方法。下面我们直接进入run方法,分析它的实现。

3.2 高斯混合模型的实现

3.2.1 初始化

  在run方法中,程序所做的第一步就是初始化权重(上文中介绍的pi)及其相对应的高斯分布。

val (weights, gaussians) = initialModel match {
      case Some(gmm) => (gmm.weights, gmm.gaussians)
      case None => {
        val samples = breezeData.takeSample(withReplacement = true, k * nSamples, seed)
        (Array.fill(k)(1.0 / k), Array.tabulate(k) { i =>
          val slice = samples.view(i * nSamples, (i + 1) * nSamples)
          new MultivariateGaussian(vectorMean(slice), initCovariance(slice))
        })
      }
    }

  在上面的代码中,当initialModel为空时,用所有值均为1.0/k的数组初始化权重,用值为MultivariateGaussian对象的数组初始化所有的高斯分布(即上文中提到的组件)。 每一个MultivariateGaussian对象都由从数据集中抽样的子集计算而来。这里用样本数据的均值和方差初始化MultivariateGaussianmusigma

//计算均值
private def vectorMean(x: IndexedSeq[BV[Double]]): BDV[Double] = {
    val v = BDV.zeros[Double](x(0).length)
    x.foreach(xi => v += xi)
    v / x.length.toDouble
  }
//计算方差
private def initCovariance(x: IndexedSeq[BV[Double]]): BreezeMatrix[Double] = {
    val mu = vectorMean(x)
    val ss = BDV.zeros[Double](x(0).length)
    x.foreach(xi => ss += (xi - mu) :^ 2.0)
    diag(ss / x.length.toDouble)
  }

3.2.2 EM算法求参数

  初始化后,就可以使用EM算法迭代求似然函数中的参数。迭代结束的条件是迭代次数达到了我们设置的次数或者两次迭代计算的对数似然值之差小于阈值。

 while (iter < maxIterations && math.abs(llh-llhp) > convergenceTol)

  在迭代内部,就可以按照E-步M-步来更新参数了。

  • E-步:更新参数gamma
 val compute = sc.broadcast(ExpectationSum.add(weights, gaussians)_)
 val sums = breezeData.aggregate(ExpectationSum.zero(k, d))(compute.value, _ += _)

  我们先要了解ExpectationSum以及add方法的实现。

private class ExpectationSum(
    var logLikelihood: Double,
    val weights: Array[Double],
    val means: Array[BDV[Double]],
    val sigmas: Array[BreezeMatrix[Double]]) extends Serializable

ExpectationSum是一个聚合类,它表示部分期望结果:主要包含对数似然值,权重值(第二章中介绍的pi),均值,方差。add方法的实现如下:

def add( weights: Array[Double],dists: Array[MultivariateGaussian])
      (sums: ExpectationSum, x: BV[Double]): ExpectationSum = {
    val p = weights.zip(dists).map {
      //计算pi_i * N(x)
      case (weight, dist) => MLUtils.EPSILON + weight * dist.pdf(x)
    }
    val pSum = p.sum
    sums.logLikelihood += math.log(pSum)
    var i = 0
    while (i < sums.k) {
      p(i) /= pSum  
      sums.weights(i) += p(i)  
      sums.means(i) += x * p(i)  
      //A := alpha * x * x^T^ + A
      BLAS.syr(p(i), Vectors.fromBreeze(x),
        Matrices.fromBreeze(sums.sigmas(i)).asInstanceOf[DenseMatrix])
      i = i + 1
    }
    sums
  }

  从上面的实现我们可以看出,最终,logLikelihood表示公式 (2) 中的对数似然。pweights分别表示公式 (3) 中的gammapimeans表示公式 (6) 中的求和部分,sigmas表示公式 (7) 中的求和部分。

  调用RDDaggregate方法,我们可以基于所有给定数据计算上面的值。利用计算的这些新值,我们可以在M-步中更新musigma

  • M-步:更新参数musigma
 var i = 0
 while (i < k) {
    val (weight, gaussian) =
       updateWeightsAndGaussians(sums.means(i), sums.sigmas(i), sums.weights(i), sumWeights)
    weights(i) = weight
    gaussians(i) = gaussian
    i = i + 1
 }
 private def updateWeightsAndGaussians(
      mean: BDV[Double],
      sigma: BreezeMatrix[Double],
      weight: Double,
      sumWeights: Double): (Double, MultivariateGaussian) = {
    //  mean/weight
    val mu = (mean /= weight)
    // -weight * mu * mut +sigma
    BLAS.syr(-weight, Vectors.fromBreeze(mu),
      Matrices.fromBreeze(sigma).asInstanceOf[DenseMatrix])
    val newWeight = weight / sumWeights
    val newGaussian = new MultivariateGaussian(mu, sigma / weight)
    (newWeight, newGaussian)
  }

  基于 E-步 计算出来的值,根据公式 (6) ,我们可以通过(mean /= weight)来更新mu;根据公式 (7) ,我们可以通过BLAS.syr()来更新sigma;同时,根据公式 (5), 我们可以通过weight / sumWeights来计算pi

  迭代执行以上的 E-步M-步,到达一定的迭代数或者对数似然值变化较小后,我们停止迭代。这时就可以获得聚类后的参数了。

3.3 多元高斯模型中相关方法介绍

  在上面的求参代码中,我们用到了MultivariateGaussian以及MultivariateGaussian中的部分方法,如pdfMultivariateGaussian定义如下:

class MultivariateGaussian @Since("1.3.0") (
    @Since("1.3.0") val mu: Vector,
    @Since("1.3.0") val sigma: Matrix) extends Serializable

MultivariateGaussian包含一个向量mu和一个矩阵sigma,分别表示期望和协方差。MultivariateGaussian最重要的方法是pdf,顾名思义就是计算给定数据的概率密度函数。它的实现如下:

private[mllib] def pdf(x: BV[Double]): Double = {
    math.exp(logpdf(x))
}
private[mllib] def logpdf(x: BV[Double]): Double = {
    val delta = x - breezeMu
    val v = rootSigmaInv * delta
    u + v.t * v * -0.5
 }

  上面的rootSigmaInvu通过方法calculateCovarianceConstants计算。根据公式 (1) ,这个概率密度函数的计算需要计算sigma的行列式以及逆。

sigma = U * D * U.t
inv(Sigma) = U * inv(D) * U.t = (D^{-1/2}^ * U.t).t * (D^{-1/2}^ * U.t)
-0.5 * (x-mu).t * inv(Sigma) * (x-mu) = -0.5 * norm(D^{-1/2}^ * U.t  * (x-mu))^2^

  这里,UD是奇异值分解得到的子矩阵。calculateCovarianceConstants具体的实现代码如下:

private def calculateCovarianceConstants: (DBM[Double], Double) = {
    val eigSym.EigSym(d, u) = eigSym(sigma.toBreeze.toDenseMatrix) // sigma = u * diag(d) * u.t
    val tol = MLUtils.EPSILON * max(d) * d.length
    try {
      //所有非0奇异值的对数和
      val logPseudoDetSigma = d.activeValuesIterator.filter(_ > tol).map(math.log).sum
      //通过求非负值的倒数平方根,计算奇异值对角矩阵的根伪逆矩阵
      val pinvS = diag(new DBV(d.map(v => if (v > tol) math.sqrt(1.0 / v) else 0.0).toArray))
      (pinvS * u.t, -0.5 * (mu.size * math.log(2.0 * math.Pi) + logPseudoDetSigma))
    } catch {
      case uex: UnsupportedOperationException =>
        throw new IllegalArgumentException("Covariance matrix has no non-zero singular values")
    }
  }

  上面的代码中,eigSym用于分解sigma矩阵。

4 参考文献

【1】漫谈 Clustering (3): Gaussian Mixture Model

原创声明,本文系作者授权云+社区发表,未经许可,不得转载。

如有侵权,请联系 yunjia_community@tencent.com 删除。

我来说两句

0 条评论
登录 后参与评论

相关文章

  • 【技术分享】奇异值分解

      在了解特征值分解之后,我们知道,矩阵A不一定是方阵。为了得到方阵,可以将矩阵A的转置乘以该矩阵。从而可以得到公式:

    腾讯智能钛AI开发者
  • 【技术分享】带权最小二乘

    $$minimize_{x}\frac{1}{2} \sum_{i=1}^n \frac{w_i(a_i^T x -b_i)^2}{\sum_{k=1}^n w...

    腾讯智能钛AI开发者
  • 【技术分享】梯度下降算法

      梯度下降(GD)是最小化风险函数、损失函数的一种常用方法,随机梯度下降和批量梯度下降是两种迭代求解思路。

    腾讯智能钛AI开发者
  • 易车公司创始人兼CEO&蔚来汽车创始人李斌:出行不等于汽车,大数据和人工智能将重新定义汽车

    <数据猿导读> 2016中国互联网大会于6月21日在北京国际会议中心举行。易车公司创始人兼CEO&蔚来汽车创始人李斌在大会中分享了自己对未来汽车发展的一些看法。...

    数据猿
  • 小学生编程入门从哪种编程语言学起?

    如果是编程零基础学习者,那么以Scratch为切入点是个不错选择。Scratch语法基于一系列孩子们可以拼插彼此的图形化“代码块”,其设计极具交互性,甚至单击一...

    贝尔科教
  • cPanet面板绑定域名和删除已绑定域名教程

    cPanel面板是一款功能强大的主机面板,国外众多大型主机商都在使用。对于cPanel面板中几十个上百个按钮功能,其实能用上的也不多。这里我们直接先步入正题,把...

    傲云
  • 零基础,如何选择一门编程语言?

    这种问题一般会被初学者问上N多遍,在这姑且分析下,选择什么语言决定性因素太多了,每个人的情况不一样,做出的决定又不尽相同。如果选择的出发点不一样选择的结果也是不...

    程序员互动联盟
  • 9-51单片机ESP8266学习-AT指令(测试TCP服务器--51单片机程序配置8266,C#TCP客户端发信息给单片机控制小灯的亮灭)

    http://www.cnblogs.com/yangfengwu/p/8780182.html 自己都是现做现写,如果想知道最终实现的功能,请看最后 先把源...

    杨奉武
  • WordPress插件设计

    如果是Php开发的同学,或者对博客和CMS有一定了解的同学都知道这个,以下是百度的解释:

    心平气和
  • 支付宝红包暴力薅羊毛

    特地去知乎搜了一波,果然有各路大佬在分享源码,特地弄了一个进行源码审计,学习学习~

    信安之路

扫码关注云+社区

领取腾讯云代金券