学习一时爽,一直学习一直爽
Hello,大家好,我是 もうり,一个从无到有的技术+语言小白。
https://blog.csdn.net/weixin_44510615/article/details/89216162
EM 算法,指的是最大期望算法(Expectation Maximization Algorithm,期望最大化算法),是一种迭代算法,在统计学中被用于寻找依赖于不可观察的隐变量的概率模型中参数的最大似然估计。基本思想是:首先随机取一个值去初始化待估计的参数,然后不断迭代寻找更优的参数,使得其似然函数值比原来的似然函数值大。
EM 算法可以看作最大似然估计的拓展,用于解决难以给出解析解(模型中存在隐变量)的最大似然估计(MLE)问题。
#!/usr/bin/python
# -*- coding:utf-8 -*-
"""Cluster the iris data with a Gaussian mixture (EM) and plot the result.

For every pair of the four iris features a 3-component GaussianMixture is
fitted without labels, the cluster ids are aligned with the true class ids,
the accuracy is printed, and the decision regions are drawn in a 2x3 grid.
"""

import numpy as np
import pandas as pd
from sklearn.mixture import GaussianMixture
import matplotlib as mpl
import matplotlib.colors
import matplotlib.pyplot as plt
from sklearn.metrics.pairwise import pairwise_distances_argmin

# Use a font that can render the Chinese labels, and keep the minus sign
# displayable with that font.
mpl.rcParams['font.sans-serif'] = ['SimHei']
mpl.rcParams['axes.unicode_minus'] = False

# Axis labels for feature columns 0..3
# (sepal length, sepal width, petal length, petal width).
iris_feature = '花萼长度', '花萼宽度', '花瓣长度', '花瓣宽度'


def expand(a, b, rate=0.05):
    """Return the interval ``(a, b)`` widened by ``rate * (b - a)`` on each side.

    Args:
        a: Lower bound of the interval.
        b: Upper bound of the interval.
        rate: Fraction of the interval width added to each side.

    Returns:
        A tuple ``(a - d, b + d)`` with ``d = (b - a) * rate``.
    """
    d = (b - a) * rate
    return a - d, b + d


def _relabel(cluster_ids, order):
    """Map GMM cluster ids onto class ids.

    ``order[i]`` is the cluster id matched to class ``i``; every element of
    ``cluster_ids`` equal to ``order[i]`` becomes ``i``.  All masks are taken
    from the untouched input first so that an already rewritten label is
    never matched again by a later class.

    Args:
        cluster_ids: 1-D integer array of predicted cluster ids.
        order: Sequence whose i-th entry is the cluster id for class ``i``.

    Returns:
        A new array with the same shape as ``cluster_ids``.
    """
    masks = [cluster_ids == cluster for cluster in order]
    relabeled = cluster_ids.copy()
    for class_id, mask in enumerate(masks):
        relabeled[mask] = class_id
    return relabeled


def main():
    """Fit a GMM on each feature pair of ``iris.data`` and plot the regions."""
    path = 'iris.data'
    data = pd.read_csv(path, header=None)
    x_prime = data[np.arange(4)]          # columns 0..3: the features
    y = pd.Categorical(data[4]).codes     # column 4: class name -> integer code

    n_components = 3
    feature_pairs = [[0, 1], [0, 2], [0, 3], [1, 2], [1, 3], [2, 3]]

    # The colormaps do not depend on the feature pair, so build them once.
    cm_light = mpl.colors.ListedColormap(['#FF8080', '#77E0A0', '#A0A0FF'])
    cm_dark = mpl.colors.ListedColormap(['r', 'g', '#6060FF'])

    plt.figure(figsize=(8, 6), facecolor='w')
    for k, pair in enumerate(feature_pairs, start=1):
        x = x_prime[pair]

        # Per-class means computed from the true labels.
        m = np.array([np.mean(x[y == i], axis=0) for i in range(n_components)])
        print('实际均值 = \n', m)

        gmm = GaussianMixture(n_components=n_components,
                              covariance_type='full', random_state=0)
        gmm.fit(x)
        print('预测均值 = \n', gmm.means_)
        print('预测方差 = \n', gmm.covariances_)
        y_hat = gmm.predict(x)

        # For each true class mean, find the nearest fitted component mean;
        # order[i] is then the cluster id that stands for class i.
        order = pairwise_distances_argmin(m, gmm.means_, axis=1,
                                          metric='euclidean')
        print('顺序:\t', order)

        y_hat = _relabel(y_hat, order)
        acc = '准确率:%.2f%%' % (100 * np.mean(y_hat == y))
        print(acc)

        # Plot range: the data range of both features, padded by 5 %.
        x1_min, x2_min = x.min()
        x1_max, x2_max = x.max()
        x1_min, x1_max = expand(x1_min, x1_max)
        x2_min, x2_max = expand(x2_min, x2_max)

        # Predict on a 200 x 200 grid to colour the decision regions.
        x1, x2 = np.mgrid[x1_min:x1_max:200j, x2_min:x2_max:200j]
        grid_test = np.stack((x1.flat, x2.flat), axis=1)
        grid_hat = _relabel(gmm.predict(grid_test), order)
        grid_hat = grid_hat.reshape(x1.shape)

        plt.subplot(2, 3, k)
        plt.pcolormesh(x1, x2, grid_hat, cmap=cm_light)
        plt.scatter(x[pair[0]], x[pair[1]], s=20, c=y, marker='o',
                    cmap=cm_dark, edgecolors='k')

        # Put the accuracy text near the upper-left corner of the subplot.
        xx = 0.95 * x1_min + 0.05 * x1_max
        yy = 0.1 * x2_min + 0.9 * x2_max
        plt.text(xx, yy, acc, fontsize=10)

        plt.xlim((x1_min, x1_max))
        plt.ylim((x2_min, x2_max))
        plt.xlabel(iris_feature[pair[0]], fontsize=11)
        plt.ylabel(iris_feature[pair[1]], fontsize=11)
        # Pass the on/off flag positionally: the old ``b=`` keyword was
        # renamed to ``visible`` and is rejected by current Matplotlib.
        plt.grid(True, ls=':', color='#606060')

    plt.suptitle('EM算法无监督分类鸢尾花数据', fontsize=14)
    # ``pad`` must be given by keyword in current Matplotlib; ``rect`` leaves
    # the top 5 % of the figure free for the suptitle.
    plt.tight_layout(pad=1, rect=(0, 0, 1, 0.95))
    plt.show()


if __name__ == '__main__':
    main()