假设我有一个坐标列表:
data = [
[(10, 20), (100, 120), (0, 5), (50, 60)],
[(13, 20), (300, 400), (100, 120), (51, 62)]
]
我想取出满足以下任一条件的所有元组：要么出现在 data 的每一个列表中，要么在每个列表中都能找到与它相差不超过 3 的元组。如何在 Python 中高效地实现？
对于上面的例子,结果应该是:
[[(100, 120), # since it occurs in both lists
(10, 20), (13, 20), # since they differ by only 3
(50, 60), (51, 62)]]
(0, 5) 和 (300, 400) 不会被包含，因为它们既没有同时出现在两个列表中，与另一个列表中的任何元素也相差超过 3。
如何计算?谢谢。
发布于 2010-06-02 02:11:46
朴素实现会很慢：复杂度为 O(n^2)，需要把每个节点与其他所有节点逐一比较。可以用树结构来加速。
这个实现使用一个简单的四叉树来提高搜索效率。它不会尝试平衡这棵树，所以如果点列表的顺序很糟糕，效率就会变低。对于许多用途，先把列表随机打乱可能就足够了；只是一定不要传入大量按坐标排序的点，因为那样树会退化成链表。
这里的优化很简单：假设我们要查找与某点的欧几里得距离在 3 个单位以内的点，并且已知某个子树中的所有项都位于该点右侧至少 3 个单位之外，那么该区域中的任何点都不可能在 3 个单位以内。
此代码属于公有领域（public domain）。请尽量不要把它当作作业交上去。
#!/usr/bin/python
import math
def euclidean_distance(pos1, pos2):
    """Return the straight-line (Euclidean) distance between two (x, y) points."""
    dx = pos1[0] - pos2[0]
    dy = pos1[1] - pos2[1]
    return math.sqrt(math.pow(dx, 2) + math.pow(dy, 2))
class QuadTreeNode(object):
    """
    One node of an unbalanced point quadtree over (x, y) tuples.

    Each node holds one point in ``pos``. Every point stored under
    ``children[q]`` lies in quadrant ``q`` relative to this node, as decided by
    classify_node(). The tree is never rebalanced, so inserting points in
    sorted order produces a chain as deep as the input is long. Insertion and
    search are therefore iterative: a degenerate tree is slow, but it does not
    raise RecursionError once it grows past Python's recursion limit.
    """

    def __init__(self, pos):
        """
        Create a QuadTreeNode at the specified position. pos must be an (x, y) tuple.
        Children are classified by quadrant.
        """
        # Children of this node are ordered TL, TR, BL, BR (origin top-left, so
        # "top" means a smaller or equal y).
        self.children = [None, None, None, None]
        self.pos = pos

    def classify_node(self, pos):
        """
        Return which entry in children can contain pos. If pos is equal to this
        node, return None.

        Ties go left/top: a point with the same x as this node is classified as
        left, and one with the same y as top.

        >>> node = QuadTreeNode((10, 20))
        >>> node.classify_node((10, 20)) is None
        True
        >>> node.classify_node((2, 2))
        0
        >>> node.classify_node((50, 2))
        1
        >>> node.classify_node((2, 50))
        2
        >>> node.classify_node((50, 50))
        3

        X boundary condition:

        >>> node.classify_node((10, 2))
        0
        >>> node.classify_node((10, 50))
        2

        Y boundary condition:

        >>> node.classify_node((2, 20))
        0
        >>> node.classify_node((50, 20))
        1
        """
        if pos == self.pos:
            return None
        is_left = pos[0] <= self.pos[0]
        is_top = pos[1] <= self.pos[1]
        if is_left:
            return 0 if is_top else 2
        return 1 if is_top else 3

    def add_node(self, node):
        """
        Add a specified point under this node.

        Walks down from this node to the first empty child slot in the point's
        quadrant and attaches node there. A point equal to one already in the
        tree is ignored (duplicates are dropped).
        """
        current = self
        while True:
            quadrant = current.classify_node(node.pos)
            if quadrant is None:
                # node is equal to an existing node, so it is a duplicate. Ignore it.
                return
            child = current.children[quadrant]
            if child is None:
                current.children[quadrant] = node
                return
            # Slot already taken: descend into that child. A loop rather than
            # recursion, so deep (e.g. sorted-input) trees cannot overflow the stack.
            current = child

    @staticmethod
    def CreateQuadTree(data):
        """
        Create a quad tree from the specified list of points and return its root.

        The first point becomes the root. Returns None for an empty list (an
        empty tree); find_point_in_range_for_all_trees treats None as a tree
        containing no points.
        """
        if not data:
            return None
        root = QuadTreeNode(data[0])
        for val in data[1:]:
            root.add_node(QuadTreeNode(val))
        return root

    def distance_from_pos(self, pos):
        """Return the Euclidean distance from this node's point to pos."""
        return euclidean_distance(self.pos, pos)

    def __str__(self):
        return str(self.pos)

    def _quadrant_out_of_range(self, quadrant, pos, distance):
        """
        Return True if no point stored under children[quadrant] can lie within
        distance of pos, so that the whole subtree can be skipped.

        Every point in a left quadrant (0, 2) has x <= self.pos[0], and every
        point in a right quadrant (1, 3) has x > self.pos[0]. The same holds
        for y with top (0, 1) and bottom (2, 3). If this node is already more
        than distance away along an axis in the direction the quadrant extends,
        nothing in that quadrant can be in range.
        """
        if quadrant in (0, 2) and self.pos[0] < pos[0] - distance:  # extends left
            return True
        if quadrant in (1, 3) and self.pos[0] > pos[0] + distance:  # extends right
            return True
        if quadrant in (0, 1) and self.pos[1] < pos[1] - distance:  # extends up
            return True
        if quadrant in (2, 3) and self.pos[1] > pos[1] + distance:  # extends down
            return True
        return False

    def find_point_within_range(self, pos, distance):
        """
        If a point exists within the specified Euclidean distance of the specified
        point, return it. Otherwise, return None.

        The value returned is the QuadTreeNode holding that point (the first one
        found in depth-first pre-order, quadrant 0 first), not the bare tuple.
        """
        # Explicit stack instead of recursion. Children are pushed in reverse
        # quadrant order so quadrant 0's subtree is explored first, which matches
        # a recursive pre-order search.
        stack = [self]
        while stack:
            node = stack.pop()
            if node.distance_from_pos(pos) <= distance:
                return node
            for quadrant in (3, 2, 1, 0):
                child = node.children[quadrant]
                if child is not None and not node._quadrant_out_of_range(quadrant, pos, distance):
                    stack.append(child)
        return None

    @staticmethod
    def find_point_in_range_for_all_trees(point, trees, distance):
        """
        Return True if every tree in trees contains a point within the specified
        Euclidean distance of point, otherwise False.

        A tree of None (what CreateQuadTree returns for an empty list) contains
        no points, so it makes the result False. Stops at the first tree
        without a match.
        """
        for tree in trees:
            if tree is None or tree.find_point_within_range(point, distance) is None:
                return False
        return True
