首页
学习
活动
专区
圈层
工具
发布
社区首页 >专栏 >ML-SYS 学习宝典:从 RLHF 系统到 SGLang 深入解析

ML-SYS 学习宝典:从 RLHF 系统到 SGLang 深入解析

原创
作者头像
qife122
发布2025-12-19 19:42:54
发布2025-12-19 19:42:54
1550
举报

项目简介

Awesome-ML-SYS-Tutorial 是一个专注于机器学习系统(ML SYS)领域的学习笔记与代码仓库。它旨在为对 ML 与系统交叉领域感兴趣的研究者和工程师提供高质量的学习资源。项目内容涵盖了从基础理论(如强化学习、马尔科夫决策过程)到前沿系统框架(如 RLHF 训练系统、SGLang 推理引擎)的深度解析,并包含了大量实战配置指南和核心代码走读。

该项目的核心价值在于其实践性系统性:不仅提供了如何配置开发环境、使用 Docker、管理依赖等实操指南,还深入剖析了 slime、verl、OpenRLHF 等主流开源框架的架构设计与核心实现,帮助读者理解大规模机器学习训练与推理背后的系统设计思想。

功能特性

  • 全面的 RLHF 系统开发笔记:详细记录了 veRL、slime、OpenRLHF 等框架的设计理念、工作流程、核心模块(如 Rollout、Training、Buffer)及异步训练、多轮对话、工具调用等高级特性的实现。
  • 深入的 SGLang 推理引擎解析:从请求生命周期、KV Cache 管理、分布式并行(TP/DP)、权重更新机制到多模态(如 Qwen2.5-VL)和扩散模型支持,进行了全方位的代码走读和原理阐述。
  • 训推不一致的系统性解决方案:探讨了 RL 训练中因数值精度、算子差异导致的训练与推理不匹配问题,并介绍了 slime 框架中“真正 On-Policy 训练”和“算法缓解(TIS/MIS)”两种解决方案。
  • 实用的环境配置与工具指南:提供了基于 Docker 的可复现环境配置方法、高效的 bash/zsh 配置、uv 包管理工具的使用,以及如何在复杂集群上配置开发环境。
  • 核心算法与代码实现:包含 PPO、GRPO、SPIN、Online DPO 等主流强化学习算法的理论推导、公式解析及其在具体框架(如 verl、trl)中的代码实现。
  • 前沿工作复现与分析:对 Search-R1、LUFFY、Kimi K1.5 等前沿研究工作的算法思想和实现细节进行了学习和复现笔记。

安装指南

本项目主要为学习笔记和代码分析,不依赖于单一的安装脚本。但项目内包含了大量环境配置的实践指导:

  1. 基础环境:推荐使用 Docker 来创建隔离且可复现的开发环境。可以参考项目中的 Docker 配置指南(如使用 lmsysorg/sglang:latestnvcr.io/nvidia/pytorch 等基础镜像)。
  2. Python 环境:建议使用 uv 作为快速的 Python 包管理器,并搭配虚拟环境(venv)。# 创建虚拟环境 python3 -m venv ~/.python/myenv source ~/.python/myenv/bin/activate安装 uvpython3 -m pip install uv
  3. 框架安装:针对不同的学习模块,需要安装对应的框架。 - verl (with SGLang):git clone https://github.com/volcengine/verl.git cd verl python3 -m uv pip install -e ".sglang" --prerelease=allow - slime:git clone https://github.com/THUDM/slime.git cd slime pip install -e . - SGLang:git clone https://github.com/sgl-project/sglang.git cd sglang pip install -e "pythonall"
  4. 依赖管理:注意处理 PyTorch、CUDA、flash-attn、transformers 等依赖的版本兼容性问题,具体版本需参考各框架的官方要求。

使用说明

快速开始:运行一个 RLHF 训练示例

以下以在 verl 框架中使用 SGLang 运行一个 GSM8K 多轮 GRPO 训练为例:

  1. 准备环境和数据:# 拉取并预处理数据集 cd verl python examples/data_preprocess/gsm8k_multiturn_w_tool.py
  2. 启动训练脚本:# 设置 GPU 并运行训练(示例脚本) export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 bash examples/sglang_multiturn/run_qwen2.5-3b_gsm8k_multiturn.sh该脚本内部会配置 Ray 集群、启动 SGLang 推理引擎,并开始 GRPO 训练循环。

