强化学习 DQN 算法将状态和动作当成神经网络的输入, 然后经过神经网络分析后得到动作的 Q 值, 这样不需要在表格中记录 Q 值, 而是直接使用神经网络生成 Q 值.也可以只输入状态值, 输出所有的动作值, 然后按照 Q learning 的原则, 直接选择拥有最大值的动作当做下一步要做的动作.
%% 读取环境
ccc
env = rlPredefinedEnv("CartPole-Discrete");
rng(0)
%% 初始化agent
statePath = [
imageInputLayer([4 1 1],'Normalization','none','Name','state')
fullyConnectedLayer(24,'Name','CriticStateFC1')
reluLayer('Name','CriticRelu1')
fullyConnectedLayer(24,'Name','CriticStateFC2')];
actionPath = [
imageInputLayer([1 1 1],'Normalization','none','Name','action')
fullyConnectedLayer(24,'Name','CriticActionFC1')];
commonPath = [
additionLayer(2,'Name','add')
reluLayer('Name','CriticCommonRelu')
fullyConnectedLayer(1,'Name','output')];
criticNetwork = layerGraph(statePath);
criticNetwork = addLayers(criticNetwork, actionPath);
criticNetwork = addLayers(criticNetwork, commonPath);
criticNetwork = connectLayers(criticNetwork,'CriticStateFC2','add/in1');
criticNetwork = connectLayers(criticNetwork,'CriticActionFC1','add/in2');
figure
plot(criticNetwork)
criticOpts = rlRepresentationOptions('LearnRate',0.01,'GradientThreshold',1);
obsInfo = getObservationInfo(env);
actInfo = getActionInfo(env);
critic = rlRepresentation(criticNetwork,obsInfo,actInfo,'Observation',{'state'},'Action',{'action'},criticOpts);
agentOpts = rlDQNAgentOptions(...
'UseDoubleDQN',false, ...
'TargetUpdateMethod',"periodic", ...
'TargetUpdateFrequency',4, ...
'ExperienceBufferLength',100000, ...
'DiscountFactor',0.99, ...
'MiniBatchSize',256);
agent = rlDQNAgent(critic,agentOpts);
%% 设置训练参数
trainOpts = rlTrainingOptions(...
'MaxEpisodes', 1000, ...
'MaxStepsPerEpisode', 500, ...
'Verbose', false, ...
'Plots','training-progress',...
'StopTrainingCriteria','AverageReward',...
'StopTrainingValue',480);
plot(env)
%% 并行学习设置
trainOpts.UseParallel = true;
trainOpts.ParallelizationOptions.Mode = "async";
trainOpts.ParallelizationOptions.DataToSendFromWorkers = "Experiences";%这里和Acagent不一样
trainOpts.ParallelizationOptions.StepsUntilDataIsSent = 32;
%% 训练
trainingStats = train(agent,env,trainOpts);
%% 结果展示
simOptions = rlSimulationOptions('MaxSteps',500);
experience = sim(env,agent,simOptions);
totalReward = sum(experience.Reward);