首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >RuntimeError:找到dtype但需要浮点型

RuntimeError:找到dtype但需要浮点型
EN

Stack Overflow用户
提问于 2021-10-14 13:57:52
回答 1查看 309关注 0票数 1

我正在使用Python3和Pytorch 1.9.1编写强化学习的代码。

我发布了一个问题,因为我不理解错误行。错误发生在loss.mean().backward()的行上。

据说数据类型应该有一个浮点数,但是双精度数进来了,但是无论打印多少数据类型,它都是一个浮点数32。有什么问题吗?

有问题的代码如下。

代码语言:javascript
运行
复制
def train_net_ap(self, idx):
    s, a, r, s_prime, done_mask, prob_a = self.make_batch(idx)
    print("a is ", a)

    for i in range(K_epoch):
        td_target = r + gamma * self.v_ap(s_prime) * done_mask
        delta = td_target - self.v_ap(s)
        delta = delta.detach().numpy()

        advantage_lst = []
        advantage = 0.0
        for delta_t in delta[::-1]:
            advantage = gamma * lmbda * advantage + delta_t[0]
            advantage_lst.append([advantage])
        advantage_lst.reverse()
        advantage = torch.tensor(advantage_lst, dtype=torch.float)

        pi = self.pi_ap(s, softmax_dim=1)
        pi_a = pi.gather(1, a)
        ratio = torch.exp(torch.log(pi_a) - torch.log(prob_a))  # a/b == exp(log(a)-log(b))

        surr1 = ratio * advantage
        surr2 = torch.clamp(ratio, 1 - eps_clip, 1 + eps_clip) * advantage
        loss = -torch.min(surr1, surr2) + F.smooth_l1_loss(self.v_ap(s), td_target.detach())

        print("loss is ", loss)
        print("loss dtype is ", loss.dtype)
        print("loss.mean() is ", loss.mean(), loss.mean().dtype)
        self.optimizer.zero_grad()
        loss.mean().backward()
        self.optimizer.step()

打印的短语和错误消息如下。

代码语言:javascript
运行
复制
loss dtype is  torch.float32 
loss.mean() is  tensor(6.1353,   grad_fn=<MeanBackward0>) torch.float32


Traceback (most recent call last):
  main()
  model.train_net_ap(x)
  loss.mean().backward()
    
  torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  allow_unreachable=True, accumulate_grad=True)  # allow_unreachable flag

RuntimeError: Found dtype Double but expected Float
EN

回答 1

Stack Overflow用户

发布于 2021-10-14 14:14:41

该错误表明它需要一个浮点数据类型,但它接收的是一个双精度类型的数据,您可以做的是将变量类型更改为本例中所需的类型,执行类似以下操作:

代码语言:javascript
运行
复制
float(double_variable)

或者,如果您需要更精确的浮点值或具有特定小数位数,则可以使用以下命令:

代码语言:javascript
运行
复制
                                   (This is an example)
v1 = 0.00582811585976
import numpy as np
np.float32(v1)
float(np.float32(v1))  #Convert to 32bit and then back to 64bit
'%.14f'%np.float32(v1) #This rounds to v2 if you're printing 14 places of precision ...
票数 -1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/69572029

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档