# （数据科学学习手札14）Mean-Shift聚类法简单介绍及Python实现

1.随机确定样本空间内一个半径确定的高维球及其球心；

2.求该高维球内质心，并将高维球的球心移动至该质心处；

3.重复2，直到高维球内的密度随着继续的球心滑动变化低于设定的阈值，算法结束

```from sklearn.cluster import MeanShift
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from matplotlib.pyplot import style
import numpy as np
'''设置绘图风格'''
style.use('ggplot')
'''生成演示用样本数据'''
data1 = np.random.normal(0,0.3,(1000,2))
data2 = np.random.normal(1,0.2,(1000,2))
data3 = np.random.normal(2,0.3,(1000,2))

data = np.concatenate((data1,data2,data3))

# data_tsne = TSNE(learning_rate=100).fit_transform(data)
'''搭建Mean-Shift聚类器'''
clf=MeanShift()
'''对样本数据进行聚类'''
predicted=clf.fit_predict(data)
colors = [['red','green','blue','grey'][i] for i in predicted]
'''绘制聚类图'''
plt.scatter(data[:,0],data[:,1],c=colors,s=10)
plt.title('Mean Shift')```

```from sklearn.cluster import MeanShift
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from matplotlib.pyplot import style
import numpy as np
'''设置绘图风格'''
style.use('ggplot')
'''生成演示用样本数据'''
data1 = np.random.normal(0,0.3,(1000,6))
data2 = np.random.normal(1,0.2,(1000,6))
data3 = np.random.normal(2,0.3,(1000,6))

data = np.concatenate((data1,data2,data3))

data_tsne = TSNE(learning_rate=100).fit_transform(data)
'''搭建Mean-Shift聚类器'''
clf=MeanShift()
'''对样本数据进行聚类'''
predicted=clf.fit_predict(data)
colors = [['red','green','blue','grey'][i] for i in predicted]
'''绘制聚类图'''
plt.scatter(data_tsne[:,0],data_tsne[:,1],c=colors,s=10)
plt.title('Mean Shift')```

```import matplotlib.pyplot as plt
from sklearn.cluster import MeanShift
from sklearn.manifold import TSNE
import pandas as pd
import numpy as np
from matplotlib.pyplot import style

style.use('ggplot')

input = pd.DataFrame({'score':data['商家评分'][data['数据所属期'] == data.iloc[0,0]],
'comment':data['商家评论数'][data['数据所属期'] == data.iloc[0,0]],
'sales':data['本月销售额'][data['数据所属期'] == data.iloc[0,0]]})

'''去缺省值'''
input = input.dropna()

input_tsne = TSNE(learning_rate=100).fit_transform(input)

'''创造色彩列表'''
with open(r'C:\Users\windows\Desktop\colors.txt','r') as cc:
col = [col[i][:7] for i in range(len(col)) if col[i][0] == '#']

'''进行Mean-Shift聚类'''
clf = MeanShift()
cl = clf.fit_predict(input)

'''绘制聚类结果'''
np.random.shuffle(col)
plt.scatter(input_tsne[:,0],input_tsne[:,1],c=[col[i] for i in cl],s=8)
plt.title('Mean-Shift Cluster of {}'.format(str(len(set(cl)))))```

0 条评论

## 相关文章

9013

### Python数据处理从零开始----第四章（可视化）（3）目录正文

plt.scatter相对于plt.plot的主要优势在于，前者在创建散点图时具有更高的灵活性，可以单独控制每个散点与数据匹配，也可以让每个散点具有不同的属性（...

792

1542

3967

5237

3355

2312

7250

1432

3450