我想复制代码这里,在Google中运行时会得到以下错误?
(/usr/local/lib/python3.7/dist-packages/torch/autograd/gradcheck.py):无法从“torch.autograd.gradcheck”torch.autograd.gradcheck导入名称“zero_gradients”
有人能帮我解决这个问题吗?
发布于 2021-07-17 14:41:42
这似乎是在使用一个非常旧的PyTorch版本,这个函数本身不再可用了。但是,如果您查看此承诺,您将看到zero_gradients
的实现。它所做的只是将输入的梯度为零:
def zero_gradients(i):
for t in iter_gradients(i):
t.zero_()
那么zero_gradients(x)
应该与x.zero_grad()
相同,这是当前的API,假设x
是nn.Module
!
或者只是:
if x.grad is not None:
x.grad.zero_()
https://stackoverflow.com/questions/68419612
复制相似问题