专栏首页帮你学MatLab强化学习FrozenLake求解

强化学习FrozenLake求解

FrozenLake

% 让agent越过危险到达目标

大概情景是湖面结冰不结实、黑色圆圈表示踩到就会跌落、白色圆圈表示目标、红色小圆圈代表agent

% 强化学习
ccc
rng('default');  
% 初始化环境
env=two_dimensional_env(4,4,0.01);
% 产生一个强化学习agent
two_dimensional_rl=rl_q_table(env.actions,0.9,0.1,0.9);
% pause(2)
for episode =1:env.max_episodes
    env = env.reset();
    env.render();
    while 1
        A = two_dimensional_rl.choose_action(env.agent);
        env = env.step(A);  % 采取动作获得状态和奖励
        two_dimensional_rl=two_dimensional_rl.learn(env, A);  % 更新
        two_dimensional_rl.dump();% 显示学习进度
        env.agent=env.observation;
        env.render();
        if env.done
            break
        end
    end
end

下面是具体的环境

classdef two_dimensional_env
    %二维宝藏环境类
    properties
        weight
        height
        actions
        fresh_time
        fig
        agent
        observation
        reward
        done
        info
        max_episodes
        step_counter
    end
    % 上面是外部可操作的属性
    % 下面是内部可操作的属性
    properties (SetAccess = private)
      weight_
      height_
      fresh_time_
      fig_
      target=[3,3]
      trap=[2,3;3,2]
   end
    % 在methods中是外部可调用的方法
    methods
        function self = two_dimensional_env(weight,height,fresh_time)
            % 新建类时候初始化
            switch nargin
                case 1
                    self.weight=weight;
                    self.height=weight;
                    self.fresh_time=0.05;
                case 2
                    self.weight=weight;
                    self.height=height;
                    self.fresh_time=0.05;
                case 3
                    self.weight=weight;
                    self.height=height;
                    self.fresh_time=fresh_time;
                otherwise
                    error('缺少参数');
            end
            self.max_episodes=100;
            self.step_counter=0;
            self.actions=[1 2 3 4];
            self.agent=[1 1];
            self.observation=self.agent;
            self.reward=0;
            self.done=0;
            self.info=0;
            h=figure;
            self.fig = h.Number;
            self.weight_=self.weight;
            self.height_=self.height;
            self.fresh_time_=self.fresh_time;
            self.fig_=self.fig;
        end
        function render(self)
            % 绘制图像
            pause(self.fresh_time)
            h=figure(self.fig);
            ax = axes('Parent',h);
            ax.YAxis.Visible = 'off';
            ax.XAxis.Visible = 'off';
            for i=1:self.weight
                for j=1:self.weight
                    draw.rect(i*10,j*10,10);
                    if all([i,j]==self.agent)
                        draw.circle(i*10,j*10,2,'r');
                    end
                    if all([i,j]==self.target)
                        draw.circle(i*10,j*10,4);
                    end
                    if any(all([i,j]==self.trap,2))
                        draw.circle(i*10,j*10,4,'k');
                    end
                end
            end
            axis(ax, 'equal')
        end
        function self = step(self,action)
            % 一步动作
            self.done = 0;
            self.reward=0;
            temp_w=self.agent(1);
            temp_h=self.agent(2);
            if action == 1
                temp_w = self.agent(1)-1;
            end
            if action == 2
                temp_w = self.agent(1)+1;
            end
            if action == 3
                temp_h = self.agent(2)-1;
            end
            if action == 4
                temp_h = self.agent(2)+1;
            end
            if temp_w < 1
                temp_w=1;
                self.reward=-1;
            end
            if temp_h < 1
                temp_h=1;
                self.reward=-1;
            end
            if temp_w > self.weight
                temp_w=self.weight;
                self.reward=-1;
            end
            if temp_h > self.height
                temp_h=self.height;
                self.reward=-1;
            end
            if all([temp_w,temp_h]==self.target)
                self.done = 1;
                self.reward=1;
            end
            if any(all([temp_w,temp_h]==self.trap,2))
                self.done = 1;
                self.reward=-1;
            end
            self.observation=[temp_w,temp_h];
            self.info=0;
        end
        function self=reset(self)
            % 重置环境
            self.weight=self.weight_;
            self.height=self.height_;
            self.fresh_time=self.fresh_time_;
            self.fig=self.fig_;
            self.actions=[1 2];
            self.step_counter=0;
            self.agent=[1 1];
            self.observation=self.agent;
            self.reward=0;
            self.done=0;
            self.info=0;
        end
    end
end

下面是学习方法对象

