我想写一个函数my_func(n,l)
,对于某个正整数n
,它可以有效地枚举长度为l
(其中l
大于n
)的有序非负整数组成*。例如,我希望my_func(2,3)
返回[[0,0,2],[0,2,0],[2,0,0],[1,1,0],[1,0,1],[0,1,1]]
。
我最初的想法是使用现有的正整数分区代码(例如,this post中的accel_asc()
),将正整数分区扩展几个零,并返回所有排列。
def my_func(n, l):
for ip in accel_asc(n):
nic = numpy.zeros(l, dtype=int)
nic[:len(ip)] = ip
for p in itertools.permutations(nic):
yield p
此函数的输出是错误的,因为其中一个数字出现两次(或多次)的每个非负整数组合在my_func
的输出中出现多次。例如,list(my_func(2,3))
返回[(1, 1, 0), (1, 0, 1), (1, 1, 0), (1, 0, 1), (0, 1, 1), (0, 1, 1), (2, 0, 0), (2, 0, 0), (0, 2, 0), (0, 0, 2), (0, 2, 0), (0, 0, 2)]
。
我可以通过生成所有非负整数组合的列表,删除重复的条目,然后返回剩余的列表(而不是生成器)来纠正这个问题。但这似乎非常低效,而且可能会遇到内存问题。解决这个问题的更好方法是什么?
编辑
我对这篇文章的答案和another post在评论中指出的解决方案进行了快速比较。
左边是l=2*n
,右边是l=n+1
。在这两种情况下,当n<=5
时,user2357112的第二个解决方案比其他解决方案更快。对于n>5
,user2357112、Nathan Verzemnieks和AndyP提出的解决方案或多或少是并驾齐驱的。但当考虑到l
和n
之间的其他关系时,结论可能会不同。
.
*我最初要求的是非负整数分区。Joseph Wood正确地指出,我实际上是在寻找整数组合,因为数字在序列中的顺序对我很重要。
发布于 2019-04-17 06:16:53
使用stars and bars概念:选择位置以在n
星形之间放置l-1
条,并计算每个部分中有多少颗星形:
import itertools
def diff(seq):
return [seq[i+1] - seq[i] for i in range(len(seq)-1)]
def generator(n, l):
for combination in itertools.combinations_with_replacement(range(n+1), l-1):
yield [combination[0]] + diff(combination) + [n-combination[-1]]
我在这里使用了combinations_with_replacement
而不是combinations
,因此索引处理与使用combinations
时所需要的稍有不同。使用combinations
的代码将更接近于对星形和条形的标准处理。
或者,使用combinations_with_replacement
的另一种方法是:从l
零的列表开始,从n
可能的位置中选取带有替换的l
位置,然后将1添加到每个选择的位置以生成输出:
def generator2(n, l):
for combination in itertools.combinations_with_replacement(range(l), n):
output = [0]*l
for i in combination:
output[i] += 1
yield output
发布于 2019-04-17 06:40:07
从一个简单的递归解决方案开始,它与您的问题相同:
def nn_partitions(n, l):
if n == 0:
yield [0] * l
else:
for part in nn_partitions(n - 1, l):
for i in range(l):
new = list(part)
new[i] += 1
yield new
也就是说,对于下一个较低数字的每个分区,对于该分区中的每个位置,将该位置的元素加1。它产生的副本与您的副本相同。不过,对于类似的问题,我想出了一个诀窍:当您将n
的分区p
更改为n+1
的分区时,将p
的所有元素都固定在您增加的元素的左侧。也就是说,跟踪p
被修改的位置,并且永远不要修改左边p
的任何“后代”。下面是代码:
def _nn_partitions(n, l):
if n == 0:
yield [0] * l, 0
else:
for part, start in _nn_partitions(n - 1, l):
for i in range(start, l):
new = list(part)
new[i] += 1
yield new, i
def nn_partitions(n, l):
for part, _ in _nn_partitions(n, l):
yield part
这非常相似--每一步都会传递额外的参数,所以我添加了包装器来删除调用者的参数。
我没有对它进行过广泛的测试,但这似乎是相当快的- nn_partitions(3, 5)
大约35微秒,nn_partitions(10, 20)
大约18秒(产生2000多万个分区)。(来自user2357112的非常优雅的解决方案对于较小的情况大约需要两倍的时间,对于较大的情况大约需要四倍的时间。编辑:这指的是该答案中的第一个解决方案;第二个解决方案在某些情况下比我的快,在另一些情况下慢。)
https://stackoverflow.com/questions/55716916
复制相似问题