核心概念:理解 SGLang 中的请求处理流程

SGLang 作为高性能推理引擎,其核心是高效处理并发的推理请求。一个请求的生命周期大致如下(简化):

代码语言:python
复制
# 伪代码,示意 SGLang 内部流程
# 1. 请求接收与 Tokenization
# 用户请求通过 FastAPI endpoint 进入
async def v1_chat_completions(request: ChatCompletionRequest):
    req_input = convert_to_generate_req_input(request)
    # TokenizerManager 进行分词和多模态数据处理
    tokenized = tokenizer_manager.tokenize(req_input)
    # 将请求放入 Scheduler 的等待队列
    scheduler.add_request(tokenized)

# 2. 调度与批处理 (Scheduler)
# Scheduler 事件循环
while True:
    # 从等待队列中根据优先级(如最长前缀)选取请求组成批次
    batch = scheduler.get_next_batch()
    if batch.prefill_tokens > 0:
        # Prefill 阶段:处理输入提示,填充 KV Cache
        run_prefill_batch(batch)
    # Decode 阶段:自回归生成 token
    next_token_ids = run_decode_batch(batch)
    # 处理生成结果,更新请求状态
    scheduler.process_batch_results(batch, next_token_ids)
    # 如果请求完成,将结果发送给 DetokenizerManager
    for req in batch.finished_requests:
        detokenizer_manager.send_output(req)

# 3. 结果返回
# DetokenizerManager 将 token ID 解码为文本,通过 HTTP 返回给用户。

API 概览:verl 训练配置关键参数

在 verl 等框架中,训练通过配置文件(如 YAML)或命令行参数控制。以下是一些关键参数:

参数组

参数名

说明

Data

train_batch_size

每次训练迭代使用的提示(Prompt)数量。

Data

max_prompt_length

提示的最大 token 长度。

Actor/Rollout

ppo_mini_batch_size

PPO 训练中,将经验数据分割成的 mini-batch 大小。

Actor/Rollout

rollout.n

每个提示采样多少条回复(Responses)。

Rollout Engine

rollout.name

指定推理引擎,如 sglangvllm

Multi-turn

rollout.multi_turn.enable

是否启用多轮对话训练。

Tool Calling

rollout.trace.backend

启用轨迹追踪(如 weave),用于分析工具调用。

核心代码

1. slime 中 GAE 的 Chunk-Scan 并行计算

该优化解决了长序列下 GAE 计算串行导致的性能瓶颈。核心思想是将时间序列分块,并行计算局部 GAE,再通过前缀扫描合并。

代码语言:python
复制
# 代码片段位于 slime 相关 PR (#850)
# 核心思想:将 GAE 计算转化为可并行的前缀扫描问题
def chunk_scan_gae(full_rewards, full_values, gamma, lambd, chunk_size):
    """
    full_rewards: [B, T]
    full_values: [B, T+1]
    """
    B, T = full_rewards.shape
    num_chunks = (T + chunk_size - 1) // chunk_size

    # 1. 将数据分块
    reward_chunks = full_rewards.split(chunk_size, dim=1)
    value_chunks = full_values.split(chunk_size, dim=1)

    # 2. 并行计算每个 chunk 的局部 delta 和 GAE(伪代码)
    # 每个 chunk 内部是串行的,但 chunk 之间可以并行处理
    chunk_results = []
    for i in range(num_chunks):
        chunk_rewards = reward_chunks[i]
        chunk_values = value_chunks[i]
        next_values = value_chunks[i+1] if i+1 < len(value_chunks) else 0.0
        # 计算该 chunk 的 delta 和局部 GAE
        local_delta = chunk_rewards + gamma * next_values - chunk_values
        local_gae = compute_sequential_gae(local_delta, gamma*lambd) # 内部串行
        chunk_results.append((local_delta, local_gae))

    # 3. 前缀扫描合并 chunk (简化示意)
    full_advantages = torch.zeros_like(full_rewards)
    carry = 0.0  # 跨 chunk 的累积因子
    for i in range(num_chunks):
        local_delta, local_gae = chunk_results[i]
        # 将上一个 chunk 的尾部影响加到当前 chunk 的 GAE 上
        adjusted_gae = local_gae + carry
        # 更新 carry 用于下一个 chunk
        carry = (gamma * lambd) ** chunk_size * local_gae[:, -1:]
        # 存储结果
        start_idx = i * chunk_size
        end_idx = start_idx + chunk_size
        full_advantages[:, start_idx:end_idx] = adjusted_gae
    return full_advantages

