这篇文章中,我们将解释GQA的思想以及如何将其转化为代码。...d -> b g h n d", g=num_head_groups)
print(query.shape) # torch.Size([1, 4, 2, 256, 64])
上面的代码我们将二维重塑为二维...:对于我们定义的张量,原始维度8(查询的头数)现在被分成两组(以匹配键和值中的头数),每组大小为4。...value张量的形状是一样的。...在我们的例子中,这些张量的形状是(1,4,2,256,64)和(1,2,256,64),所以沿着最后两个维度的矩阵乘法得到(1,4,2,256,256)。