
场景:在分类/分割/检索模型上,换了优化器和 weight_decay 之后,训练明显更慢、验证精度长时间不上升。最常见的原因是:把 Adam 的 weight_decay 当成解耦的权重衰减来用——它实际上是耦合进梯度的 L2 正则,会与自适应二阶矩相互作用,使有效衰减强度失真、收敛受损;同时 Norm 层权重与偏置也被一并衰减。下面给出最小复现实验(CPU 可跑)与一键修复模板。
保存为 adamw_decay_pitfalls.py:
# adamw_decay_pitfalls.py
# Minimal CPU-runnable repro of Adam/AdamW weight-decay pitfalls.
# NOTE(review): `math` appears unused in this file.
import argparse, math, torch, torch.nn as nn, torch.nn.functional as F

# Fix the global RNG seed so the BUG and FIX runs are reproducible.
torch.manual_seed(0)
class TinyTransformer(nn.Module):
    """Tiny Transformer classifier.

    A flat 20-dim input is reshaped into 5 tokens of 4 features each,
    encoded by a small Transformer, mean-pooled over tokens, and mapped
    to `ncls` logits.
    """

    def __init__(self, d=128, ncls=5, nhead=4, nl=2):
        super().__init__()
        # Bug fix: forward() reshapes x to [B, 5, 4], so each token carries
        # 4 features and the embedding must map 4 -> d. The original
        # nn.Linear(20, d) raised a shape mismatch on the first forward pass.
        self.embed = nn.Linear(4, d)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d, nhead=nhead, dim_feedforward=256,
            batch_first=True, norm_first=True,
        )
        self.enc = nn.TransformerEncoder(encoder_layer, num_layers=nl)
        self.pool = nn.AdaptiveAvgPool1d(1)
        self.head = nn.Linear(d, ncls)

    def forward(self, x):  # x: [B, 20] -> disguised as 5 tokens
        x = x.view(x.size(0), 5, 4)   # [B, 5, 4]
        x = self.embed(x)             # [B, 5, d]
        x = self.enc(x)               # [B, 5, d]
        x = x.transpose(1, 2)         # [B, d, 5]
        x = self.pool(x).squeeze(-1)  # [B, d]  (mean over the 5 tokens)
        return self.head(x)
def make_loader(n=6000, bs=128, labeler=None):
    """Build a DataLoader over a synthetic linearly-separable classification task.

    Args:
        n: number of samples.
        bs: batch size (shuffled, incomplete last batch dropped).
        labeler: optional (W, b) pair defining the ground-truth linear map.
            By default a fresh random (W, b) is drawn, which means two calls
            produce *different* tasks — pass the same pair to the train and
            val loaders if both should measure the same problem.

    Returns:
        torch.utils.data.DataLoader yielding (x [bs, 20], y [bs]) batches.
    """
    X = torch.randn(n, 20)
    if labeler is None:
        # Default keeps the original behavior (and RNG consumption order).
        labeler = (torch.randn(20, 5), torch.randn(5))
    W, b = labeler
    y = (X @ W + b).argmax(dim=1)  # linearly separable labels
    ds = torch.utils.data.TensorDataset(X, y)
    return torch.utils.data.DataLoader(ds, batch_size=bs, shuffle=True, drop_last=True)
def split_decay(model, wd=0.01):
    """Partition trainable parameters into decay / no-decay optimizer groups.

    Biases and normalization-layer parameters should not be weight-decayed;
    everything else receives `wd`.

    Args:
        model: any nn.Module.
        wd: weight decay for the decaying group (default 0.01, the value
            previously hard-coded; now a parameter for reuse).

    Returns:
        A two-element param-group list suitable for torch.optim.AdamW.
    """
    decay, no_decay = [], []
    for name, p in model.named_parameters():
        if not p.requires_grad:
            continue
        # Name-based heuristic: matches attribute names containing
        # "norm"/"bn" (LayerNorm/BatchNorm etc.); anonymous containers
        # (e.g. nn.Sequential indices) are not matched.
        is_norm = any(k in name.lower() for k in ["norm", "bn"])
        if name.endswith(".bias") or is_norm:
            no_decay.append(p)
        else:
            decay.append(p)
    return [
        {"params": decay, "weight_decay": wd},
        {"params": no_decay, "weight_decay": 0.0},
    ]
