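The Mean Shift algorithm was first proposed by Fukunaga in 1975 and was later extended by Yizong Cheng, who introduced two main improvements:
First, a kernel function is defined so that each sample's contribution to the shift vector depends on its distance from the point being shifted. Second, a weight coefficient is introduced so that different samples carry different weights. Mean Shift is widely used in clustering, image smoothing, image segmentation, and video tracking.
The kernel function is introduced into Mean Shift so that a sample's offset contributes differently to the mean shift vector depending on how far the sample is from the point being shifted. Kernel functions are a common tool in machine learning. A kernel function is defined as follows: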
Let $\mathbf{X}$ be a $d$-dimensional Euclidean space and let $x=\left \{ x_1,x_2,x_3,\cdots ,x_d \right \}$ be a point in it, written as a row vector, with squared norm $\left \| x \right \|^2=xx^T$. Let $\mathbf{R}$ denote the real numbers. A function $K:\mathbf{X}\rightarrow \mathbf{R}$ is called a kernel function if there is a profile function $k:\left [ 0,\infty \right )\rightarrow \mathbf{R}$ such that $K\left ( x \right )=k\left ( \left \| x \right \|^2 \right )$ and:
(1) $k$ is non-negative;
(2) $k$ is non-increasing;
(3) $k$ is piecewise continuous.
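To make the definition concrete, here is a minimal sketch that builds a kernel from a profile function and checks conditions (1) and (2) numerically on a grid. The function names are hypothetical, and the flat profile is included only for comparison; it does not appear in the original text.

```python
import numpy as np

def gaussian_profile(t):
    """Profile k(t) = exp(-t / 2) of an (unnormalized) Gaussian kernel, t >= 0."""
    return np.exp(-0.5 * t)

def flat_profile(t):
    """Profile of a flat kernel: 1 on [0, 1], 0 elsewhere (added for comparison)."""
    return np.where(t <= 1.0, 1.0, 0.0)

def kernel_from_profile(profile, x):
    """Evaluate K(x) = k(||x||^2) for a single d-dimensional point x."""
    x = np.asarray(x, dtype=float)
    return profile(np.dot(x, x))

def looks_like_valid_profile(profile, t_max=10.0, num=1001):
    """Check non-negativity and non-increase on a grid (a sanity check, not a proof)."""
    t = np.linspace(0.0, t_max, num)
    values = profile(t)
    non_negative = np.all(values >= 0.0)
    non_increasing = np.all(np.diff(values) <= 1e-12)
    return bool(non_negative and non_increasing)

print(looks_like_valid_profile(gaussian_profile))
print(looks_like_valid_profile(flat_profile))
print(kernel_from_profile(gaussian_profile, [1.0, 2.0]))  # equals exp(-2.5)
```

A profile such as `lambda t: t` fails the check because it increases with `t`.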
A commonly used kernel is the Gaussian kernel, defined as:
$$N\left ( x \right )=\frac{1}{\sqrt{2\pi }h}e^{-\frac{x^2}{2h^2}}$$
where $h$ is called the bandwidth. Kernels with different bandwidths are shown in the figure below:
The script that draws the figure above is:
'''
Date: 20160426
@author: zhaozhiyong
'''
import math

import matplotlib.pyplot as plt


def cal_Gaussian(x, h=1):
    """Value of the Gaussian kernel at x for bandwidth h."""
    numerator = x * x
    denominator = 2 * h * h
    left = 1 / (math.sqrt(2 * math.pi) * h)
    return left * math.exp(-numerator / denominator)


# Sample x from -20 to 19.5 in steps of 0.5
x = [i * 0.5 for i in range(-40, 40)]

# One curve per bandwidth h = 1, 2, 3, 4
score_1 = [cal_Gaussian(i, 1) for i in x]
score_2 = [cal_Gaussian(i, 2) for i in x]
score_3 = [cal_Gaussian(i, 3) for i in x]
score_4 = [cal_Gaussian(i, 4) for i in x]

plt.plot(x, score_1, 'b--', label="h=1")
plt.plot(x, score_2, 'k--', label="h=2")
plt.plot(x, score_3, 'g--', label="h=3")
plt.plot(x, score_4, 'r--', label="h=4")
plt.legend(loc="upper right")
plt.xlabel("x")
plt.ylabel("N")
plt.show()
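Mean Shift is an iterative procedure: compute the shifted mean of the current point, move the point to that shifted mean, take it as the new starting point, and keep moving until a termination condition is met. The figure below illustrates the process (image from reference 3):
As this process shows, the key step in Mean Shift is computing the shifted mean of each point and then updating the point's position with it.
Given $n$ sample points $x_i, i=1,\cdots ,n$ in the $d$-dimensional space $R^d$, the basic form of the Mean Shift vector at a point $x$ is: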
$$M_h\left ( x \right )=\frac{1}{k}\sum_{x_i\in S_h}\left ( x_i-x \right )$$
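where $S_h$ is a high-dimensional ball of radius $h$ centered at $x$ (the blue circle in the figure above) and $k$ is the number of samples that fall inside it. $S_h$ is defined as: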
$$S_h\left ( x \right )=\left \{ y\mid \left ( y-x \right )\left ( y-x \right )^T\leqslant h^2 \right \}$$
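The basic form can be sketched in a few lines of NumPy. This is a minimal illustration only; the function name `basic_mean_shift_vector` and the toy samples are assumptions, not part of the original code.

```python
import numpy as np

def basic_mean_shift_vector(x, samples, h):
    """Average offset (x_i - x) over the samples inside the ball S_h(x)."""
    x = np.asarray(x, dtype=float)
    samples = np.asarray(samples, dtype=float)
    diffs = samples - x
    # Membership test (y - x)(y - x)^T <= h^2
    inside = np.sum(diffs * diffs, axis=1) <= h * h
    k = np.count_nonzero(inside)
    if k == 0:
        # No sample in the ball: there is nothing to shift toward
        return np.zeros_like(x)
    return diffs[inside].sum(axis=0) / k

samples = [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [5.0, 5.0]]
# Three samples lie within radius 1.5 of the origin; (5, 5) is ignored
print(basic_mean_shift_vector([0.0, 0.0], samples, h=1.5))
```

Every sample inside the ball counts equally here, no matter how close it is to `x`.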
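This basic form of Mean Shift has a problem: within $S_h$, every point contributes equally to $x$. In practice, the contribution should depend on the distance between $x$ and each point. In addition, not every sample is equally important.
With this in mind, a kernel function and sample weights are added to the basic Mean Shift vector, giving the following improved form: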
$$M_h\left ( x \right )=\frac{\sum_{i=1}^{n}G_H\left ( x_i-x \right )w\left ( x_i \right )\left ( x_i-x \right )}{\sum_{i=1}^{n}G_H\left ( x_i-x \right )w\left ( x_i \right )}$$
where:
$$G_H\left ( x_i-x \right )=\left | H \right |^{-\frac{1}{2}}G\left ( H^{-\frac{1}{2}}\left ( x_i-x \right ) \right )$$
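$G\left ( x \right )$ is a unit kernel function. $H$ is a symmetric positive definite $d\times d$ matrix called the bandwidth matrix, taken here to be diagonal. $w\left ( x_i \right )\geqslant 0$ is the weight of each sample. The diagonal matrix $H$ has the form: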
$$H=\begin{pmatrix}h_1^2 & 0 & \cdots & 0\\ 0 & h_2^2 & \cdots & 0\\ \vdots & \vdots & & \vdots \\ 0 & 0 & \cdots & h_d^2\end{pmatrix}_{d\times d}$$
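When all diagonal entries are equal, $h_1=\cdots =h_d=h$ (that is, $H=h^2I$), the factor $\left | H \right |^{-\frac{1}{2}}$ cancels between numerator and denominator, and the Mean Shift vector above can be rewritten as:
$$M_h\left ( x \right )=\frac{\sum_{i=1}^{n}G\left ( \frac{x_i-x}{h} \right )w\left ( x_i \right )\left ( x_i-x \right )}{\sum_{i=1}^{n}G\left ( \frac{x_i-x}{h} \right )w\left ( x_i \right )}$$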
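As a hedged illustration of the weighted form, the sketch below assumes a single scalar bandwidth `h` and a Gaussian choice $G(u)=e^{-\|u\|^2/2}$. The function name `kernel_mean_shift_vector` is hypothetical, and any normalizing constant of $G$ is omitted because it cancels in the ratio.

```python
import numpy as np

def kernel_mean_shift_vector(x, samples, h, weights=None):
    """Kernel- and weight-averaged offset (x_i - x), assuming a Gaussian G."""
    x = np.asarray(x, dtype=float)
    samples = np.asarray(samples, dtype=float)
    if weights is None:
        weights = np.ones(len(samples))
    weights = np.asarray(weights, dtype=float)
    diffs = samples - x
    # G((x_i - x) / h) * w(x_i) for every sample
    scaled_sq_dist = np.sum(diffs * diffs, axis=1) / (h * h)
    g = np.exp(-0.5 * scaled_sq_dist) * weights
    return g @ diffs / g.sum()

samples = [[-1.0, 0.0], [1.0, 0.0]]
# Symmetric samples with equal weights: the offsets cancel
print(kernel_mean_shift_vector([0.0, 0.0], samples, h=1.0))
# Tripling the weight of the right-hand sample pulls the vector toward it
print(kernel_mean_shift_vector([0.0, 0.0], samples, h=1.0, weights=[1.0, 3.0]))
```

In the second call both samples are equally far from `x`, so the kernel values are equal and only the weights decide the result, (-1·1 + 1·3)/4 = 0.5 along the first axis.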
The Mean Shift vector $M_h\left ( x \right )$ is a normalized probability density gradient.
In effect, Mean Shift works with the probability density and seeks its local maxima.
For a probability density function $f\left ( x \right )$ and $n$ sample points $x_i,i=1,\cdots ,n$ in $d$-dimensional space, the kernel estimate of $f\left ( x \right )$ (also called the Parzen window estimate) is:
$$\hat{f}\left ( x \right )=\frac{\sum_{i=1}^{n}K\left ( \frac{x_i-x}{h} \right )w\left ( x_i \right )}{h^d\sum_{i=1}^{n}w\left ( x_i \right )}$$
where $w\left ( x_i \right )\geqslant 0$ is the weight assigned to sample $x_i$ and $K\left ( x \right )$ is a kernel function.
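The weighted Parzen window estimate can be sketched as follows. The sketch assumes a normalized Gaussian kernel $K(u)=(2\pi)^{-d/2}e^{-\|u\|^2/2}$; the function name `weighted_kde` and the toy data are assumptions for illustration.

```python
import numpy as np

def weighted_kde(x, samples, h, weights=None):
    """Weighted kernel density estimate at x, assuming a Gaussian kernel K."""
    samples = np.asarray(samples, dtype=float)
    x = np.asarray(x, dtype=float)
    n, d = samples.shape
    if weights is None:
        weights = np.ones(n)
    weights = np.asarray(weights, dtype=float)
    u = (samples - x) / h
    kernel_values = np.exp(-0.5 * np.sum(u * u, axis=1)) / (2.0 * np.pi) ** (d / 2.0)
    return np.sum(kernel_values * weights) / (h ** d * np.sum(weights))

# One-dimensional toy data: the sample at 2.0 carries three times the weight
samples = [[-1.0], [2.0]]
weights = [1.0, 3.0]

# A density estimate should integrate to about 1; check with a Riemann sum
grid = np.arange(-10.0, 10.0, 0.01)
area = sum(weighted_kde([g], samples, h=0.5, weights=weights) for g in grid) * 0.01
print(area)
```

Dividing by $h^d\sum_i w(x_i)$ is what keeps the estimate normalized for any bandwidth and any non-negative weights.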
The estimate of the gradient $\nabla f\left ( x \right )$ of the density $f\left ( x \right )$ is:
$$\nabla \hat{f}\left ( x \right )=\frac{2\sum_{i=1}^{n}\left ( x-x_i \right ){k}'\left ( \left \| \frac{x_i-x}{h} \right \|^2 \right )w\left ( x_i \right )}{h^{d+2}\sum_{i=1}^{n}w\left ( x_i \right )}$$
Let $g\left ( x \right )=-{k}'\left ( x \right )$ and $G\left ( x \right )=g\left ( \left \| x \right \|^2 \right )$; then:
$$\begin{align*}\nabla \hat{f}\left ( x \right ) &= \frac{2\sum_{i=1}^{n}\left ( x_i-x \right )G\left ( \frac{x_i-x}{h} \right )w\left ( x_i \right )}{h^{d+2}\sum_{i=1}^{n}w\left ( x_i \right )}\\ &= \frac{2}{h^2}\left [ \frac{\sum_{i=1}^{n}G\left ( \frac{x_i-x}{h} \right )w\left ( x_i \right )}{h^d\sum_{i=1}^{n}w\left ( x_i \right )} \right ]\left [ \frac{\sum_{i=1}^{n}\left ( x_i-x \right )G\left ( \frac{x_i-x}{h} \right )w\left ( x_i \right )}{\sum_{i=1}^{n}G\left ( \frac{x_i-x}{h} \right )w\left ( x_i \right )} \right ]\end{align*}$$
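The first bracket is the density estimate at $x$ computed with the kernel $G$, and the second bracket is the Mean Shift vector, which is therefore proportional to the probability density gradient. Written out, the Mean Shift vector is: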
$$M_h\left ( x \right )=\frac{\sum_{i=1}^{n}G\left ( \frac{x_i-x}{h} \right )w\left ( x_i \right )x_i}{\sum_{i=1}^{n}G\left ( \frac{x_i-x}{h} \right )w\left ( x_i \right )}-x$$
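The factorization can be checked numerically. The sketch below assumes the profile $k(t)=e^{-t/2}$, so that $g(t)=-k'(t)=\frac{1}{2}e^{-t/2}$, and compares the product form of the gradient with central finite differences of $\hat{f}$. All function names and the random test data are hypothetical.

```python
import numpy as np

def profile_k(t):
    return np.exp(-0.5 * t)

def profile_g(t):
    # g = -k' for k(t) = exp(-t / 2)
    return 0.5 * np.exp(-0.5 * t)

def density(x, samples, w, h):
    """f_hat(x) built from the profile k (normalizing constant omitted)."""
    d = samples.shape[1]
    t = np.sum((samples - x) ** 2, axis=1) / h ** 2
    return np.sum(profile_k(t) * w) / (h ** d * w.sum())

def gradient_via_mean_shift(x, samples, w, h):
    """(2 / h^2) * [density estimate with G] * [Mean Shift vector]."""
    d = samples.shape[1]
    diffs = samples - x
    t = np.sum(diffs ** 2, axis=1) / h ** 2
    gw = profile_g(t) * w
    density_g = gw.sum() / (h ** d * w.sum())
    mean_shift = gw @ diffs / gw.sum()
    return 2.0 / h ** 2 * density_g * mean_shift

def numerical_gradient(x, samples, w, h, eps=1e-6):
    """Central finite differences of density(), one coordinate at a time."""
    grad = np.zeros_like(x)
    for j in range(len(x)):
        step = np.zeros_like(x)
        step[j] = eps
        grad[j] = (density(x + step, samples, w, h) - density(x - step, samples, w, h)) / (2 * eps)
    return grad

rng = np.random.default_rng(0)
samples = rng.normal(size=(50, 2))
w = rng.uniform(0.5, 1.5, size=50)
x = np.array([0.3, -0.2])
analytic = gradient_via_mean_shift(x, samples, w, h=0.8)
numeric = numerical_gradient(x, samples, w, h=0.8)
print(analytic, numeric)
```

The two results should agree up to finite-difference error, which is what "the Mean Shift vector is proportional to the density gradient" means in practice.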
Let
$$m_h\left ( x \right )=\frac{\sum_{i=1}^{n}G\left ( \frac{x_i-x}{h} \right )w\left ( x_i \right )x_i}{\sum_{i=1}^{n}G\left ( \frac{x_i-x}{h} \right )w\left ( x_i \right )}$$
Then the expression above becomes:
$$M_h\left ( x \right )=m_h\left ( x \right )-x$$
Moving $x$ to $m_h\left ( x \right )=x+M_h\left ( x \right )$ is therefore a step along the estimated density gradient, which matches the process of gradient ascent.
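Repeatedly replacing $x$ with $m_h(x)$ moves a point uphill toward a local density maximum. Here is a minimal sketch of that fixed-point iteration, assuming a Gaussian $G$, unit weights, and a scalar bandwidth; the function name `mean_shift_mode`, the tolerance, and the toy data are assumptions.

```python
import numpy as np

def mean_shift_mode(x0, samples, h, tol=1e-6, max_iter=500):
    """Iterate x <- m_h(x) until the move is smaller than tol."""
    x = np.asarray(x0, dtype=float)
    samples = np.asarray(samples, dtype=float)
    for _ in range(max_iter):
        diffs = samples - x
        g = np.exp(-0.5 * np.sum(diffs * diffs, axis=1) / (h * h))
        x_new = g @ samples / g.sum()  # m_h(x): kernel-weighted mean of the samples
        if np.linalg.norm(x_new - x) < tol:
            return x_new
        x = x_new
    return x

# Two tight one-dimensional clusters, around 0 and around 10
samples = [[-0.1], [0.0], [0.1], [9.9], [10.0], [10.1]]
print(mean_shift_mode([0.5], samples, h=1.0))  # ends near 0
print(mean_shift_mode([9.0], samples, h=1.0))  # ends near 10
```

Each starting point climbs to the mode of the cluster it is closest to, which is the behavior the clustering procedure relies on.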
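The Mean Shift algorithm proceeds as follows: for each sample point, compute the shifted mean $m_h\left ( x \right )$ at its current position and move the point there; repeat until the largest move among all points falls below a small threshold; finally, assign points that converge to the same position to the same cluster.
The experimental data is shown in the figure below (from reference 1):
The plotting code is: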
'''
Date: 20160426
@author: zhaozhiyong
'''
import matplotlib.pyplot as plt

x = []
y = []
# Each line of the data file holds two tab-separated coordinates
with open("data") as f:
    for line in f:
        lines = line.strip().split("\t")
        if len(lines) == 2:
            x.append(float(lines[0]))
            y.append(float(lines[1]))

plt.plot(x, y, 'b.', label="original data")
plt.title('Mean Shift')
plt.legend(loc="upper right")
plt.show()
The Mean Shift implementation is as follows:

#!/usr/bin/env python
# coding: UTF-8
'''
Date: 20160426
@author: zhaozhiyong
'''
import math

import numpy as np

MIN_DISTANCE = 0.000001  # convergence threshold


def load_data(path, feature_num=2):
    """Load tab-separated samples, skipping lines with the wrong field count."""
    data = []
    with open(path) as f:
        for line in f:
            lines = line.strip().split("\t")
            if len(lines) != feature_num:
                continue
            data.append([float(v) for v in lines])
    return data


def gaussian_kernel(distance, bandwidth):
    """Gaussian kernel values for an array of distances."""
    left = 1 / (bandwidth * math.sqrt(2 * math.pi))
    right = np.exp(-0.5 * distance * distance / (bandwidth * bandwidth))
    return left * right


def shift_point(point, points, kernel_bandwidth):
    """Return the kernel-weighted mean of all samples as seen from point."""
    points = np.asarray(points, dtype=float)
    # Distance from point to every sample
    point_distances = np.sqrt(np.sum((points - point) ** 2, axis=1))
    # Gaussian kernel weight of every sample
    point_weights = gaussian_kernel(point_distances, kernel_bandwidth)
    # Shifted mean: sum_i w_i * x_i / sum_i w_i
    return point_weights @ points / np.sum(point_weights)


def euclidean_dist(pointA, pointB):
    """Euclidean distance between pointA and pointB."""
    return math.sqrt(np.sum((pointA - pointB) ** 2))


def distance_to_group(point, group):
    """Smallest distance from point to any member of group."""
    min_distance = 10000.0
    for pt in group:
        dist = euclidean_dist(point, pt)
        if dist < min_distance:
            min_distance = dist
    return min_distance


def group_points(mean_shift_points):
    """Label points by their converged position, rounded to two decimals."""
    group_assignment = []
    index_dict = {}
    for row in mean_shift_points:
        item_1 = "_".join("%5.2f" % v for v in row)
        if item_1 not in index_dict:
            index_dict[item_1] = len(index_dict)
        group_assignment.append(index_dict[item_1])
    return group_assignment


def train_mean_shift(points, kernel_bandwidth=2):
    points = np.asarray(points, dtype=float)
    mean_shift_points = points.copy()
    max_min_dist = 1
    iteration = 0
    m = mean_shift_points.shape[0]
    need_shift = [True] * m

    # Keep shifting until no point moves more than MIN_DISTANCE
    while max_min_dist > MIN_DISTANCE:
        max_min_dist = 0
        iteration += 1
        print("iter : " + str(iteration))
        for i in range(m):
            # Skip points that have already converged
            if not need_shift[i]:
                continue
            p_new_start = mean_shift_points[i].copy()
            p_new = shift_point(p_new_start, points, kernel_bandwidth)
            dist = euclidean_dist(p_new, p_new_start)
            if dist > max_min_dist:  # record the largest move in this pass
                max_min_dist = dist
            if dist < MIN_DISTANCE:  # converged, no need to move again
                need_shift[i] = False
            mean_shift_points[i] = p_new

    # Compute the final groups
    group = group_points(mean_shift_points)
    return points, mean_shift_points, group


if __name__ == "__main__":
    # Load the data set
    path = "./data"
    data = load_data(path, 2)
    # Train with h = 2
    points, shift_points, cluster = train_mean_shift(data, 2)
    for i in range(len(cluster)):
        print("%5.2f,%5.2f\t%5.2f,%5.2f\t%i" % (points[i, 0], points[i, 1], shift_points[i, 0], shift_points[i, 1], cluster[i]))
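`group_points` labels clusters by formatting each converged coordinate to two decimals, so two points that converge to almost the same position but fall on different sides of a rounding boundary (for example 1.004 and 1.006) would receive different labels. One possible alternative, sketched here under stated assumptions, is to group by a distance threshold instead. The function name `group_by_distance` and the `tolerance` value are hypothetical and are not part of the original code.

```python
import numpy as np

def group_by_distance(converged_points, tolerance=0.1):
    """Assign each point to the first known center within tolerance, else start a new one."""
    converged_points = np.asarray(converged_points, dtype=float)
    centers = []
    labels = []
    for p in converged_points:
        for idx, c in enumerate(centers):
            if np.linalg.norm(p - c) < tolerance:
                labels.append(idx)
                break
        else:
            # No existing center is close enough: p opens a new group
            centers.append(p)
            labels.append(len(centers) - 1)
    return labels, np.array(centers)

pts = [[1.004, 2.0], [1.006, 2.0], [5.0, 5.0], [1.005, 2.001]]
labels, centers = group_by_distance(pts)
print(labels)  # [0, 0, 1, 0]
```

The tolerance should sit well above the convergence threshold `MIN_DISTANCE` and well below the spacing between distinct modes; 0.1 is only a placeholder for this toy data.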
The data after clustering with Mean Shift is shown below:
'''
Date: 20160426
@author: zhaozhiyong
'''
import matplotlib.pyplot as plt

# Points of clusters 0 and 1; every other label is collected under key 2
cluster_x = {0: [], 1: [], 2: []}
cluster_y = {0: [], 1: [], 2: []}
center_x = []
center_y = []
center_dict = {}

# Each line: "x,y<TAB>shifted_x,shifted_y<TAB>label"
with open("data_mean") as f:
    for line in f:
        lines = line.strip().split("\t")
        if len(lines) != 3:
            continue
        label = int(lines[2])
        key = label if label in (0, 1) else 2
        data_1 = lines[0].strip().split(",")
        cluster_x[key].append(float(data_1[0]))
        cluster_y[key].append(float(data_1[1]))
        # Record the converged position once per label
        if label not in center_dict:
            center_dict[label] = 1
            data_2 = lines[1].strip().split(",")
            center_x.append(float(data_2[0]))
            center_y.append(float(data_2[1]))

plt.plot(cluster_x[0], cluster_y[0], 'b.', label="cluster_0")
plt.plot(cluster_x[1], cluster_y[1], 'g.', label="cluster_1")
plt.plot(cluster_x[2], cluster_y[2], 'k.', label="cluster_2")
plt.plot(center_x, center_y, 'r+', label="mean point")
plt.title('Mean Shift 2')
# plt.legend(loc="best")
plt.show()
Published by 全栈程序员栈长. Please credit the source when reposting: https://javaforall.cn/158343.html Original link: https://javaforall.cn