代码注释

  • 传统 GAE 计算需要对时间步 tT-10 进行串行循环,无法利用 GPU 并行能力。
  • chunk_scan_gae 函数首先将长度为 T 的序列划分为多个 chunk
  • 每个 chunk 内部的 GAE 计算仍是串行的,但不同的 chunk 可以并行计算,这显著提高了计算吞吐。
  • 之后,通过一个轻量的“前缀扫描”步骤,将前一个 chunk 的末端 GAE 值(carry)传播到后一个 chunk,从而合并得到整个序列正确的 GAE。
  • 该优化在超长序列的 Agentic RL 场景下,可带来 100-300 倍 的加速。

2. SGLang 中多模态请求的 Token 扩展与特征注入

以 Qwen2.5-VL 为例,展示了 SGLang 如何处理包含图像的请求。

代码语言:python
复制
# 代码思想基于 sglang/runtime/multimodal_extensions/qwen_vl.py
# 关键步骤:Tokenizer 扩展与 M-RoPE 位置编码
def process_qwen_vl_request(generate_req_input):
    """
    generate_req_input: 包含 text 和 image_data 的请求输入
    """
    text = generate_req_input.text
    image_data_list = generate_req_input.image_data

    # 1. 并发加载和预处理图像
    pixel_values_list = []
    for img_data in image_data_list:
        # 加载图像,并应用模型特定的 resize (如 smart_resize)
        pixel_values = load_and_preprocess_image(img_data)
        pixel_values_list.append(pixel_values)

    # 2. Tokenization 与即时 Token 扩展
    # 原始 prompt 可能包含类似 `<|vision_start|><image><|vision_end|>` 的占位符
    # Tokenizer 会直接将 `<image>` 替换为一连串特定的 image placeholder tokens (如 <|image_pad|>)
    input_ids = tokenizer.encode(text)
    # 此时 input_ids 中已经包含了代表图像区域的特殊 token 序列

    # 3. 计算 M-RoPE (Multimodal Rotary Position Embedding) 位置
    # 这为图像 token 和文本 token 提供了统一的、精确的位置信息
    mrope_positions = compute_mrope_positions(input_ids, pixel_values_list)

    # 4. 构建多模态数据项
    mm_items = []
    for pv in pixel_values_list:
        mm_items.append(MultimodalDataItem(pixel_values=pv))

    # 5. 返回给调度器
    return {
        "input_ids": input_ids,
        "mm_items": mm_items,
        "mrope_positions": mrope_positions
    }

# 在模型前向传播时 (如 qwen2_5_vl.py)
def forward_in_model(input_ids, mrope_positions, mm_items):
    # 获取常规的词嵌入(包含特殊 token)
    embeddings = word_embedding(input_ids)
    # 应用 RoPE,位置信息由 mrope_positions 提供
    embeddings = apply_rotary_pos_emb(embeddings, mrope_positions)

    # 识别出 input_ids 中图像占位符 token 的位置
    image_token_indices = find_image_token_indices(input_ids)
    # 获取视觉特征
    image_features = vision_transformer(mm_items.pixel_values)
    # 将视觉特征投影到语言模型嵌入空间
    projected_image_features = project_to_lm_dim(image_features)
    # 将视觉特征注入到对应位置的 embeddings 中
    embeddings[image_token_indices] = projected_image_features

    # 后续送入 LLM backbone 进行计算
    logits = llm_backbone(embeddings)
    return logits

代码注释

  • 多模态请求处理的关键在于Token 扩展特征注入
  • Token 扩展发生在 Tokenizer 阶段,图像占位符被直接替换为一系列预定义的特殊 Token,这使得输入序列在结构上包含了图像信息。
  • M-RoPE 为融合后的序列(文本 Token + 图像占位 Token)计算统一的位置编码,确保注意力机制能理解各部分的位置关系。
  • 在模型前向传播时,视觉特征由视觉编码器(如 Vision Transformer)提取,并投影到语言模型的嵌入维度,然后精确地替换嵌入层中对应图像占位符位置的向量。
  • 这种设计使得图像信息能够“无缝”地融入语言模型的推理流程,同时保持了 SGLang 调度和缓存机制的有效性(如图片缓存基于 pixel values 的哈希)。

