首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >获取除特定索引外的false掩码

获取除特定索引外的false掩码
EN

Stack Overflow用户
提问于 2021-11-19 14:38:47
回答 3查看 40关注 0票数 0

我正在学习numpy,并试图找到一种更好的方法来编写这个示例。我有一个数组,表示元素(在其他数组中)所属的集群。元素0属于簇1,元素1属于簇2,依此类推。我想创建一个集群映射,使用一个掩码来表示属于这个集群的元素。下面的代码可以工作,但我讨厌tmp_mask的两行代码,我想知道是否可以避免它们。

代码语言:javascript
复制
cluster = np.array([1,2,1,1,2,3,1,2])
cluster_map = {}

empty_mask = np.zeros(len(cluster), dtype=bool)

for idx, cl in enumerate(cluster):
    tmp_mask = empty_mask.copy() 
    tmp_mask[idx] = True
    cluster_map[cl] = cluster_map.get(cl, empty_mask) | tmp_mask

cluster_map

我只是想看看是否有更短的版本,例如:

代码语言:javascript
复制
    #tmp_mask = empty_mask.copy() 
    #tmp_mask[idx] = True
    cluster_map[cl] = cluster_map.get(cl, empty_mask) | get_falses_except(idx, len(cluster))

我知道我可以创建函数get_falses_except,只是想知道它是否存在,或者可以用更好的方式重写代码?

谢谢大家

EN

回答 3

Stack Overflow用户

回答已采纳

发布于 2021-11-19 15:20:09

你可以通过一个非常简单的理解来做到这一点:

代码语言:javascript
复制
In [8]: {k: cluster==k for k in set(cluster)}
Out[8]:
{1: array([ True, False,  True,  True, False, False,  True, False]),
 2: array([False,  True, False, False,  True, False, False,  True]),
 3: array([False, False, False, False, False,  True, False, False])}
票数 1
EN

Stack Overflow用户

发布于 2021-11-19 15:24:04

Advanced indexing可能会帮助您:

代码语言:javascript
复制
empty_mask = np.zeros((np.max(cluster), len(cluster)), dtype=bool)
empty_mask[cluster-1, np.arange(len(cluster))] = True
>>> empty_mask
array([[ True, False,  True,  True, False, False,  True, False],
       [False,  True, False, False,  True, False, False,  True],
       [False, False, False, False, False,  True, False, False]])

如果需要,也可以返回dict

代码语言:javascript
复制
>>> dict(zip(range(1, shape[0]+1), empty_mask))
{1: array([ True, False,  True,  True, False, False,  True, False]),
 2: array([False,  True, False, False,  True, False, False,  True]),
 3: array([False, False, False, False, False,  True, False, False])}
票数 2
EN

Stack Overflow用户

发布于 2021-11-19 15:13:50

我认为for循环也可以跳过,但我不知道如何跳过

代码语言:javascript
复制
import numpy as np
cluster = np.array([1,2,1,1,2,3,1,2])

cluster_set=np.unique(cluster)
cluster_map=np.zeros((cluster_set.size, cluster.size), dtype=bool)
for idx, val in enumerate(cluster_set):
    cluster_map[idx]=cluster==val

print(cluster_map)
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/70036532

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档