前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >matlab倒立摆环境建模

matlab倒立摆环境建模

作者头像
万木逢春
发布2019-11-05 16:05:07
1.1K0
发布2019-11-05 16:05:07
举报
文章被收录于专栏:帮你学MatLab帮你学MatLab

MATLAB 强化学习工具箱(Reinforcement Learning Toolbox)提供了创建自定义环境的模板,只需执行

rlCreateEnvTemplate("CartPoleEnv")

即可生成环境类的模板文件,按需修改后就能建立自己的学习环境。环境类建立成功之后,执行

env = CartPoleEnv;

即可得到环境对象,再配合强化学习 agent 就可以进行训练。

训练之前,可以先执行

step(env,10);

检查环境是否工作正常(10 是一个合法的推力动作,即 MaxForce)。

下面是环境类的完整定义:

classdef CartPoleEnv < rl.env.MATLABEnvironment
    %CartPoleEnv: MATLAB cart-pole (inverted pendulum) environment.
    %   The state vector is [x; dx; theta; dtheta]:
    %     x      - cart position
    %     dx     - cart velocity
    %     theta  - pole angle from vertical (rad)
    %     dtheta - pole angular velocity
    %   Actions are the two discrete forces -MaxForce and +MaxForce.
    %   The state is advanced with one explicit Euler step of length Ts per
    %   call to step().

    %% Properties
    properties
        % Gravitational acceleration
        Gravity = 9.8

        % Cart mass
        CartMass = 1.0

        % Pole mass
        PoleMass = 0.1

        % Half of the pole length (pivot to the pole's centre of mass)
        HalfPoleLength = 0.5

        % Magnitude of the push force; actions are -MaxForce and +MaxForce
        MaxForce = 10

        % Sample time (length of one Euler integration step)
        Ts = 0.02

        % Episode ends when |theta| exceeds this angle (rad)
        AngleThreshold = 12 * pi/180

        % Episode ends when |x| exceeds this displacement
        DisplacementThreshold = 2.4

        % Reward for each step that stays within both limits
        RewardForNotFalling = 1

        % Reward (penalty) for the step that leaves the limits
        PenaltyForFalling = -10

        % Handle of the figure used for visualisation
        h

        % Whether to draw the environment on every update
        show
    end

    properties
        % Current state [x; dx; theta; dtheta]
        State = zeros(4,1)
    end

    properties(Access = protected)
        % Termination flag of the current episode
        IsDone = false
    end

    %% Required methods
    methods
        % Constructor: define observation/action specs and open the plot window.
        function this = CartPoleEnv()
            % Observation specification: 4x1 numeric state vector
            ObservationInfo = rlNumericSpec([4 1]);
            ObservationInfo.Name = 'CartPole States';
            ObservationInfo.Description = 'x, dx, theta, dtheta';

            % Action specification: two discrete values, rescaled to
            % +/- MaxForce by updateActionInfo below
            ActionInfo = rlFiniteSetSpec([-1 1]);
            ActionInfo.Name = 'CartPole Action';

            % Initialise the base environment class
            this = this@rl.env.MATLABEnvironment(ObservationInfo,ActionInfo);

            updateActionInfo(this);

            % Visualisation is enabled by default
            this.h = figure;
            this.show = 1;

            notifyEnvUpdated(this);
        end

        % Apply one action, advance the dynamics by Ts, and return the
        % new observation, the reward and the termination flag.
        function [Observation,Reward,IsDone,LoggedSignals] = step(this,Action)
            LoggedSignals = [];

            % Validate the action and convert it to a force
            Force = getForce(this,Action);

            % Unpack the current state
            XDot = this.State(2);
            Theta = this.State(3);
            ThetaDot = this.State(4);

            % Shared intermediate terms of the equations of motion
            CosTheta = cos(Theta);
            SinTheta = sin(Theta);
            SystemMass = this.CartMass + this.PoleMass;
            temp = (Force + this.PoleMass*this.HalfPoleLength * ThetaDot^2 * SinTheta) / SystemMass;

            % Angular and linear accelerations
            ThetaDotDot = (this.Gravity * SinTheta - CosTheta* temp) / (this.HalfPoleLength * (4.0/3.0 - this.PoleMass * CosTheta * CosTheta / SystemMass));
            XDotDot = temp - this.PoleMass*this.HalfPoleLength * ThetaDotDot * CosTheta / SystemMass;

            % Explicit Euler update of [x; dx; theta; dtheta]
            Observation = this.State + this.Ts.*[XDot;XDotDot;ThetaDot;ThetaDotDot];

            % set.State validates the new state and already calls
            % notifyEnvUpdated, so no second redraw is issued here
            this.State = Observation;

            % Terminate when the cart or the pole leaves its allowed range
            X = Observation(1);
            Theta = Observation(3);
            IsDone = abs(X) > this.DisplacementThreshold || abs(Theta) > this.AngleThreshold;
            this.IsDone = IsDone;

            Reward = getReward(this);
        end

        % Reset to a new initial state and return it as the first observation.
        function InitialObservation = reset(this)
            % Initial pole angle, uniform in [-0.05, 0.05] rad
            T0 = 2 * 0.05 * rand - 0.05;
            % Initial pole angular velocity
            Td0 = 0;
            % Initial cart position
            X0 = 0;
            % Initial cart velocity
            Xd0 = 0;

            % Order must match the state layout [x; dx; theta; dtheta]
            % used by step() and the observation description
            InitialObservation = [X0;Xd0;T0;Td0];

            % Start the new episode as not terminated
            this.IsDone = false;

            % set.State validates the state and triggers notifyEnvUpdated
            this.State = InitialObservation;
        end
    end

    %% Optional helper methods
    methods
        % Validate an action and return the corresponding force.
        % Errors if the action is not one of ActionInfo.Elements.
        function force = getForce(this,action)
            if ~ismember(action,this.ActionInfo.Elements)
                error('Action must be %g for going left and %g for going right.',-this.MaxForce,this.MaxForce);
            end
            force = action;
        end

        % Make the discrete action set equal to +/- MaxForce.
        function updateActionInfo(this)
            this.ActionInfo.Elements = this.MaxForce*[-1 1];
        end

        % Reward for the current step, based on the termination flag.
        function Reward = getReward(this)
            if ~this.IsDone
                Reward = this.RewardForNotFalling;
            else
                Reward = this.PenaltyForFalling;
            end
        end

        % Open a new figure, enable visualisation and draw the current state.
        function plot(this)
            this.h = figure;
            this.show = 1;
            envUpdatedCallback(this)
        end

        % Property validators (useful when experimenting with the environment)
        function set.State(this,state)
            validateattributes(state,{'numeric'},{'finite','real','vector','numel',4},'','State');
            this.State = double(state(:));
            notifyEnvUpdated(this);
        end

        function set.HalfPoleLength(this,val)
            validateattributes(val,{'numeric'},{'finite','real','positive','scalar'},'','HalfPoleLength');
            this.HalfPoleLength = val;
            notifyEnvUpdated(this);
        end

        function set.Gravity(this,val)
            validateattributes(val,{'numeric'},{'finite','real','positive','scalar'},'','Gravity');
            this.Gravity = val;
        end

        function set.CartMass(this,val)
            validateattributes(val,{'numeric'},{'finite','real','positive','scalar'},'','CartMass');
            this.CartMass = val;
        end

        function set.PoleMass(this,val)
            validateattributes(val,{'numeric'},{'finite','real','positive','scalar'},'','PoleMass');
            this.PoleMass = val;
        end

        function set.MaxForce(this,val)
            validateattributes(val,{'numeric'},{'finite','real','positive','scalar'},'','MaxForce');
            this.MaxForce = val;
            % Keep the discrete action set in sync with the new force
            updateActionInfo(this);
        end

        function set.Ts(this,val)
            validateattributes(val,{'numeric'},{'finite','real','positive','scalar'},'','Ts');
            this.Ts = val;
        end

        function set.AngleThreshold(this,val)
            validateattributes(val,{'numeric'},{'finite','real','positive','scalar'},'','AngleThreshold');
            this.AngleThreshold = val;
        end

        function set.DisplacementThreshold(this,val)
            validateattributes(val,{'numeric'},{'finite','real','positive','scalar'},'','DisplacementThreshold');
            this.DisplacementThreshold = val;
        end

        function set.RewardForNotFalling(this,val)
            validateattributes(val,{'numeric'},{'real','finite','scalar'},'','RewardForNotFalling');
            this.RewardForNotFalling = val;
        end

        function set.PenaltyForFalling(this,val)
            validateattributes(val,{'numeric'},{'real','finite','scalar'},'','PenaltyForFalling');
            this.PenaltyForFalling = val;
        end
    end

    methods (Access = protected)
        % Redraw the cart and pole whenever the environment notifies an update.
        function envUpdatedCallback(this)
            % Skip drawing when visualisation is disabled (or not yet set up)
            if isempty(this.show) || ~this.show
                return
            end
            % Skip drawing if the figure was closed; plot(env) opens a new one
            if isempty(this.h) || ~isgraphics(this.h)
                return
            end

            % Draw into the stored figure's axes without raising the window
            ha = gca(this.h);
            cla(ha)

            % Cart position and pole angle
            X = this.State(1);
            theta = this.State(3);

            % Cart: 0.5 x 0.25 rectangle centred at (X, 0)
            cartpoly = polyshape([-0.25 -0.25 0.25 0.25],[-0.125 0.125 0.125 -0.125]);
            cartpoly = translate(cartpoly,[X 0]);
            plot(ha,cartpoly,'FaceColor',[0.8500 0.3250 0.0980])
            hold(ha,'on')

            % Pole: full length 2*HalfPoleLength, pivoting at (X, 0)
            L = this.HalfPoleLength*2;
            polepoly = polyshape([-0.1 -0.1 0.1 0.1],[0 L L 0]);
            polepoly = translate(polepoly,[X,0]);
            % In step(), positive theta leans the pole toward +x (clockwise),
            % while rotate() turns counter-clockwise for positive angles,
            % hence the minus sign
            polepoly = rotate(polepoly,-rad2deg(theta),[X,0]);
            plot(ha,polepoly,'FaceColor',[0 0.4470 0.7410])
            hold(ha,'off')

            xlim(ha,[-3 3])
            ylim(ha,[-1 2])

            % Flush the graphics queue so the plot updates inside loops
            drawnow limitrate
        end
    end
end

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2019-10-29,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 帮你学MatLab 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档