3. 训推不一致的算法缓解:Masked Importance Sampling (MIS)

在 slime 框架中,除了追求比特级对齐的“真正 On-Policy”方案,还提供了基于重要性采样的算法缓解方案。

代码语言:python
复制
# 代码思想基于 slime 中训推不一致缓解的实现
# Masked Importance Sampling (MIS) 核心公式应用
def compute_mis_corrected_loss(actor_log_probs,              # 训练引擎计算的 log prob
                               rollout_log_probs,            # 推理引擎 rollout 时计算的 log prob
                               advantages,                   # 优势函数估计
                               clip_epsilon=0.2,
                               mismatch_threshold=0.1):      # 失配阈值
    """
    通过重要性采样权重来修正策略梯度,抑制训推差异过大的样本。
    """
    # 1. 计算重要性权重 (importance weight)
    #   即 rollout 策略与当前训练策略的概率比
    log_ratio = rollout_log_probs - actor_log_probs
    ratio = torch.exp(log_ratio)

    # 2. 计算原始的 PPO 裁剪损失
    surr1 = ratio * advantages
    surr2 = torch.clamp(ratio, 1 - clip_epsilon, 1 + clip_epsilon) * advantages
    original_pg_loss = -torch.min(surr1, surr2).mean()

    # 3. 计算训推差异度量(例如,每个 token 的 KL 散度)
    kl_per_token = actor_log_probs - rollout_log_probs  # 近似 KL
    avg_kl = kl_per_token.mean(dim=-1)  # 序列平均 KL

    # 4. 构造 Mask (基于阈值)
    #    差异过大的样本,其梯度会被抑制
    mask = (avg_kl < mismatch_threshold).float().unsqueeze(-1)  # 扩张维度以匹配 token 级

    # 5. 应用 MIS:将 mask 作为权重乘到原始损失上
    #    或者更精细地,调整重要性权重
    mis_ratio = ratio * mask
    surr1_mis = mis_ratio * advantages
    surr2_mis = torch.clamp(mis_ratio, 1 - clip_epsilon, 1 + clip_epsilon) * advantages
    mis_corrected_loss = -torch.min(surr1_mis, surr2_mis).mean()

    # 也可以选择将 mask 直接作用于损失
    # mis_corrected_loss = (original_pg_loss * mask.mean())

    return mis_corrected_loss, avg_kl, mask

代码注释

  • 训推不一致 指即使模型权重相同,训练引擎和推理引擎计算出的 token 对数概率也存在微小差异,本质上是浮点运算顺序等系统原因造成的异策略(off-policy)效应。
  • MIS 核心思想:识别出训推差异(可用 KL 散度度量)过大的样本,并在策略梯度更新时降低这些样本的权重(mask)。
  • ratio 是重要性采样权重,衡量 rollout 时策略与当前训练策略的偏离程度。
  • mask 是一个 0/1 掩码,当序列的平均 KL 低于阈值时,mask=1,否则为 0。这相当于过滤掉了差异过大的“不可靠”样本。
  • mask 应用到 ratio 或直接应用到损失上,可以抑制因训推不一致而产生的有害梯度,从而提高训练稳定性,尤其是在 MoE 模型或长序列任务中。
  • 这种方法是一种效率与正确性之间的折中,相比于实现比特级对齐,它的开销更小,但能有效缓解不匹配带来的负面影响。

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 项目简介
  • 功能特性
  • 安装指南
  • 使用说明
    • 快速开始:运行一个 RLHF 训练示例
    • 核心概念:理解 SGLang 中的请求处理流程
    • API 概览:verl 训练配置关键参数
  • 核心代码
    • 1. slime 中 GAE 的 Chunk-Scan 并行计算
    • 2. SGLang 中多模态请求的 Token 扩展与特征注入
    • 3. 训推不一致的算法缓解:Masked Importance Sampling (MIS)
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档