首页
学习
活动
专区
圈层
工具
发布
社区首页 >专栏 >准确率长期卡在 70%?把 weight_decay 用成 L2、给 Norm/偏置也衰减、与 AdamW 混用的三连坑

准确率长期卡在 70%?把 weight_decay 用成 L2、给 Norm/偏置也衰减、与 AdamW 混用的三连坑

原创
作者头像
九年义务漏网鲨鱼
发布2025-12-19 16:06:14
发布2025-12-19 16:06:14
4800
举报

准确率长期卡在 70%?把 weight_decay 用成 L2、给 Norm/偏置也衰减、与 AdamW 混用的三连坑

场景:在分类/分割/检索模型上,换了优化器和 weight_decay 之后,训练明显更慢、验证精度长时间不上升。常见三件事:

  1. 用 Adam 配 weight_decay 以为等价于 L2,实则是“耦合衰减”,与自适应矩影响叠加,收敛受损;
  2. 对 LayerNorm/BatchNorm 的权重、所有偏置项一并做衰减,等效把“尺度/偏置”往 0 拉,表达力下降;
  3. 代码里一会儿 AdamW、一会儿又手搓 L2 正则(或 scheduler 的 param group 混乱),造成双重/不一致正则。

下面给出最小复现实验(CPU 可跑)与一键修复模板。


Bug 现象

  • 训练 loss 能降,但验证准确率长期停在 60%~75% 左右,继续训练收益甚微;
  • 统计 LayerNorm/BatchNorm 权重范数缓慢下降,甚至接近 0;
  • 同样超参,把 Adam(weight_decay=0.01) 换成 AdamW(weight_decay=0.01) + 正确的“排除表”,曲线立刻改善。

场景复现

保存为 adamw_decay_pitfalls.py

代码语言:python
复制
# adamw_decay_pitfalls.py
import argparse, math, torch, torch.nn as nn, torch.nn.functional as F
torch.manual_seed(0)

class TinyTransformer(nn.Module):
    def __init__(self, d=128, ncls=5, nhead=4, nl=2):
        super().__init__()
        self.embed = nn.Linear(20, d)
        encoder_layer = nn.TransformerEncoderLayer(d_model=d, nhead=nhead, dim_feedforward=256, batch_first=True, norm_first=True)
        self.enc = nn.TransformerEncoder(encoder_layer, num_layers=nl)
        self.pool = nn.AdaptiveAvgPool1d(1)
        self.head = nn.Linear(d, ncls)
    def forward(self, x):              # x: [B,20] → 伪装成 5 个 token
        x = x.view(x.size(0), 5, 4)    # [B,5,4]
        x = self.embed(x)              # [B,5,d]
        x = self.enc(x)                # [B,5,d]
        x = x.transpose(1,2)           # [B,d,5]
        x = self.pool(x).squeeze(-1)   # [B,d]
        return self.head(x)

def make_loader(n=6000, bs=128):
    X = torch.randn(n, 20)
    W = torch.randn(20, 5); b = torch.randn(5)
    y = (X @ W + b).argmax(dim=1)      # 线性可分 + 轻微扰动
    ds = torch.utils.data.TensorDataset(X, y)
    return torch.utils.data.DataLoader(ds, batch_size=bs, shuffle=True, drop_last=True)

def split_decay(model):
    decay, no_decay = [], []
    for n, p in model.named_parameters():
        if p.requires_grad is False: 
            continue
        is_norm = any(k in n.lower() for k in ["norm", "bn"])
        if n.endswith(".bias") or is_norm:
            no_decay.append(p)
        else:
            decay.append(p)
    return [
        {"params": decay, "weight_decay": 0.01},
        {"params": no_decay, "weight_decay": 0.0},
    ]

def run(bug=True, steps=400):
    device = "cpu"
    model = TinyTransformer().to(device)
    train_loader = make_loader(n=6000, bs=128)
    val_loader   = make_loader(n=2000, bs=256)

    if bug:
        # 错误:Adam + weight_decay(耦合 L2)且不排除 Norm/偏置
        opt = torch.optim.Adam(model.parameters(), lr=3e-4, weight_decay=1e-2)
    else:
        # 正确:AdamW(解耦)+ param group 排除 Norm/偏置
        opt = torch.optim.AdamW(split_decay(model), lr=3e-4, weight_decay=1e-2)

    def one_epoch(loader, train=True):
        model.train(train)
        total, correct, loss_sum, n = 0, 0, 0.0, 0
        for x, y in loader:
            logits = model(x.to(device))
            loss = F.cross_entropy(logits, y.to(device))
            if train:
                opt.zero_grad(set_to_none=True)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                opt.step()
            with torch.no_grad():
                pred = logits.argmax(1)
                correct += (pred.cpu() == y).sum().item()
                total   += y.numel()
                loss_sum += float(loss.item()) * y.size(0)
                n += y.size(0)
        return loss_sum / n, correct / total

    # 训练若干 epoch 并周期评估
    for ep in range(1, 11):
        tr_loss, tr_acc = one_epoch(train_loader, train=True)
        va_loss, va_acc = one_epoch(val_loader, train=False)
        # 监控:Norm 权重的均值范数
        with torch.no_grad():
            norm_ws = [p for n,p in model.named_parameters() if ("norm" in n.lower()) and p.ndim>=1]
            norm_mean = float(torch.stack([w.norm() for w in norm_ws]).mean()) if norm_ws else 0.0
        tag = "BUG" if bug else "FIX"
        print(f"[{tag}] epoch={ep:02d} train_acc={tr_acc*100:5.1f}% val_acc={va_acc*100:5.1f}% norm|w|≈{norm_mean:.4f}")

