首页
学习
活动
专区
圈层
工具
发布
社区首页 >专栏 >PPO 学不动?advantage 未归一、旧 log_prob 没缓存

PPO 学不动?advantage 未归一、旧 log_prob 没缓存

原创
作者头像
九年义务漏网鲨鱼
发布2025-12-17 17:14:30
发布2025-12-17 17:14:30
2980
举报

PPO 学不动?advantage 未归一、旧 log_prob 没缓存

场景:在 CartPole / LunarLander 这类经典离散环境上用 PPO 复现论文曲线,结果训练很慢或直接“横着走”。复盘常见三件事:

  1. 优势函数(advantage)方差巨大但未做标准化;
  2. 更新阶段没有使用“采样时的旧 log_prob”(ratio≈1 学不动);
  3. 将 time-limit 截断当成真正终止,导致 GAE/回报 bootstrap 错误。

本文给出可复现实验与一键修复模板,帮助你把 PPO 从“划水”拉回正轨。


Bug 现象

  • 同样超参下,平均回报卡在 100~200;甚至训练 1e6 步也不升。
  • 近似 KL 始终接近 0,说明策略几乎没更新;或者 KL 偶发爆高后熔断。
  • 用 Gymnasium 时,把 terminatedtruncated 合并当作 done,优势估计偏小且震荡。
  • 打印 advantage 的均值/方差,发现均值远离 0 或方差巨大;ratio 的均值≈1 且方差极小。

场景复现

保存为 ppo_three_pitfalls.py,两行命令对比:

代码语言:python
复制
python ppo_three_pitfalls.py --bug on --env CartPole-v1

# 正确设置(修复三连坑)
python ppo_three_pitfalls.py --bug off --env CartPole-v1
# ppo_three_pitfalls.py
import argparse, math, time, numpy as np, torch, torch.nn as nn, torch.nn.functional as F
import gymnasium as gym
torch.manual_seed(0); np.random.seed(0)

def mlp(s, a, hidden=64):
    policy = nn.Sequential(
        nn.Linear(s, hidden), nn.Tanh(),
        nn.Linear(hidden, hidden), nn.Tanh(),
        nn.Linear(hidden, a)
    )
    value = nn.Sequential(
        nn.Linear(s, hidden), nn.Tanh(),
        nn.Linear(hidden, hidden), nn.Tanh(),
        nn.Linear(hidden, 1)
    )
    return policy, value

def select_action(pi, obs):
    logits = pi(obs)
    dist = torch.distributions.Categorical(logits=logits)
    a = dist.sample()
    logp = dist.log_prob(a)
    return a, logp, dist.entropy()

def compute_gae(rews, vals, dones_bootstrap_mask, gamma=0.99, lam=0.95):
    # rews/vals: [T]; dones_bootstrap_mask: [T],1 表示“非真正终止,可以 bootstrap”
    T = len(rews)
    adv = torch.zeros(T)
    lastgaelam = 0.0
    for t in reversed(range(T)):
        # 当 mask=0(真正终止),下一状态价值系数为 0
        next_nonterminal = dones_bootstrap_mask[t]
        next_value = vals[t+1] if t+1 < len(vals) else 0.0
        delta = rews[t] + gamma * next_value * next_nonterminal - vals[t]
        lastgaelam = delta + gamma * lam * next_nonterminal * lastgaelam
        adv[t] = lastgaelam
    ret = adv + vals[:-1]
    return adv, ret

def explained_variance(y_pred, y_true):
    var_y = torch.var(y_true)
    return torch.tensor(1.0) - torch.var(y_true - y_pred) / (var_y + 1e-8)

