首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >如何求出张量沿维数的最大值?

如何求出张量沿维数的最大值?
EN

Stack Overflow用户
提问于 2022-09-30 00:10:26
回答 2查看 131关注 0票数 0

我有一个三维张量,我想取最大值沿第0维在利比里亚火炬。

我知道如何用Python (PyTorch)来完成这个任务,但在LibTorch中却遇到了困难。

在LibTorch中,我的代码是

代码语言:javascript
复制
auto target_q_T = torch::rand({5, 10, 1});
auto max_q = torch::max({target_q_T}, 0);
std::cout << max_q;

它会返回这么长的重复错误。

代码语言:javascript
复制
note: candidate: ‘template<class _Traits> std::basic_ostream<char, _Traits>& std::operator<<(std::basic_ostream<char, _Traits>&, const char*)’
  611 |     operator<<(basic_ostream<char, _Traits>& __out, const char* __s)
      |     ^~~~~~~~
/usr/include/c++/11/ostream:611:5: note:   template argument deduction/substitution failed:
/home/iii/tor/m_gym/multiv_normal.cpp:432:18: note:   cannot convert ‘max_q’ (type ‘std::tuple<at::Tensor, at::Tensor>’) to type ‘const char*’
  432 |     std::cout << max_q;
      |                  ^~~~~
In file included from /usr/include/c++/11/istream:39,
                 from /usr/include/c++/11/sstream:38,
                 from /home/iii/tor/m_gym/libtorch/include/c10/macros/Macros.h:246,
                 from /home/iii/tor/m_gym/libtorch/include/c10/core/DeviceType.h:8,
                 from /home/iii/tor/m_gym/libtorch/include/c10/core/Device.h:3,
                 from /home/iii/tor/m_gym/libtorch/include/ATen/core/TensorBody.h:11,
                 from /home/iii/tor/m_gym/libtorch/include/ATen/core/Tensor.h:3,
                 from /home/iii/tor/m_gym/libtorch/include/ATen/Tensor.h:3,
                 from /home/iii/tor/m_gym/libtorch/include/torch/csrc/autograd/function_hook.h:3,
                 from /home/iii/tor/m_gym/libtorch/include/torch/csrc/autograd/cpp_hook.h:2,
                 from /home/iii/tor/m_gym/libtorch/include/torch/csrc/autograd/variable.h:6,
                 from /home/iii/tor/m_gym/libtorch/include/torch/csrc/autograd/autograd.h:3,
                 from /home/iii/tor/m_gym/libtorch/include/torch/csrc/api/include/torch/autograd.h:3,
                 from /home/iii/tor/m_gym/libtorch/include/torch/csrc/api/include/torch/all.h:7,
                 from /home/iii/tor/m_gym/libtorch/include/torch/csrc/api/include/torch/torch.h:3,
                 from /home/iii/tor/m_gym/multiv_normal.cpp:2:
/usr/include/c++/11/ostream:624:5: note: candidate: ‘template<class _Traits> std::basic_ostream<char, _Traits>& std::operator<<(std::basic_ostream<char, _Traits>&, const signed char*)’
  624 |     operator<<(basic_ostream<char, _Traits>& __out, const signed char* __s)
      |     ^~~~~~~~

这就是它在Python中的工作方式。

代码语言:javascript
复制
target_q_np = torch.rand(5, 10, 1)
max_q = torch.max(target_q_np, 0)
max_q

torch.return_types.max(
values=tensor([[0.8517],
        [0.7526],
        [0.6546],
        [0.9913],
        [0.8521],
        [0.9757],
        [0.9080],
        [0.9376],
        [0.9901],
        [0.7445]]),
indices=tensor([[4],
        [2],
        [3],
        [4],
        [1],
        [0],
        [2],
        [4],
        [4],
        [4]]))
EN

回答 2

Stack Overflow用户

回答已采纳

发布于 2022-10-01 03:48:06

如果您读取编译器错误,它基本上告诉您您正在尝试打印两个张量的元组。这是因为C++代码的工作方式与python代码完全一样,并返回最大值及其各自的索引(您的python代码正是这样打印的)。您需要性病获取从元组中提取张量:

代码语言:javascript
复制
auto target_q_T = torch::rand({5, 10, 1});
auto max_q = torch::max({target_q_T}, 0);
std::cout << "max: " << std::get<0>(max_q) 
          << "indices: " << std::get<1>(max_q)
          << std::endl;

在C++17中,您还应该能够编写

代码语言:javascript
复制
auto [max_t, idx_t] = torch::max({target_q_T}, 0);
std::cout << ... ;
票数 1
EN

Stack Overflow用户

发布于 2022-10-01 02:50:16

我从来没有发现max在LibTorch中的相同用法,就像在PyTorch中一样,所以我做了一个解决方案。

max在LibTorch中将从一个一维数组中获取最大值,因此我在索引数组上循环并连接结果。它实际上返回与torch.max(target_q_np,0)相同的内容。

我的解决方案在LibTorch (C++)。最大值数组以反向顺序返回,作为原始张量,因此我将其翻转。

代码语言:javascript
复制
auto target_q_T = torch::rand({5, 10, 1});

torch::Tensor zero_max;
for (int i=0; i<5; i++) {
    zero_max = torch::cat({torch::max({target_q_T[i]}).unsqueeze(-1), zero_max}, 0);
}
zero_max = zero_max.flip(-1);
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/73902752

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档