这篇文章中,我们将解释GQA的思想以及如何将其转化为代码。...这个过程如下图所示:
而GQA将查询头分成G组,每组共享一个键和值。...:对于我们定义的张量,原始维度8(查询的头数)现在被分成两组(以匹配键和值中的头数),每组大小为4。...在我们的例子中,这些张量的形状是(1,4,2,256,64)和(1,2,256,64),所以沿着最后两个维度的矩阵乘法得到(1,4,2,256,256)。...2、对第二个维度(维度g)上的元素求和——如果在指定的输出形状中省略了维度,einsum将自动完成这项工作,这样的求和是用来匹配键和值中的头的数量。