前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >Python中的堆排序与优先队列

Python中的堆排序与优先队列

原创
作者头像
杜逸先
修改2021-06-10 11:51:32
9810
修改2021-06-10 11:51:32
举报

对数据进行排序是一个很常见的需求,但有时候我们并不需要对完整的数据进行排序,只需要排前几的数据,也就是经典的 Top-K 问题。

Top-K 问题的经典解法有两种:一种是脱胎于快速排序(Quick Sort)的快速选择(Quick Select)算法,核心思路是在每一次Partion操作后下一次递归只操作前K项数据。另一种是基于堆排序的方法。

Python 中有两个标准库可以原生的支持堆排序(优先队列),分别是heapqPriorityQueue(queue)

heapq

heapq标准库提供了一些工具函数用来对list对象实现二叉堆的各种操作(就地修改list对象)。简单的用法如下:

建堆

代码语言:txt
复制
import heapq

# 可以用过random.shuffle函数创造乱序数组
arr = [4, 0, 3, 1, 6, 5, 9, 7, 8, 2]
heapq.heapify(arr)
assert arr == [0, 1, 3, 4, 2, 5, 9, 7, 8, 6]

获取堆顶元素

代码语言:txt
复制
assert heapq.heappop(arr) == 0
assert arr == [1, 2, 3, 4, 6, 5, 9, 7, 8]

<!-- more -->

插入新元素

代码语言:txt
复制
heapq.heappush(arr, 11)
assert arr == [1, 2, 3, 4, 6, 5, 9, 7, 8, 11]

heapq也提供了直接获取nlargestnsmallest函数,并且这两个函数并不会就地修改原数据。

代码语言:txt
复制
arr = [4, 0, 3, 1, 6, 5, 9, 7, 8, 2]
assert heapq.nlargest(5, arr) == [9, 8, 7, 6, 5]
assert arr == [4, 0, 3, 1, 6, 5, 9, 7, 8, 2]
assert heapq.nsmallest(5, arr) == [0, 1, 2, 3, 4]

queue.PriorityQueue

queue标准库为 Python 代码提供了原生线程安全的队列实现。queue.PriorityQueue则是 Python 原生的优先队列实现,相比heapq有着更直观易用的接口。

创建优先队列

代码语言:txt
复制
from queue import PriorityQueue

pq = PriorityQueue()

arr = [4, 0, 3, 1, 6, 5, 9, 7, 8, 2]

for num in arr:
    pq.put(num)

获取队首元素

代码语言:txt
复制
while not pq.empty():
    assert pq.get() == 0

对比

heapq标准库是专门用来做堆排序相关操作的,而PriorityQueue类毕竟继承于queue.Queue,适用于多线程通信场景。两者的效率还是有着不小差距的。

我们以 LeetCode 973(最接近原点的 K 个点)为例,分别用heapqPriorityQueue实现,比较 一下二者的运行效率。

题目描述

973. 最接近原点的 K 个点

我们有一个由平面上的点组成的列表 points。需要从中找出 K 个距离原点 (0, 0) 最近的点。

(这里,平面上两点之间的距离是欧几里德距离。)

你可以按任何顺序返回答案。除了点坐标的顺序之外,答案确保是唯一的。

示例 1

输入:points = [1,3,-2,2], K = 1

输出:[-2,2]

解释:

(1, 3) 和原点之间的距离为 sqrt(10),

(-2, 2) 和原点之间的距离为 sqrt(8),

由于 sqrt(8) < sqrt(10),(-2, 2) 离原点更近。

我们只需要距离原点最近的 K = 1 个点,所以答案就是 [-2,2]。

示例 2

输入:points = [3,3,5,-1,-2,4], K = 2

输出:[3,3,-2,4]

(答案 [-2,4,3,3] 也会被接受。)

提示:

1 <= K <= points.length <= 10000

-10000 < pointsi < 10000

-10000 < pointsi < 10000

生成测试数据

代码语言:txt
复制
from random import randint
def genPoints(n:int = 100):
    return [(randint(0, 100), randint(0, 100)) for _ in range(n)]
points = genPoints(1_0000)
less_points = genPoints(100)

heapq实现

代码语言:txt
复制
import heapq

from typing import List


def distance(point: List[int]):
    return point[0] ** 2 + point[1] ** 2


class Solution:
    def kClosest(self, points: List[List[int]], K: int) -> List[List[int]]:
        distances = [(distance(point), point) for point in points]
        return [e[1] for e in heapq.nsmallest(K, distances)]


solution = Solution()
%timeit solution.kClosest(points, 100)
# 6.79 ms ± 181 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

PriorityQueue实现

代码语言:txt
复制
from queue import PriorityQueue
from typing import List


def distance(point: List[int]):
    return point[0] ** 2 + point[1] ** 2


class Solution:
    def kClosest(self, points: List[List[int]], K: int) -> List[List[int]]:
        pq = PriorityQueue()
        for point in points:
            pq.put((distance(point), point), block=False)
        ret = []
        while not pq.empty():
            ret.append(pq.get()[1])
        return ret


solution = Solution()
%timeit solution.kClosest(points,100)
# 52.2 ms ± 1.26 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

我们可以看到heapq版本比PriorityQueue版本快接近一个数量级,并且代码也更精简。

这也说明了我们要在合适的地方使用合适的工具。

原文

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • heapq
    • 建堆
      • 获取堆顶元素
        • 插入新元素
        • queue.PriorityQueue
          • 创建优先队列
            • 获取队首元素
            • 对比
              • 题目描述
                • 973. 最接近原点的 K 个点
              • 生成测试数据
                • heapq实现
                  • PriorityQueue实现
                  领券
                  问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档