我知道StackOverflow并不是用来向其他人询问代码的,但是让我来谈谈。
我试图在CUDA C++设备代码中实现一些AES功能。当我试图实现左转操作符时,我感到很困惑,因为我发现没有任何本机SIMD对此感兴趣。所以我开始了一个天真的实现,但是.它是巨大的,虽然我还没有试过它,但由于昂贵的解压/包装,它不会很快.那么,是否有一种方法来执行至少有点效率的每字节位旋转操作?
这是密码,如果你不想看的话。
__inline__ __device__ uint32_t per_byte_bit_left_rotate(uint32_t input, uint8_t amount) {
return ((((input & 0xFF) >> 0) << amount) | (((input & 0xFF) >> 0) >> 7) & ~0x100) << 0 |
     ((((input & 0xFF00) >> 8) << amount) | ((input & 0xFF00 >> 8) >> 7) & ~0x100) << 8 |
     ((((input & 0xFF0000) >> 16) << amount) | ((input & 0xFF0000 >> 16) >> 7) & ~0x100) << 16 |
     ((((input & 0xFF000000) >> 24) << amount) | ((input & 0xFF000000 >> 24) >> 7) & ~0x100) << 24; } // The XORs are for clearing the old 7th bit who is getting pushed to the next byte of the intermediate int发布于 2017-08-27 02:06:05
