首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
社区首页 >问答首页 >PyTorch中的截断反向传播(代码检查)

PyTorch中的截断反向传播(代码检查)
EN

Stack Overflow用户
提问于 2020-07-14 18:20:10
回答 1查看 1K关注 0票数 5

对于K1=K2的简单情况,我正在尝试在PyTorch中实现截断的反向传播。我在下面有一个实现,可以产生合理的输出,但我只想确保它是正确的。当我在网上查找TBTT的PyTorch示例时,它们在分离隐藏状态和归零梯度以及这些操作的顺序方面做了一些不一致的事情。如果我弄错了,请告诉我。

在下面的代码中,H维护当前的隐藏状态,model(weights, H, x)输出预测和新的隐藏状态。

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
while i < NUM_STEPS:
    # Grab x, y for ith datapoint
    x = data[i]
    target = true_output[i]

    # Run model
    output, new_hidden = model(weights, H, x)
    H = new_hidden

    # Update running error
    error += (output - target)**2

    if (i+1) % K == 0:
        # Backpropagate
        error.backward()
        opt.step()
        opt.zero_grad()
        error = 0
        H = H.detach()

    i += 1
EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2020-07-18 07:07:26

因此,您的代码的思想是在每个第k步之后隔离最后一个变量。是的,你的实现是绝对正确的,这个answer证实了这一点。

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
# truncated to the last K timesteps
while i < NUM_STEPS:
    out = model(out)
    if (i+1) % K == 0:
        out.backward()
        out.detach()
out.backward()

您也可以按照this示例进行参考。

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
import torch

from ignite.engine import Engine, EventEnum, _prepare_batch
from ignite.utils import apply_to_tensor


class Tbptt_Events(EventEnum):
    """Aditional tbptt events.

    Additional events for truncated backpropagation throught time dedicated
    trainer.
    """

    TIME_ITERATION_STARTED = "time_iteration_started"
    TIME_ITERATION_COMPLETED = "time_iteration_completed"


def _detach_hidden(hidden):
    """Cut backpropagation graph.

    Auxillary function to cut the backpropagation graph by detaching the hidden
    vector.
    """
    return apply_to_tensor(hidden, torch.Tensor.detach)


def create_supervised_tbptt_trainer(
    model, optimizer, loss_fn, tbtt_step, dim=0, device=None, non_blocking=False, prepare_batch=_prepare_batch
):
    """Create a trainer for truncated backprop through time supervised models.

    Training recurrent model on long sequences is computationally intensive as
    it requires to process the whole sequence before getting a gradient.
    However, when the training loss is computed over many outputs
    (`X to many <https://karpathy.github.io/2015/05/21/rnn-effectiveness/>`_),
    there is an opportunity to compute a gradient over a subsequence. This is
    known as
    `truncated backpropagation through time <https://machinelearningmastery.com/
    gentle-introduction-backpropagation-time/>`_.
    This supervised trainer apply gradient optimization step every `tbtt_step`
    time steps of the sequence, while backpropagating through the same
    `tbtt_step` time steps.

    Args:
        model (`torch.nn.Module`): the model to train.
        optimizer (`torch.optim.Optimizer`): the optimizer to use.
        loss_fn (torch.nn loss function): the loss function to use.
        tbtt_step (int): the length of time chunks (last one may be smaller).
        dim (int): axis representing the time dimension.
        device (str, optional): device type specification (default: None).
            Applies to batches.
        non_blocking (bool, optional): if True and this copy is between CPU and GPU,
            the copy may occur asynchronously with respect to the host. For other cases,
            this argument has no effect.
        prepare_batch (callable, optional): function that receives `batch`, `device`,
            `non_blocking` and outputs tuple of tensors `(batch_x, batch_y)`.

    .. warning::

        The internal use of `device` has changed.
        `device` will now *only* be used to move the input data to the correct device.
        The `model` should be moved by the user before creating an optimizer.

        For more information see:

        * `PyTorch Documentation <https://pytorch.org/docs/stable/optim.html#constructing-it>`_
        * `PyTorch's Explanation <https://github.com/pytorch/pytorch/issues/7844#issuecomment-503713840>`_

    Returns:
        Engine: a trainer engine with supervised update function.

    """

    def _update(engine, batch):
        loss_list = []
        hidden = None

        x, y = batch
        for batch_t in zip(x.split(tbtt_step, dim=dim), y.split(tbtt_step, dim=dim)):
            x_t, y_t = prepare_batch(batch_t, device=device, non_blocking=non_blocking)
            # Fire event for start of iteration
            engine.fire_event(Tbptt_Events.TIME_ITERATION_STARTED)
            # Forward, backward and
            model.train()
            optimizer.zero_grad()
            if hidden is None:
                y_pred_t, hidden = model(x_t)
            else:
                hidden = _detach_hidden(hidden)
                y_pred_t, hidden = model(x_t, hidden)
            loss_t = loss_fn(y_pred_t, y_t)
            loss_t.backward()
            optimizer.step()

            # Setting state of engine for consistent behaviour
            engine.state.output = loss_t.item()
            loss_list.append(loss_t.item())

            # Fire event for end of iteration
            engine.fire_event(Tbptt_Events.TIME_ITERATION_COMPLETED)

        # return average loss over the time splits
        return sum(loss_list) / len(loss_list)

    engine = Engine(_update)
    engine.register_events(*Tbptt_Events)
    return engine
