首页
学习
活动
专区
圈层
工具
发布
社区首页 >专栏 >语言模型 loss 居高不下?未移位 labels、PAD 未忽略、注意力 mask 错误的三连坑

语言模型 loss 居高不下?未移位 labels、PAD 未忽略、注意力 mask 错误的三连坑

原创
作者头像
九年义务漏网鲨鱼
发布2025-12-26 16:11:27
发布2025-12-26 16:11:27
1820
举报

语言模型 loss 居高不下?未移位 labels、PAD 未忽略、注意力 mask 错误的三连坑

场景:做自回归语言建模(next-token prediction)。训练 loss 长期在 3.x~5.x,perplexity 居高不下;采样生成大量 PAD/EOS;训练集 loss 偶尔还能降,但验证集几乎不动。复盘后是三件老问题叠加:

  1. 没有把 labels 做“右移一位”(off-by-one);
  2. 交叉熵没有 ignore PAD,PAD 也被当成监督;
  3. 自注意力没有因果 mask,当前位置偷看未来 token。

下面给出一个 CPU 可跑的极小脚本,复现问题与一键修复模板。


Bug 现象

  • 训练 loss 波动或缓慢下降,验证基本不动;
  • 采样输出里充斥 PAD/EOS,或者重复当前 token;
  • 打印注意力掩码与 label 对齐后,发现没有做 shift,且没有因果 mask。

场景复现(两行命令 A/B)

保存为 lm_shift_pad_mask_pitfalls.py

代码语言:python
复制
# lm_shift_pad_mask_pitfalls.py
import argparse, math, torch, torch.nn as nn, torch.nn.functional as F
torch.manual_seed(0)

PAD, BOS, EOS = 0, 1, 2

class TinyLM(nn.Module):
    def __init__(self, vocab=128, d=128, nlayer=2, nhead=4, max_len=64):
        super().__init__()
        self.tok = nn.Embedding(vocab, d)
        self.pos = nn.Embedding(max_len, d)
        enc_layer = nn.TransformerEncoderLayer(d_model=d, nhead=nhead, dim_feedforward=256, batch_first=True, norm_first=True)
        self.enc = nn.TransformerEncoder(enc_layer, num_layers=nlayer)
        self.proj = nn.Linear(d, vocab)
        self.max_len = max_len
    def forward(self, x, pad_mask=None, causal_mask=None):
        B, T = x.shape
        pos_id = torch.arange(T, device=x.device).unsqueeze(0).expand(B, T)
        h = self.tok(x) + self.pos(pos_id)
        h = self.enc(h, mask=causal_mask, src_key_padding_mask=pad_mask)  # 注意 mask 语义:True=masked
        return self.proj(h)  # [B,T,V]

def make_batch(B=64, T=32, vocab=128):
    # 构造:BOS + 随机 token(长度5~T-2) + EOS + PAD
    x = torch.full((B, T), PAD, dtype=torch.long)
    for b in range(B):
        L = torch.randint(5, T-1, (1,)).item()
        core = torch.randint(3, vocab, (L-2,))  # 避开特殊符
        seq = torch.cat([torch.tensor([BOS]), core, torch.tensor([EOS])], dim=0)
        x[b, :len(seq)] = seq
    return x

def subsequent_mask(T, device):
    # True=mask 掉未来
    return torch.triu(torch.ones(T, T, device=device, dtype=torch.bool), diagonal=1)

def run(bug=True, steps=400, device="cpu"):
    vocab, T, B = 128, 32, 64
    model = TinyLM(vocab=vocab, max_len=T).to(device)
    opt = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.01)

    def perplexity(loss): return float(math.exp(min(20.0, loss)))

    for it in range(1, steps+1):
        x = make_batch(B=B, T=T, vocab=vocab).to(device)

        if bug:
            # 三连坑:未移位、未忽略PAD、无因果mask
            inp = x                               # [B,T]
            lab = x                               # [B,T]
            pad_mask = None
            causal = None
            ignore_index = -100                   # 但我们没用到
        else:
            # 修复:labels 右移、忽略 PAD、加因果 mask
            inp = x[:, :-1]                       # 输入到 T-1
            lab = x[:, 1:]                        # 预测下一个
            pad_mask = (inp == PAD)               # [B,T-1] True=mask
            causal = subsequent_mask(inp.size(1), device)  # [T-1,T-1]
            ignore_index = PAD

        logits = model(inp, pad_mask=pad_mask, causal_mask=causal)  # [B,Ti,V]
        loss = F.cross_entropy(
            logits.reshape(-1, logits.size(-1)),
            lab.reshape(-1),
            ignore_index=ignore_index
        )

        opt.zero_grad(set_to_none=True)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        opt.step()

        if it % 50 == 0:
            print(f"[{'BUG' if bug else 'FIX'}] step={it:04d} loss={loss.item():.3f} ppl≈{perplexity(loss.item()):.1f}")

