前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >强化学习FrozenLake求解

强化学习FrozenLake求解

作者头像
万木逢春
发布2019-07-15 15:07:50
2.3K0
发布2019-07-15 15:07:50
举报
文章被收录于专栏:帮你学MatLab帮你学MatLab
FrozenLake

% 让agent越过危险到达目标

大概情景是湖面结冰不结实、黑色圆圈表示踩到就会跌落、白色圆圈表示目标、红色小圆圈代表agent

代码语言:javascript
复制
% 强化学习
ccc
rng('default');  
% 初始化环境
env=two_dimensional_env(4,4,0.01);
% 产生一个强化学习agent
two_dimensional_rl=rl_q_table(env.actions,0.9,0.1,0.9);
% pause(2)
for episode =1:env.max_episodes
    env = env.reset();
    env.render();
    while 1
        A = two_dimensional_rl.choose_action(env.agent);
        env = env.step(A);  % 采取动作获得状态和奖励
        two_dimensional_rl=two_dimensional_rl.learn(env, A);  % 更新
        two_dimensional_rl.dump();% 显示学习进度
        env.agent=env.observation;
        env.render();
        if env.done
            break
        end
    end
end

下面是具体的环境

代码语言:javascript
复制
classdef two_dimensional_env
    %二维宝藏环境类
    properties
        weight
        height
        actions
        fresh_time
        fig
        agent
        observation
        reward
        done
        info
        max_episodes
        step_counter
    end
    % 上面是外部可操作的属性
    % 下面是内部可操作的属性
    properties (SetAccess = private)
      weight_
      height_
      fresh_time_
      fig_
      target=[3,3]
      trap=[2,3;3,2]
   end
    % 在methods中是外部可调用的方法
    methods
        function self = two_dimensional_env(weight,height,fresh_time)
            % 新建类时候初始化
            switch nargin
                case 1
                    self.weight=weight;
                    self.height=weight;
                    self.fresh_time=0.05;
                case 2
                    self.weight=weight;
                    self.height=height;
                    self.fresh_time=0.05;
                case 3
                    self.weight=weight;
                    self.height=height;
                    self.fresh_time=fresh_time;
                otherwise
                    error('缺少参数');
            end
            self.max_episodes=100;
            self.step_counter=0;
            self.actions=[1 2 3 4];
            self.agent=[1 1];
            self.observation=self.agent;
            self.reward=0;
            self.done=0;
            self.info=0;
            h=figure;
            self.fig = h.Number;
            self.weight_=self.weight;
            self.height_=self.height;
            self.fresh_time_=self.fresh_time;
            self.fig_=self.fig;
        end
        function render(self)
            % 绘制图像
            pause(self.fresh_time)
            h=figure(self.fig);
            ax = axes('Parent',h);
            ax.YAxis.Visible = 'off';
            ax.XAxis.Visible = 'off';
            for i=1:self.weight
                for j=1:self.weight
                    draw.rect(i*10,j*10,10);
                    if all([i,j]==self.agent)
                        draw.circle(i*10,j*10,2,'r');
                    end
                    if all([i,j]==self.target)
                        draw.circle(i*10,j*10,4);
                    end
                    if any(all([i,j]==self.trap,2))
                        draw.circle(i*10,j*10,4,'k');
                    end
                end
            end
            axis(ax, 'equal')
        end
        function self = step(self,action)
            % 一步动作
            self.done = 0;
            self.reward=0;
            temp_w=self.agent(1);
            temp_h=self.agent(2);
            if action == 1
                temp_w = self.agent(1)-1;
            end
            if action == 2
                temp_w = self.agent(1)+1;
            end
            if action == 3
                temp_h = self.agent(2)-1;
            end
            if action == 4
                temp_h = self.agent(2)+1;
            end
            if temp_w < 1
                temp_w=1;
                self.reward=-1;
            end
            if temp_h < 1
                temp_h=1;
                self.reward=-1;
            end
            if temp_w > self.weight
                temp_w=self.weight;
                self.reward=-1;
            end
            if temp_h > self.height
                temp_h=self.height;
                self.reward=-1;
            end
            if all([temp_w,temp_h]==self.target)
                self.done = 1;
                self.reward=1;
            end
            if any(all([temp_w,temp_h]==self.trap,2))
                self.done = 1;
                self.reward=-1;
            end
            self.observation=[temp_w,temp_h];
            self.info=0;
        end
        function self=reset(self)
            % 重置环境
            self.weight=self.weight_;
            self.height=self.height_;
            self.fresh_time=self.fresh_time_;
            self.fig=self.fig_;
            self.actions=[1 2];
            self.step_counter=0;
            self.agent=[1 1];
            self.observation=self.agent;
            self.reward=0;
            self.done=0;
            self.info=0;
        end
    end
end

下面是学习方法对象

代码语言:javascript
复制
classdef rl_q_table
    % 强化学习逻辑
    properties
        q_table
        actions
        epsilon
        alpha
        gamma
    end
    methods
        function obj = rl_q_table(actions,epsilon,alpha,gamma)
            % 初始化
            obj.actions=actions;
            obj.epsilon=epsilon;
            obj.alpha=alpha;
            obj.gamma=gamma;
            obj.q_table = containers.Map();
        end
        function dump(obj)
            % 显示学习表
            keySet = keys(obj.q_table);
            len=length(keySet);
            if len<=0
                return
            end
            disp('---------------------------')
            for i=1:len
                disp([keySet{i} ':' StrHelper.arr2str(obj.q_table(keySet{i}))])
            end
        end
        function table_ling=find_line(obj,state)
            % 取出学习表的一行
            agent_str = StrHelper.arr2str(state);
            if ~isKey(obj.q_table,agent_str)
                obj.q_table(agent_str) = zeros(1,length(obj.actions));
            end
            table_ling = obj.q_table(agent_str);
        end
        function obj=learn(obj,env, A)
            % 更新
            q_predict_arr = obj.find_line(env.agent);
            q_predict = q_predict_arr(A);
            if env.done ~= 1
                q_target = env.reward + obj.gamma * max(obj.find_line(env.observation));  % 没有结束
            else
                q_target = env.reward;  % 一局结束了
            end
            % 更新QLearning table
            table_line = obj.find_line(env.agent);
            table_line(A) = table_line(A) + obj.alpha * (q_target - q_predict);
            obj.q_table(StrHelper.arr2str(env.agent)) = table_line;
        end
        function action_name = choose_action(obj,state)
            % 选择一个动作
            state_actions = obj.find_line(state);% 取出这一步的概率
            if (rand() > obj.epsilon) || (all(state_actions == 0))
                % 初始时随机选择
                action_name = obj.actions(randi(length(obj.actions)));
            else  % 贪心选择
                [~,I] = max(state_actions);
                max_index = state_actions==state_actions(I);
                if sum(max_index)>1
                    action_name = obj.actions(max_index);
                    action_name = action_name(randi(length(action_name)));
                else
                action_name = obj.actions(I);
                end
                % 选概率大的
            end
        end
    end
end

相关工具方法在

https://github.com/sickle12138/MatGamer

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2019-07-09,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 帮你学MatLab 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档