classdef rl_q_table
    % 强化学习逻辑
    properties
        q_table
        actions
        epsilon
        alpha
        gamma
    end
    methods
        function obj = rl_q_table(actions,epsilon,alpha,gamma)
            % 初始化
            obj.actions=actions;
            obj.epsilon=epsilon;
            obj.alpha=alpha;
            obj.gamma=gamma;
            obj.q_table = containers.Map();
        end
        function dump(obj)
            % 显示学习表
            keySet = keys(obj.q_table);
            len=length(keySet);
            if len<=0
                return
            end
            disp('---------------------------')
            for i=1:len
                disp([keySet{i} ':' StrHelper.arr2str(obj.q_table(keySet{i}))])
            end
        end
        function table_ling=find_line(obj,state)
            % 取出学习表的一行
            agent_str = StrHelper.arr2str(state);
            if ~isKey(obj.q_table,agent_str)
                obj.q_table(agent_str) = zeros(1,length(obj.actions));
            end
            table_ling = obj.q_table(agent_str);
        end
        function obj=learn(obj,env, A)
            % 更新
            q_predict_arr = obj.find_line(env.agent);
            q_predict = q_predict_arr(A);
            if env.done ~= 1
                q_target = env.reward + obj.gamma * max(obj.find_line(env.observation));  % 没有结束
            else
                q_target = env.reward;  % 一局结束了
            end
            % 更新QLearning table
            table_line = obj.find_line(env.agent);
            table_line(A) = table_line(A) + obj.alpha * (q_target - q_predict);
            obj.q_table(StrHelper.arr2str(env.agent)) = table_line;
        end
        function action_name = choose_action(obj,state)
            % 选择一个动作
            state_actions = obj.find_line(state);% 取出这一步的概率
            if (rand() > obj.epsilon) || (all(state_actions == 0))
                % 初始时随机选择
                action_name = obj.actions(randi(length(obj.actions)));
            else  % 贪心选择
                [~,I] = max(state_actions);
                max_index = state_actions==state_actions(I);
                if sum(max_index)>1
                    action_name = obj.actions(max_index);
                    action_name = action_name(randi(length(action_name)));
                else
                action_name = obj.actions(I);
                end
                % 选概率大的
            end
        end
    end
end

相关工具方法在

https://github.com/sickle12138/MatGamer

本文分享自微信公众号 - 帮你学MatLab(MatLab_helper)

原文出处及转载信息见文内详细说明,如有侵权,请联系 yunjia_community@tencent.com 删除。

原始发表时间:2019-07-09

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

我来说两句

0 条评论
登录 后参与评论

相关文章

  • 强化学习第-1步

    function self = one_dimensional_env(len,fresh_time)

    万木逢春
  • 简单游戏环境搭建

    function self = single_pong_env(weight,height,fresh_time)

    万木逢春
  • matlab增强学习之小球不落地

    在利用之前FrozenLake环境训练当中那种面向对象方式管理程序时发现训练后期运行速度变慢、而且是肉眼可观察到的变慢,所以采用了matlab另一种程序文件管理...

    万木逢春
  • 从PEP-8学习Python编码风格

    Python3中应当总是使用UTF-8。(Python2使用ASCII。)在使用了规定编码后不需要再声明文件编码。

    py3study
  • 强化学习第-1步

    function self = one_dimensional_env(len,fresh_time)

    万木逢春
  • DBSCAN算法的Python实现

    当我傻傻的用python写DBSCAN,我才突然想起来在scikit-learn中有DBSCAN,可以直接调用啊,我本来想要放弃快完成的代码,但是我想我可以发博...

    张凝可
  • python第四十二课——__str__(self)函数

    4.__str__(self): 作用: 创建完对象,直接打印对象名/引用名我们得到的是对象的内存信息(十六进制的地址信息), 这串数据我们程序员并不关心...

    hankleo
  • PaddlePaddle版Flappy-Bird—使用DQN算法实现游戏智能

    刚刚举行的 WAVE SUMMIT 2019 深度学习开发者峰会上,PaddlePaddle 发布了 PARL 1.1 版本,这一版新增了 IMPALA、A3C...

    用户1386409
  • Objective-C AVPlayer播放视频的使用与封装大致效果界面搭建Demo地址

    看下成员变量就知道我怎么搭建的了,这里我将video播放层的size作为参照量,对所有控件的size按照其video的size宽高进行比例缩放

    gwk_iOS
  • 用python解析xml的几种方法

    4.开始部分:<?xml version="1.0" encoding="utf-8"?>

    py3study

扫码关注云+社区

领取腾讯云代金券