前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >matlab强化学习-复杂网格环境搭建与训练

matlab强化学习-复杂网格环境搭建与训练

作者头像
万木逢春
发布2019-10-25 01:55:34
1.5K0
发布2019-10-25 01:55:34
举报
文章被收录于专栏:帮你学MatLab帮你学MatLab

借助通用的网格世界环境类搭建各式各样的二维环境进行强化学习训练

先看这个向量式方法生成随机maze来训练Q—learning的agent、主要特点是地图大

ccc

rng(0)

%% 布置环境硬件

n = 12;% 尺寸

GW = createGridWorld(n,n,'Kings');

s = randi([1,n],(n/3)*n,2); % 生成随即障碍物

ss = num2cell(s,2); % 按行生成cell数组

sss = arrayfun(@(x) StrHelper.arr2str(x{:}),ss);

% 将cell数组每一行生成字符串

GW.ObstacleStates = unique(sss); % 去除重复

GW.CurrentState = StrHelper.arr2str([1,1]);

GW.TerminalStates = StrHelper.arr2str([10,12]);

%% 设置可否行进

updateStateTranstionForObstacles(GW)

%% 设置reward

nS = numel(GW.States);

nA = numel(GW.Actions);

GW.R = -1*ones(nS,nS,nA);

GW.R(:,state2idx(GW,GW.ObstacleStates),:) = -n*2;

GW.R(:,state2idx(GW,GW.TerminalStates),:) = n*2;

%% 生成环境及初始位置

env = rlMDPEnv(GW);

env.ResetFcn = @() 1;

plot(env)

%% 参数初始化

qTable = rlTable(getObservationInfo(env),getActionInfo(env));

tableRep = rlRepresentation(qTable);

tableRep.Options.LearnRate = 1;

agentOpts = rlQAgentOptions;

agentOpts.EpsilonGreedyExploration.Epsilon = .04;

qAgent = rlQAgent(tableRep,agentOpts);

trainOpts = rlTrainingOptions;

trainOpts.MaxStepsPerEpisode = n*10;

trainOpts.MaxEpisodes= 1000;

trainOpts.StopTrainingCriteria = "AverageReward";

trainOpts.StopTrainingValue = n*2;

trainOpts.ScoreAveragingWindowLength = 30;

%% 训练

trainingStats = train(qAgent,env,trainOpts);

%% 结果展示

plot(env)

env.Model.Viewer.ShowTrace = true;

env.Model.Viewer.clearTrace;

sim(qAgent,env)

甚至还可以设置跳跃、类似传送门的效果

先创建障碍物

GW = createGridWorld(5,5);

GW.CurrentState = '[2,1]';

GW.TerminalStates = '[5,5]';

GW.ObstacleStates = ["[3,3]";"[3,4]";"[3,5]";"[4,3]"];

在updateStateTranstionForObstacles(GW)之后 GW.T(state2idx(GW,"[2,4]"),:,:) = 0; GW.T(state2idx(GW,"[2,4]"),state2idx(GW,"[4,4]"),:) = 1;

设定可以从[2,4]跳跃到[4,4]

在设定reward的时候加一句

nS = numel(GW.States); nA = numel(GW.Actions); GW.R = -1*ones(nS,nS,nA); GW.R(state2idx(GW,"[2,4]"),state2idx(GW,"[4,4]"),:) = 5; GW.R(:,state2idx(GW,GW.TerminalStates),:) = 10;

如果用了这一步跳跃就得5分

训练就会找到这个跳跃节点

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2019-10-22,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 帮你学MatLab 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档