pytorch学习笔记(十八):C 语言扩展 pytorch

上篇博文已经介绍了如何通过 继承 Function ,然后使用python 来扩展 pytorch, 本文主要介绍如何通过 cffi 来扩展 pytorch

官网给出了一个 MyAddDemo github地址,本文通过 这个 Demo 来搞定如何 通过 cffi 来扩展 pytorch

自定义 OP

pytorch 自定义 op 的基本步骤总结如下。

一、C部分

  • new_op.h : CPU forward(), backward() 接口声明
  • new_op_cu.h : GPU forward(), backward() 接口声明
  • new_op.c: 实现 forward(), backward() CPU 代码
  • new_op.cu: 实现 forward(), backward() GPU 代码

二、编译上面写的 C/CUDA 代码

三、python部分:

  • Function 包装 C OP
  • Module 包装 Function

下面,来看一下 官方的 Demo

看Script 部分

Script 部分的文件结构如下:

  • src/ : 放着 C 代码
  • functions/Function 包装
  • modules/Module 包装
  • build : 编译 C 源码的 代码

C/CUDA 代码

#include <TH/TH.h>

int my_lib_add_forward(THFloatTensor *input1, THFloatTensor *input2,
               THFloatTensor *output)
{
  if (!THFloatTensor_isSameSizeAs(input1, input2))
    return 0;
  THFloatTensor_resizeAs(output, input1);
  THFloatTensor_cadd(output, input1, 1.0, input2);
  return 1;
}

int my_lib_add_backward(THFloatTensor *grad_output, THFloatTensor *grad_input)
{
  THFloatTensor_resizeAs(grad_input, grad_output);
  THFloatTensor_fill(grad_input, 1);
  return 1;
}

编译用代码

import os
import torch
from torch.utils.ffi import create_extension

this_file = os.path.dirname(__file__)

sources = ['src/my_lib.c']
headers = ['src/my_lib.h']
defines = []
with_cuda = False

if torch.cuda.is_available():
    print('Including CUDA code.')
    sources += ['src/my_lib_cuda.c']
    headers += ['src/my_lib_cuda.h']
    defines += [('WITH_CUDA', None)]
    with_cuda = True

ffi = create_extension(
    '_ext.my_lib', # _ext/my_lib 编译后的动态 链接库 存放路径。
    headers=headers,
    sources=sources,
    define_macros=defines,
    relative_to=__file__,
    with_cuda=with_cuda
)

if __name__ == '__main__':
    ffi.build()

Function Wrapper

import torch
from torch.autograd import Function
from _ext import my_lib
from torch.autograd import Variable


class MyAddFunction(Function):

    @staticmethod
    def forward(ctx, input1, input2):
        output = input1.new()
        if not input1.is_cuda:
            my_lib.my_lib_add_forward(input1, input2, output)
        else:
            my_lib.my_lib_add_forward_cuda(input1, input2, output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        t_grad_output = grad_output.data
        t_grad_input = t_grad_output.new().resize_as_(t_grad_output).zero_()
        grad_input = Variable(t_grad_input, requires_grad=grad_output.requires_grad, volatile=grad_output.volatile)
        if not grad_output.is_cuda:
            my_lib.my_lib_add_backward(grad_output.data, t_grad_input)
        else:
            my_lib.my_lib_add_backward_cuda(grad_output.data, t_grad_input)
        return grad_input, grad_input

Module Wrapper

class MyAddModule(Module):
    def forward(self, input1, input2):
        return MyAddFunction.apply(input1, input2)

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏一枝花算不算浪漫

JSON.toJSONString中序列化空字符串遇到的坑

最近在做系统Bug修复时遇到了一个问题,调用其他服务时传递的参数和自己预先的不一致,例如Map中有10条记录,然后使用JSON.toJSONString 包装后...

44320
来自专栏武培轩的专栏

京东面经汇总

一、Java Java的优势 平台无关性、垃圾回收 Java有哪些特性,举个多态的例子。 封装、继承、多态 abstract interface区别 含有abs...

63660
来自专栏源哥的专栏

多媒体处理类

import java.io.*; import java.util.*; import javax.servlet.http.*;

9010
来自专栏Dawnzhang的开发者手册

Spring中的@Transactional(rollbackFor = Exception.class)属性详解

今天我在写代码的时候,看到了。一个注解@Transactional(rollbackFor = Exception.class),今天就和大家分享一下,这个注解...

23710
来自专栏Golang语言社区

GO语言并发编程之互斥锁、读写锁详解

在本节,我们对Go语言所提供的与锁有关的API进行说明。这包括了互斥锁和读写锁。我们在第6章描述过互斥锁,但却没有提到过读写锁。这两种锁对于传统的并发程序来说都...

41470
来自专栏码匠的流水账

解决WebDriverWait中的cannot be applied的问题

本文主要描述下如何解决WebDriverWait中的cannot applied的问题。

9410
来自专栏C/C++基础

Linux命令(36)——awk命令

AWK是一个优良的文本处理工具,Linux及Unix环境中现有的功能最强大的数据处理引擎之一。数据可以来自标准输入(stdin)、一个或多个文件,或其它命令的输...

18820
来自专栏游戏杂谈

Node.js文件编码格式的转换

项目很多 lua 文件不是 utf-8格式,使用 EditPlus 查看的时候,显示为ASCII。还有的是带BOM的,带BOM倒好处理,之前写过,有一定规律。

24840
来自专栏领域驱动设计DDD实战进阶

14-TypeScript简单工厂模式

在TypeScript中,要调用功能,通常在调用方通过实例化被调用方对象来调用相关方法,但这种实现在调用方和被调用方形成了强耦合的关系。 另外如果被调用方有种实...

36930
来自专栏封碎

Java多线程参考手册 博客分类: 经典文章转载

http://blog.csdn.net/ring0hx/article/details/6858582

7920

扫码关注云+社区

领取腾讯云代金券