def rollout(env, pi, vf, steps, bug):
    obs_buf, act_buf, logp_old_buf, val_buf = [], [], [], []
    rew_buf, term_mask_buf = [], []
    o, _ = env.reset(seed=np.random.randint(10_000))
    for _ in range(steps):
        ot = torch.as_tensor(o, dtype=torch.float32)
        with torch.no_grad():
            v = vf(ot).squeeze(-1)
            a, logp, _ = select_action(pi, ot)
        o2, r, terminated, truncated, _ = env.step(a.item())
        # time-limit 处理:
        if bug:
            # 错误:把 time-limit 截断也当作终止(不 bootstrap)
            bootstrap_mask = 0.0 if (terminated or truncated) else 1.0
        else:
            # 正确:仅在“真正终止”时不 bootstrap;time-limit 仍可 bootstrap
            bootstrap_mask = 0.0 if terminated else 1.0

        obs_buf.append(ot)
        act_buf.append(a)
        logp_old_buf.append(logp.detach())
        val_buf.append(v.detach())
        rew_buf.append(torch.as_tensor(r, dtype=torch.float32))
        term_mask_buf.append(torch.as_tensor(bootstrap_mask, dtype=torch.float32))

        o = o2 if (not terminated and not truncated) else env.reset()[0]
    # 末值估计:用于 bootstrapping
    with torch.no_grad():
        last_v = vf(torch.as_tensor(o, dtype=torch.float32)).squeeze(-1)
    vals = torch.stack(val_buf + [last_v])
    adv, ret = compute_gae(torch.stack(rew_buf), vals, torch.stack(term_mask_buf))
    if not bug:
        # 修复1:优势标准化
        adv = (adv - adv.mean()) / (adv.std() + 1e-8)
    return (
        torch.stack(obs_buf), torch.stack(act_buf),
        torch.stack(logp_old_buf), torch.stack(val_buf),
        adv, ret
    )

def ppo_update(pi, vf, data, clip_coef=0.2, vf_coef=0.5, ent_coef=0.0, lr=3e-4, epochs=4, batch=256, bug=True):
    opt = torch.optim.Adam(list(pi.parameters())+list(vf.parameters()), lr=lr)
    obs, act, logp_old, val_old, adv, ret = data
    N = obs.shape[0]
    for ep in range(epochs):
        idx = torch.randperm(N)
        for start in range(0, N, batch):
            mb = idx[start:start+batch]
            ob, ac = obs[mb], act[mb]
            lp_old, val_targ, adv_t = logp_old[mb], ret[mb], adv[mb]

            # 前向
            logits = pi(ob)
            dist = torch.distributions.Categorical(logits=logits)
            logp = dist.log_prob(ac)
            entropy = dist.entropy().mean()

            if bug:
                # 错误2:没有缓存“采样时”的旧 log_prob,直接用当前的 logp 当 old
                ratio = torch.exp(logp - logp.detach())   # ≈ 1
            else:
                # 正确:使用 rollout 时缓存的 old_logp
                ratio = torch.exp(logp - lp_old)

            # 策略损失(clip)
            adv_now = adv_t
            pg_loss1 = -adv_now * ratio
            pg_loss2 = -adv_now * torch.clamp(ratio, 1.0 - clip_coef, 1.0 + clip_coef)
            pg_loss = torch.max(pg_loss1, pg_loss2).mean()

            # 价值函数损失
            v_pred = vf(ob).squeeze(-1)
            v_loss = F.mse_loss(v_pred, val_targ)

            loss = pg_loss + vf_coef * v_loss - ent_coef * entropy

            opt.zero_grad(set_to_none=True)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(list(pi.parameters())+list(vf.parameters()), 0.5)
            opt.step()

    # 监控
    with torch.no_grad():
        approx_kl = (logp_old - dist.log_prob(act)).mean().abs() if 'dist' in locals() else torch.tensor(0.)
        ev = explained_variance(vf(obs).squeeze(-1), ret)
        ratio_std = torch.std(torch.exp(dist.log_prob(act) - logp_old))
        return {
            "loss": float(loss),
            "pg_loss": float(pg_loss),
            "v_loss": float(v_loss),
            "entropy": float(entropy),
            "approx_kl": float(approx_kl),
            "ratio_std": float(ratio_std),
            "adv_mean": float(adv.mean()),
            "adv_std": float(adv.std()),
            "ev": float(ev),
        }

