我正在尝试生成只包含0
's和1
's的序列。我已经编写了以下代码,它可以工作。
import numpy as np
batch = 1000
dim = 32
while 1:
is_same = False
seq = np.random.randint(0, 2, [batch, dim])
for i in range(batch):
for j in range(i + 1, batch):
if np.array_equal(seq[i], seq[j]):
is_same = True
if is_same:
continue
else:
break
我的batch
变量有数千个。上面的循环大约需要30秒才能完成。这是另一个for
循环的数据生成部分,它运行大约500个迭代,因此非常慢。是否有更快的方法不重复地生成这个序列列表?谢谢。
期望的结果是一个序列的batch_size
数集合,每个序列的长度都是dim
,只包含0
、S和1
s,因此集合中没有两个序列是相同的。
发布于 2021-01-12 08:19:05
生成所有的序列,然后检查它们是否是唯一的,这是相当昂贵的,正如您注意到的。考虑这一备选办法:
batch_size
,则返回它,否则转到步骤1这种方法可以实现如下:
def unique_01_sequences(dim, batch_size):
sequences = set()
while len(sequences) != batch_size:
sequences.add(tuple(np.random.randint(0, 2, dim)))
return sequences
运行dim=32
和batch_size=1000
的两种解决方案:
Original: 2.296s
Improved: 0.017s
注意:我建议的函数的结果是一组元组,但它可以转换成您喜欢的格式。
其他一些建议和考虑:
dim
和batch_size
的某些配置,建议的方法可能变得非常慢。例如,如果输入是dim=10
和batch_size=1024
,则结果包含10“位”的所有配置,这些配置是0到1023之间数字的二进制表示。在生成过程中,当集合sequences
的大小接近1024时,碰撞次数增加,从而减慢了函数的速度。在这些情况下,生成所有配置(作为数字)并对它们进行洗牌将更有效。dim=10
和batch_size=1025
,函数永远不会结束。考虑验证输入。发布于 2021-01-12 14:19:04
正如其他答案在@Marc中所指出的,生成一个随机样本,然后在有任何重复的情况下将其全部丢弃是非常浪费和缓慢的。相反,您可以使用内置set
,也可以使用np.unique
。我还会使用稍微快一点的算法,一次生成多个元组,然后去重复,检查丢失了多少元组,然后生成足够多的元组,假设现在存在重复,然后重复它。
def random_bytes_numpy(dim, n):
nums = np.unique(np.random.randint(0, 2, [n, dim]), axis=1)
while len(nums) < n:
nums = np.unique(
np.stack([nums, np.random.randint(0, 2, [n - len(nums), dim])]),
axis=1
)
return nums
这里有一种使用set
的替代方法,但是使用相同的算法,在没有重复的情况下,生成的样本总是完全相同的:
def random_bytes_set(dim, n):
nums = set()
while len(nums) < n:
nums.update(map(tuple, np.random.randint(0, 2, [n - len(nums), dim])))
return nums
下面是他们在固定batch_size
上增加dim=32
所需时间的比较,包括@Marc和您的函数:
对于更大的batch_size
值,不使用您的算法,因为这需要太长时间:
https://codereview.stackexchange.com/questions/254587
复制相似问题