if __name__ == "__main__":
    ap = argparse.ArgumentParser()
    ap.add_argument("--bug", choices=["on","off"], default="on")
    args = ap.parse_args()
    print("== 错误设置 =="); run(bug=True)
    print("\n== 正确设置 =="); run(bug=False)

你通常会看到:

  • 错误设置:val_acc 在 60%~75% 上下徘徊,norm|w| 持续变小;
  • 正确设置:val_acc 稳步上升(常>85%),Norm 权重范数维持在健康区间。

Debug 过程

  1. 排查优化器与正则来源 搜索是否同时出现:
  • Adam(weight_decay=λ) 与手搓 loss += λ * ||w||^2
  • AdamW + 额外 L2 项; 这会导致双重正则或语义不一致。
  1. 检查 param group 是否正确 打印每个参数名是否被归入“衰减组”。凡是 *.bias*.norm*(LayerNorm/BatchNorm/RMSNorm/GroupNorm)应在 no_decay 组;Embedding、绝大多数归一化层的权重也不应做 weight decay(除非论文特别说明)。
  2. 观测层级健康度 定期打印/记录:
  • 标准化层(Norm)与偏置参数的范数是否持续下滑;
  • 训练/验证损失与准确率的背离(过强正则常见)。
  1. 做最小 A/B 用上面的脚本在 CPU 直接跑 10 个 epoch;若你的任务较大,先在子集复现。

代码修改要点(模板)

  1. 统一改为 AdamW,并解耦 weight decay
代码语言:python
复制
optimizer = torch.optim.AdamW(
    [
        {"params": decay_params, "weight_decay": wd},
        {"params": no_decay_params, "weight_decay": 0.0},
    ],
    lr=lr, betas=(0.9, 0.999), eps=1e-8
)
  1. 构建“排除表”工具(适用于任意模型)
代码语言:python
复制
def build_param_groups(model, wd=0.01):
    decay, no_decay = [], []
    for n, p in model.named_parameters():
        if p.requires_grad is False: continue
        if n.endswith(".bias") or "norm" in n.lower() or "bn" in n.lower():
            no_decay.append(p)
        else:
            decay.append(p)
    return [
        {"params": decay, "weight_decay": wd},
        {"params": no_decay, "weight_decay": 0.0},
    ]
  1. 杜绝“手搓 L2 + AdamW” 若为了研究目的必须显式 L2,请只用 Adam(注意这与 AdamW 的正则效果不同),并不要再设 weight_decay
  2. 不同优化器的注意事项
  • SGD:weight_decay 等价于 L2;
  • Adam/RMSprop/Adagrad:若要 L2,请谨慎评估耦合影响;主流实践优先 AdamW(解耦)。

监控与护栏

代码语言:python
复制
def assert_param_groups(optimizer):
    has_decay = any(g.get("weight_decay", 0) > 0 for g in optimizer.param_groups)
    has_nodecay = any(g.get("weight_decay", 0) == 0 for g in optimizer.param_groups)
    assert has_decay and has_nodecay, "缺少衰减/不衰减分组,请检查"

def summarize_decay(model, optimizer):
    names = {id(p): n for n,p in model.named_parameters()}
    for i, g in enumerate(optimizer.param_groups):
        wd = g.get("weight_decay", 0)
        cnt = len(g["params"])
        some = [names.get(id(p), "<?>") for p in g["params"][:5]]
        print(f"[group {i}] wd={wd} count={cnt} e.g. {some}")

def norm_weight_health(model):
    with torch.no_grad():
        norms = [p.norm().item() for n,p in model.named_parameters() if "norm" in n.lower()]
        if norms:
            print(f"[health] norm|w| mean={sum(norms)/len(norms):.4f}")

配合训练日志把 val_accnorm|w||bias| 的变化一起记录,快速定位“正则过强”的问题。


常见问答

  • 为什么 Adam 的 weight_decay 不等价于 L2? Adam 的自适应缩放与 L2 的梯度项相互耦合,会改变等效正则强度;AdamW 通过参数更新时“解耦衰减”避免此问题。
  • 哪些参数应该参与衰减? 大多数线性/卷积权重;不建议对偏置与标准化层权重(LayerNorm/BatchNorm/RMSNorm 等)做衰减。Embedding 也常设为不衰减(大型语言模型惯例)。
  • weight_decay 取多大合适? Transformer 族常见 0.01;视觉网络 1e-4~1e-2 需网格搜索。若 batch 极小或强增广,适当减小 wd。
  • 可以一边 AdamW 一边再手工加 L2 吗? 不建议,除非你精确地想实现“AdamW + 额外泛化惩罚”的特定实验;否则易造成“双重正则”。

结语

“Adam + weight_decay 当 L2 用”“Norm/偏置也衰减”“AdamW 之外再叠 L2”是阻碍收敛的三连坑。把优化器统一为 AdamW,严格区分衰减/不衰减参数组,并监控标准化层与偏置的权重范数,你会看到验证曲线立刻恢复应有的上升势头。上面的复现实验与模板可以直接拷贝进项目,作为优化器配置的通用基线。

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 准确率长期卡在 70%?把 weight_decay 用成 L2、给 Norm/偏置也衰减、与 AdamW 混用的三连坑
    • Bug 现象
    • 场景复现
    • Debug 过程
    • 代码修改要点(模板)
    • 监控与护栏
    • 常见问答
    • 结语
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档