def main():
    ap = argparse.ArgumentParser()
    ap.add_argument("--env", default="CartPole-v1")
    ap.add_argument("--steps", type=int, default=4096)
    ap.add_argument("--iters", type=int, default=10)
    ap.add_argument("--bug", choices=["on","off"], default="on")
    args = ap.parse_args()

    env = gym.make(args.env)
    obs_dim = env.observation_space.shape[0]
    act_dim = env.action_space.n
    pi, vf = mlp(obs_dim, act_dim)

    returns = []
    for it in range(1, args.iters+1):
        data = rollout(env, pi, vf, steps=args.steps, bug=(args.bug=="on"))
        # 评估当前策略
        with torch.no_grad():
            ep_rets = []
            for _ in range(5):
                o, _ = env.reset()
                done = False; ep_ret = 0
                while not done:
                    a = torch.distributions.Categorical(logits=pi(torch.as_tensor(o, dtype=torch.float32))).probs.argmax().item()
                    o, r, terminated, truncated, _ = env.step(a)
                    ep_ret += r; done = terminated or truncated
                ep_rets.append(ep_ret)
            mean_ret = float(np.mean(ep_rets)); returns.append(mean_ret)

        stats = ppo_update(pi, vf, data, bug=(args.bug=="on"))
        tag = "BUG" if args.bug=="on" else "FIX"
        print(f"[{tag}] iter={it:02d} return={mean_ret:6.1f} "
              f"adv(mean/std)={stats['adv_mean']:.3f}/{stats['adv_std']:.3f} "
              f"ratio_std={stats['ratio_std']:.4f} kl≈{stats['approx_kl']:.4f} ev={stats['ev']:.3f}")

if __name__ == "__main__":
    main()

你会观察到的典型差异

  • 错误设置:ratio_std≈0approx_kl≈0,mean return 缓慢或停滞;adv_std 偏大。
  • 正确设置:ratio 有波动、KL 非零且受 clip 约束;平均回报明显上升。

Debug 过程

  1. 打印优势统计 adv.mean() 应接近 0;adv.std() 过大说明方差爆炸,优先做标准化。
  2. 检查 ratio 的分布 ratio = exp(new_logp - old_logp) 的标准差不应长期接近 0。若≈0,说明你没用旧 log_prob 或把 old_logp 也用当前策略算了。
  3. time-limit 与真正终止分离 Gymnasium 返回 (terminated, truncated)
  • 只有 terminated=True 时才是真的终止(例如失败/成功);
  • truncated=True 多是时间上限,不应把 bootstrap 砍掉。 把 bootstrap_mask = 1 - terminated 喂给 GAE/回报计算。
  1. 价值学习健康度 打印 explained_variance(v_pred, return);接近 1 说明价值模型拟合良好,负值说明严重偏离。

代码修改要点(模板)

  1. 优势标准化
代码语言:python
复制
adv = (adv - adv.mean()) / (adv.std() + 1e-8)
  1. 缓存旧 log_prob(在 rollout 时计算并 detach 存起来)
代码语言:python
复制
with torch.no_grad():
    a, logp, _ = select_action(pi, ot)
logp_old_buf.append(logp.detach())
...
ratio = torch.exp(new_logp - old_logp)  # old_logp 从 buffer 读取
  1. 区分 terminated 与 truncated
代码语言:python
复制
# 错误:mask = 1 - (terminated or truncated)
# 正确:
bootstrap_mask = 0.0 if terminated else 1.0
  1. 监控与护栏
代码语言:python
复制
def assert_ratio_variation(r):
    s = float(torch.std(r))
    assert s > 1e-6, f"ratio std≈0,疑似未使用旧 log_prob"

def assert_adv_norm(adv):
    m, s = float(adv.mean()), float(adv.std())
    assert abs(m) < 0.1 and s > 0.1, f"adv 分布异常:mean={m:.3f}, std={s:.3f}"

常见问答

  • advantage 标准化会不会改变目标 不会改变无偏性,等价于自适应学习率缩放,能显著稳定训练。
  • 为什么 ratio≈1 就学不动 PPO 的核心驱动是 ratio 偏离 1 后被剪切产生的梯度信号;ratio 一直≈1 意味着策略实际没变或你没用到旧行为概率。
  • truncated 一律当终止会怎样 GAE 在 episode 尾部少加了一段 γV(s_{t+1}),系统性低估优势与回报,学习信号偏弱且抖动。
  • 还需要哪些健康指标 近似 KL、ratio 标准差、adv 统计、explained variance、价值损失与熵,至少这几个放到日志里。

结语

PPO 的三个“隐形坑”——优势未归一、旧 log_prob 未缓存、time-limit 误处理——足以让曲线“横在半腰”。把上面的复现实验跑一遍,再把修复模板固化到你的 PPO 框架里,同时持

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • PPO 学不动?advantage 未归一、旧 log_prob 没缓存
    • Bug 现象
    • 场景复现
    • Debug 过程
    • 代码修改要点(模板)
    • 常见问答
    • 结语
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档