MATLAB Reinforcement Learning: the DQN Algorithm

万木逢春 · Published 2019-11-18 · From the column 帮你学MatLab

The DQN algorithm feeds the state and the action into a neural network, which outputs the Q-value of that action. There is no need to record Q-values in a table; the network produces them directly. Alternatively, the network can take only the state as input and output one Q-value per action; then, following the Q-learning rule, the action with the largest Q-value is chosen as the next action.
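To make this concrete, here is a minimal conceptual sketch of the state-only variant, separate from the toolbox script below; qnet, state, nextState, reward and gamma are hypothetical placeholders, not anything defined in this post.

% Conceptual sketch only: assume qnet(s) returns one Q-value per discrete action.
qValues = qnet(state);                          % Q(s,a) for every action a
[~, action] = max(qValues);                     % greedy choice: action with the largest Q-value
% One-step Q-learning target used to train the network: y = r + gamma * max_a' Q(s',a')
target = reward + gamma * max(qnet(nextState));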

%% Load the environment
ccc                                          % custom helper used in this column (roughly clc; clear; close all)
env = rlPredefinedEnv("CartPole-Discrete");  % predefined cart-pole environment with discrete actions
rng(0)                                       % fix the random seed for reproducibility

%% 初始化agent

statePath = [

imageInputLayer([4 1 1],'Normalization','none','Name','state')

fullyConnectedLayer(24,'Name','CriticStateFC1')

reluLayer('Name','CriticRelu1')

fullyConnectedLayer(24,'Name','CriticStateFC2')];

actionPath = [

imageInputLayer([1 1 1],'Normalization','none','Name','action')

fullyConnectedLayer(24,'Name','CriticActionFC1')];

commonPath = [

additionLayer(2,'Name','add')

reluLayer('Name','CriticCommonRelu')

fullyConnectedLayer(1,'Name','output')];

criticNetwork = layerGraph(statePath);

criticNetwork = addLayers(criticNetwork, actionPath);

criticNetwork = addLayers(criticNetwork, commonPath);

criticNetwork = connectLayers(criticNetwork,'CriticStateFC2','add/in1');

criticNetwork = connectLayers(criticNetwork,'CriticActionFC1','add/in2');

figure

plot(criticNetwork)

criticOpts = rlRepresentationOptions('LearnRate',0.01,'GradientThreshold',1);
obsInfo = getObservationInfo(env);
actInfo = getActionInfo(env);
% rlRepresentation is the R2019-era syntax; newer releases use rlQValueRepresentation
critic = rlRepresentation(criticNetwork,obsInfo,actInfo, ...
    'Observation',{'state'},'Action',{'action'},criticOpts);
agentOpts = rlDQNAgentOptions(...
    'UseDoubleDQN',false, ...
    'TargetUpdateMethod',"periodic", ...
    'TargetUpdateFrequency',4, ...
    'ExperienceBufferLength',100000, ...
    'DiscountFactor',0.99, ...
    'MiniBatchSize',256);
agent = rlDQNAgent(critic,agentOpts);

%% Set training options
trainOpts = rlTrainingOptions(...
    'MaxEpisodes',1000, ...
    'MaxStepsPerEpisode',500, ...
    'Verbose',false, ...
    'Plots','training-progress', ...
    'StopTrainingCriteria','AverageReward', ...
    'StopTrainingValue',480);                % stop once the average reward reaches 480
plot(env)                                    % visualize the cart-pole environment

%% 并行学习设置

trainOpts.UseParallel = true;

trainOpts.ParallelizationOptions.Mode = "async";

trainOpts.ParallelizationOptions.DataToSendFromWorkers = "Experiences";%这里和Acagent不一样

trainOpts.ParallelizationOptions.StepsUntilDataIsSent = 32;

%% Train
trainingStats = train(agent,env,trainOpts);

%% Simulate and show the result
simOptions = rlSimulationOptions('MaxSteps',500);
experience = sim(env,agent,simOptions);
totalReward = sum(experience.Reward);        % total reward of the evaluation episode
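As an optional follow-up (not in the original post), the trained agent could be saved to a MAT-file and reloaded later for further simulation; the file name trainedDQNAgent.mat is only an example.

save('trainedDQNAgent.mat','agent');         % store the trained agent
% ... later, in a fresh session ...
tmp = load('trainedDQNAgent.mat');
experience2 = sim(env,tmp.agent,simOptions); % replay with the reloaded agent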