def test_naive(data, distance):
    """
    Brute-force reference search.

    Return the set of points from data that have at least one point within
    `distance` (Euclidean) in every list of data. Each candidate is compared
    against every point of every list.
    """
    def near_any(candidates, target):
        # True if some candidate lies within `distance` of target.
        return any(euclidean_distance(c, target) <= distance for c in candidates)

    def near_every_list(target):
        # True if every list in data has a point near target.
        return all(near_any(points, target) for points in data)

    found = set()
    for points in data:
        for candidate in points:
            if near_every_list(candidate):
                found.add(candidate)
    return found
def test_tree(data, distance):
    """
    Quadtree-accelerated search.

    Return the set of points from data that have at least one point within
    `distance` (Euclidean) in every list of data. One quadtree is built per
    list; test() checks that the result matches test_naive.
    """
    trees = [QuadTreeNode.CreateQuadTree(d) for d in data]
    # Use the caller's distance; the threshold was previously hard-coded to 3.
    return set(
        point
        for d in data
        for point in d
        if QuadTreeNode.find_point_in_range_for_all_trees(point, trees, distance)
    )
def test():
    """
    Self-test: check that the naive and quadtree searches agree.

    Runs both on the question's sample data (printing the result), then on ten
    random lists of 500 points each. Uses single-argument print() calls, which
    behave the same on Python 2 and Python 3, instead of Python 2-only print
    statements.
    """
    sample_data = [
        [(10, 20), (100, 120), (0, 5), (50, 60)],
        [(13, 20), (300, 400), (100, 120), (51, 62)],
    ]
    result1 = test_naive(sample_data, 3)
    result2 = test_tree(sample_data, 3)
    print(result1)
    assert result1 == result2

    # Loosely validate the tree algorithm against a lot of sample data. The
    # progress messages let you eyeball the relative speed of the two searches.
    def random_data():
        import random
        return [(random.randint(0, 1000), random.randint(0, 1000)) for _ in range(500)]

    data = [random_data() for _ in range(10)]
    print("Searching (naive)...")
    result1 = test_naive(data, 3)
    print("Searching (tree)...")
    result2 = test_tree(data, 3)
    assert result1 == result2
