pytorch学习笔记(十八):C 语言扩展 pytorch

上篇博文已经介绍了如何通过 继承 Function ,然后使用python 来扩展 pytorch, 本文主要介绍如何通过 cffi 来扩展 pytorch

官网给出了一个 MyAddDemo github地址,本文通过 这个 Demo 来搞定如何 通过 cffi 来扩展 pytorch

自定义 OP

pytorch 自定义 op 的基本步骤总结如下。

一、C部分

  • new_op.h : CPU forward(), backward() 接口声明
  • new_op_cu.h : GPU forward(), backward() 接口声明
  • new_op.c: 实现 forward(), backward() CPU 代码
  • new_op.cu: 实现 forward(), backward() GPU 代码

二、编译上面写的 C/CUDA 代码

三、python部分:

  • Function 包装 C OP
  • Module 包装 Function

下面,来看一下 官方的 Demo

看Script 部分

Script 部分的文件结构如下:

  • src/ : 放着 C 代码
  • functions/Function 包装
  • modules/Module 包装
  • build : 编译 C 源码的 代码

C/CUDA 代码

#include <TH/TH.h>

int my_lib_add_forward(THFloatTensor *input1, THFloatTensor *input2,
               THFloatTensor *output)
{
  if (!THFloatTensor_isSameSizeAs(input1, input2))
    return 0;
  THFloatTensor_resizeAs(output, input1);
  THFloatTensor_cadd(output, input1, 1.0, input2);
  return 1;
}

int my_lib_add_backward(THFloatTensor *grad_output, THFloatTensor *grad_input)
{
  THFloatTensor_resizeAs(grad_input, grad_output);
  THFloatTensor_fill(grad_input, 1);
  return 1;
}

编译用代码

import os
import torch
from torch.utils.ffi import create_extension

this_file = os.path.dirname(__file__)

sources = ['src/my_lib.c']
headers = ['src/my_lib.h']
defines = []
with_cuda = False

if torch.cuda.is_available():
    print('Including CUDA code.')
    sources += ['src/my_lib_cuda.c']
    headers += ['src/my_lib_cuda.h']
    defines += [('WITH_CUDA', None)]
    with_cuda = True

ffi = create_extension(
    '_ext.my_lib', # _ext/my_lib 编译后的动态 链接库 存放路径。
    headers=headers,
    sources=sources,
    define_macros=defines,
    relative_to=__file__,
    with_cuda=with_cuda
)

if __name__ == '__main__':
    ffi.build()

Function Wrapper

import torch
from torch.autograd import Function
from _ext import my_lib
from torch.autograd import Variable


class MyAddFunction(Function):

    @staticmethod
    def forward(ctx, input1, input2):
        output = input1.new()
        if not input1.is_cuda:
            my_lib.my_lib_add_forward(input1, input2, output)
        else:
            my_lib.my_lib_add_forward_cuda(input1, input2, output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        t_grad_output = grad_output.data
        t_grad_input = t_grad_output.new().resize_as_(t_grad_output).zero_()
        grad_input = Variable(t_grad_input, requires_grad=grad_output.requires_grad, volatile=grad_output.volatile)
        if not grad_output.is_cuda:
            my_lib.my_lib_add_backward(grad_output.data, t_grad_input)
        else:
            my_lib.my_lib_add_backward_cuda(grad_output.data, t_grad_input)
        return grad_input, grad_input

Module Wrapper

class MyAddModule(Module):
    def forward(self, input1, input2):
        return MyAddFunction.apply(input1, input2)

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏未闻Code

记住变量类型的三种方式

Python作为一门动态语言,其变量的类型可以自由变化。这个特性提高了代码的开发效率,却也增加了阅读代码和维护代码的难度。

1011
来自专栏性能与架构

Linux 内存性能指标

内存基础概念 先执行一下 top 命令,看结果中关于内存的相关部分 # top ? 其中的 VIRT、RES、SWAP 都是什么呢? 分别是下面的3个概念 ...

3685
来自专栏xingoo, 一个梦想做发明家的程序员

非递归版归并排序

非递归版的归并排序,省略了中间的栈空间,直接申请一段O(n)的地址空间即可,因此空间复杂度为O(n),时间复杂度为O(nlogn); 算法思想:   开始以间隔...

1907
来自专栏xingoo, 一个梦想做发明家的程序员

快速排序

快速排序时间复杂度为O(nlogn),由于是在原数组上面利用替换来实现,因此不需要额外的存储空间。 算法思想:   通过设置一个岗哨,每次跟这个岗哨进行比较,比...

18110
来自专栏漫漫深度学习路

pytorch学习笔记(十):learning rate decay(学习率衰减)

pytorch learning rate decay 本文主要是介绍在pytorch中如何使用learning rate decay. 先上代码: def...

54110
来自专栏Petrichor的专栏

tensorflow编程: Running Graphs

  A class for running TensorFlow operations.   这是一个类,执行 tensorflow 中的 op 。它里面定...

1252
来自专栏人工智能LeadAI

TensorFlow会话的配置项

01 TensorFlow配置项的文档位于这里 TensorFlow可以通过指定配置项,来配置需要运行的会话,示例代码如下: run_config = tf.C...

4314
来自专栏nummy

Python数据科学手册(三)【Pandas的对象介绍】

Pandas构建在Numpy的基础上,它同时支持行和列的操作。 使用pip进行安装:

393
来自专栏光变

Java中关于i=i++的问题解些

JVM在方法体中的操作指令,一部分是直接作用stack栈,也有一些部分是直接操作Local Variable(本地变量区/局部变量区)。

501
来自专栏Fish

零拷贝内存 or 页锁定内存

这是一个小实验,在于验证GPU上使用零拷贝内存和页锁定内存的性能差别。使用的是点积计算,数据量在100M左右。实验步骤很简单,分别在主机上开辟普通内存,页锁定内...

2025

扫码关注云+社区