pytorch 学习笔记之编写 C 扩展

pytorch利用 CFFI 进行 C 语言扩展。包括两个基本的步骤(docs):

  1. 编写 C 代码;
  2. python 调用 C 代码,实现相应的 Function 或 Module。

在之前的文章中,我们已经了解了如何自定义 Module。至于 [py]torch 的 C 代码库的结构,我们留待之后讨论; 这里,重点关注,如何在 pytorch C 代码库高层接口的基础上,编写 C 代码,以及如何调用自己编写的 C 代码。

官方示例了如何定义一个加法运算(见 repo)。这里我们定义ReLU函数(见 repo)。

1. C 代码

pytorch C 的基本数据结构是 THTensor(THFloatTensor、THByteTensor等)。我们以简单的 ReLU 函数为例,示例编写 C 。

y=ReLU(x)=max(x,0)

Function 需要定义前向和后向两个方向的操作,因此,C 代码要实现相应的功能。

1.1 头文件声明

/* ext_lib.h */
int relu_forward(THFloatTensor *input, THFloatTensor *output);
int relu_backward(THFloatTensor *grad_output, THFloatTensor *input, THFloatTensor *grad_input);

1.2 函数实现

TH/TH.h 包括了 pytorch C 代码数据结构和函数的声明,这是唯一需要添加的 include 依赖。

/* ext_lib.c */

#include <TH/TH.h>

int relu_forward(THFloatTensor *input, THFloatTensor *output)
{
  THFloatTensor_resizeAs(output, input);
  THFloatTensor_clamp(output, input, 0, INFINITY);
  return 1;
}

int relu_backward(THFloatTensor *grad_output, THFloatTensor *input, THFloatTensor *grad_input)
{
  THFloatTensor_resizeAs(grad_input, grad_output);
  THFloatTensor_zero(grad_input);

  THLongStorage* size = THFloatTensor_newSizeOf(grad_output);
  THLongStorage *stride = THFloatTensor_newStrideOf(grad_output);
  THByteTensor *mask = THByteTensor_newWithSize(size, stride);

  THFloatTensor_geValue(mask, input, 0);
  THFloatTensor_maskedCopy(grad_input, mask, grad_output);
  return 1;
}

2. 编译代码

2.1 依赖

由于 pytorch 的代码是纯 C 的,因此没有过多的依赖,只需要安装:

  • pytorch - 安装方法见官网
  • cffi - pip install cffi

编译文件非常简单,主要是添加头文件和实现文件,以及相关的宏定义; 同时文件还指定了编译后的调用位置(此外为_ext.ext_lib):

# build.py
import os
import torch
from torch.utils.ffi import create_extension


sources = ['src/ext_lib.c']
headers = ['src/ext_lib.h']
defines = []
with_cuda = False

if torch.cuda.is_available():
    print('Including CUDA code.')
    sources += ['src/ext_lib_cuda.c']
    headers += ['src/ext_lib_cuda.h']
    defines += [('WITH_CUDA', None)]
    with_cuda = True

ffi = create_extension(
    '_ext.ext_lib',
    headers=headers,
    sources=sources,
    define_macros=defines,
    relative_to=__file__,
    with_cuda=with_cuda
)

if __name__ == '__main__':
    ffi.build()
python build.py

3. python 调用

3.1 编写配置文件

python 的调用非常简单——pytorch 的 tensor 对象,对应 C 代码的 THTensor 对象,以此作参数进行调用即可。配置文件如下:

import torch
from torch.autograd import Function
from _ext import ext_lib

class ReLUF(Function):
    def forward(self, input):
        self.save_for_backward(input)

        output = input.new()
        if not input.is_cuda:
            ext_lib.relu_forward(input, output)
        else:
            raise Exception, "No CUDA Implementation"
        return output

    def backward(self, grad_output):
        input, = self.saved_tensors

        grad_input = grad_output.new()
        if not grad_output.is_cuda:
            ext_lib.relu_backward(grad_output, input, grad_input)
        else:
            raise Exception, "No CUDA Implementation"
        return grad_input

