首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >在numpy/openblas上在运行时设置最大线程数

在numpy/openblas上在运行时设置最大线程数
EN

Stack Overflow用户
提问于 2015-04-10 10:39:13
回答 2查看 10.1K关注 0票数 15

我想知道是否可以在(Python)运行时更改numpy后面的OpenBLAS使用的最大线程数?

我知道在通过环境变量OMP_NUM_THREADS运行解释器之前可以设置它,但是我想在运行时更改它。

通常,当使用MKL而不是OpenBLAS时,有可能:

代码语言:javascript
运行
复制
import mkl
mkl.set_num_threads(n)
EN

回答 2

Stack Overflow用户

回答已采纳

发布于 2015-04-11 21:01:44

您可以通过使用openblas_set_num_threads调用ctypes函数来实现这一点。我经常发现自己想要这样做,所以我写了一个小上下文管理器:

代码语言:javascript
运行
复制
import contextlib
import ctypes
from ctypes.util import find_library

# Prioritize hand-compiled OpenBLAS library over version in /usr/lib/
# from Ubuntu repos
try_paths = ['/opt/OpenBLAS/lib/libopenblas.so',
             '/lib/libopenblas.so',
             '/usr/lib/libopenblas.so.0',
             find_library('openblas')]
openblas_lib = None
for libpath in try_paths:
    try:
        openblas_lib = ctypes.cdll.LoadLibrary(libpath)
        break
    except OSError:
        continue
if openblas_lib is None:
    raise EnvironmentError('Could not locate an OpenBLAS shared library', 2)


def set_num_threads(n):
    """Set the current number of threads used by the OpenBLAS server."""
    openblas_lib.openblas_set_num_threads(int(n))


# At the time of writing these symbols were very new:
# https://github.com/xianyi/OpenBLAS/commit/65a847c
try:
    openblas_lib.openblas_get_num_threads()
    def get_num_threads():
        """Get the current number of threads used by the OpenBLAS server."""
        return openblas_lib.openblas_get_num_threads()
except AttributeError:
    def get_num_threads():
        """Dummy function (symbol not present in %s), returns -1."""
        return -1
    pass

try:
    openblas_lib.openblas_get_num_procs()
    def get_num_procs():
        """Get the total number of physical processors"""
        return openblas_lib.openblas_get_num_procs()
except AttributeError:
    def get_num_procs():
        """Dummy function (symbol not present), returns -1."""
        return -1
    pass


@contextlib.contextmanager
def num_threads(n):
    """Temporarily changes the number of OpenBLAS threads.

    Example usage:

        print("Before: {}".format(get_num_threads()))
        with num_threads(n):
            print("In thread context: {}".format(get_num_threads()))
        print("After: {}".format(get_num_threads()))
    """
    old_n = get_num_threads()
    set_num_threads(n)
    try:
        yield
    finally:
        set_num_threads(old_n)

你可以这样使用它:

代码语言:javascript
运行
复制
with num_threads(8):
    np.dot(x, y)

正如注释中提到的,在编写本报告时,openblas_get_num_threadsopenblas_get_num_procs是非常新的特性,因此,除非您从最新版本的源代码中编译OpenBLAS,否则可能无法使用。

票数 16
EN

Stack Overflow用户

发布于 2019-06-04 11:42:49

我们最近开发了threadpoolctl,这是一个跨平台的包,用于控制在python中调用C级线程池时使用的线程数。它的工作原理类似于@ali_m的答案,但自动检测需要通过循环遍历所有加载库来限制的库。它还附带了内省API。

这个包可以使用pip install threadpoolctl安装,并附带一个上下文管理器,允许您控制包(如numpy )所使用的线程数。

代码语言:javascript
运行
复制
from threadpoolctl import threadpool_limits
import numpy as np


with threadpool_limits(limits=1, user_api='blas'):
    # In this block, calls to blas implementation (like openblas or MKL)
    # will be limited to use only one thread. They can thus be used jointly
    # with thread-parallelism.
    a = np.random.randn(1000, 1000)
    a_squared = a @ a

您还可以对不同的线程池进行更精细的控制(例如,不同的blasopenmp调用)。

注意:这个包还在开发中,欢迎任何反馈意见。

票数 12
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/29559338

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档