
Scenario: you are reproducing a paper's curves with PPO on classic discrete-action environments such as CartPole / LunarLander, and training is either painfully slow or the return curve simply flatlines. Post-mortems usually turn up the same three things:
- advantages are never normalized, so their variance explodes;
- the rollout-time log_prob is not cached, so the PPO ratio is stuck at 1;
- terminated and truncated are merged into a single done, so advantage estimates are biased low and oscillate.
This post gives a reproducible experiment and a drop-in fix template to pull PPO back on track. Save the script below as ppo_three_pitfalls.py and compare the two settings:
# buggy settings (all three pitfalls enabled)
python ppo_three_pitfalls.py --bug on --env CartPole-v1
# correct settings (all three pitfalls fixed)
python ppo_three_pitfalls.py --bug off --env CartPole-v1
# ppo_three_pitfalls.py
import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import gymnasium as gym

torch.manual_seed(0); np.random.seed(0)

def mlp(s, a, hidden=64):
    # Separate policy and value networks; tanh MLPs are the standard PPO default.
    policy = nn.Sequential(
        nn.Linear(s, hidden), nn.Tanh(),
        nn.Linear(hidden, hidden), nn.Tanh(),
        nn.Linear(hidden, a),
    )
    value = nn.Sequential(
        nn.Linear(s, hidden), nn.Tanh(),
        nn.Linear(hidden, hidden), nn.Tanh(),
        nn.Linear(hidden, 1),
    )
    return policy, value

def select_action(pi, obs):
    # Sample from the categorical policy; return the action, its log-prob, and the entropy.
    logits = pi(obs)
    dist = torch.distributions.Categorical(logits=logits)
    a = dist.sample()
    logp = dist.log_prob(a)
    return a, logp, dist.entropy()
def compute_gae(rews, vals, dones_bootstrap_mask, gamma=0.99, lam=0.95):
    # rews: [T]; vals: [T+1] (includes the bootstrap value of the last state);
    # dones_bootstrap_mask: [T], 1 means "not a true terminal, bootstrapping allowed".
    T = len(rews)
    adv = torch.zeros(T)
    lastgaelam = 0.0
    for t in reversed(range(T)):
        # When mask = 0 (true terminal), the next-state value contributes nothing.
        next_nonterminal = dones_bootstrap_mask[t]
        next_value = vals[t + 1]  # always valid since vals has T+1 entries
        delta = rews[t] + gamma * next_value * next_nonterminal - vals[t]
        lastgaelam = delta + gamma * lam * next_nonterminal * lastgaelam
        adv[t] = lastgaelam
    ret = adv + vals[:-1]
    return adv, ret
def explained_variance(y_pred, y_true):
    # 1 - Var[y_true - y_pred] / Var[y_true]; 1.0 means a perfect value fit.
    var_y = torch.var(y_true)
    return torch.tensor(1.0) - torch.var(y_true - y_pred) / (var_y + 1e-8)
def rollout(env, pi, vf, steps, bug):
    obs_buf, act_buf, logp_old_buf, val_buf = [], [], [], []
    rew_buf, term_mask_buf = [], []
    o, _ = env.reset(seed=np.random.randint(10_000))
    for _ in range(steps):
        ot = torch.as_tensor(o, dtype=torch.float32)
        with torch.no_grad():
            v = vf(ot).squeeze(-1)
            a, logp, _ = select_action(pi, ot)
        o2, r, terminated, truncated, _ = env.step(a.item())
        # Time-limit handling:
        if bug:
            # Wrong: treat a time-limit truncation as a terminal too (no bootstrap).
            bootstrap_mask = 0.0 if (terminated or truncated) else 1.0
        else:
            # Correct: skip bootstrapping only on a true terminal;
            # a time-limit truncation may still bootstrap.
            bootstrap_mask = 0.0 if terminated else 1.0
        obs_buf.append(ot)
        act_buf.append(a)
        logp_old_buf.append(logp.detach())  # cache the behavior policy's log-prob
        val_buf.append(v.detach())
        rew_buf.append(torch.as_tensor(r, dtype=torch.float32))
        term_mask_buf.append(torch.as_tensor(bootstrap_mask, dtype=torch.float32))
        o = o2 if (not terminated and not truncated) else env.reset()[0]
    # Value of the last state, used to bootstrap the final transition.
    with torch.no_grad():
        last_v = vf(torch.as_tensor(o, dtype=torch.float32)).squeeze(-1)
    vals = torch.stack(val_buf + [last_v])
    adv, ret = compute_gae(torch.stack(rew_buf), vals, torch.stack(term_mask_buf))
    if not bug:
        # Fix 1: advantage normalization.
        adv = (adv - adv.mean()) / (adv.std() + 1e-8)
    return (
        torch.stack(obs_buf), torch.stack(act_buf),
        torch.stack(logp_old_buf), torch.stack(val_buf),
        adv, ret,
    )
def ppo_update(pi, vf, data, clip_coef=0.2, vf_coef=0.5, ent_coef=0.0, lr=3e-4, epochs=4, batch=256, bug=True):
    # NB: a fresh optimizer per call keeps the demo self-contained;
    # a real trainer would construct it once outside the loop.
    opt = torch.optim.Adam(list(pi.parameters()) + list(vf.parameters()), lr=lr)
    obs, act, logp_old, val_old, adv, ret = data
    N = obs.shape[0]
    for ep in range(epochs):
        idx = torch.randperm(N)
        for start in range(0, N, batch):
            mb = idx[start:start + batch]
            ob, ac = obs[mb], act[mb]
            lp_old, val_targ, adv_t = logp_old[mb], ret[mb], adv[mb]
            # Forward pass
            logits = pi(ob)
            dist = torch.distributions.Categorical(logits=logits)
            logp = dist.log_prob(ac)
            entropy = dist.entropy().mean()
            if bug:
                # Pitfall 2: the sampling-time log_prob is never used, so the
                # current logp stands in for "old" and the ratio collapses to 1.
                ratio = torch.exp(logp - logp.detach())  # ≈ 1
            else:
                # Correct: use the old_logp cached during rollout.
                ratio = torch.exp(logp - lp_old)
            # Clipped policy loss
            adv_now = adv_t
            pg_loss1 = -adv_now * ratio
            pg_loss2 = -adv_now * torch.clamp(ratio, 1.0 - clip_coef, 1.0 + clip_coef)
            pg_loss = torch.max(pg_loss1, pg_loss2).mean()
            # Value-function loss
            v_pred = vf(ob).squeeze(-1)
            v_loss = F.mse_loss(v_pred, val_targ)
            loss = pg_loss + vf_coef * v_loss - ent_coef * entropy
            opt.zero_grad(set_to_none=True)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(list(pi.parameters()) + list(vf.parameters()), 0.5)
            opt.step()
    # Monitoring (computed on the last minibatch; the original full-batch version
    # mixed minibatch and full-batch tensors and would raise a shape error).
    with torch.no_grad():
        approx_kl = (lp_old - logp).mean().abs()
        ratio_std = torch.std(ratio)
        ev = explained_variance(vf(obs).squeeze(-1), ret)
    return {
        "loss": float(loss),
        "pg_loss": float(pg_loss),
        "v_loss": float(v_loss),
        "entropy": float(entropy),
        "approx_kl": float(approx_kl),
        "ratio_std": float(ratio_std),
        "adv_mean": float(adv.mean()),
        "adv_std": float(adv.std()),
        "ev": float(ev),
    }
def main():
    ap = argparse.ArgumentParser()
    ap.add_argument("--env", default="CartPole-v1")
    ap.add_argument("--steps", type=int, default=4096)
    ap.add_argument("--iters", type=int, default=10)
    ap.add_argument("--bug", choices=["on", "off"], default="on")
    args = ap.parse_args()
    env = gym.make(args.env)
    obs_dim = env.observation_space.shape[0]
    act_dim = env.action_space.n
    pi, vf = mlp(obs_dim, act_dim)
    returns = []
    for it in range(1, args.iters + 1):
        data = rollout(env, pi, vf, steps=args.steps, bug=(args.bug == "on"))
        # Evaluate the current policy (greedy actions, 5 episodes)
        with torch.no_grad():
            ep_rets = []
            for _ in range(5):
                o, _ = env.reset()
                done = False; ep_ret = 0
                while not done:
                    a = torch.distributions.Categorical(
                        logits=pi(torch.as_tensor(o, dtype=torch.float32))
                    ).probs.argmax().item()
                    o, r, terminated, truncated, _ = env.step(a)
                    ep_ret += r; done = terminated or truncated
                ep_rets.append(ep_ret)
        mean_ret = float(np.mean(ep_rets)); returns.append(mean_ret)
        stats = ppo_update(pi, vf, data, bug=(args.bug == "on"))
        tag = "BUG" if args.bug == "on" else "FIX"
        print(f"[{tag}] iter={it:02d} return={mean_ret:6.1f} "
              f"adv(mean/std)={stats['adv_mean']:.3f}/{stats['adv_std']:.3f} "
              f"ratio_std={stats['ratio_std']:.4f} kl≈{stats['approx_kl']:.4f} ev={stats['ev']:.3f}")

if __name__ == "__main__":
    main()
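One subtlety the minimal loop above glosses over: when an episode is truncated mid-buffer, val_buf[t+1] holds the value of the freshly reset state rather than of the state the time limit cut off. A common remedy is to fold the bootstrap value of the pre-reset observation into the last reward and then mask the step as terminal. A minimal sketch, assuming the rollout loop's variables (o2, vf, rew_buf, term_mask_buf) and hard-coding compute_gae's gamma:

# Hypothetical patch inside the rollout loop, just before the env.reset():
if truncated and not terminated:
    with torch.no_grad():
        v_final = vf(torch.as_tensor(o2, dtype=torch.float32)).squeeze(-1)
    # delta then becomes r + gamma * V(s_final) - V(s): the correct truncation target
    rew_buf[-1] = rew_buf[-1] + 0.99 * v_final  # gamma = 0.99, matching compute_gae
    term_mask_buf[-1] = torch.tensor(0.0)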
Typical differences you will observe

With --bug on you will typically see ratio_std ≈ 0 and approx_kl ≈ 0 while the mean return crawls or stagnates, and adv_std is inflated. Three rules of thumb for reading the logs:
- adv.mean() should be close to 0; an oversized adv.std() means the variance is exploding, so prioritize advantage normalization.
- The standard deviation of ratio = exp(new_logp - old_logp) should not stay near 0 for long. If it does, you either never used the old log_prob or computed old_logp with the current policy as well.
- Gymnasium's step() returns (terminated, truncated) separately: terminated=True marks a genuine terminal state (e.g. failure/success), while truncated=True usually just means the time limit was hit and must not cut off bootstrapping.
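For reference, this is the masked GAE recursion that compute_gae implements, writing m_t for bootstrap_mask[t] (1 when bootstrapping is allowed):

δ_t = r_t + γ m_t V(s_{t+1}) − V(s_t)
Â_t = δ_t + γ λ m_t Â_{t+1}

Setting m_t = 0 at a true terminal zeroes both the bootstrap term and the recursion; setting m_t = 1 at a time-limit truncation keeps the γ V(s_{t+1}) term alive.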
The corresponding fixes:

1. Time limits: feed bootstrap_mask = 1 - terminated into the GAE/return computation, never 1 - (terminated or truncated):

# Wrong: mask = 1 - (terminated or truncated)
# Correct:
bootstrap_mask = 0.0 if terminated else 1.0

Treating a time-limit truncation as a true terminal silently drops the γV(s_{t+1}) bootstrap term, which systematically underestimates advantages and returns and leaves the learning signal weak and jittery.

2. Old log-probs: cache the log_prob at sampling time and read it back from the buffer when forming the ratio:

with torch.no_grad():
    a, logp, _ = select_action(pi, ot)
logp_old_buf.append(logp.detach())
...
ratio = torch.exp(new_logp - old_logp)  # old_logp is read from the buffer

PPO's clipping only yields a useful gradient once the ratio drifts away from 1 and gets clipped; a ratio pinned at 1 means the policy has effectively not moved, or the old behavior probabilities were never used.

3. Advantage normalization: adv = (adv - adv.mean()) / (adv.std() + 1e-8). Alongside it, watch explained_variance(v_pred, return): values close to 1 mean the value model fits well, while negative values mean it is badly off.

Two cheap assertions catch the first two pitfalls early:

def assert_ratio_variation(r):
    s = float(torch.std(r))
    assert s > 1e-6, "ratio std ≈ 0: the old log_prob is probably not being used"

def assert_adv_norm(adv):
    m, s = float(adv.mean()), float(adv.std())
    assert abs(m) < 0.1 and s > 0.1, f"abnormal adv distribution: mean={m:.3f}, std={s:.3f}"

These three hidden pitfalls of PPO — unnormalized advantages, an uncached old log_prob, and mishandled time limits — are enough to leave the learning curve stuck halfway up. Run the reproduction above once, then bake the fix templates into your own PPO framework and keep these diagnostics in your logs.
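A sketch of where these checks could be wired in, assuming the script above (the call sites are suggestions, not part of the original):

# Inside ppo_update, right after the ratio is computed:
ratio = torch.exp(logp - lp_old)
assert_ratio_variation(ratio)
# Inside rollout, right after the normalization step:
adv = (adv - adv.mean()) / (adv.std() + 1e-8)
assert_adv_norm(adv)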
Originality statement: this article is published in the Tencent Cloud Developer Community with the author's authorization and may not be reproduced without permission.
For infringement concerns, please contact cloudcommunity@tencent.com for removal.