Computing centroids and purity

Stack Overflow user
Asked 2017-06-01 08:53:33
1 answer, 425 views, 0 followers, 0 votes

I obtained two sets of points, feat_left and feat_right, from a Siamese network, and I plotted these points in x, y coordinates as shown below.

Here is the Python script:

import json
import matplotlib.pyplot as plt
import numpy as np



# Load the network predictions (JSON content in a .txt file)
with open('predictions-mnist.txt') as fh:
    data = json.load(fh)

n = len(data['outputs'])
label_list = np.zeros(n, dtype=int)
feat_left = np.zeros((n, 2))

for idx, (key, val) in enumerate(data['outputs'].items()):
    feat_left[idx] = val['feat_left']
    # The digit label is the 7th component of the '/'-separated key
    label_list[idx] = int(key.split("/")[6])


f = plt.figure(figsize=(16,9))

c = ['#ff0000', '#ffff00', '#00ff00', '#00ffff', '#0000ff',
     '#ff00ff', '#990000', '#999900', '#009900', '#009999']

for i in range(10):
    plt.plot(feat_left[label_list==i,0].flatten(), feat_left[label_list==i,1].flatten(), '.', c=c[i])
plt.legend(['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'])
plt.grid()
plt.show()

Now I want to compute the centroid of each cluster, and then the purity.


1 Answer

Stack Overflow user

Accepted answer

Posted 2017-06-01 14:04:12

The centroid is simply the mean:

centroids = np.zeros((10,2), dtype='f4')
for i in range(10):
    # mean of all points labelled i, taken over rows
    centroids[i,:] = np.mean( feat_left[label_list==i, :2], axis=0 )
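
If a digit has no points at all, the loop above takes the mean of an empty slice, which gives NaN and a runtime warning. A minimal sketch that makes this case explicit, checked on made-up toy data; `class_centroids` and the `toy_*` arrays are illustrative names, not part of the original answer:

```python
import numpy as np

def class_centroids(features, labels, n_classes):
    """Return an (n_classes, d) array; row i is the mean of the points labelled i.

    Rows for labels that never occur are left as NaN.
    """
    features = np.asarray(features, dtype=float)
    labels = np.asarray(labels)
    centroids = np.full((n_classes, features.shape[1]), np.nan)
    for i in range(n_classes):
        mask = labels == i
        if mask.any():  # skip empty classes instead of averaging nothing
            centroids[i] = features[mask].mean(axis=0)
    return centroids

# Hypothetical toy data: two classes in 2-D
toy_features = np.array([[0.0, 0.0], [2.0, 2.0], [10.0, 0.0], [10.0, 4.0]])
toy_labels = np.array([0, 0, 1, 1])
toy_centroids = class_centroids(toy_features, toy_labels, n_classes=2)
print(toy_centroids)
```

With the question's data this would presumably be called as `class_centroids(feat_left, label_list, 10)`.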

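As for accuracy, you can compute the squared error (distance) from the centroid. The loop below accumulates the sum over each class; divide by the class size to get the mean squared error: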

sqerr = np.zeros((10,), dtype='f4')
for i in range(10):
    # sum of squared distances from the points of class i to their centroid
    sqerr[i] = np.sum( (feat_left[label_list==i, :2]-centroids[i,:])**2 )
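
A summed error grows with the number of points in a class, so classes of different sizes are hard to compare. One possible variant averages the squared Euclidean distance per class instead. This is a sketch under that assumption; `per_class_mse` and the toy arrays are hypothetical names introduced for illustration:

```python
import numpy as np

def per_class_mse(features, labels, centroids):
    """Mean squared Euclidean distance of each class's points to its centroid."""
    features = np.asarray(features, dtype=float)
    labels = np.asarray(labels)
    centroids = np.asarray(centroids, dtype=float)
    mse = np.full(len(centroids), np.nan)  # NaN marks classes without points
    for i, centroid in enumerate(centroids):
        pts = features[labels == i]
        if len(pts):
            # squared distance of every point to the centroid, then the average
            mse[i] = np.mean(np.sum((pts - centroid) ** 2, axis=1))
    return mse

# Hypothetical toy data with known centroids
toy_features = np.array([[0.0, 0.0], [2.0, 2.0], [10.0, 0.0], [10.0, 4.0]])
toy_labels = np.array([0, 0, 1, 1])
toy_centroids = np.array([[1.0, 1.0], [10.0, 2.0]])
toy_mse = per_class_mse(toy_features, toy_labels, toy_centroids)
print(toy_mse)
```

Taking the square root of each entry would give a root-mean-square radius in the same units as the features.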

Computing purity:

def compute_cluster_purity(gt_labels, pred_labels):
  """
  Compute purity of predicted labels (pred_labels), given 
  the ground-truth labels (gt_labels).

  Assuming gt_labels and pred_labels are both lists of int of length n
  """
  n = len(gt_labels) # number of elements
  assert len(pred_labels) == n
  purity = 0
  for l in set(pred_labels):
    # for predicted label l, what are the gt_labels of this cluster?
    gt = [gt_labels[i] for i, il in enumerate(pred_labels) if il==l]
    # most frequent gt label in this cluster:
    mfgt = max(set(gt), key=gt.count)
    purity += gt.count(mfgt) # count intersection between most frequent ground truth and this cluster
  return float(purity)/n
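
The function above needs predicted labels, which the question does not have yet. One way to obtain them, offered here only as an assumption about the intended workflow, is to assign each point to its nearest centroid and then score that assignment against the ground truth. In the sketch below, `nearest_centroid_labels`, `purity` (a compact `Counter`-based equivalent of the function above) and the toy data are all illustrative:

```python
from collections import Counter

import numpy as np

def nearest_centroid_labels(features, centroids):
    """Index of the closest centroid (squared Euclidean distance) for each point."""
    features = np.asarray(features, dtype=float)
    centroids = np.asarray(centroids, dtype=float)
    # (n_points, n_centroids) matrix of squared distances via broadcasting
    d2 = ((features[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)
    return d2.argmin(axis=1)

def purity(gt_labels, pred_labels):
    """Fraction of points that carry the majority ground-truth label of their cluster."""
    total = 0
    for cluster in set(pred_labels):
        members = [g for g, p in zip(gt_labels, pred_labels) if p == cluster]
        total += Counter(members).most_common(1)[0][1]  # size of the majority group
    return total / len(gt_labels)

# Hypothetical toy data: the third point has label 0 but lies closer to centroid 1
toy_features = np.array([[0.0, 0.0], [2.0, 2.0], [6.0, 1.0], [10.0, 0.0], [10.0, 4.0]])
toy_gt = [0, 0, 0, 1, 1]
toy_centroids = np.array([[1.0, 1.0], [10.0, 2.0]])
toy_pred = nearest_centroid_labels(toy_features, toy_centroids).tolist()
toy_purity = purity(toy_gt, toy_pred)
print(toy_pred, toy_purity)
```

Any other clustering of the features, such as k-means, could supply the predicted labels in the same way.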

For details on selecting the most frequent label in a cluster, see this answer.

2 votes
Original page content provided by Stack Overflow.
Original link:

https://stackoverflow.com/questions/44302824
