% Goal: steer the agent past the hazards and onto the target
The scenario is a frozen lake whose ice is thin in places: black circles mark holes (stepping on one means falling in), the white circle marks the target, and the small red circle is the agent.
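Given the coordinates hard-coded in the environment class below (target [3,3], traps [2,3] and [3,2], agent starting at [1,1]), one way to picture the 4x4 grid is the following sketch, reading the first coordinate as the row and the second as the column:

S . . .     S = start  [1,1]
. . X .     X = trap   [2,3] and [3,2]
. X T .     T = target [3,3]
. . . .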
% Q-learning main loop
ccc % housekeeping utility from the MatGamer toolkit (see the link at the end)
rng('default');
% Initialize the environment: a 4x4 grid, redrawn every 0.01 s
env = two_dimensional_env(4,4,0.01);
% Create a reinforcement-learning agent (epsilon=0.9, alpha=0.1, gamma=0.9)
two_dimensional_rl = rl_q_table(env.actions,0.9,0.1,0.9);
for episode = 1:env.max_episodes
    env = env.reset();
    env.render();
    while 1
        A = two_dimensional_rl.choose_action(env.agent);
        env = env.step(A);                                     % take the action, get next state and reward
        two_dimensional_rl = two_dimensional_rl.learn(env, A); % Q-table update
        two_dimensional_rl.dump();                             % print learning progress
        env.agent = env.observation;                           % move to the new state
        env.render();
        if env.done
            break
        end
    end
end
Below is the environment itself.
classdef two_dimensional_env
    % 2-D treasure-hunt environment
    properties
        weight       % grid width
        height       % grid height
        actions
        fresh_time
        fig
        agent
        observation
        reward
        done
        info
        max_episodes
        step_counter
    end
    % The properties above are publicly writable;
    % the ones below can only be set inside the class.
    properties (SetAccess = private)
        weight_
        height_
        fresh_time_
        fig_
        target = [3,3]
        trap = [2,3;3,2]
    end
    % Methods in this block are callable from outside the class
    methods
        function self = two_dimensional_env(weight,height,fresh_time)
            % Constructor: initialize the environment
            switch nargin
                case 1 % one argument: a square grid with the default refresh time
                    self.weight = weight;
                    self.height = weight;
                    self.fresh_time = 0.05;
                case 2
                    self.weight = weight;
                    self.height = height;
                    self.fresh_time = 0.05;
                case 3
                    self.weight = weight;
                    self.height = height;
                    self.fresh_time = fresh_time;
                otherwise
                    error('Not enough input arguments');
            end
            self.max_episodes = 100;
            self.step_counter = 0;
            self.actions = [1 2 3 4]; % 1: -w, 2: +w, 3: -h, 4: +h
            self.agent = [1 1];
            self.observation = self.agent;
            self.reward = 0;
            self.done = 0;
            self.info = 0;
            h = figure;
            self.fig = h.Number;
            % Keep private copies so reset() can restore the initial values
            self.weight_ = self.weight;
            self.height_ = self.height;
            self.fresh_time_ = self.fresh_time;
            self.fig_ = self.fig;
        end
        function render(self)
            % Draw the current state of the grid
            pause(self.fresh_time)
            h = figure(self.fig);
            ax = axes('Parent',h);
            ax.YAxis.Visible = 'off';
            ax.XAxis.Visible = 'off';
            for i = 1:self.weight
                for j = 1:self.height
                    draw.rect(i*10,j*10,10);
                    if all([i,j]==self.agent)
                        draw.circle(i*10,j*10,2,'r'); % small red circle: the agent
                    end
                    if all([i,j]==self.target)
                        draw.circle(i*10,j*10,4);     % white circle: the target
                    end
                    if any(all([i,j]==self.trap,2))
                        draw.circle(i*10,j*10,4,'k'); % black circle: a hole in the ice
                    end
                end
            end
            axis(ax,'equal')
        end
        function self = step(self,action)
            % Apply one action; update observation, reward, and done
            self.done = 0;
            self.reward = 0;
            temp_w = self.agent(1);
            temp_h = self.agent(2);
            if action == 1
                temp_w = self.agent(1)-1;
            end
            if action == 2
                temp_w = self.agent(1)+1;
            end
            if action == 3
                temp_h = self.agent(2)-1;
            end
            if action == 4
                temp_h = self.agent(2)+1;
            end
            % Clamp to the grid; walking into a wall costs -1
            if temp_w < 1
                temp_w = 1;
                self.reward = -1;
            end
            if temp_h < 1
                temp_h = 1;
                self.reward = -1;
            end
            if temp_w > self.weight
                temp_w = self.weight;
                self.reward = -1;
            end
            if temp_h > self.height
                temp_h = self.height;
                self.reward = -1;
            end
            if all([temp_w,temp_h]==self.target)
                self.done = 1;  % reached the target
                self.reward = 1;
            end
            if any(all([temp_w,temp_h]==self.trap,2))
                self.done = 1;  % fell into a hole
                self.reward = -1;
            end
            self.observation = [temp_w,temp_h];
            self.info = 0;
        end
        function self = reset(self)
            % Restore the environment to its initial state
            self.weight = self.weight_;
            self.height = self.height_;
            self.fresh_time = self.fresh_time_;
            self.fig = self.fig_;
            self.actions = [1 2 3 4];
            self.step_counter = 0;
            self.agent = [1 1];
            self.observation = self.agent;
            self.reward = 0;
            self.done = 0;
            self.info = 0;
        end
    end
end
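As a quick sanity check of the boundary handling (a minimal sketch, assuming the class above is on the MATLAB path): stepping off the grid leaves the agent in place and returns the wall penalty.

env = two_dimensional_env(4,4,0.01);
env = env.step(1);    % action 1 tries to move to [0,1], which is off the grid
disp(env.observation) % [1 1]: the position was clamped to the boundary
disp(env.reward)      % -1: wall penalty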
Below is the learning agent object.
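It is plain tabular Q-learning: after taking action a in state s and observing reward r and next state s', the stored value is nudged toward a bootstrapped target,

Q(s,a) <- Q(s,a) + alpha * ( r + gamma * max_a' Q(s',a') - Q(s,a) )

where the max term is dropped once the episode has ended. The learn method below is a direct transcription of this rule.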
classdef rl_q_table
    % Tabular Q-learning agent
    properties
        q_table
        actions
        epsilon  % probability of acting greedily (explore with probability 1-epsilon)
        alpha    % learning rate
        gamma    % discount factor
    end
    methods
        function obj = rl_q_table(actions,epsilon,alpha,gamma)
            % Constructor
            obj.actions = actions;
            obj.epsilon = epsilon;
            obj.alpha = alpha;
            obj.gamma = gamma;
            obj.q_table = containers.Map();
        end
        function dump(obj)
            % Print the current Q-table
            keySet = keys(obj.q_table);
            len = length(keySet);
            if len <= 0
                return
            end
            disp('---------------------------')
            for i = 1:len
                disp([keySet{i} ':' StrHelper.arr2str(obj.q_table(keySet{i}))])
            end
        end
        function table_line = find_line(obj,state)
            % Fetch the Q-table row for a state, creating it on first visit.
            % containers.Map is a handle class, so the insertion persists
            % even though obj itself is not returned.
            agent_str = StrHelper.arr2str(state);
            if ~isKey(obj.q_table,agent_str)
                obj.q_table(agent_str) = zeros(1,length(obj.actions));
            end
            table_line = obj.q_table(agent_str);
        end
        function obj = learn(obj,env,A)
            % Q-learning update for the transition just taken
            q_predict_arr = obj.find_line(env.agent);
            q_predict = q_predict_arr(A);
            if env.done ~= 1
                % Episode continues: bootstrap from the best next-state value
                q_target = env.reward + obj.gamma * max(obj.find_line(env.observation));
            else
                % Episode over: the target is just the final reward
                q_target = env.reward;
            end
            % Move Q(s,a) toward the target
            table_line = obj.find_line(env.agent);
            table_line(A) = table_line(A) + obj.alpha * (q_target - q_predict);
            obj.q_table(StrHelper.arr2str(env.agent)) = table_line;
        end
        function action_name = choose_action(obj,state)
            % Epsilon-greedy action selection
            state_actions = obj.find_line(state); % Q-values for this state
            if (rand() > obj.epsilon) || all(state_actions == 0)
                % Explore, or pick at random while the state is still unvisited
                action_name = obj.actions(randi(length(obj.actions)));
            else
                % Exploit: pick the highest-valued action, breaking ties at random
                [~,I] = max(state_actions);
                max_index = state_actions == state_actions(I);
                if sum(max_index) > 1
                    action_name = obj.actions(max_index);
                    action_name = action_name(randi(length(action_name)));
                else
                    action_name = obj.actions(I);
                end
            end
        end
    end
end
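Once the main script has finished, the learned values can be inspected directly (a small sketch; it assumes two_dimensional_rl is still in the workspace):

q = two_dimensional_rl.find_line([1 1]); % Q-values for the start state, one entry per action
[~,best_action] = max(q);                % the greedy action at [1 1]
disp(q)
disp(best_action)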
The helper utilities used above (ccc, draw, StrHelper) are in
https://github.com/sickle12138/MatGamer