3.2 测试

此处省略 Module 的定义。下面测试下新定义的基于 C 的 ReLU 函数。

import torch
import torch.nn as nn
from torch.autograd import Variable

from modules.relu import ReLUM

torch.manual_seed(1111)

class MyNetwork(nn.Module):
    def __init__(self):
        super(MyNetwork, self).__init__()
        self.relu = ReLUM()

    def forward(self, input):
        return self.relu(input)

model = MyNetwork()
x = torch.randn(1, 25).view(5, 5)
input = Variable(x, requires_grad=True)
output = model(input)
print(output)
print(input.clamp(min=0))

output.backward(torch.ones(input.size()))
print(input.grad.data)

输出结果如下:

Variable containing:
 0.8749  0.5990  0.6844  0.0000  0.0000
 0.6516  0.0000  1.5117  0.5734  0.0072
 0.1286  1.4171  0.0796  1.0355  0.0000
 0.0000  0.0000  0.0312  0.0999  0.0000
 1.0401  1.0599  0.0000  0.0000  0.0000
[torch.FloatTensor of size 5x5]

Variable containing:
 0.8749  0.5990  0.6844  0.0000  0.0000
 0.6516  0.0000  1.5117  0.5734  0.0072
 0.1286  1.4171  0.0796  1.0355  0.0000
 0.0000  0.0000  0.0312  0.0999  0.0000
 1.0401  1.0599  0.0000  0.0000  0.0000
[torch.FloatTensor of size 5x5]


 1  1  1  0  0
 1  0  1  1  1
 1  1  1  1  0
 0  0  1  1  0
 1  1  0  0  0

原创声明,本文系作者授权云+社区发表,未经许可,不得转载。

如有侵权,请联系 yunjia_community@tencent.com 删除。

编辑于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏Laoqi's Linux运维专列

正则扩展练习

grep命令的-P选项: 最典型的用法是,匹配指定字符串之间的字符。 比如,我们想在一句话(Hello,my name is aming.)中匹配中间的一段字符...

4136
来自专栏用户2442861的专栏

C++ STL空间配置源码分析以及实现一

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/haluoluo211/article/d...

1213
来自专栏扎心了老铁

python重试(指数退避算法)

本文实现了一个重试的装饰器,并且使用了指数退避算法。指数退避算法实现还是很简单的。先上代码再详细解释。 1、指数退避算法 欠奉。http://hugnew.co...

4754
来自专栏人工智能LeadAI

机器学习实战 | 第五章:模型保存(持久化)

一、工具 sklearn官方给出了两种保存模型的方式:3.4. Model persistence 其中一种是pickle的方式,还有一种就是joblib包的...

3988
来自专栏CSDN技术头条

使用Go语言来理解Tensorflow

【译者注】本文通过一个简单的Go绑定实例,让读者一步一步地学习到Tensorflow有关ID、作用域、类型等方面的知识。以下是译文。 Tensorflow并不是...

27210
来自专栏深度学习自然语言处理

【python】命令行参数argparse用法详解

prog.py是我在linux下测试argparse的文件,放在/tmp目录下,其内容如下:

1113
来自专栏FreeBuf

使用Burpsuite扩展Hackvertor绕过WAF并解密XOR

最近,我一直在忙于开发自己的一个Burp扩展Hackvertor。这是一个具有基于标签转换功能的编码器,相比起Burp内置的解码器它的功能要强大的多。通过标签的...

1421
来自专栏编程

说说正则表达式的使用

今日分享:正则表达式 一:正则表达式的定义及用途 正则表达式是一种特殊的字符串,字符串中的每个字符都含有特定的意义。使用者通过将正则中不同的字符组合成不同的字符...

2018
来自专栏一“技”之长

一个移动开发者的Mock数据之路 原

    在前端开发中,很大一部分工作都是将后台数据获取到后展示在前端界面上。如果接口是现成的,这个过程还相对容易一些,但是如果接口的开发和前端开发是同时进行的,...

751
来自专栏云霄雨霁

范式总结

1514

扫码关注云+社区

领取腾讯云代金券