前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >如何使用 PyTorch Hook

如何使用 PyTorch Hook

作者头像
McGL
发布2020-11-10 11:07:59
9760
发布2020-11-10 11:07:59
举报
文章被收录于专栏:PyVision

作者:Frank Odom 编译:McGL

什么是钩子(Hook)?

Hook 实际上在软件工程中相当常见,并不是 PyTorch 所独有的。一般来说,“hook”是在特定事件之后自动执行的函数。在现实世界中,你可能遇到过的一些 hook 的例子:

  • 网站在你访问 N 个不同页面后会显示一个广告。
  • 你的账户有资金入账时,银行 app 发送通知消息。
  • 当周围光线减弱时,手机屏幕亮度会变暗。

这些事情没有 hook 也可以实现,但是很多情况下,hook 使程序员的生活更轻松。

PyTorch 为每个张量或 nn.Module 对象注册 hook。hook 由对象的向前或向后传播触发。它们具有以下函数签名:

代码语言:javascript
复制
from torch import nn, Tensor

def module_hook(module: nn.Module, input: Tensor, output: Tensor):
    # For nn.Module objects only.
    
def tensor_hook(grad: Tensor):
    # For Tensor objects only.
    # Only executed during the *backward* pass!

每个 hook 都可以修改输入、输出或内部模块参数。最常见的是用于调试目的。但我们将看到它们还有很多其他用途。

示例 #1: 模型执行详情

你自己有没有在模型中插入 print 语句,来试图找出错消息的原因?(我当然对此有罪恶感。)这是一个丑陋的调试实践,而且在很多情况下,我们在完成 print 语句时忘记删除它。导致我们的代码看起来很不专业,用户每次使用你的代码都会得到一些奇怪的信息。

以后再也不会了!让我们使用 hook 来调试模型,而不用以任何方式修改它们的实现。例如,假如你想知道每个层输出的形状。我们可以创建一个简单的 wrapper,使用 hook 打印输出形状。

代码语言:javascript
复制
class VerboseExecution(nn.Module):
    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model

        # Register a hook for each layer
        for name, layer in self.model.named_children():
            layer.__name__ = name
            layer.register_forward_hook(
                lambda layer, _, output: print(f"{layer.__name__}: {output.shape}")
            )

    def forward(self, x: Tensor) -> Tensor:
        return self.model(x)

最大的好处是: 它甚至可以用于不是我们创建的 PyTorch 模块!下面用 ResNet50 和一些虚拟输入来展示一下。

代码语言:javascript
复制
import torch
from torchvision.models import resnet50

verbose_resnet = VerboseExecution(resnet50())
dummy_input = torch.ones(10, 3, 224, 224)

_ = verbose_resnet(dummy_input)
# conv1: torch.Size([10, 64, 112, 112])
# bn1: torch.Size([10, 64, 112, 112])
# relu: torch.Size([10, 64, 112, 112])
# maxpool: torch.Size([10, 64, 56, 56])
# layer1: torch.Size([10, 256, 56, 56])
# layer2: torch.Size([10, 512, 28, 28])
# layer3: torch.Size([10, 1024, 14, 14])
# layer4: torch.Size([10, 2048, 7, 7])
# avgpool: torch.Size([10, 2048, 1, 1])
# fc: torch.Size([10, 1000])

示例 #2: 特征提取

通常,我们希望从一个预先训练好的网络中生成特性,然后用它们来完成另一个任务(例如分类、相似度搜索等)。使用 hook,我们可以提取特征,而不需要重新创建现有模型或以任何方式修改它。

代码语言:javascript
复制
from typing import Dict, Iterable, Callable

class FeatureExtractor(nn.Module):
    def __init__(self, model: nn.Module, layers: Iterable[str]):
        super().__init__()
        self.model = model
        self.layers = layers
        self._features = {layer: torch.empty(0) for layer in layers}

        for layer_id in layers:
            layer = dict([*self.model.named_modules()])[layer_id]
            layer.register_forward_hook(self.save_outputs_hook(layer_id))

    def save_outputs_hook(self, layer_id: str) -> Callable:
        def fn(_, __, output):
            self._features[layer_id] = output
        return fn

    def forward(self, x: Tensor) -> Dict[str, Tensor]:
        _ = self.model(x)
        return self._features

我们可以像使用其他 PyTorch 模块一样使用特征提取器。用之前同样的虚拟输入,运行得到:

代码语言:javascript
复制
resnet_features = FeatureExtractor(resnet50(), layers=["layer4", "avgpool"])
features = resnet_features(dummy_input)

print({name: output.shape for name, output in features.items()})
# {'layer4': torch.Size([10, 2048, 7, 7]), 'avgpool': torch.Size([10, 2048, 1, 1])}

示例 #3: 梯度裁剪

梯度裁剪是处理梯度爆炸的一种著名方法。PyTorch 已经提供了梯度裁剪的工具方法,但是我们也可以很容易地使用 hook 来实现它。其他任何用于梯度裁剪/归一化/修改的方法都可以用同样的方式实现。

代码语言:javascript
复制
def gradient_clipper(model: nn.Module, val: float) -> nn.Module:
    for parameter in model.parameters():
        parameter.register_hook(lambda grad: grad.clamp_(-val, val))
    
    return model

这个 hook 是后向传播时触发的,所以这次我们还计算了一个虚拟的损失度量。在执行 loss.backward() 之后,我们可以手动检查参数梯度,以确认它是否正常工作。

代码语言:javascript
复制
clipped_resnet = gradient_clipper(resnet50(), 0.01)
pred = clipped_resnet(dummy_input)
loss = pred.log().mean()
loss.backward()

print(clipped_resnet.fc.bias.grad[:25])
# tensor([-0.0010, -0.0047, -0.0010, -0.0009, -0.0015,  0.0027,  0.0017, -0.0023,
#          0.0051, -0.0007, -0.0057, -0.0010, -0.0039, -0.0100, -0.0018,  0.0062,
#          0.0034, -0.0010,  0.0052,  0.0021,  0.0010,  0.0017, -0.0100,  0.0021,
#          0.0020])

「来源:」https://towardsdatascience.com/how-to-use-pytorch-hooks-5041d777f904

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2020-11-09,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 PyVision 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 什么是钩子(Hook)?
  • 示例 #1: 模型执行详情
  • 示例 #2: 特征提取
  • 示例 #3: 梯度裁剪
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档