# Script entry point: run the self-test, then the doctests embedded in this
# module's docstrings (the classify_node examples).
# NOTE(review): indentation was lost in extraction; the three lines below the
# `if` presumably belong inside it. The text fused onto the last line is the
# next post's timestamp, not code.
if __name__ == "__main__":
test()
import doctest
doctest.testmod()发布于 2010-06-02 00:42:26
希望这能帮你起个头。欢迎提出任何改进。
“出现在所有列表中”这一部分很简单——只需求出所有列表的交集即可。
>>> data = [
... [(10, 20), (100, 120), (0, 5), (50, 60)],
... [(13, 20), (300, 400), (100, 120), (51, 62)]
... ]
>>> dataset = [set(d) for d in data]
>>> dataset[0].intersection(*dataset[1:])
set([(100, 120)])
在我看来，“相差 3 或以下（且不在同一列表中）”这一部分是一个图 / 二维空间问题。除了多项式复杂度的逐一比较之外，没有什么简单的算法；如果数据集不是很大，只需遍历所有点，把不在同一列表中的相近点归为一组即可。
发布于 2010-06-02 02:38:44
@barrycarter 的思路很有意思：为了减少比较次数（这里“比较”两点指检查它们的距离是否 <= 3），先把二维平面“虚拟地”划分成合适的“单元格”，这样每个点只需与自身所在及相邻“单元格”中的点进行比较。如果数据集很大，这确实会有帮助（相对于需要把每个点与所有其他点比较的蛮力解法而言）。
下面是这个想法的 Python 实现（因为 barry 的代码草图看起来像 Perl 之类的语言），目标是清晰而不是速度：
import collections
import math
def cellof(point, size=3):
    """
    Return the (column, row) index of the square grid cell containing point.

    Cells are `size` units wide, so two points within distance `size` of each
    other always fall in the same or adjacent cells. That is what lets
    process() scan only the 3x3 block of cells around each point, so `size`
    must be at least the search radius. The default of 3 matches the
    threshold process() uses.

    Raises ValueError if size is not positive.
    """
    if size <= 0:
        raise ValueError("size must be positive, got %r" % (size,))
    x, y = point
    # Floor division keeps negative coordinates in the correct (lower) cell.
    return x // size, y // size
def distance(p1, p2):
    """Return the Euclidean distance between points p1 and p2."""
    dx = p1[0] - p2[0]
    dy = p1[1] - p2[1]
    return math.hypot(dx, dy)
def process(data):
    """
    Return, sorted, every point that lies within distance 3 of some point
    belonging to a *different* list of data.

    Points are bucketed into 3-unit grid cells (cellof), so each point is only
    compared against points in its own and the 8 surrounding cells.

    NOTE: a point qualifies if it is near a point from *any* other list. With
    exactly two lists that means "near a point in the other list"; with more
    lists it does not require a match in every list.
    """
    # Group points by grid cell, remembering which list each came from.
    cells = collections.defaultdict(list)
    for i, points in enumerate(data):
        for p in points:
            cells[cellof(p)].append((i, p))

    res = set()
    for (ccx, ccy), members in cells.items():
        for i, p in members:
            for cx in range(ccx - 1, ccx + 2):
                for cy in range(ccy - 1, ccy + 2):
                    # .get() rather than cells[...]: indexing a defaultdict inserts
                    # missing keys, which would resize `cells` while it is being
                    # iterated (RuntimeError on Python 3).
                    for otheri, otherp in cells.get((cx, cy), ()):
                        if otheri != i and distance(p, otherp) <= 3:
                            res.add(p)
    return sorted(res)
# Demo: run process() on the sample data from the question.
# NOTE(review): `print process(data)` is Python 2 syntax. Indentation was lost
# in extraction (the lines below presumably belong inside the `if`), and the
# text fused onto the last line is page prose, not code.
if __name__ == '__main__': # just an example
data = [
[(10, 20), (100, 120), (0, 5), (50, 60)],
[(13, 20), (300, 400), (100, 120), (51, 62)]
]
print process(data)作为脚本运行时,这将生成输出
[(10, 20), (13, 20), (50, 60), (51, 62), (100, 120)]
当然，要判断这种做法是否值得，还是更简单的蛮力方法其实更好，唯一可靠的办法是在真实数据上对两种方案做基准测试——也就是你的程序在实际中需要处理的那类数据集。根据列表的数量、每个列表中的点数以及距离阈值的不同，性能可能相差很大，所以测量胜过猜测！
https://stackoverflow.com/questions/2953878
复制相似问题