所有元素的旋转计数是相同的,对吗?
向左和右移动整个输入,然后是和那些带掩码的人,即所有跨越字节边界的位为零,所有4个字节都在一个和。我认为amount在AES中总是一个编译时常数,所以您不必担心动态生成掩码的运行时成本。就让编译器来做吧。(IDK CUDA,但这似乎与为普通斯瓦尔比特黑客编写32位整数的C++问题相同)
这是基于通常的旋转成语,有掩蔽和一个不同的右移计数,使它成为单独的8位旋转。
inline
uint32_t per_byte_bit_left_rotate(uint32_t input, unsigned amount)
{
    // With constant amount, the left/right masks are constants
    uint32_t rmask = 0xFF >> ((8 - amount) & 7);
    rmask = (rmask<<24 | rmask<<16 | rmask<<8 | rmask);
    uint32_t lmask = ~rmask;
    uint32_t lshift = input << amount;
    lshift &= lmask;
    if (amount == 1) {  // special case left-shift by 1 using an in-lane add instead of shift&mask
        lshift = __vadd4(input, input);
    }
    uint32_t rshift = input >> ((8 - amount) & 7);
    rshift &= rmask;
    uint32_t rotated = lshift | rshift;
    return rotated;
}在换档前用一种方式屏蔽输入,在移位后屏蔽输出可能更有效((in&lmask)<<amount | ((in>>(8-amount))&rmask),使用不同的l掩码)。NVidia硬件是有序的超标量,而班次吞吐量有限。是.这样做更有可能作为两个独立的shift+mask对执行。
(这并不是为了避免使用C++ UB和amount>=32。见C++中循环移位(旋转)操作的最佳实践。在这种情况下,我认为更改为lshift = input << (amount & 7)可以做到这一点。
为了测试这是否有效地编译,我使用一个常量的asm输出查看x86-64的amount。戈德波特编译器浏览器有各种架构的编译器(但不是CUDA ),所以单击该链接并翻转到ARM、MIPS或PowerPC,如果您可以比x86更容易地读取这些asm语言的话。
uint32_t rol7(uint32_t a) {
    return per_byte_bit_left_rotate(a, 7);
}
    mov     eax, edi
    shl     eax, 7
    shr     edi
    and     eax, -2139062144   # 0x80808080
    and     edi, 2139062143    # 0x7F7F7F7F
    lea     eax, [rdi + rax]   # ADD = OR when no bits intersect
    ret太完美了,正是我所希望的。
几个测试用例:
uint32_t test_rol() {
    return per_byte_bit_left_rotate(0x02ffff04, 0);
}
    // yup, returns the input with count=0
    // return 0x2FFFF04
uint32_t test2_rol() {
    return per_byte_bit_left_rotate(0x02f73804, 4);
}
    // yup, swaps nibbles
    // return 0x207F8340对于使用x86 SSE2 / AVX2的8位移位,这是相同的情况,因为硬件支持的最小移位粒度是16位。
发布于 2017-08-27 05:14:39
code有一个__byte_perm()内部,它直接映射到机器代码( has )级别的PRMT指令,这是一个按字节排列的指令。它可以用于有效地提取和合并字节。为了影响一个字节的左旋转,我们可以将每个字节加倍,按所需的数量移动字节对,然后提取和合并这四个字节对的高字节。
对于字节级旋转,我们只需要最小的三位移位量,因为s的旋转和s mod 8的旋转是一样的。为了提高效率,最好避免包含小于32位的整数类型,因为C++语义要求在表达式中使用之前将比int窄的整数类型扩展到int。这可能并确实会在包括GPU在内的许多体系结构上产生转换开销。
PRMT指令的吞吐量依赖于体系结构,因此使用__byte_perm()可能会导致代码比使用另一个答案中演示的经典SIMD-a-寄存器方法更快或更慢,因此在部署之前一定要在用例上下文中进行基准测试。
#include <stdio.h>
#include <stdint.h>
#include <stdlib.h>
__device__ uint32_t per_byte_bit_left_rotate (uint32_t input, uint32_t amount)
{
     uint32_t l = __byte_perm (input, 0, 0x1100) << (amount & 7);
     uint32_t h = __byte_perm (input, 0, 0x3322) << (amount & 7);
     return __byte_perm (l, h, 0x7531);
}
__global__ void rotl_kernel (uint32_t input, uint32_t amount, uint32_t *res)
{
    *res = per_byte_bit_left_rotate (input, amount);
}
uint32_t ref_per_byte_bit_left_rotate (uint32_t input, uint32_t amount)
{
   int s = amount & 7;
   uint8_t b0 = (input >>  0) & 0xff;
   uint8_t b1 = (input >>  8) & 0xff;
   uint8_t b2 = (input >> 16) & 0xff;
   uint8_t b3 = (input >> 24) & 0xff;
   b0 = s ? ((b0 << s) | (b0 >> (8 - s))) : b0;
   b1 = s ? ((b1 << s) | (b1 >> (8 - s))) : b1;
   b2 = s ? ((b2 << s) | (b2 >> (8 - s))) : b2;
   b3 = s ? ((b3 << s) | (b3 >> (8 - s))) : b3;
   return (b3 << 24) | (b2 << 16) | (b1 << 8) | (b0 << 0);
}
// Fixes via: Greg Rose, KISS: A Bit Too Simple. http://eprint.iacr.org/2011/007
static unsigned int z=362436069,w=521288629,jsr=362436069,jcong=123456789;
#define znew (z=36969*(z&0xffff)+(z>>16))
#define wnew (w=18000*(w&0xffff)+(w>>16))
#define MWC  ((znew<<16)+wnew)
#define SHR3 (jsr^=(jsr<<13),jsr^=(jsr>>17),jsr^=(jsr<<5)) /* 2^32-1 */
#define CONG (jcong=69069*jcong+13579)                     /* 2^32 */
#define KISS ((MWC^CONG)+SHR3)
// Macro to catch CUDA errors in CUDA runtime calls
#define CUDA_SAFE_CALL(call)                                          \
do {                                                                  \
    cudaError_t err = call;                                           \
    if (cudaSuccess != err) {                                         \
        fprintf (stderr, "Cuda error in file '%s' in line %i : %s.\n",\
                 __FILE__, __LINE__, cudaGetErrorString(err) );       \
        exit(EXIT_FAILURE);                                           \
    }                                                                 \
} while (0)
// Macro to catch CUDA errors in kernel launches
#define CHECK_LAUNCH_ERROR()                                          \
do {                                                                  \
    /* Check synchronous errors, i.e. pre-launch */                   \
    cudaError_t err = cudaGetLastError();                             \
    if (cudaSuccess != err) {                                         \
        fprintf (stderr, "Cuda error in file '%s' in line %i : %s.\n",\
                 __FILE__, __LINE__, cudaGetErrorString(err) );       \
        exit(EXIT_FAILURE);                                           \
    }                                                                 \
    /* Check asynchronous errors, i.e. kernel failed (ULF) */         \
    err = cudaThreadSynchronize();                                    \
    if (cudaSuccess != err) {                                         \
        fprintf (stderr, "Cuda error in file '%s' in line %i : %s.\n",\
                 __FILE__, __LINE__, cudaGetErrorString( err) );      \
        exit(EXIT_FAILURE);                                           \
    }                                                                 \
} while (0)
int main (void)
{
    uint32_t arg, ref, res = 0, *res_d = 0;
    uint32_t shft;
    CUDA_SAFE_CALL (cudaMalloc ((void**)&res_d, sizeof(*res_d)));
    for (int i = 0; i < 100000; i++) {
        arg  = KISS;
        shft = KISS;
        ref = ref_per_byte_bit_left_rotate (arg, shft);
        rotl_kernel <<<1,1>>>(arg, shft, res_d);
        CHECK_LAUNCH_ERROR();
        CUDA_SAFE_CALL (cudaMemcpy (&res, res_d, sizeof (res), 
                                    cudaMemcpyDeviceToHost));
        if (res != ref) {
            printf ("!!!! arg=%08x shft=%d  res=%08x  ref=%08x\n", 
                    arg, shft, res, ref);
        }
    }
    CUDA_SAFE_CALL (cudaFree (res_d));
    CUDA_SAFE_CALL (cudaDeviceSynchronize());
    return EXIT_SUCCESS;
}https://stackoverflow.com/questions/45900662
复制相似问题