
Scenario: autoregressive language modeling (next-token prediction). Training loss sits at 3.x–5.x for a long stretch and perplexity stays high; sampling produces large amounts of PAD/EOS; training loss occasionally drops a little, but validation loss barely moves. The post-mortem found three old problems stacked on top of each other: labels not right-shifted, PAD not ignored in the loss, and no causal mask.
Below is a minimal script that runs on CPU, reproduces the problems, and doubles as a one-click fix template. Save it as lm_shift_pad_mask_pitfalls.py:
# lm_shift_pad_mask_pitfalls.py
import argparse, math, torch, torch.nn as nn, torch.nn.functional as F

torch.manual_seed(0)
PAD, BOS, EOS = 0, 1, 2

class TinyLM(nn.Module):
    def __init__(self, vocab=128, d=128, nlayer=2, nhead=4, max_len=64):
        super().__init__()
        self.tok = nn.Embedding(vocab, d)
        self.pos = nn.Embedding(max_len, d)
        enc_layer = nn.TransformerEncoderLayer(d_model=d, nhead=nhead, dim_feedforward=256,
                                               batch_first=True, norm_first=True)
        self.enc = nn.TransformerEncoder(enc_layer, num_layers=nlayer)
        self.proj = nn.Linear(d, vocab)
        self.max_len = max_len

    def forward(self, x, pad_mask=None, causal_mask=None):
        B, T = x.shape
        pos_id = torch.arange(T, device=x.device).unsqueeze(0).expand(B, T)
        h = self.tok(x) + self.pos(pos_id)
        # Mind the mask semantics: for bool masks, True = position is masked out
        h = self.enc(h, mask=causal_mask, src_key_padding_mask=pad_mask)
        return self.proj(h)  # [B,T,V]
def make_batch(B=64, T=32, vocab=128):
    # Build: BOS + random tokens (total length 5..T-2) + EOS + PAD
    x = torch.full((B, T), PAD, dtype=torch.long)
    for b in range(B):
        L = torch.randint(5, T - 1, (1,)).item()
        core = torch.randint(3, vocab, (L - 2,))  # avoid the special tokens
        seq = torch.cat([torch.tensor([BOS]), core, torch.tensor([EOS])], dim=0)
        x[b, :len(seq)] = seq
    return x
def subsequent_mask(T, device):
    # True = mask out the future (strict upper triangle)
    return torch.triu(torch.ones(T, T, device=device, dtype=torch.bool), diagonal=1)
def run(bug=True, steps=400, device="cpu"):
    vocab, T, B = 128, 32, 64
    model = TinyLM(vocab=vocab, max_len=T).to(device)
    opt = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.01)

    def perplexity(loss):
        return float(math.exp(min(20.0, loss)))  # clip to avoid overflow

    for it in range(1, steps + 1):
        x = make_batch(B=B, T=T, vocab=vocab).to(device)
        if bug:
            # Three stacked pitfalls: labels not shifted, PAD not ignored, no causal mask
            inp = x  # [B,T]
            lab = x  # [B,T]
            pad_mask = None
            causal = None
            ignore_index = -100  # matches no label here, so PAD is supervised
        else:
            # Fix: right-shift labels, ignore PAD, add a causal mask
            inp = x[:, :-1]          # feed positions 0..T-2
            lab = x[:, 1:]           # predict the next token
            pad_mask = (inp == PAD)  # [B,T-1], True = mask
            causal = subsequent_mask(inp.size(1), device)  # [T-1,T-1]
            ignore_index = PAD
        logits = model(inp, pad_mask=pad_mask, causal_mask=causal)  # [B,Ti,V]
        loss = F.cross_entropy(
            logits.reshape(-1, logits.size(-1)),
            lab.reshape(-1),
            ignore_index=ignore_index,
        )
        opt.zero_grad(set_to_none=True)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        opt.step()
        if it % 50 == 0:
            print(f"[{'BUG' if bug else 'FIX'}] step={it:04d} loss={loss.item():.3f} "
                  f"ppl≈{perplexity(loss.item()):.1f}")
if __name__ == "__main__":
    ap = argparse.ArgumentParser()
    ap.add_argument("--bug", choices=["on", "off", "both"], default="both")
    args = ap.parse_args()
    if args.bug in ("on", "both"):
        print("== buggy setup ==")
        run(bug=True)
    if args.bug in ("off", "both"):
        print("\n== fixed setup ==")
        run(bug=False)

A typical diagnostic:
mask_PAD_in_loss = (lab == PAD).float().mean()

If this value is well above 0 and ignore_index=PAD was not set, PAD tokens are being used as supervision. The language-modeling forward pass and loss should look like this:
# x: [B,T] containing BOS .. EOS .. PAD
inp = x[:, :-1]          # [B,T-1]
lab = x[:, 1:]           # [B,T-1]
pad_mask = (inp == PAD)  # [B,T-1], True = mask
causal = subsequent_mask(inp.size(1), inp.device)
logits = model(inp, pad_mask=pad_mask, causal_mask=causal)
loss = F.cross_entropy(logits.reshape(-1, V), lab.reshape(-1), ignore_index=PAD)

Mind the Transformer mask semantics: with a bool mask, True means the position is masked out; with a float mask, use additive 0/−inf values instead.
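If you need the float form (for example, to combine the causal mask with other additive attention biases), a minimal conversion sketch from the bool mask above:

import torch

T = 8
bool_mask = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)  # True = masked
# Additive float mask: 0 where attention is allowed, -inf where it is masked,
# so softmax assigns masked positions exactly zero attention weight.
float_mask = torch.zeros(T, T).masked_fill(bool_mask, float("-inf"))
# Either form can be passed as `mask` to nn.TransformerEncoder / nn.MultiheadAttention.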
Three guard rails to bake into the scaffold:

def assert_shift(inp, lab):
    # inp = x[:, :-1], lab = x[:, 1:] -> their overlap must match exactly
    ok = torch.equal(inp[:, 1:], lab[:, :-1])
    assert ok, "labels not right-shifted, or misaligned"

def assert_causal(mask):
    # The strict upper triangle should be True
    assert mask.dtype == torch.bool
    T = mask.size(0)
    expect = torch.triu(torch.ones(T, T, dtype=torch.bool, device=mask.device), diagonal=1)
    assert torch.equal(mask, expect), "causal mask has wrong shape/values"

def loss_pad_ratio(lab, ignore_index):
    frac = float((lab == PAD).float().mean())
    if ignore_index != PAD and frac > 0.05:
        print(f"[warn] {frac:.2%} of loss targets are PAD; set ignore_index=PAD")

Adding these assertions to the training scaffold's unit tests / startup self-check prevents this class of regressions.
In autoregressive language modeling, right-shifted labels, PAD ignored in the loss, and a causal mask are three pieces of groundwork that must all be in place at the same time. Miss any one of them and you get the stubborn symptoms above: loss stuck high, perplexity that won't drop, generations flooded with PAD/EOS. Run the minimal reproduction once, then bake the fix template and the guard rails into your project scaffold, and this class of problem is essentially eliminated for good.
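To check the generation symptom directly, here is a minimal greedy-decoding sketch (greedy_decode is a name I'm introducing; it reuses TinyLM, subsequent_mask, and the special tokens from the script):

@torch.no_grad()
def greedy_decode(model, max_new=20, device="cpu"):
    # Start from BOS and repeatedly append the argmax token; a healthy fixed
    # model should eventually emit EOS rather than flooding PAD from step one.
    seq = torch.tensor([[BOS]], dtype=torch.long, device=device)
    for _ in range(max_new):
        causal = subsequent_mask(seq.size(1), device)
        logits = model(seq, pad_mask=(seq == PAD), causal_mask=causal)
        nxt = logits[:, -1].argmax(dim=-1, keepdim=True)  # [1,1] next token
        seq = torch.cat([seq, nxt], dim=1)
        if nxt.item() == EOS:
            break
    return seq.squeeze(0).tolist()

Since the toy data is random, the fixed model mainly learns length/EOS statistics, but this loop is enough to spot a model that emits PAD from the very first step.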
Originality statement: this article is published on the Tencent Cloud Developer Community with the author's authorization; reproduction without permission is prohibited.
For infringement concerns, contact cloudcommunity@tencent.com for removal.