
Optimizing a Simulink Model Trained with MATLAB Reinforcement Learning

万木逢春
Published 2020-04-14 15:57:12
Column: 帮你学MatLab

This is the result obtained after training:

The Simulink physical model is the inverted-pendulum environment from the earlier article matlab倒立摆环境建模, drawn in SolidWorks and then imported. Remove the submodule masks and add the actuator and sensor interfaces, so it looks like this:

The reinforcement learning model follows the earlier article MATLAB强化学习训练simulink模型; the overall structure is unchanged.

However, the episode-termination check now also limits the pole's rotation angle: if the pole keeps spinning, the episode stops.

Concretely, if the pole rotates more than 6 rad in either direction, the episode stops and -100 is added to the reward; the absolute-value block unifies the clockwise and counterclockwise rotation angles.

This is added the same way as the cart-position limit.

An OR logic block then merges the cart-out-of-range check with the continuous-rotation check.
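The termination logic those blocks implement can be sketched as a plain MATLAB function; this is only a hypothetical equivalent for illustration, not part of the model (the function name `checkStop` and the cart limit `xLimit` are assumed; the 6 rad threshold and -100 penalty come from the text above):

```matlab
% Hypothetical MATLAB equivalent of the Simulink stop-condition blocks.
function [isDone, penalty] = checkStop(theta, x, xLimit)
    % abs() unifies clockwise and counterclockwise rotation angles
    spinning   = abs(theta) > 6;      % pole rotated more than 6 rad one way
    outOfRange = abs(x) > xLimit;     % cart left the fixed rail
    isDone  = spinning | outOfRange;  % OR block merges the two checks
    penalty = -100 * spinning;        % -100 added to the reward when spinning
end
```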

Once the model is built, the environment's inputs and outputs must be declared to the agent:

%% Open the Simulink model
mdl = 'CartPole_RL';
open_system(mdl)
%% Action: a scalar force limited to [-0.5, 0.5]
actionInfo = rlNumericSpec([1 1],...
    'LowerLimit',-0.5,...
    'UpperLimit',0.5);
actionInfo.Name = 'force';
%% Observation: a 5x1 vector, unbounded
observationInfo = rlNumericSpec([5 1],...
    'LowerLimit',-Inf,...
    'UpperLimit',Inf);
observationInfo.Name = 'observations';
%% Create the environment from the RL Agent block in the model
env = rlSimulinkEnv(mdl,[mdl '/RL Agent'],observationInfo,actionInfo);
This binds the agent block inside the Simulink model file. Next, set the agent parameters:
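Optionally, the environment can randomize the initial state at the start of each episode via a reset function. A minimal sketch, assuming the model reads an initial pole angle from a workspace variable (the name `theta0` and the 0.05 rad spread are assumptions, not from the original model):

```matlab
% Optional: randomize the initial pole angle at each episode reset.
% 'theta0' is an assumed workspace variable name used by the model.
env.ResetFcn = @(in) setVariable(in,'theta0',0.05*randn,'Workspace',mdl);
```

Randomized resets typically make the learned policy more robust than always starting from the same state.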
%% Number of observations
numObservations = observationInfo.Dimension(1);
%% Sample time and episode duration (seconds); fix the random seed
Ts = 0.02;
Tf = 25;
rng(0)
%% Initialize the agent: critic network (state and action paths merged by an addition layer)
statePath = [
    imageInputLayer([numObservations 1 1],'Normalization','none','Name','observation')
    fullyConnectedLayer(128,'Name','CriticStateFC1')
    reluLayer('Name','CriticRelu1')
    fullyConnectedLayer(200,'Name','CriticStateFC2')];
actionPath = [
    imageInputLayer([1 1 1],'Normalization','none','Name','action')
    fullyConnectedLayer(200,'Name','CriticActionFC1','BiasLearnRateFactor',0)];
commonPath = [
    additionLayer(2,'Name','add')
    reluLayer('Name','CriticCommonRelu')
    fullyConnectedLayer(1,'Name','CriticOutput')];
criticNetwork = layerGraph(statePath);
criticNetwork = addLayers(criticNetwork,actionPath);
criticNetwork = addLayers(criticNetwork,commonPath);
criticNetwork = connectLayers(criticNetwork,'CriticStateFC2','add/in1');
criticNetwork = connectLayers(criticNetwork,'CriticActionFC1','add/in2');
figure
plot(criticNetwork)
criticOptions = rlRepresentationOptions('LearnRate',1e-03,'GradientThreshold',1);
critic = rlRepresentation(criticNetwork,observationInfo,actionInfo,...
    'Observation',{'observation'},'Action',{'action'},criticOptions);
actorNetwork = [
    imageInputLayer([numObservations 1 1],'Normalization','none','Name','observation')
    fullyConnectedLayer(128,'Name','ActorFC1')
    reluLayer('Name','ActorRelu1')
    fullyConnectedLayer(200,'Name','ActorFC2')
    reluLayer('Name','ActorRelu2')
    fullyConnectedLayer(1,'Name','ActorFC3')
    tanhLayer('Name','ActorTanh1')
    scalingLayer('Name','ActorScaling','Scale',max(actionInfo.UpperLimit))];
actorOptions = rlRepresentationOptions('LearnRate',5e-04,'GradientThreshold',1);
actor = rlRepresentation(actorNetwork,observationInfo,actionInfo,...
    'Observation',{'observation'},'Action',{'ActorScaling'},actorOptions);
agentOptions = rlDDPGAgentOptions(...
    'SampleTime',Ts,...
    'TargetSmoothFactor',1e-3,...
    'ExperienceBufferLength',1e6,...
    'MiniBatchSize',128);
agentOptions.NoiseOptions.Variance = 0.4;
agentOptions.NoiseOptions.VarianceDecayRate = 1e-5;
agent = rlDDPGAgent(actor,critic,agentOptions);
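Before training, it can be worth a quick sanity check that the actor maps a 5x1 observation to a force inside the action limits (the random observation here is illustrative; in the R2020a-era API `getAction` returns a cell array):

```matlab
% Feed one random observation through the (still untrained) agent.
obs = {rand(5,1)};            % a 5x1 observation, values arbitrary
act = getAction(agent,obs);   % cell array containing the scalar force
% The tanh + scaling layers should keep act{1} within [-0.5, 0.5].
```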
%% Training options
maxepisodes = 2000;
maxsteps = ceil(Tf/Ts);
trainingOptions = rlTrainingOptions(...
    'MaxEpisodes',maxepisodes,...
    'MaxStepsPerEpisode',maxsteps,...
    'ScoreAveragingWindowLength',5,...
    'Verbose',false,...
    'Plots','training-progress',...
    'StopTrainingCriteria','AverageReward',...
    'StopTrainingValue',400,...
    'SaveAgentCriteria','EpisodeReward',...
    'SaveAgentValue',400);
%% Parallel training settings
trainingOptions.UseParallel = true;
trainingOptions.ParallelizationOptions.Mode = "async";
trainingOptions.ParallelizationOptions.DataToSendFromWorkers = "Experiences";
trainingOptions.ParallelizationOptions.StepsUntilDataIsSent = -1; % -1: send at episode end
Then comes the long wait while it trains.
%% Train
trainingStats = train(agent,env,trainingOptions);
%% Simulate the trained agent and show the result
simOptions = rlSimulationOptions('MaxSteps',500);
experience = sim(env,agent,simOptions);
totalReward = sum(experience.Reward);
% bdclose(mdl)
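Since 'SaveAgentCriteria' is set, every episode whose reward exceeds 400 saves a snapshot during training (by default into a savedAgents folder). A snapshot can be reloaded and simulated later; a sketch, where the file name `Agent500.mat` is purely illustrative and `saved_agent` is the variable name these snapshot files use:

```matlab
% Reload a snapshot saved during training ('Agent500.mat' is illustrative).
s = load('savedAgents/Agent500.mat');
experience  = sim(env,s.saved_agent,simOptions);
totalReward = sum(experience.Reward);
```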

This is what the final result of this training run looks like:

Related files download link: https://pan.baidu.com/s/1wrc_HUpRH--0ug3iIYP5zg

Originally published 2020-04-07 on the WeChat official account 帮你学MatLab; shared via the Tencent Cloud self-media sharing program.