def run(bug=True, steps=400):
    """Train TinyTransformer for 10 epochs, printing train/val accuracy per epoch.

    Args:
        bug: True -> Adam with coupled-L2 weight_decay on *all* params (the
             pitfall); False -> AdamW with decoupled decay and norm/bias
             excluded via split_decay().
        steps: NOTE(review) currently unused — the loop below always runs
             exactly 10 epochs over the full loaders.
    """
    device = "cpu"
    model = TinyTransformer().to(device)
    # NOTE(review): make_loader draws a fresh random (W, b) on each call, so
    # the validation set is labeled by a *different* linear map than the
    # training set — val_acc measures a mismatched task. Confirm intent.
    train_loader = make_loader(n=6000, bs=128)
    val_loader = make_loader(n=2000, bs=256)
    if bug:
        # Wrong: Adam + weight_decay (coupled L2) without excluding norm/bias.
        opt = torch.optim.Adam(model.parameters(), lr=3e-4, weight_decay=1e-2)
    else:
        # Right: AdamW (decoupled decay) + param groups excluding norm/bias.
        opt = torch.optim.AdamW(split_decay(model), lr=3e-4, weight_decay=1e-2)

    def one_epoch(loader, train=True):
        # One pass over `loader`; updates weights only when train=True.
        # Returns (mean loss per sample, accuracy).
        model.train(train)
        total, correct, loss_sum, n = 0, 0, 0.0, 0
        for x, y in loader:
            # NOTE(review): in eval mode the forward still builds an autograd
            # graph (no torch.no_grad() around it) — wasteful but harmless here.
            logits = model(x.to(device))
            loss = F.cross_entropy(logits, y.to(device))
            if train:
                opt.zero_grad(set_to_none=True)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                opt.step()
            with torch.no_grad():
                pred = logits.argmax(1)
                correct += (pred.cpu() == y).sum().item()
                total += y.numel()
                loss_sum += float(loss.item()) * y.size(0)
                n += y.size(0)
        return loss_sum / n, correct / total

    # Train 10 epochs, evaluating after each one.
    for ep in range(1, 11):
        tr_loss, tr_acc = one_epoch(train_loader, train=True)
        va_loss, va_acc = one_epoch(val_loader, train=False)
        # Monitor: mean L2 norm of normalization-layer weights (the article's
        # point: this shrinks steadily under the buggy coupled-L2 setup).
        with torch.no_grad():
            norm_ws = [p for n,p in model.named_parameters() if ("norm" in n.lower()) and p.ndim>=1]
            norm_mean = float(torch.stack([w.norm() for w in norm_ws]).mean()) if norm_ws else 0.0
        tag = "BUG" if bug else "FIX"
        print(f"[{tag}] epoch={ep:02d} train_acc={tr_acc*100:5.1f}% val_acc={va_acc*100:5.1f}% norm|w|≈{norm_mean:.4f}")
if __name__ == "__main__":
    ap = argparse.ArgumentParser()
    ap.add_argument("--bug", choices=["on","off"], default="on")
    args = ap.parse_args()
    # Bug fix: `--bug` was parsed but never consulted — the buggy baseline
    # always ran regardless. Honor it: "on" (the default) includes the buggy
    # run for comparison, "off" skips straight to the fixed configuration.
    if args.bug == "on":
        print("== 错误设置 ==")
        run(bug=True)
    print("\n== 正确设置 ==")
    run(bug=False)

你通常会看到:
在 BUG 设置下 norm|w| 持续变小——耦合式 weight_decay 相当于给损失加上 loss += λ·||w||²,把 Norm 权重也一并衰减;
而 *.bias 与 *.norm*(LayerNorm/BatchNorm/RMSNorm/GroupNorm)的参数应放进 no_decay 组;
Embedding 以及绝大多数归一化层的权重同样不应做 weight decay(除非论文特别说明)。修复模板:

optimizer = torch.optim.AdamW(
    [
        {"params": decay_params, "weight_decay": wd},
        {"params": no_decay_params, "weight_decay": 0.0},
    ],
    lr=lr, betas=(0.9, 0.999), eps=1e-8
)
def build_param_groups(model, wd=0.01):
    """Split a model's trainable parameters into decay / no-decay groups.

    Biases and parameters whose dotted name contains "norm"/"bn" go into the
    zero-decay group; everything else decays with `wd`.
    """
    decay, no_decay = [], []
    for n, p in model.named_parameters():
        if not p.requires_grad:
            continue
        if n.endswith(".bias") or "norm" in n.lower() or "bn" in n.lower():
            no_decay.append(p)
        else:
            decay.append(p)
    return [
        {"params": decay, "weight_decay": wd},
        {"params": no_decay, "weight_decay": 0.0},
    ]

# Reminder: Adam's `weight_decay` is coupled L2 regularization; only AdamW
# applies a decoupled weight decay.

def assert_param_groups(optimizer):
    """Fail fast if the optimizer lacks either a decaying or a non-decaying group."""
    has_decay = any(g.get("weight_decay", 0) > 0 for g in optimizer.param_groups)
    has_nodecay = any(g.get("weight_decay", 0) == 0 for g in optimizer.param_groups)
    assert has_decay and has_nodecay, "缺少衰减/不衰减分组,请检查"
def summarize_decay(model, optimizer):
    """Print one summary line per optimizer param group: its weight decay,
    parameter count, and up to five example parameter names."""
    # Map parameter identity -> dotted name for readable output.
    id_to_name = {}
    for pname, param in model.named_parameters():
        id_to_name[id(param)] = pname
    for idx, group in enumerate(optimizer.param_groups):
        wd = group.get("weight_decay", 0)
        members = group["params"]
        cnt = len(members)
        sample = [id_to_name.get(id(p), "<?>") for p in members[:5]]
        print(f"[group {idx}] wd={wd} count={cnt} e.g. {sample}")
def norm_weight_health(model):
    """Print the mean L2 norm of parameters whose name contains "norm".

    Prints nothing when no such parameters exist. Useful for spotting
    coupled weight decay silently shrinking normalization weights.
    """
    with torch.no_grad():
        norms = [p.norm().item() for n, p in model.named_parameters() if "norm" in n.lower()]
        if norms:
            print(f"[health] norm|w| mean={sum(norms)/len(norms):.4f}")

# Tip (from the article): record val_acc, norm|w| and |bias| alongside the
# training logs to quickly localize "regularization too strong" problems.
“Adam + weight_decay 当 L2 用”“Norm/偏置也衰减”“AdamW 之外再叠 L2”是阻碍收敛的三连坑。把优化器统一为 AdamW,严格区分衰减/不衰减参数组,并监控标准化层与偏置的权重范数,你会看到验证曲线立刻恢复应有的上升势头。上面的复现实验与模板可以直接拷贝进项目,作为优化器配置的通用基线。
原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。
如有侵权,请联系 cloudcommunity@tencent.com 删除。