nn.MultiheadAttention
到 SDPA 的排障实录)当我们做中文小型 GPT时,明明 Batch/学习率都合理,但就是有些时候loss 长期不降或震荡,ppl
比基线还差。单句推理“看起来能写”,但是在验证集的效果确实很差。为了解释这种玄幻的现象,本文记录完整的bug以及debug过程,并给出PyTorch 2.x 推荐写法(F.scaled_dot_product_attention
,布尔 Mask 更稳)。
grad_norm
正常、无 NaN。ppl
异常、EOS 附近预测错乱;可视化注意力发现 PAD/未来位置有高权重。import torch, torch.nn.functional as F
B, h, T, d = 4, 8, 16, 64
q = torch.randn(B, h, T, d)
k = torch.randn(B, h, T, d)
v = torch.randn(B, h, T, d)
pad_id = 0
tokens = torch.randint(1, 100, (B, T))
tokens[:, -3:] = pad_id # 尾部 PAD
# ❌ 错:用 1 表示可见,0 表示屏蔽,然后直接“加”到分数
# (很多人从 HF attention_mask 误迁移)
vis_mask = (tokens != pad_id).float() # [B, T],1=可见,0=PAD
scores = (q @ k.transpose(-1, -2)) / d**0.5 # [B,h,T,T]
scores = scores + vis_mask[:, None, None, :] # ⛔ 广播后只是+0或+1,PAD仍可见
attn = torch.softmax(scores, dim=-1) # PAD 被分到概率
out = attn @ v
2:key_padding_ma
sk 语义用反(nn.MultiheadAttenti
on)import torch, torch.nn as nn
mha = nn.MultiheadAttention(embed_dim=d*h, num_heads=h, batch_first=True)
x = torch.randn(B, T, d*h)
# ❌ 传入“1=可见,0=PAD”的 mask,但 MHA 的 key_padding_mask 语义是 True=PAD
attention_mask = (tokens != pad_id) # True=可见
out, _ = mha(x, x, x, key_padding_mask=attention_mask) # ⛔ 语义反了
两类错误共同后果:模型能看见 PAD/未来位,梯度里掺进“无意义对齐”,尤其 Batch 一大,PAD 数量成规模,学习被噪声淹没。python
print("vis_mask sum per batch:", (tokens != pad_id).sum(dim=1)) # 看看 PAD 比例
print("scores:", scores.shape) # 期望 [B,h,T,T]
print("mask:", vis_mask[:,None,None,:].shape) # [B,1,1,T]
若
loss
对 Batch 极度敏感,先怀疑 PAD 处理。
i, t = 0, T-1 # 看最后一个非 PAD 位置的注意力
print("attn to PAD tail:", attn[i, 0, t, -4:])
# 期望 ~0;若出现明显权重,基本就是 mask 错了
PyTorch 2 的 SDPA 支持 布尔 Mask(True=屏蔽),可以快速验证你的语义是否写对:
from torch.nn.functional import scaled_dot_product_attention as sdpa
pad_mask = (tokens == pad_id) # True=PAD
causal = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1) # True=未来位
attn_mask = pad_mask[:, None, None, :] | causal[None, None, :, :] # 布尔或
out = sdpa(q, k, v, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
# 若这样 loss 明显变好,你的旧实现 99% 是 mask 语义/广播错了
布尔 Mask 在 AMP/半精度下更稳(免去 -1e9
精度问题)。
import torch
import torch.nn.functional as F
def causal_pad_mask(tokens: torch.Tensor, pad_id: int):
B, T = tokens.shape
pad_mask = (tokens == pad_id) # [B,T], True=PAD
causal = torch.triu(torch.ones(T, T, dtype=torch.bool, device=tokens.device), 1)
# [B,1,1,T] | [1,1,T,T] → [B,1,T,T], True=屏蔽
attn_mask = pad_mask[:, None, None, :] | causal[None, None, :, :]
return attn_mask
def attention(q, k, v, tokens, pad_id=0, dropout_p=0.0):
mask = causal_pad_mask(tokens, pad_id) # 布尔
return F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=dropout_p)
nn.MultiheadAttention
时的正确语义key_padding_mask
: True=PAD/忽略(与 HF 的 attention_mask
语义相反)attn_mask
: L,S 或 B⋯,L,S,加性 -inf。mha = nn.MultiheadAttention(embed_dim=d*h, num_heads=h, batch_first=True)
key_padding_mask = (tokens == pad_id) # True=PAD
causal = torch.triu(torch.ones(T, T, dtype=torch.bool), 1)
out, _ = mha(x, x, x, key_padding_mask=key_padding_mask, attn_mask=causal)
不要把 0/1 mask 直接“加到分数上”;若用加性 mask,请用
-inf
或-torch.finfo(scores.dtype).max
,且注意 半精度溢出。
ppl
接近基线;很多“玄学不收敛”不是优化器的锅,而是Mask 语义/形状这类“工程正确性”。把布尔 Mask + SDPA作为默认姿势,可以基本把这类问题钉死。这类问题的最后定位是注意力 Mask 写错:把PAD/未来位当可见、或把 0/1 语义用反,甚至让 Mask 在 [B,T]
和 [B,h,T,T]
间广播错位。
原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。
如有侵权,请联系 cloudcommunity@tencent.com 删除。
原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。
如有侵权,请联系 cloudcommunity@tencent.com 删除。