前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >第七篇:强化学习策略迭代 代码实现

第七篇:强化学习策略迭代 代码实现

作者头像
double
发布2023-08-08 14:39:45
2160
发布2023-08-08 14:39:45
举报
文章被收录于专栏:算法channel算法channel

你好,我是郭震(zhenguo)

今天介绍强化学习第7篇:强化学习策略迭代代码实现

首先,我们导入包:

代码语言:javascript
复制
import numpy as np

其次,定义迷宫状态空间大小,在这里是9个空格,所以状态数9。

动作空间大小,最多为4个,上、下、左、右

代码语言:javascript
复制
num_states = 9
num_actions = 4

定义迷宫的奖励矩阵,它是二维数组,维度含义:[num_states][num_actions]

代码语言:javascript
复制
rewards = np.array([
    [-1, -1, -1, -1],
    [-1, 0, -1, 1],
    [-1, -1, -1, -1],
    [1, -1, -1, -1],
    [-1, -1, -1, -1],
    [-1, 0, -1, 1],
    [-1, -1, -1, -1],
    [-1, 0, -1, 1],
    [-1, 1, -1, 1],
])

接下来,初始化策略和值函数:

代码语言:javascript
复制
policy = np.ones((num_states, num_actions)) / num_actions
values = np.zeros(num_states)

策略迭代两步:

  • 1)策略评估函数
  • 2)策略改进

迭代第一步:策略评估

主要就是更新values,核心一行代码下面:

代码语言:javascript
复制
v += policy[s][a] * (rewards[s][a] + values[next_state])

这是贝尔曼方程!贝尔曼方程是动态规划和强化学习中的基本方程,由Richard Bellman提出。

贝尔曼方程表达了状态或状态-动作对的值与按照特定策略获得的预期回报之间的关系。

贝尔曼方程的一般形式如下:

其中,

V(s)

表示状态

s

的值函数,即按照某个策略获得的预期回报。

\max_a

表示选择能够使得值最大化的动作

a

\sum_{s', r}

表示对所有可能的下一个状态

s'

和奖励

r

进行求和。

p(s', r | s, a)

表示在状态

s

下执行动作

a

后转移到状态

s'

且获得奖励

r

的概率。

\gamma

是折扣因子,用于平衡当前和未来的奖励。

通过求解贝尔曼方程,我们可以计算出每个状态的值函数

V(s)

,从而确定最优策略。

代码语言:javascript
复制
# 
def policy_evaluation():
    global values
    delta = 1e-6  # 停止迭代的阈值
    max_iterations = 1000  # 最大迭代次数
    for _ in range(max_iterations):
        new_values = np.zeros(num_states)
        for s in range(num_states):
            v = 0
            for a in range(num_actions):
                next_state = get_next_state(s, a)  # 获取下一个状态
                v += policy[s][a] * (rewards[s][a] + values[next_state])
            new_values[s] = v
        if np.max(np.abs(new_values - values)) < delta:
            break
        values = new_values

定义获取下一个状态的函数:

代码语言:javascript
复制
def get_next_state(state, action):
    if state == 1 and action == 3:
        return 3
    elif state == 3 and action == 0:
        return 1
    else:
        return state

迭代第二步:策略改进

主要更新策略 policy 数组

代码语言:javascript
复制
# 定义策略改进函数
def policy_improvement():
    global policy
    for s in range(num_states):
        q_values = np.zeros(num_actions)
        for a in range(num_actions):
            next_state = get_next_state(s, a)  # 获取下一个状态
            q_values[a] = rewards[s][a] + values[next_state]
        best_action = np.argmax(q_values)
        new_policy = np.zeros(num_actions)
        new_policy[best_action] = 1
        policy[s] = new_policy

联合以上两步就得到策略迭代算法。

代码语言:javascript
复制
# 策略迭代算法
def policy_iteration():
    max_iterations = 30  # 最大迭代次数
    for _ in range(max_iterations):
        policy_evaluation()  # 策略评估
        policy_improvement()  # 策略改进
本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2023-06-04,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 程序员郭震zhenguo 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档