Tell Metaflow to install packages with pip when using the @conda decorator

Stack Overflow user
Asked 2020-12-07 11:10:38
1 answer, 153 views, 0 followers, score 2

When running on AWS, I usually define a step like this:

```python
from metaflow import batch, conda, step

@batch(cpu=1, memory=5000)
@conda(libraries={'pandas': '1'})
@step
def hello(self):
    ...  # do stuff
```

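However, I'm using deep learning libraries (MXNet/TensorFlow/PyTorch) whose conda packages tend to lag behind the latest releases, so they are better installed with pip.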

How do I define pip dependencies for this step?


1 answer

Stack Overflow user

Posted 2021-06-30 01:15:39

Metaflow does not natively support pip installs (discussion here). However, I wrote a @pip decorator that you can use (based on a comment in that issue):

```python
import functools
import logging
import signal
import subprocess
from importlib import import_module
from typing import Dict

logger = logging.getLogger(__name__)


def pip(libraries: Dict[str, str]):
    """
    A Flow decorator which mimics @conda, except it installs pip deps. Use @conda instead whenever possible.

    Note: this requires 3rd-party modules to be imported _inside_ the flow/step this decorator scopes; otherwise you
    will get ModuleNotFoundError. Also note that this decorator has to be on the line before @conda is used.

    To install wheels from a specific source URL, put the URL after the library name, separated by a pipe, i.e.
    @pip(libraries={'torch|your.url/here': '1.8.1'})

    Checks whether each package is already importable before installing it. This means that the exact pinned
    version will NOT be installed if some version of the library already exists.

    Based on: https://github.com/Netflix/metaflow/issues/24#issuecomment-571976372
    """

    def decorator(function):
        @functools.wraps(function)
        def wrapper(*args, **kwargs):
            to_install = []
            to_install_source_override = []
            for library, version in libraries.items():  # NOTE: for some reason, a list comprehension breaks this
                if "|" in library:  # signals an alternative source URL
                    # if an alternative source is specified, always install: the import_module() check
                    # would succeed and leave the wrong build in the env
                    library, source = library.split("|", 1)
                    to_install_source_override.append((f"{library}=={version}", source))
                    continue
                try:
                    # note: this fails to import any requirement that has an extra, e.g. ray[serve]. That is
                    # intended: ignoring the extra would skip installing it and cause downstream errors
                    import_module(library)  # note: does not raise if the existing lib is the wrong version
                except ModuleNotFoundError:
                    logger.info(f"failed to import library {library}; pip installing")
                    to_install.append(f"{library}=={version}")
                except Exception as e:
                    raise Exception(f"An error occurred while loading module {library}") from e

            # without this context manager, you can break your env if you keyboard-interrupt a flow while pip is
            # installing libraries
            with DelayedKeyboardInterrupt():
                # install directly from PyPI
                # NOTE: do not use [sys.executable, "-m", "pip"], because that will install into the wrong conda env!
                if to_install:  # `pip install` with no requirements exits with an error
                    subprocess.run(["pip", "install", "--quiet", *to_install])
                # install packages from remote sources
                for parsed, src in to_install_source_override:
                    logger.info(f"pip installing {parsed} from {src}")
                    subprocess.run(
                        [
                            "pip",
                            "install",
                            # "--ignore-installed",  # force install of the remote version
                            "--quiet",
                            parsed,
                            "--find-links",
                            src,
                        ]
                    )

            return function(*args, **kwargs)

        return wrapper

    return decorator


class DelayedKeyboardInterrupt:
    """
    Context manager that prevents KeyboardInterrupt from interrupting important code.

    source: https://stackoverflow.com/a/21919644/4212158
    """

    def __enter__(self):
        self.signal_received = False
        self.old_handler = signal.signal(signal.SIGINT, self.handler)

    def handler(self, sig, frame):
        self.signal_received = (sig, frame)
        logging.debug("SIGINT received. Delaying KeyboardInterrupt.")

    def __exit__(self, type, value, traceback):
        signal.signal(signal.SIGINT, self.old_handler)
        # re-deliver the delayed SIGINT, unless the old handler was SIG_DFL/SIG_IGN (not callable)
        if self.signal_received and callable(self.old_handler):
            self.old_handler(*self.signal_received)
```
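The docstring above admits that the `import_module()` check cannot tell when an already-installed library has the wrong version. One possible workaround, sketched here under our own assumptions, is to compare against the installed distribution's version with `importlib.metadata.version` (standard library, Python 3.8+) instead of importing the module. `plan_pip_installs` is a hypothetical helper, not part of Metaflow or of the answer above: it only computes what would be installed, reuses the same `'name|source'` key convention, still always installs requirements with extras, and does a naive exact string comparison (so a local build such as `1.8.1+cu111` would not count as matching a pin of `1.8.1`).

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Callable, Dict, List, Tuple


def plan_pip_installs(
    libraries: Dict[str, str],
    get_version: Callable[[str], str] = version,
) -> Tuple[List[str], List[Tuple[str, str]]]:
    """Hypothetical helper: split a @pip-style spec into plain and source-override requirements.

    Returns (to_install, to_install_source_override), mirroring the two lists built by the decorator.
    """
    to_install: List[str] = []
    to_install_source_override: List[Tuple[str, str]] = []
    for key, wanted in libraries.items():
        if "|" in key:
            # 'name|url': always install from the alternative source
            name, source = key.split("|", 1)
            to_install_source_override.append((f"{name}=={wanted}", source))
            continue
        if "[" in key:
            # requirements with extras (e.g. ray[serve]) are always installed, as in the decorator,
            # because the base package being present says nothing about the extra's dependencies
            to_install.append(f"{key}=={wanted}")
            continue
        try:
            installed = get_version(key)  # version of the installed distribution, by pip name
        except PackageNotFoundError:
            installed = None
        # naive exact comparison of version strings (no PEP 440 normalization)
        if installed != wanted:
            to_install.append(f"{key}=={wanted}")
    return to_install, to_install_source_override
```

Because the lookup goes through distribution (pip) names rather than import names, a package such as `scikit-learn` is checked correctly even though it is imported as `sklearn`. The returned lists could then be fed to the same `pip install` and `--find-links` calls the decorator already makes.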
Page content provided by Stack Overflow.
Original link:

https://stackoverflow.com/questions/65175602
