首页
学习
活动
专区
圈层
工具
发布
社区首页 >专栏 >【Debug日志 | 梯度问题】

【Debug日志 | 梯度问题】

原创
作者头像
九年义务漏网鲨鱼
发布2025-09-29 15:08:10
发布2025-09-29 15:08:10
1740
举报
文章被收录于专栏:tencent cloudtencent cloud

学着学着就炸了?忘记 zero_grad 和重复 backward 造成的隐式梯度累加

在普通的分类模型中,最初几步 loss 正常下降,随后出现两种诡异现象之一:要么 loss/gnorm 突然飙升并发散,要么显存一轮比一轮高直到 OOM。复盘后发现是训练循环里没有在每次优化步前清梯度,或者对同一计算图做了多次 backward(还 retain_graph=True),导致梯度被悄悄“叠罗汉”。

❓ Bug 现象

  • 同样学习率与 batch,曲线更抖更慢,偶发直接发散。
  • 梯度范数逐步增大,权重更新幅度异常(等效学习率被放大)。
  • 显存占用在若干 step 后持续上升,最终 OOM(retain_graph=True 常见)。
  • 打印若干参数的 grad,发现下一步未清零仍保留上一步的数值。

📽️ 场景复现

CPU 可跑;分两类错误切换观察。

代码语言:python
复制
import argparse, torch, torch.nn as nn, torch.nn.functional as F
torch.manual_seed(0)

class TinyMLP(nn.Module):
    def __init__(self, din=32, dh=128, ncls=3):
        super().__init__()
        self.f = nn.Sequential(nn.Linear(din, dh), nn.ReLU(), nn.Linear(dh, ncls))
    def forward(self, x): return self.f(x)

def make_loader(n=8192, bs=128, din=32, ncls=3):
    X = torch.randn(n, din)
    W = torch.randn(din, ncls); b = torch.randn(ncls)
    y = (X @ W + b).argmax(1)
    ds = torch.utils.data.TensorDataset(X, y)
    return torch.utils.data.DataLoader(ds, batch_size=bs, shuffle=True, drop_last=True)

def run(mode="no_zero", steps=200):
    model = TinyMLP()
    opt = torch.optim.AdamW(model.parameters(), lr=3e-3)

    loader = make_loader()
    it = iter(loader)
    if mode != "no_zero":
        opt.zero_grad(set_to_none=True)

    for step in range(1, steps+1):
        try: x, y = next(it)
        except StopIteration:
            it = iter(loader); x, y = next(it)

        logits = model(x)
        loss_a = F.cross_entropy(logits, y, label_smoothing=0.0)
        loss_b = 0.1 * logits.pow(2).mean()

        if mode == "no_zero":
            total = loss_a + loss_b
            total.backward()
            opt.step()
            # 没有 zero_grad
        elif mode == "double_backward":
            opt.zero_grad(set_to_none=True)
            loss_a.backward(retain_graph=True)  
            loss_b.backward()                  
            opt.step()
        elif mode == "fix":
            opt.zero_grad(set_to_none=True)
            total = loss_a + loss_b
            total.backward()
            opt.step()
        else:
            raise ValueError(mode)

        with torch.no_grad():
            gnorm = 0.0
            for p in model.parameters():
                if p.grad is not None:
                    gnorm += p.grad.norm().item()
            acc = (logits.argmax(1) == y).float().mean().item()
            if step % 25 == 0:
                print(f"[{mode}] step={step:03d} loss={float((loss_a+loss_b)):.3f} acc={acc:.3f} gnorm≈{gnorm:.2f}")

if __name__ == "__main__":
    ap = argparse.ArgumentParser()
    ap.add_argument("--mode", choices=["no_zero","double_backward","fix"], default="no_zero")
    args = ap.parse_args()
    run(args.mode, steps=200)

你会看到

  • no_zero:gnorm 随步数偏大且不稳定,loss 容易抖。
  • double_backward:曲线更不稳,显存会比 fix 高(retain_graph=True 会保留计算图)。
  • fix:gnorm 与 loss 平滑,准确率稳步上升。

Debug 过程

  1. 打印某些参数的 grad 范数与是否为 None 若在 backward 前就不是 None,说明上一轮没清干净;若在 opt.step 后仍非 None,多半是没 zero_grad。
代码语言:python
复制
def probe_grad(model, tag):
    with torch.no_grad():
        grads = [p.grad.norm().item() for n,p in model.named_parameters() if p.grad is not None]
        print(f"[{tag}] has_grad={len(grads)}>0 grad_mean={sum(grads)/max(1,len(grads)):.2f}")
  1. 排查 retain_graph 的使用场景 只有在需要对同一前向图多次求不同导且图还要继续使用时才需要 retain_graph=True;多数分类/回归训练一次 backward 足够。
  2. 检查 loss 的组合方式 如果有多个项(主损失、正则、蒸馏、对比学习等),应统一到一个 total_loss 一次 backward;不要分别 backward 两次。
  3. 观察显存 显存随 step 逐步抬升通常与重复保留图、未清梯度或额外缓存挂钩。逐项关闭 retain_graph、register_hook、TensorBoard 图跟踪等验证。

代码修改

  1. 每个优化步清梯度;推荐 set_to_none=True
代码语言:python
复制
for step, (x,y) in enumerate(loader):
    opt.zero_grad(set_to_none=True)  # 清到 None 更省显存,也更容易暴露未清零
    logits = model(x)
    loss = main_loss(logits, y) + aux_regularizer(logits, y)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # 若需要
    opt.step()
  1. 合并损失,一次 backward
代码语言:python
复制
total_loss = loss_a + loss_b + loss_c
total_loss.backward()
  1. 避免不必要的 retain_graph 除非你明确需要多次沿相同图求导,否则不要加;很多场景其实可以通过第二次前向来避免 retain_graph。
  2. 有意的梯度累积要显式控制 如果你确实想跨 micro-batch 累积,参考梯度累积模板:在非边界步不清梯度,边界步清零与 step;并把 loss 除以 accum_steps。
代码语言:python
复制
opt.zero_grad(set_to_none=True)
for i, (x,y) in enumerate(loader):
    loss = criterion(model(x), y) / accum_steps
    loss.backward()
    if (i+1) % accum_steps == 0:
        opt.step(); opt.zero_grad(set_to_none=True)

结语

隐式梯度累加是最容易忽视、却能把训练一步步带偏的暗坑。把清梯度的位置固定在每个优化步开头,合并损失一次反传,避免随手 retain_graph,再配上小小的 grad 监控与显存观察,这类问题基本可以一劳永逸。

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 学着学着就炸了?忘记 zero_grad 和重复 backward 造成的隐式梯度累加
    • ❓ Bug 现象
    • 📽️ 场景复现
    • Debug 过程
    • 代码修改
    • 结语
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档