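Monte Carlo Tree Search (MCTS) is a decision-making algorithm that performs especially well on complex decision problems and in game-playing settings. It is one of the core algorithms behind AlphaGo.
MCTS has recently drawn renewed attention in the large language model (LLM) community, because OpenAI's o1 model is widely speculated to combine MCTS with reinforcement learning (RL), with notable gains on mathematical problem solving in particular. Many reproduction frameworks also use MCTS. Marco-o1, for example, improves problem-solving accuracy through chain-of-thought fine-tuning and MCTS, and does especially well in mathematics, physics, and programming. ReST-MCTS is a reinforced self-training method that uses tree search (MCTS) to guide process rewards, automatically collects reliable reasoning paths, and uses the reward signal for verification and LLM self-training.
This article introduces the principles of MCTS and how MCTS is used with LLMs.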
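First, what is MCTS? Monte Carlo Tree Search (MCTS) is an algorithm for finding optimal decisions, typically applied to action planning in combinatorial games. It estimates the value of each available action through simulation and uses those estimates to choose the best next move, combining the generality of random simulation with the precision of tree search. MCTS improves the search tree by iterating over selection, expansion, simulation, and backpropagation, and finally picks the best action.
The basic MCTS procedure consists of the following steps:
Selection: starting from the root node, recursively select the best child node until a leaf node is reached.
Expansion: if the leaf node is not a terminal node, create one or more child nodes and pick one of them to expand.
Simulation: starting from the expanded node, run a random simulation until the game ends or a termination condition is met.
Backpropagation: propagate the simulation result back up the search tree and update the statistics of the nodes along the path.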
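The four steps above can be sketched on a toy problem. The game below is an assumption made purely for illustration: pick four bits one at a time, with a reward equal to the fraction of ones. The names `Node`, `select`, `expand`, `simulate`, `backpropagate`, and `mcts` are hypothetical and are not taken from any library. Selection uses the UCT score that is discussed later in this article.

```python
import math
import random

DEPTH = 4          # toy game (assumption): choose 4 bits in sequence
ACTIONS = (0, 1)   # each move appends a 0 or a 1

class Node:
    def __init__(self, state, parent=None):
        self.state = state        # tuple of bits chosen so far
        self.parent = parent
        self.children = {}        # action -> Node
        self.visits = 0           # N(v)
        self.total_reward = 0.0   # Q(v)

def is_terminal(state):
    return len(state) == DEPTH

def reward(state):
    return sum(state) / DEPTH     # fraction of ones, in [0, 1]

def uct(child, parent, c=1.4):
    if child.visits == 0:
        return math.inf
    exploit = child.total_reward / child.visits
    explore = c * math.sqrt(math.log(parent.visits) / child.visits)
    return exploit + explore

def select(node):
    # Descend while the node is non-terminal and fully expanded.
    while not is_terminal(node.state) and len(node.children) == len(ACTIONS):
        node = max(node.children.values(), key=lambda ch: uct(ch, node))
    return node

def expand(node, rng):
    # Add one untried child; a terminal node is returned unchanged.
    if is_terminal(node.state):
        return node
    untried = [a for a in ACTIONS if a not in node.children]
    action = rng.choice(untried)
    child = Node(node.state + (action,), parent=node)
    node.children[action] = child
    return child

def simulate(state, rng):
    # Random playout until the game ends.
    while not is_terminal(state):
        state = state + (rng.choice(ACTIONS),)
    return reward(state)

def backpropagate(node, value):
    # Update statistics from the new node back up to the root.
    while node is not None:
        node.visits += 1
        node.total_reward += value
        node = node.parent

def mcts(iterations=1000, seed=0):
    rng = random.Random(seed)
    root = Node(())
    for _ in range(iterations):
        leaf = select(root)
        child = expand(leaf, rng)
        value = simulate(child.state, rng)
        backpropagate(child, value)
    # Recommend the most-visited action at the root.
    best = max(root.children, key=lambda a: root.children[a].visits)
    return best, root

best_action, root = mcts()
```

In this toy game choosing 1 is always better, so the search should concentrate its visits on the child reached by action 1.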
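The rollout algorithm is a decision-time planning algorithm based on Monte Carlo control.
Its basic steps are:
Trajectory simulation: for each candidate action, start from the current state and follow the rollout policy until a terminal state is reached.
Value estimation: estimate the value of each action as the average return of its simulated trajectories.
Action selection: execute the action with the highest estimated value.
So, for each current state, the rollout algorithm samples simulated trajectories for the different actions, estimates each action's value from them, and then selects the action with the largest estimate. In essence it does not search for an optimal path.
Here, one rollout is one episode, and many rollouts together build up a Monte Carlo tree.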
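A minimal sketch of the rollout algorithm described above, under stated assumptions: the environment is a made-up counting game in which each action adds 1, 2, or 3 to a running total, the episode ends once the total reaches 10 or more, and the return is 1 only if the total is exactly 10. The rollout policy is uniform random. All names are hypothetical.

```python
import random

TARGET = 10
ACTIONS = (1, 2, 3)

def is_terminal(total):
    return total >= TARGET

def episode_return(total):
    # Return 1.0 only for hitting the target exactly.
    return 1.0 if total == TARGET else 0.0

def rollout_value(state, action, rng, n_rollouts=200):
    # Average return of simulated trajectories that start with `action`.
    returns = []
    for _ in range(n_rollouts):
        total = state + action
        while not is_terminal(total):
            total += rng.choice(ACTIONS)   # rollout policy: uniform random
        returns.append(episode_return(total))
    return sum(returns) / len(returns)

def rollout_decision(state, seed=0):
    rng = random.Random(seed)
    values = {a: rollout_value(state, a, rng) for a in ACTIONS}
    best = max(values, key=values.get)     # act greedily on the estimates
    return best, values

best, values = rollout_decision(7)
```

From a total of 7, adding 3 ends the episode at exactly 10, so its estimated value is 1.0 and it is the action selected. The other two actions depend on random continuations and score lower.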
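UCT in MCTS stands for Upper Confidence Bound applied to Trees, that is, the upper-confidence-bound algorithm applied to tree search. It is a heuristic search strategy that balances exploitation and exploration in a tree. The UCT formula is:
\[ \mathrm{UCT}(v_i, v) = \frac{Q(v_i)}{N(v_i)} + c \sqrt{\frac{\log N(v)}{N(v_i)}} \]
where:
- \( Q(v_i) \) is the cumulative reward of node \( v_i \), so \( \frac{Q(v_i)}{N(v_i)} \) is its average reward, usually read as the node's win rate.
- \( N(v_i) \) is the number of times node \( v_i \) has been visited.
- \( N(v) \) is the number of times \( v \), the parent of \( v_i \), has been visited.
- \( c \) is the exploration parameter that controls the balance between exploration and exploitation.
The node with the largest UCT value is the one selected during MCTS traversal. The value has two parts. The first part, \( \frac{Q(v_i)}{N(v_i)} \), represents exploitation of existing knowledge and favors nodes with a high win rate. The second part, \( c \sqrt{\frac{\log N(v)}{N(v_i)}} \), represents exploration of insufficiently simulated nodes and favors nodes with few visits. Adjusting \( c \) controls the trade-off between exploiting nodes known to be good and exploring unknown ones.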
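As a small worked example of this score, the sketch below compares two sibling nodes with made-up statistics. The function name `uct_score` and the numbers are assumptions for illustration only. An unvisited node is given an infinite score, a common convention that is not stated in the formula itself.

```python
import math

def uct_score(q, n_child, n_parent, c):
    # Q/N (exploitation) plus c * sqrt(log N(parent) / N(child)) (exploration).
    if n_child == 0:
        return math.inf   # convention (assumption): try unvisited nodes first
    return q / n_child + c * math.sqrt(math.log(n_parent) / n_child)

# name -> (cumulative reward Q, visit count N); illustrative numbers
children = {"a": (8.0, 10), "b": (1.0, 2)}
n_parent = 12

def pick(c):
    return max(children, key=lambda k: uct_score(*children[k], n_parent, c))

greedy_choice = pick(0.0)      # pure exploitation
exploring_choice = pick(1.4)   # exploration term included
```

With c = 0 the node with the higher win rate ("a", 0.8 versus 0.5) is chosen. With c = 1.4 the rarely visited node "b" gets the larger score, because its exploration bonus outweighs its lower win rate.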
The concepts above are all used in the code walkthrough below, so a rough understanding is enough for now.
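The article referenced here explains the principles of MCTS in great detail, so this post does not repeat them.
The two figures in that article make the MCTS procedure easy to follow.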
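Next we focus on how the four steps (selection, expansion, simulation, and backpropagation) are carried out with an LLM.
With the MCTS workflow in mind, we turn to LLMs. The first thing to settle is how to build a tree for a natural-language question: what are the actions and the reward? We use Microsoft's open-source rStar code to answer this.
In the code, a node performs different actions depending on its type. The action space is A1-A5: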
#! create children
if self.node_type is Node_Type.USER_QUESTION:
    # A1: Propose an one-step thought.
    if not self.disable_a1:
        do_action_generate_ost_step()
    # A2: Propose the remaining thought steps
    do_action_generate_direct_answers()
    # A3: Propose next sub-question along with its answer.
    do_action_generate_subquestions()
    # A5: Rephrase the question/sub-question.
    if not self.disable_a5:
        do_action_generate_rephrased_user_question()
elif self.node_type is Node_Type.REPHRASED_USER_QUESTION:
    # A1: Propose an one-step thought.
    if not self.disable_a1:
        do_action_generate_ost_step()
    # A2: Propose the remaining thought steps
    do_action_generate_direct_answers()
    # A3: Propose next sub-question along with its answer.
    do_action_generate_subquestions()
elif self.node_type is Node_Type.DIRECT_ANSWER:
    raise ValueError("DIRECT_ANSWER node cannot create children!!")
elif self.node_type is Node_Type.SUBQUESTION:
    # A1: Propose an one-step thought.
    if not self.disable_a1:
        do_action_generate_ost_step()
    # A2: Propose the remaining thought steps
    do_action_generate_direct_answers()
    # A3: Propose next sub-question along with its answer.
    do_action_generate_subquestions()
    # A4: Answer the sub-question again.
    do_action_generate_re_subanswers()
elif self.node_type is Node_Type.RE_SUBANSWER:
    # A1: Propose an one-step thought.
    if not self.disable_a1:
        do_action_generate_ost_step()
    # A2: Propose the remaining thought steps
    do_action_generate_direct_answers()
    # A3: Propose next sub-question along with its answer.
    do_action_generate_subquestions()
elif self.node_type is Node_Type.OST_STEP:
    # A1: Propose an one-step thought.
    if not self.disable_a1:
        do_action_generate_ost_step()
    # A2: Propose the remaining thought steps
    do_action_generate_direct_answers()
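The branching above can also be read as a lookup table from node type to allowed actions. The sketch below restates it that way. It is a hypothetical re-expression for readability, not rStar's actual API: `NodeType`, `ACTIONS_BY_NODE_TYPE`, and `allowed_actions` are names invented here, and only the `disable_a1` and `disable_a5` switches visible in the excerpt are modeled.

```python
from enum import Enum, auto

class NodeType(Enum):
    USER_QUESTION = auto()
    REPHRASED_USER_QUESTION = auto()
    DIRECT_ANSWER = auto()
    SUBQUESTION = auto()
    RE_SUBANSWER = auto()
    OST_STEP = auto()

# Which of A1-A5 each node type may use to create children,
# read off the if/elif chain in the excerpt above.
ACTIONS_BY_NODE_TYPE = {
    NodeType.USER_QUESTION: ["A1", "A2", "A3", "A5"],
    NodeType.REPHRASED_USER_QUESTION: ["A1", "A2", "A3"],
    NodeType.SUBQUESTION: ["A1", "A2", "A3", "A4"],
    NodeType.RE_SUBANSWER: ["A1", "A2", "A3"],
    NodeType.OST_STEP: ["A1", "A2"],
}

def allowed_actions(node_type, disable_a1=False, disable_a5=False):
    # A DIRECT_ANSWER node is a leaf and cannot create children.
    if node_type is NodeType.DIRECT_ANSWER:
        raise ValueError("DIRECT_ANSWER node cannot create children!!")
    disabled = set()
    if disable_a1:
        disabled.add("A1")
    if disable_a5:
        disabled.add("A5")
    return [a for a in ACTIONS_BY_NODE_TYPE[node_type] if a not in disabled]

root_actions = allowed_actions(NodeType.USER_QUESTION)
step_actions = allowed_actions(NodeType.OST_STEP, disable_a1=True)
```

Read this way, A5 (rephrasing) is only available at the original user question, and A4 (re-answering) is only available directly after a sub-question.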
Let's look at the prompts to understand the action space.
A1 corresponds to do_action_generate_ost_step(). The core prompt is (fewshot_ost_prompt.txt):
### Instruction:
There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there will be 21 trees. How many trees did the grove workers plant today?
### Response:
Let's think step by step.
Step 1: Identify the initial number of trees. The problem states there are 15 trees in the grove.
Step 2: Identify the final number of trees. The problem states there will be 21 trees after the workers are done planting.
Step 3: Subtract the initial number of trees from the final number of trees to find out how many trees were planted.
Step 4: Therefore, the grove workers planted 21 (final number of trees) - 15 (initial number of trees) = 6 trees today.
Step 5: The answer is 6.
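Action 1 has the model think one step at a time ("Let's think step by step"), reasoning step by step until it reaches the final answer ("The answer is").
A2 corresponds to do_action_generate_direct_answers(). The core prompt is (fewshot_cot_prompt.txt):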
### Instruction:
Olivia has $23. She bought five bagels for $3 each. How much money does she have left?
### Response:
Let's think step by step. Olivia had 23 dollars.
5 bagels for 3 dollars each will be 5 x 3 = 15 dollars.
So she has 23 - 15 dollars left. 23 - 15 is 8.
The answer is: 8.
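Action 2 has the model reason through the remaining steps in one go and give the answer directly.
A3 corresponds to do_action_generate_subquestions(). The core prompt is (decompose_prompt.txt/decompose_prompt_rephrased.txt):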
Given a question, please decompose it into sub-questions. For each sub-question, please answer it in a complete sentence, ending with "The answer is <a numeric answer>". When the original question is answerable, please start the subquestion with "Now we can answer the question: <original question>".
Question 1: Given a list of conditions, please answer the question. Condition 1: Four years ago, Kody's age was half of Mohamed's age. Condition 2: Mohamed is currently twice as old as 30 years. Question: What is Kody's current age?
Question 1.1: How old is Mohamed currently?
Answer 1.1: Mohamed is twice as old as 30 years, which means he is 30 * 2 = 60 years old.
Question 1.2: What was Kody's age four years ago, given that it was half of Mohamed's age at that time?
Answer 1.2: Four years ago, Mohamed was 60 - 4 = 56 years old, so Kody was half of that, which is 56 / 2 = 28 years old.
Question 1.3: Now we can answer the question: How old is Kody?
Answer 1.3: Kody is currently 28 + 4 = 32 years old. The answer is 32.
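Action 3 splits the question into several sub-questions, answers each one, and then arrives at the final answer.
A4 corresponds to do_action_generate_re_subanswers(). It answers a sub-question again as a check, so it is used together with A3 to reconsider and correct the answer to an A3 sub-question. The implementation is in generate_re_subanswers(); readers can look it up themselves.
A5 corresponds to do_action_generate_rephrased_user_question(). The core prompt is (rephrasing_prompt_template.txt):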
Original Question: There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there will be 21 trees. How many trees did the grove workers plant today?
Rephrased Question: Given a list of conditions, please answer the question.
Condition 1: Initially, there are 15 trees in the grove.
Condition 2: Grove workers will add more trees to the grove today.
Condition 3: After planting, the total number of trees in the grove will increase to 21.
Question: How many trees did the grove workers plant today?
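Action 5 rewrites the question as a list of conditions and then answers the question based on those conditions.
Readers can use the .txt file names given above to find the corresponding source code and get a better feel for the action space.
Finally, let's look at how the search tree is built. Tracing into the core code of search_for_answers, we find the keyword rollout: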
for i in (pbar := trange(args.num_rollouts, disable=True, position=0)):
    # MCTS: run num_rollouts rollouts to build the search tree
    rollout_node = mcts_searcher.do_rollout(root_node, i)
    model_rollout_nodes.append(rollout_node)
    # search the tree for the best path
    _, best_solution, _, chosen_node, all_solution_nodes, all_solutions = stochastic_find_best_solution(
        root_node, generator.evaluator, enable_potential_score=args.enable_potential_score
    )
    model_solutions.append(best_solution)
    model_all_solutions.append(all_solutions)
Let's see what logic a rollout executes:
def do_rollout(self, root_node: MCTS_Node, rollout_id: int):
    "Make the tree one layer better. (Train for one iteration.)"
    # select: choose tree nodes down to a leaf
    path_1 = self._select(root_node, rollout_id)
    leaf = path_1[-1]
    # expand: expand the leaf node
    self._expand(leaf, rollout_id)
    # simulate: run a random simulation from the expanded node until an answer is reached
    path_2 = self._simulate(leaf, rollout_id)
    # backpropagate: propagate back along the chosen path and update node statistics
    self._backpropagate(path_1 + path_2)
    try:
        return path_2[-1]
    except:
        return path_1[-1]
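This matches the MCTS procedure described at the beginning of the article, so the four steps are not covered in detail here.
Finally, after several rounds of rollout we have the tree structure and the per-node information, namely the UCT values described earlier. How is the best path then obtained? That is the job of stochastic_find_best_solution().
The implementation is shown below, with the reasoning added as comments: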
def stochastic_find_best_solution(
    root_node,
    evaluator,
    enable_potential_score,
):
    # todo: what strategy do we use to select best node?
    """The function finds the best solution from the solution nodes in the MCTS tree.
    Return: top answer, top solution, confidence of the top answer, the corresponding node of the answer, all solution nodes
    """
    # Valid solution nodes: SUBQUESTION, DIRECT_ANSWER, or OST_STEP type, i.e. leaf nodes or A1 nodes
    solution_nodes = find_valid_solution_nodes(root_node)
    if len(solution_nodes) == 0:
        return None, None

    # Extract the answer text from a node
    def extract_solution_from_node(node):
        if node.node_type is Node_Type.SUBQUESTION:
            return node.subanswer
        elif node.node_type is Node_Type.DIRECT_ANSWER:
            return node.direct_answer
        else:
            return None

    # Extract the answers of all solution nodes, i.e. the final answer of each path
    solutions = [extract_solution_from_node(node) for node in solution_nodes]

    # Compute the score of each solution node
    def calculate_potential_score_for_solution_node(node):
        # Extract the final answer according to the evaluator's rules
        model_answer = evaluator.extract_answer_from_model_completion(extract_solution_from_node(node))
        potential_answers_history = node.potential_answers_history  # {depth -> [potential answers]}
        assert potential_answers_history[node.depth] is None
        # Compute depth_score, the solution node's score at each level of the tree.
        # potential_answers_history records the potential answers produced by all of the node's ancestors.
        # The larger the share of those answers that equal the solution node's answer, the higher the score.
        # The per-level depth_score values are multiplied together to give the final potential_score.
        potential_score = 1
        for depth, depth_potential_answers in potential_answers_history.items():
            if depth < node.depth:
                depth_score = sum(
                    evaluator.check_answers_equiv(dpa, model_answer) for dpa in depth_potential_answers
                ) / len(depth_potential_answers)
                potential_score *= depth_score
        node.set_potential_score(potential_score)
        return potential_score

    prior_weights = (
        [calculate_potential_score_for_solution_node(node) for node in solution_nodes]
        if enable_potential_score
        else None
    )
    # Find the highest-scoring answer and its path
    top_answer, top_completion, top_completion_id, top_confidence = evaluator.stochastic_find_most_confident_answer(
        completions=solutions, prior_weights=prior_weights
    )
    return top_answer, top_completion, top_confidence, solution_nodes[top_completion_id], solution_nodes, solutions
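To make the potential-score computation concrete, here is a standalone sketch of the same product-of-fractions idea on made-up data. It is an illustration only: `potential_score` is a hypothetical helper, plain string equality stands in for evaluator.check_answers_equiv, and the history dictionary is invented.

```python
def potential_score(potential_answers_history, node_depth, model_answer,
                    answers_equiv=lambda a, b: a == b):
    # Multiply, over every depth above the node, the fraction of potential
    # answers at that depth that agree with the node's final answer.
    score = 1.0
    for depth, answers in potential_answers_history.items():
        if depth < node_depth:
            matches = sum(answers_equiv(a, model_answer) for a in answers)
            score *= matches / len(answers)
    return score

# {depth -> potential answers produced at that depth}; the entry at the
# node's own depth is None, as the assert in the excerpt above requires.
history = {1: ["6", "6", "5"], 2: ["6", "6"], 3: None}

score_for_6 = potential_score(history, node_depth=3, model_answer="6")
score_for_5 = potential_score(history, node_depth=3, model_answer="5")
```

For the answer "6" the per-depth fractions are 2/3 and 1, giving a score of 2/3. For the answer "5" the second depth contributes 0, so the product is 0. A single depth with no agreement therefore zeroes the whole score.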
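That concludes this introduction to MCTS and its application to LLM reasoning. If anything here is misunderstood, corrections are welcome.
As an aside: based on this code, I switched the prompts to Chinese and ran one simple math problem with the 4o model, using rollout_num=3. It took 23 minutes to get an answer, and the one consolation is that the answer was correct.
References: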
https://zhuanlan.zhihu.com/p/61062275
https://blog.csdn.net/qq_41033011/article/details/109034887
https://zhuanlan.zhihu.com/p/864190605
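Original content statement: this article is published on the Tencent Cloud Developer Community with the author's authorization and may not be reproduced without permission.
For infringement concerns, contact cloudcommunity@tencent.com for removal.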