if __name__ == "__main__":
    ap = argparse.ArgumentParser()
    ap.add_argument("--bug", choices=["on","off"], default="on")
    args = ap.parse_args()
    print("== 错误设置 =="); run(bug=True)
    print("\n== 正确设置 =="); run(bug=False)

典型观测:

  • 错误设置:loss 与 ppl 长期维持在高位,波动大;
  • 正确设置:loss 稳定下降,ppl 明显降低。

Debug 过程

  1. 检查 label 是否右移 打印一小段样本的 inp:1 与 lab:1,确认 lab 是 inp 向右平移一位,且第一位对应 BOS 后的 token。
  2. 检查损失是否忽略 PAD 统计 mask_PAD_in_loss = (lab == PAD).float().mean();若远大于 0 且没有 ignore_index=PAD,那就是把 PAD 当监督了。
  3. 检查因果 mask 是否生效 取长度 T,验证 causal_mask 的上三角是否为 True,且传入 Transformer 的 mask 语义正确(bool 的 True=masked)。
  4. 采样自检 关掉教师强制,按自回归逐步采样几步。若输出充斥 PAD/EOS,多半是把 PAD 学进去了或偷看了未来。

代码修改要点(模板)

语言建模前向与损失:

代码语言:python
复制
# x: [B,T] 包含 BOS..EOS..PAD
inp = x[:, :-1]                  # [B,T-1]
lab = x[:, 1:]                   # [B,T-1]
pad_mask = (inp == PAD)          # [B,T-1] True=mask
causal   = subsequent_mask(inp.size(1), inp.device)
logits   = model(inp, pad_mask=pad_mask, causal_mask=causal)
loss     = F.cross_entropy(logits.reshape(-1, V), lab.reshape(-1), ignore_index=PAD)

注意 Transformer 的 mask 语义:bool 类型时 True 表示位置被屏蔽;float mask 时需用 0/−inf 的加法方式。


监控与护栏

代码语言:python
复制
def assert_shift(inp, lab):
    # inp=t[:T-1], lab=t[1:]
    ok = torch.allclose(inp[:,1:], lab[:, :-1])
    assert ok, "labels 未右移或对齐错误"

def assert_causal(mask):
    # 上三角应为 True
    assert mask.dtype == torch.bool
    T = mask.size(0)
    expect = torch.triu(torch.ones(T,T, dtype=torch.bool, device=mask.device), diagonal=1)
    assert torch.equal(mask, expect), "causal mask 形状/取值不正确"

def loss_pad_ratio(lab, ignore_index):
    frac = float((lab == PAD).float().mean())
    if ignore_index != PAD and frac > 0.05:
        print(f"[warn] {frac:.2%} 的 PAD 落入损失,请设置 ignore_index=PAD")

将上述断言加入训练脚手架的单元测试/启动自检,可避免回归。


常见问答

  • 交叉熵是否一定用 ignore_index=PAD 是。否则模型会被强迫在 PAD 处也最小化损失,学会“预测 PAD”,影响生成与 perplexity。
  • 因果 mask 与 key_padding_mask 的关系 两者互补:causal 防止看未来;key_padding 屏蔽序列中的 PAD 位置。两者都需要。
  • 为什么“未移位 + 无 causal”时训练集 loss 还能降 模型可以在当前位置“抄答案”(偷看自己或未来 token),训练 loss 假降但泛化与生成都崩。
  • 验证时需要哪些对齐 验证集构造同样按 inp=x:,:-1、lab=x:,1:,并且使用相同的 ignore_index 与 mask 规则。

结语

自回归语言建模里,labels 右移、PAD 忽略、因果 mask 是三件必须同时到位的基础工作。一旦漏了其中之一,就会出现 loss 居高不下、ppl 不降、生成充斥 PAD/EOS 等顽固问题。把上面的最小复现跑一遍,再把修复模板与护栏固化到项目脚手架,这类问题基本可以一次性清零。

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 语言模型 loss 居高不下?未移位 labels、PAD 未忽略、注意力 mask 错误的三连坑
    • Bug 现象
    • 场景复现(两行命令 A/B)
    • Debug 过程
    • 代码修改要点(模板)
    • 监控与护栏
    • 常见问答
    • 结语
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档