票数 5
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/62901561

复制
相关文章
CNN的反向传播DNN中的反向传播卷积神经网络中的反向传播
DNN中的反向传播 反向传播算法是神经网络的训练的基本算法组成之一,在训练神经网络时,训练分为两个步骤:计算梯度和更新权值。其中反向传播负责的是梯度的计算,而训练算法的区分主要在更新权值的方式上。对于DNN,基本的反向传播思路为: $$\cfrac{dz}{dw_{i}} = \cfrac{dz}{da_{i+1}} \times \cfrac{da_{i+1}}{dw_{i}}$$ 其中,$\cfrac{dz}{dw_{i}}$为输出(多为代价函数输出)对第i层的权值的梯度,$\cfrac{da_{i+1
月见樽
2018/04/27
1.2K0
pytorch学习笔记(三):反向传播
细节提要: Tensor张量可以视作数据结构:数据data + 梯度grad(grad也是张量) 核心代码:l.backward()反向传播,程序自动求出所有需要的梯度 w.grad.data.zero_()每步做完清零很重要,否则影响下一次求的梯度
zstar
2022/06/14
2660
pytorch学习笔记(三):反向传播
pytorch基础知识-反向传播
如图所示为由最开始的输入到输出。给了每个节点上一个weight(权重),第一层和第二层均用O代表输出。通过链式法则分解成上式子。
用户6719124
2019/11/17
4750
pytorch loss反向传播出错
今天在使用pytorch进行训练,在运行 loss.backward() 误差反向传播时出错 :
全栈程序员站长
2022/08/31
1.2K0
CNN中的反向传播
我们知道Pooling操作会使得feature map的尺寸发生变化,假如做$2\times 2$的池化,假设$l+1$层的feature map有16个梯度,那么第$l$层应该需要64个梯度,做法很简单,只需要把1个像素的梯度反向传给4个像素即可,但是要保证传递的梯度总和不变。由于有这条原则,因此avg pooling和max pooling的反向传播并不相同
mathor
2020/02/20
8460
[源码解析] PyTorch 分布式(13) ----- DistributedDataParallel 之 反向传播
上文我们已经对Reduer的前向传播进行了分析,本文就接着来看看如何进行反向传播。
罗西的思考
2021/12/01
9600
[源码解析] PyTorch 分布式(13) ----- DistributedDataParallel 之 反向传播
pytorch .detach() .detach_() 和 .data用于切断反向传播的实现
这篇文章主要介绍了pytorch .detach() .detach_() 和 .data用于切断反向传播的实现。
狼啸风云
2020/07/13
6.6K1
反向传播算法
反向传播算法 目录 关键词 损失函数 反向传播算法 BP算法伪代码 上一篇----》神经网络 》点击查看 1 关键词 反向传播算法 Backpropagation Algorithm 批量梯度下降法 batch gradient descent 整体代价函数 overall cost function 方差 squared-error 均方差 average sum-of-squares error 规则化项 regularization term 权重衰减 weight decay 偏置项 bias te
昱良
2018/04/08
1.3K0
反向传播算法
前反向传播
X: [[0.83494319 0.11482951] [0.66899751 0.46594987] [0.60181666 0.58838408] [0.31836656 0.20502072] [0.87043944 0.02679395] [0.41539811 0.43938369] [0.68635684 0.24833404] [0.97315228 0.68541849] [0.03081617 0.89479913] [0.24665715 0.28584862] [0.31375667 0.47718349] [0.56689254 0.77079148] [0.7321604 0.35828963] [0.15724842 0.94294584] [0.34933722 0.84634483] [0.50304053 0.81299619] [0.23869886 0.9895604 ] [0.4636501 0.32531094] [0.36510487 0.97365522] [0.73350238 0.83833013] [0.61810158 0.12580353] [0.59274817 0.18779828] [0.87150299 0.34679501] [0.25883219 0.50002932] [0.75690948 0.83429824] [0.29316649 0.05646578] [0.10409134 0.88235166] [0.06727785 0.57784761] [0.38492705 0.48384792] [0.69234428 0.19687348] [0.42783492 0.73416985] [0.09696069 0.04883936]] Y: [[1], [0], [0], [1], [1], [1], [1], [0], [1], [1], [1], [0], [0], [0], [0], [0], [0], [1], [0], [0], [1], [1], [0], [1], [0], [1], [1], [1], [1], [1], [0], [1]] w1: [[-0.8113182 1.4845988 0.06532937] [-2.4427042 0.0992484 0.5912243 ]] w2: [[-0.8113182 ] [ 1.4845988 ] [ 0.06532937]] After 0 training steps,loss on all data is 5.13118 After 500 training steps,loss on all data is 0.429111 After 1000 training steps,loss on all data is 0.409789 After 1500 training steps,loss on all data is 0.399923 After 2000 training steps,loss on all data is 0.394146 After 2500 training steps,loss on all data is 0.390597 w1: [[-0.7000663 0.9136318 0.08953571] [-2.3402493 -0.14641273 0.58823055]] w2: [[-0.06024271] [ 0.9195618 ] [-0.06820712]]
裴来凡
2022/05/29
5970
前反向传播
反向传播详解
反向传播 下降的幅度与斜率大小成正比,越到最小点时,每步应越来越小,防止调过头 某一点导数大小即可表现为 w的变化幅度 w - k * ▽J(w) ▽J(w)为正表示向左走 w变小 ,为负表示
Dean0731
2020/05/11
6740
BP反向传播
""" 案例:研究生学院录取数据,用梯度下降训练一个网络。 数据有三个输入特征:GRE 分数、GPA 分数和本科院校排名(从 1 到 4)。排名 1 代表最好,排名 4 代表最差。 """ # ----------分隔线------------- """ 数据解读说明: admit 0未录取 1 录取 gre分数 gpa绩点分数 rank 本科院校等级 #数据预处理 admit --目标标签 rank ---分类变量--》亚编码|one-hot独热编码,相当于去除量纲的影响 gre,gpa
诡途
2021/12/30
4400
反向传播算法详解和Python代码实现
反向传播算法是训练神经网络的经典算法,是深度学习的最重要的基础,适合于多层神经元网络的一种学习算法,它建立在梯度下降法的基础上,通过迭代的方法求出目标损失函数(loss function)的近似最小值。
deephub
2020/05/09
3.5K0
反向传播算法详解和Python代码实现
反向传播算法原理推导及代码实现
《实例》阐述算法,通俗易懂,助您对算法的理解达到一个新高度。包含但不限于:经典算法,机器学习,深度学习,LeetCode 题解,Kaggle 实战。期待您的到来! 01 — 回顾 昨天,分析了手写字数据集分类的原理,利用神经网络模型,编写了SGD算法的代码,分多个epochs,每个 epoch 又对 mini_batch 样本做多次迭代计算,详细的过程,请参考: 这其中,非常重要的一个步骤,便是利用反向传播(BP)算法求权重参数的梯度,偏置量的梯度。下面根据以下几个问题展开BP算法: 什么是BP算法? 为什
企鹅号小编
2018/02/12
1.1K0
反向传播算法原理推导及代码实现
[MachineLearning] 反向传播Back Propagation
如何直观地解释 back propagation 算法? - 胡逸夫的回答 - 知乎
wOw
2018/09/18
9010
[MachineLearning] 反向传播Back Propagation
【TensorFlow篇】--反向传播
反向自动求导是 TensorFlow 实现的方案,首先,它执行图的前向阶段,从输入到输出,去计算节点 值,然后是反向阶段,从输出到输入去计算所有的偏导。
LhWorld哥陪你聊算法
2018/09/13
6750
【TensorFlow篇】--反向传播
反向传播是什么?
深度学习系统能够学习极其复杂的模式,它们通过调整权重来实现这一点。深度神经网络则通过反向传播的过程进行调整,如果没有反向传播,深度神经网络就无法执行识别图像和解释自然语言等任务。
AiTechYun
2020/02/23
1.4K0
反向传播算法(Backpropagation)
BP算法(即反向传播算法)是在有导师指导下,适合于多层神经元网络的一种学习算法,它建立在梯度下降法的基础上。BP网络的输入输出关系实质上是一种映射关系:一个n输入m输出的BP神经网络所完成的功能是从n维欧氏空间向m维欧氏空间中一有限域的连续映射,这一映射具有高度非线性。它的信息处理能力来源于简单非线性函数的多次复合,因此具有很强的函数复现能力。这是BP算法得以应用的基础。
easyAI
2019/12/18
9030
理解多层CNN中转置卷积的反向传播(附代码)
【导读】转置卷积一直不太好理解,今天我们通过详细的推导示例及代码了解简单的两层CNN中转置卷积的反向传播。 编译 | 专知 参与 | Yingying, Xiaowen 今天,我们要训练一个简单的有两
WZEARW
2018/06/05
3.7K0
卷积神经网络中卷积运算的前向传播与反向传播推导
版权声明:博客文章都是作者辛苦整理的,转载请注明出处,谢谢! https://blog.csdn.net/Quincuntial/article/details/90412121
Tyan
2019/05/25
1.2K0
点击加载更多

相似问题

了解PyTorch中的反向传播

32

PyTorch布尔-停止反向传播?

23

截断反向传播对时间反向传播步数的影响

11

pyTorch中的反向传播考虑了哪些参数?

14

在火炬中截断时间的反向传播(BPTT)

13
添加站长 进交流群

领取专属 10元无门槛券

AI混元助手 在线答疑

扫码加入开发者社群
关注 腾讯云开发者公众号

洞察 腾讯核心技术

剖析业界实践案例

扫码关注腾讯云开发者公众号
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档
查看详情【社区公告】 技术创作特训营有奖征文