首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >在哪里可以找到torch.unique()的源代码?

在哪里可以找到torch.unique()的源代码?
EN

Stack Overflow用户
提问于 2022-01-22 00:57:06
回答 1查看 977关注 0票数 2

我只能在pytorch源代码(https://github.com/pytorch/pytorch/blob/2367face24afb159f73ebf40dc6f23e46132b770/torch/functional.py#L783)中找到以下函数调用:

_VF.unique_dim()torch._unique2()

但它们没有指向目录中的其他任何位置

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2022-01-22 04:05:54

大部分的火炬后端代码都是用C++和/或CUDA实现的。要查看它,需要在源代码中找到适当的入口点。有几种方法可以做到这一点,但我发现,在没有下载所有代码的情况下,最简单的方法就是在github上搜索关键字。

例如,如果您转到github.com并搜索unique_dim repo:pytorch/pytorch,然后单击左侧的"Code“选项卡,您将很快找到以下内容。

来自builtins.py:103

代码语言:javascript
运行
复制
 17: _builtin_ops = [
...
103:    (torch._VF.unique_dim, "aten::unique_dim"),

通过对代码的进一步分析,我们可以得出结论,torch._VF.unique_dim实际上是从ATen库调用aten::unique_dim函数。

ATen中的大多数函数一样,该函数有多个实现。大多数ATen函数都是在functions.yaml中注册的,通常这里的函数都有一个_cpu_cuda版本。

回到搜索结果,我们可以发现CUDA实现实际上正在调用函数unique_dim_cuda at ATen/src/ATen/本地/cuda/Unique.cu:197

代码语言:javascript
运行
复制
196: std::tuple<Tensor, Tensor, Tensor>
197: unique_dim_cuda(const Tensor& self, const int64_t dim, const bool sorted, const bool return_inverse, const bool return_counts) {
198:   return AT_DISPATCH_ALL_TYPES_AND2(kBool, kHalf, self.scalar_type(), "unique_dim", [&] {
199:     return unique_dim_cuda_template<scalar_t>(self, dim, false, return_inverse, return_counts);
200:   });
201: }

并且CPU实现在ATen/src/ATen/本地/Unique.cpp:271上调用函数ATen/src/ATen/本地/Unique.cpp:271

代码语言:javascript
运行
复制
270: std::tuple<Tensor, Tensor, Tensor>
271: unique_dim_cpu(const Tensor& self, const int64_t dim, const bool sorted, const bool return_inverse, const bool return_counts) {
272:   return AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::BFloat16, at::ScalarType::Bool, self.scalar_type(), "unique_dim", [&] {
273:     // The current implementation using `dim` always sorts due to unhashable tensors
274:     return _unique_dim_cpu_template<scalar_t>(self, dim, false, return_inverse, return_counts);
275:   });
276: }

从这一点开始,您应该能够进一步跟踪函数调用,以查看它们到底在做什么。

在进行类似的搜索之后,您应该会发现,torch._unique2分别在ATen/src/ATen/本地/cuda/Unique.cu:188ATen/src/ATen/本地/Unique.cpp:264上实现。

票数 5
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/70809160

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档