首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >火把中有类似theano.tensor.switch的东西吗?

火把中有类似theano.tensor.switch的东西吗?
EN

Stack Overflow用户
提问于 2018-04-19 23:35:35
回答 2查看 752关注 0票数 1

我想把向量中低于某一阈值的所有元素强制为零。我想这样做,这样我仍然可以通过非零的梯度传播。

例如,在theano,我可以写:

B = theano.tensor.switch(A < .1, 0, A)

在火把里有解决办法吗?

EN

回答 2

Stack Overflow用户

回答已采纳

发布于 2018-04-20 09:48:48

对于pytorch 0.4+,您可以轻松地使用torch.where(参见文档合并PR)。

就像在西亚诺一样容易。举个例子来看看你自己:

代码语言:javascript
运行
复制
import torch
from torch.autograd import Variable

x = Variable(torch.arange(0,4), requires_grad=True) # x     = [0 1 2 3]
zeros = Variable(torch.zeros(*x.shape))             # zeros = [0 0 0 0]

y = x**2                         # y = [0 1 4 9]
z = torch.where(y < 5, zeros, y) # z = [0 0 0 9]

# dz/dx = (dz/dy)(dy/dx) = (y < 5)(0) + (y ≥ 5)(2x) = 2x(x**2 ≥ 5) 
z.backward(torch.Tensor([1.0])) 
x.grad # (dz/dx) = [0 0 0 6]
票数 2
EN

Stack Overflow用户

发布于 2018-04-20 04:18:56

我不认为switch是默认在PyTorch中实现的。但是,您可以在PyTorch中通过torch.autograd.Function定义自己的函数。

所以,开关函数看起来就像

代码语言:javascript
运行
复制
class switchFunction(Function):
    @staticmethod
    def forward(ctx, flag, value, tensor):
        ctx.save_for_backward(flag)
        tensor[flag] = value
        return tensor

    @staticmethod
    def backward(ctx, grad_output):
        flag, = ctx.saved_variables
        grad_output[flag] = 0
        return grad_output
switch = switchFunction.apply

现在,您可以简单地将switch称为switch(A < 0.1, 0, A)

编辑

实际上有一个函数可以做到这一点。它被称为阈值。你可以把它当作

代码语言:javascript
运行
复制
import torch.nn as nn
m = nn.Threshold(0.1, 0)
B = m(A)
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/49931756

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档