首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >专栏 >【Debug日志 | Transformer 泛化能力差】

【Debug日志 | Transformer 泛化能力差】

原创
作者头像
九年义务漏网鲨鱼
发布2025-09-10 14:22:29
发布2025-09-10 14:22:29
6410
代码可运行
举报
文章被收录于专栏:tencent cloudtencent cloud
运行总次数:0
代码可运行

Transformer 训不动:注意力 Mask 用反 / 广播错位,PAD 被看到导致 Loss 不降(一次从 nn.MultiheadAttention 到 SDPA 的排障实录)

当我们做中文小型 GPT时,明明 Batch/学习率都合理,但就是有些时候loss 长期不降或震荡,ppl 比基线还差。单句推理“看起来能写”,但是在验证集的效果确实很差。为了解释这种玄幻的现象,本文记录完整的bug以及debug过程,并给出PyTorch 2.x 推荐写法(F.scaled_dot_product_attention,布尔 Mask 更稳)。

❓ Bug 现象

  • 训练 loss 基本不降(或震荡 3.8~4.2),grad_norm 正常、无 NaN。
  • Batch=1 还能降一点,一到 Batch≥8 就“卡住”。
  • 评估 ppl 异常、EOS 附近预测错乱;可视化注意力发现 PAD/未来位置有高权重
  • 换优化器/学习率没用;关掉 Dropout/LayerNorm 也无改善。

📽️ 场景复现

1️⃣ 把 0/1 Mask 直接加到注意力分数上
代码语言:python
代码运行次数:0
运行
复制
import torch, torch.nn.functional as F

B, h, T, d = 4, 8, 16, 64
q = torch.randn(B, h, T, d)
k = torch.randn(B, h, T, d)
v = torch.randn(B, h, T, d)

pad_id = 0
tokens = torch.randint(1, 100, (B, T))
tokens[:, -3:] = pad_id  # 尾部 PAD

# ❌ 错:用 1 表示可见,0 表示屏蔽,然后直接“加”到分数
# (很多人从 HF attention_mask 误迁移)
vis_mask = (tokens != pad_id).float()              # [B, T],1=可见,0=PAD
scores = (q @ k.transpose(-1, -2)) / d**0.5        # [B,h,T,T]
scores = scores + vis_mask[:, None, None, :]       # ⛔ 广播后只是+0或+1,PAD仍可见
attn = torch.softmax(scores, dim=-1)               # PAD 被分到概率
out = attn @ v
2️⃣ 错误 2:key_padding_mask 语义用反(nn.MultiheadAttention)
代码语言:python
代码运行次数:0
运行
复制
import torch, torch.nn as nn

mha = nn.MultiheadAttention(embed_dim=d*h, num_heads=h, batch_first=True)
x = torch.randn(B, T, d*h)
# ❌ 传入“1=可见,0=PAD”的 mask,但 MHA 的 key_padding_mask 语义是 True=PAD
attention_mask = (tokens != pad_id)               # True=可见
out, _ = mha(x, x, x, key_padding_mask=attention_mask)  # ⛔ 语义反了

两类错误共同后果:模型能看见 PAD/未来位,梯度里掺进“无意义对齐”,尤其 Batch 一大,PAD 数量成规模,学习被噪声淹没。python

Debug过程

1️⃣ 立即打印 Mask 统计与形状
代码语言:python
代码运行次数:0
运行
复制
print("vis_mask sum per batch:", (tokens != pad_id).sum(dim=1))  # 看看 PAD 比例
print("scores:", scores.shape)           # 期望 [B,h,T,T]
print("mask:", vis_mask[:,None,None,:].shape)  # [B,1,1,T]

loss 对 Batch 极度敏感,先怀疑 PAD 处理。

2️⃣在一条样本上可视化一行注意力
代码语言:python
代码运行次数:0
运行
复制
i, t = 0, T-1   # 看最后一个非 PAD 位置的注意力
print("attn to PAD tail:", attn[i, 0, t, -4:])
# 期望 ~0;若出现明显权重,基本就是 mask 错了
3️⃣ 切到 SDPA + 布尔 Mask 验证

PyTorch 2 的 SDPA 支持 布尔 Mask(True=屏蔽),可以快速验证你的语义是否写对:

代码语言:python
代码运行次数:0
运行
复制
from torch.nn.functional import scaled_dot_product_attention as sdpa

pad_mask = (tokens == pad_id)                     # True=PAD
causal = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)  # True=未来位
attn_mask = pad_mask[:, None, None, :] | causal[None, None, :, :]    # 布尔或

out = sdpa(q, k, v, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
# 若这样 loss 明显变好,你的旧实现 99% 是 mask 语义/广播错了

代码调整

✅ 推荐:统一使用布尔 Mask + SDPA

布尔 Mask 在 AMP/半精度下更稳(免去 -1e9 精度问题)。

代码语言:python
代码运行次数:0
运行
复制
import torch
import torch.nn.functional as F

def causal_pad_mask(tokens: torch.Tensor, pad_id: int):
    B, T = tokens.shape
    pad_mask = (tokens == pad_id)                                     # [B,T], True=PAD
    causal = torch.triu(torch.ones(T, T, dtype=torch.bool, device=tokens.device), 1)
    # [B,1,1,T] | [1,1,T,T] → [B,1,T,T], True=屏蔽
    attn_mask = pad_mask[:, None, None, :] | causal[None, None, :, :]
    return attn_mask

def attention(q, k, v, tokens, pad_id=0, dropout_p=0.0):
    mask = causal_pad_mask(tokens, pad_id)                            # 布尔
    return F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=dropout_p)
✅ 使用 nn.MultiheadAttention 时的正确语义
  • key_padding_mask: True=PAD/忽略(与 HF 的 attention_mask 语义相反)
  • attn_mask: L,S 或 B⋯,L,S,加性 -inf。
代码语言:python
代码运行次数:0
运行
复制
mha = nn.MultiheadAttention(embed_dim=d*h, num_heads=h, batch_first=True)
key_padding_mask = (tokens == pad_id)             # True=PAD
causal = torch.triu(torch.ones(T, T, dtype=torch.bool), 1)
out, _ = mha(x, x, x, key_padding_mask=key_padding_mask, attn_mask=causal)

不要把 0/1 mask 直接“加到分数上”;若用加性 mask,请用 -inf-torch.finfo(scores.dtype).max,且注意 半精度溢出

验证结果

  • 修复后,loss 在前几千步明显下降,ppl 接近基线;
  • 可视化注意力:PAD 与未来位的权重 ≈ 0(数值 < 1e-4);
  • Batch 从 1 → 64,收敛曲线单调更稳,不再出现“大 batch 更差”的反常。

结语

很多“玄学不收敛”不是优化器的锅,而是Mask 语义/形状这类“工程正确性”。把布尔 Mask + SDPA作为默认姿势,可以基本把这类问题钉死。这类问题的最后定位是注意力 Mask 写错:把PAD/未来位当可见、或把 0/1 语义用反,甚至让 Mask 在 [B,T][B,h,T,T] 间广播错位。

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • Transformer 训不动:注意力 Mask 用反 / 广播错位,PAD 被看到导致 Loss 不降(一次从 nn.MultiheadAttention 到 SDPA 的排障实录)
    • ❓ Bug 现象
    • 📽️ 场景复现
    • Debug过程
    • 代码调整
    • 验证结果
    • 结语
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档