首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >在CUDA中用SIMD实现位旋转算子

在CUDA中用SIMD实现位旋转算子
EN

Stack Overflow用户
提问于 2017-08-27 00:02:24
回答 2查看 2.5K关注 0票数 6

我知道StackOverflow并不是用来向其他人询问代码的,但是让我来谈谈。

我试图在CUDA C++设备代码中实现一些AES功能。当我试图实现左转操作符时,我感到很困惑,因为我发现没有任何本机SIMD对此感兴趣。所以我开始了一个天真的实现,但是.它是巨大的,虽然我还没有试过它,但由于昂贵的解压/包装,它不会很快.那么,是否有一种方法来执行至少有点效率的每字节位旋转操作?

这是密码,如果你不想看的话。

代码语言:javascript
运行
复制
__inline__ __device__ uint32_t per_byte_bit_left_rotate(uint32_t input, uint8_t amount) {
return ((((input & 0xFF) >> 0) << amount) | (((input & 0xFF) >> 0) >> 7) & ~0x100) << 0 |
     ((((input & 0xFF00) >> 8) << amount) | ((input & 0xFF00 >> 8) >> 7) & ~0x100) << 8 |
     ((((input & 0xFF0000) >> 16) << amount) | ((input & 0xFF0000 >> 16) >> 7) & ~0x100) << 16 |
     ((((input & 0xFF000000) >> 24) << amount) | ((input & 0xFF000000 >> 24) >> 7) & ~0x100) << 24; } // The XORs are for clearing the old 7th bit who is getting pushed to the next byte of the intermediate int
EN

Stack Overflow用户

发布于 2017-08-27 05:14:39

code有一个__byte_perm()内部,它直接映射到机器代码( has )级别的PRMT指令,这是一个按字节排列的指令。它可以用于有效地提取和合并字节。为了影响一个字节的左旋转,我们可以将每个字节加倍,按所需的数量移动字节对,然后提取和合并这四个字节对的高字节。

对于字节级旋转,我们只需要最小的三位移位量,因为s的旋转和s mod 8的旋转是一样的。为了提高效率,最好避免包含小于32位的整数类型,因为C++语义要求在表达式中使用之前将比int窄的整数类型扩展到int。这可能并确实会在包括GPU在内的许多体系结构上产生转换开销。

PRMT指令的吞吐量依赖于体系结构,因此使用__byte_perm()可能会导致代码比使用另一个答案中演示的经典SIMD-a-寄存器方法更快或更慢,因此在部署之前一定要在用例上下文中进行基准测试。

代码语言:javascript
运行
复制
#include <stdio.h>
#include <stdint.h>
#include <stdlib.h>

__device__ uint32_t per_byte_bit_left_rotate (uint32_t input, uint32_t amount)
{
     uint32_t l = __byte_perm (input, 0, 0x1100) << (amount & 7);
     uint32_t h = __byte_perm (input, 0, 0x3322) << (amount & 7);
     return __byte_perm (l, h, 0x7531);
}

__global__ void rotl_kernel (uint32_t input, uint32_t amount, uint32_t *res)
{
    *res = per_byte_bit_left_rotate (input, amount);
}

uint32_t ref_per_byte_bit_left_rotate (uint32_t input, uint32_t amount)
{
   int s = amount & 7;
   uint8_t b0 = (input >>  0) & 0xff;
   uint8_t b1 = (input >>  8) & 0xff;
   uint8_t b2 = (input >> 16) & 0xff;
   uint8_t b3 = (input >> 24) & 0xff;
   b0 = s ? ((b0 << s) | (b0 >> (8 - s))) : b0;
   b1 = s ? ((b1 << s) | (b1 >> (8 - s))) : b1;
   b2 = s ? ((b2 << s) | (b2 >> (8 - s))) : b2;
   b3 = s ? ((b3 << s) | (b3 >> (8 - s))) : b3;
   return (b3 << 24) | (b2 << 16) | (b1 << 8) | (b0 << 0);
}

// Fixes via: Greg Rose, KISS: A Bit Too Simple. http://eprint.iacr.org/2011/007
static unsigned int z=362436069,w=521288629,jsr=362436069,jcong=123456789;
#define znew (z=36969*(z&0xffff)+(z>>16))
#define wnew (w=18000*(w&0xffff)+(w>>16))
#define MWC  ((znew<<16)+wnew)
#define SHR3 (jsr^=(jsr<<13),jsr^=(jsr>>17),jsr^=(jsr<<5)) /* 2^32-1 */
#define CONG (jcong=69069*jcong+13579)                     /* 2^32 */
#define KISS ((MWC^CONG)+SHR3)

// Macro to catch CUDA errors in CUDA runtime calls
#define CUDA_SAFE_CALL(call)                                          \
do {                                                                  \
    cudaError_t err = call;                                           \
    if (cudaSuccess != err) {                                         \
        fprintf (stderr, "Cuda error in file '%s' in line %i : %s.\n",\
                 __FILE__, __LINE__, cudaGetErrorString(err) );       \
        exit(EXIT_FAILURE);                                           \
    }                                                                 \
} while (0)

// Macro to catch CUDA errors in kernel launches
#define CHECK_LAUNCH_ERROR()                                          \
do {                                                                  \
    /* Check synchronous errors, i.e. pre-launch */                   \
    cudaError_t err = cudaGetLastError();                             \
    if (cudaSuccess != err) {                                         \
        fprintf (stderr, "Cuda error in file '%s' in line %i : %s.\n",\
                 __FILE__, __LINE__, cudaGetErrorString(err) );       \
        exit(EXIT_FAILURE);                                           \
    }                                                                 \
    /* Check asynchronous errors, i.e. kernel failed (ULF) */         \
    err = cudaThreadSynchronize();                                    \
    if (cudaSuccess != err) {                                         \
        fprintf (stderr, "Cuda error in file '%s' in line %i : %s.\n",\
                 __FILE__, __LINE__, cudaGetErrorString( err) );      \
        exit(EXIT_FAILURE);                                           \
    }                                                                 \
} while (0)

int main (void)
{
    uint32_t arg, ref, res = 0, *res_d = 0;
    uint32_t shft;

    CUDA_SAFE_CALL (cudaMalloc ((void**)&res_d, sizeof(*res_d)));
    for (int i = 0; i < 100000; i++) {
        arg  = KISS;
        shft = KISS;
        ref = ref_per_byte_bit_left_rotate (arg, shft);
        rotl_kernel <<<1,1>>>(arg, shft, res_d);
        CHECK_LAUNCH_ERROR();
        CUDA_SAFE_CALL (cudaMemcpy (&res, res_d, sizeof (res), 
                                    cudaMemcpyDeviceToHost));
        if (res != ref) {
            printf ("!!!! arg=%08x shft=%d  res=%08x  ref=%08x\n", 
                    arg, shft, res, ref);
        }
    }
    CUDA_SAFE_CALL (cudaFree (res_d));
    CUDA_SAFE_CALL (cudaDeviceSynchronize());
    return EXIT_SUCCESS;
}
票数 6
EN
查看全部 2 条回答
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/45900662

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档