rlDQNAgent
Deep Q-network (DQN) reinforcement learning agent
Description
The deep Q-network (DQN) algorithm is a model-free, online, off-policy reinforcement learning method. A DQN agent is a value-based reinforcement learning agent that trains a critic to estimate the return or future rewards. DQN is a variant of Q-learning, and it operates only within discrete action spaces.
For more information, see Deep Q-Network (DQN) Agents. For more information on the different types of reinforcement learning agents, see Reinforcement Learning Agents.
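The learning signal behind this description can be sketched with the standard DQN target from the literature. For a minibatch of experiences, the critic is regressed toward y = R + gamma*max over a' of Q_target(S',a'), and the bootstrap term is dropped when the episode ends (D = 1). The plain-MATLAB sketch below computes these targets from precomputed Q-values. The variable names and numbers are illustrative assumptions, and the code is not the toolbox's internal implementation.

```matlab
% Illustrative minibatch of 3 experiences (values are made up)
r      = [1 0 2];            % rewards R, one per experience
isDone = [0 0 1];            % is-done flags D
gamma  = 0.9;                % discount factor
% Target-critic Q-values for the next observations S':
% one row per discrete action, one column per experience
qNextTarget = [0.5  2  1;
               1.5 -1  3];

% Standard DQN target: bootstrap with the best next-state action value,
% and drop the bootstrap term for terminal transitions (D = 1)
y = r + gamma .* (1 - isDone) .* max(qNextTarget,[],1);
disp(y)   % approximately 2.35  1.8  2
```

The third experience is terminal, so its target is just its reward.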
Creation
Syntax
Description
Create Agent from Observation and Action Specifications
agent = rlDQNAgent(observationInfo,actionInfo) creates a DQN agent for an environment with the given observation and action specifications, using default initialization options. The critic in the agent uses a default vector (that is, multi-output) Q-value deep neural network built from the observation specification observationInfo and the action specification actionInfo. The ObservationInfo and ActionInfo properties of agent are set to the observationInfo and actionInfo input arguments, respectively.
agent = rlDQNAgent(observationInfo,actionInfo,initOpts) creates a DQN agent for an environment with the given observation and action specifications. The agent uses a default network configured using options specified in the initOpts object. For more information on the initialization options, see rlAgentInitializationOptions.
Create Agent from Critic
agent = rlDQNAgent(critic) creates a DQN agent with the specified critic network using a default option set for a DQN agent.
Specify Agent Options
agent = rlDQNAgent(critic,agentOptions) creates a DQN agent with the specified critic network and sets the AgentOptions property to the agentOptions input argument. You can also specify agentOptions after the input arguments of any of the previous syntaxes.
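For example, you can pass an options object after the specifications, so that the agent gets a default critic but custom options. The specifications and option values below are illustrative assumptions, not recommended settings.

```matlab
% Hypothetical cart-pole-sized specifications
obsInfo = rlNumericSpec([4 1]);
actInfo = rlFiniteSetSpec([-10 10]);

% Agent options: a few commonly tuned DQN settings (values are illustrative)
agentOptions = rlDQNAgentOptions(...
    DiscountFactor=0.95,...
    MiniBatchSize=64,...
    ExperienceBufferLength=1e5);

% Default critic built from the specifications, plus the custom options
agent = rlDQNAgent(obsInfo,actInfo,agentOptions);
```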
Input Arguments
initOpts
—Agent initialization options
rlAgentInitializationOptions object
Agent initialization options, specified as an rlAgentInitializationOptions object.
critic
—Critic
rlQValueFunction object | rlVectorQValueFunction object
Critic, specified as an rlQValueFunction object or as the generally more efficient rlVectorQValueFunction object. For more information on creating critics, see Create Policies and Value Functions.
Your critic can use a recurrent neural network as its function approximator. However, only rlVectorQValueFunction supports recurrent neural networks. For an example, see Create DQN Agent with Recurrent Neural Network.
Properties
ObservationInfo
—Observation specifications
specification object | array of specification objects
Observation specifications, specified as a reinforcement learning specification object or an array of specification objects defining properties such as dimensions, data type, and names of the observation signals.
If you create the agent by specifying a critic object, the value of ObservationInfo matches the value specified in critic.
You can extract observationInfo from an existing environment or agent using getObservationInfo. You can also construct the specifications manually using rlFiniteSetSpec or rlNumericSpec.
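As an alternative to extracting the specifications from an environment, you can build them yourself. The sketch below constructs an observation specification array that mirrors the two-channel pendulum environment used in the examples. The channel names and the [0,1] image bounds are assumptions chosen for illustration.

```matlab
% Channel 1: 50-by-50 grayscale image (bounds assumed to be [0,1])
imgInfo = rlNumericSpec([50 50 1],LowerLimit=0,UpperLimit=1);
imgInfo.Name = "pendImage";        % hypothetical channel name

% Channel 2: scalar angular velocity (unbounded)
velInfo = rlNumericSpec([1 1]);
velInfo.Name = "angularRate";      % hypothetical channel name

% Array of specification objects, one element per observation channel
obsInfo = [imgInfo velInfo];
```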
ActionInfo
—Action specification
specification object
Action specifications, specified as a reinforcement learning specification object defining properties such as dimensions, data type, and names of the action signals.
Since a DQN agent operates in a discrete action space, you must specify actionInfo as an rlFiniteSetSpec object.
If you create the agent by specifying a critic object, the value of ActionInfo matches the value specified in critic.
You can extract actionInfo from an existing environment or agent using getActionInfo. You can also construct the specification manually using rlFiniteSetSpec.
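Similarly, the five-torque action set of the pendulum example can be specified manually with rlFiniteSetSpec. The sketch pairs it with a hypothetical three-element observation specification, creates a default agent, and reads the specification back with getActionInfo. The channel name and observation size are assumptions.

```matlab
% Discrete action set: torque values in N*m
actInfo = rlFiniteSetSpec([-2 -1 0 1 2]);
actInfo.Name = "torque";            % hypothetical channel name

% Hypothetical 3-element observation specification
obsInfo = rlNumericSpec([3 1]);

% Default DQN agent; its ActionInfo property holds the specification passed in
agent = rlDQNAgent(obsInfo,actInfo);
agentActInfo = getActionInfo(agent);
```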
AgentOptions
—Agent options
rlDQNAgentOptions object
Agent options, specified as an rlDQNAgentOptions object.
If you create a DQN agent with a default critic that uses a recurrent neural network, the default value of AgentOptions.SequenceLength is 32.
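To request a default recurrent critic, you can set the UseRNN option of rlAgentInitializationOptions. This sketch, which uses hypothetical cart-pole-sized specifications, creates such an agent and reads back the default sequence length stated above.

```matlab
% Hypothetical specifications
obsInfo = rlNumericSpec([4 1]);
actInfo = rlFiniteSetSpec([-10 10]);

% Ask for a default critic that contains a recurrent (LSTM) layer
initOpts = rlAgentInitializationOptions(UseRNN=true);
agent = rlDQNAgent(obsInfo,actInfo,initOpts);

% Default sequence length used for training the recurrent critic
seqLen = agent.AgentOptions.SequenceLength;
```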
ExperienceBuffer
—Experience buffer
rlReplayMemory object
Experience buffer, specified as an rlReplayMemory object. During training, the agent stores each of its experiences (S,A,R,S',D) in the buffer, where:
- S is the current observation of the environment.
- A is the action taken by the agent.
- R is the reward for taking action A.
- S' is the next observation after taking action A.
- D is the is-done signal after taking action A.
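To make the (S,A,R,S',D) layout concrete, here is a conceptual sketch in plain MATLAB of a fixed-capacity circular buffer with uniform minibatch sampling. It illustrates the idea only and is not the internal implementation of rlReplayMemory. The capacity, observation size, dummy data, and sampling without replacement are assumptions.

```matlab
rng(0)                                 % reproducible illustration
capacity = 5;                          % maximum number of stored experiences
obsDim   = 4;                          % observation dimension (cart-pole-like)
buf.S  = zeros(obsDim,capacity);       % current observations S
buf.A  = zeros(1,capacity);            % actions A
buf.R  = zeros(1,capacity);            % rewards R
buf.S2 = zeros(obsDim,capacity);       % next observations S'
buf.D  = false(1,capacity);            % is-done flags D
buf.next  = 1;                         % slot that the next experience overwrites
buf.count = 0;                         % number of valid experiences stored

% Store 7 dummy experiences; once full, the oldest entries are overwritten
for t = 1:7
    k = buf.next;
    buf.S(:,k)  = rand(obsDim,1);
    buf.A(k)    = 10*(2*(rand > 0.5) - 1);   % -10 or 10
    buf.R(k)    = 1;
    buf.S2(:,k) = rand(obsDim,1);
    buf.D(k)    = (t == 7);                  % last experience ends the episode
    buf.next  = mod(k,capacity) + 1;
    buf.count = min(buf.count + 1,capacity);
end

% Draw a minibatch uniformly at random, without replacement
batchSize = 3;
idx = randperm(buf.count,batchSize);
batch = struct('S',buf.S(:,idx),'A',buf.A(idx),'R',buf.R(idx),...
               'S2',buf.S2(:,idx),'D',buf.D(idx));
```

After seven insertions into five slots, the sixth and seventh experiences occupy slots 1 and 2, and the next write goes to slot 3.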
UseExplorationPolicy
—Option to use exploration policy
false (default) | true
Option to use exploration policy when selecting actions, specified as one of the following logical values.
- true: Use the base agent exploration policy when selecting actions.
- false: Use the base agent greedy policy when selecting actions.
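For a DQN agent, the exploration policy is epsilon-greedy: with probability epsilon it selects a random action, otherwise the greedy (highest Q-value) action. The plain-MATLAB sketch below contrasts the two policies and applies a multiplicative decay, epsilon = epsilon*(1-EpsilonDecay) while epsilon exceeds EpsilonMin, which is assumed here to match the EpsilonGreedyExploration options. The Q-values and parameter values are illustrative.

```matlab
rng(0)
actions = [-10 10];            % discrete action set (cart-pole forces)
q = [0.2; 0.7];                % illustrative Q-values, one per action

% Greedy policy (UseExplorationPolicy = false): always the argmax action
[~,iBest] = max(q);
aGreedy = actions(iBest);

% Epsilon-greedy policy (UseExplorationPolicy = true)
epsilon = 1; epsilonMin = 0.01; epsilonDecay = 0.005;
if rand < epsilon
    aExplore = actions(randi(numel(actions)));   % random action
else
    aExplore = aGreedy;                          % greedy action
end

% Assumed decay rule: shrink epsilon once per step until it reaches the floor
for step = 1:1000
    if epsilon > epsilonMin
        epsilon = epsilon*(1 - epsilonDecay);
    end
end
```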
SampleTime
—Sample time of agent
positive scalar | -1
Sample time of agent, specified as a positive scalar or as -1. Setting this parameter to -1 allows for event-based simulations. The value of SampleTime matches the value specified in AgentOptions.
Within a Simulink® environment, the RL Agent block in which the agent is specified executes every SampleTime seconds of simulation time. If SampleTime is -1, the block inherits the sample time from its parent subsystem.
Within a MATLAB® environment, the agent is executed every time the environment advances. In this case, SampleTime is the time interval between consecutive elements in the output experience returned by sim or train. If SampleTime is -1, the time interval between consecutive elements in the returned output experience reflects the timing of the event that triggers the agent execution.
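Because SampleTime follows AgentOptions, you set it through rlDQNAgentOptions when you create the agent. This minimal sketch uses hypothetical specifications and a 0.02 s sample time chosen for illustration.

```matlab
% Hypothetical specifications
obsInfo = rlNumericSpec([4 1]);
actInfo = rlFiniteSetSpec([-10 10]);

% Set the sample time in the agent options (0.02 s is an illustrative value)
agentOptions = rlDQNAgentOptions(SampleTime=0.02);
agent = rlDQNAgent(obsInfo,actInfo,agentOptions);

% The agent's SampleTime property reflects the option value
ts = agent.SampleTime;
```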
Object Functions
train | Train reinforcement learning agents within a specified environment
sim | Simulate trained reinforcement learning agents within a specified environment
getAction | Obtain action from agent, actor, or policy object given environment observations
getActor | Get actor from reinforcement learning agent
setActor | Set actor of reinforcement learning agent
getCritic | Get critic from reinforcement learning agent
setCritic | Set critic of reinforcement learning agent
generatePolicyFunction | Generate function that evaluates policy of an agent or policy object
Examples
Create DQN Agent from Observation and Action Specifications
Create an environment with a discrete action space, and obtain its observation and action specifications. For this example, load the environment used in the example Create Agent Using Deep Network Designer and Train Using Image Observations. This environment has two observations: a 50-by-50 grayscale image and a scalar (the angular velocity of the pendulum). The action is a scalar with five possible elements (a torque of either -2, -1, 0, 1, or 2 Nm applied to a swinging pole).
% load predefined environment
env = rlPredefinedEnv("SimplePendulumWithImage-Discrete");
% obtain observation and action specifications
obsInfo = getObservationInfo(env);
actInfo = getActionInfo(env);
The agent creation function initializes the critic network randomly. You can ensure reproducibility by fixing the seed of the random generator.
rng(0)
Create a deep Q-network agent from the environment observation and action specifications.
agent = rlDQNAgent(obsInfo,actInfo);
To check your agent, use getAction to return the action from a random observation.
getAction(agent,{rand(obsInfo(1).Dimension),rand(obsInfo(2).Dimension)})
ans = 1x1 cell array
    {[1]}
You can now test and train the agent within the environment.
Create DQN Agent Using Initialization Options
Create an environment with a discrete action space, and obtain its observation and action specifications. For this example, load the environment used in the example Create Agent Using Deep Network Designer and Train Using Image Observations. This environment has two observations: a 50-by-50 grayscale image and a scalar (the angular velocity of the pendulum). The action is a scalar with five possible elements (a torque of either -2, -1, 0, 1, or 2 Nm applied to a swinging pole).
% load predefined environment
env = rlPredefinedEnv("SimplePendulumWithImage-Discrete");
% obtain observation and action specifications
obsInfo = getObservationInfo(env);
actInfo = getActionInfo(env);
Create an agent initialization option object, specifying that each hidden fully connected layer in the network must have 128 neurons (instead of the default number, 256).
initOpts = rlAgentInitializationOptions(NumHiddenUnit=128);
The agent creation function initializes the critic network randomly. Ensure reproducibility by fixing the seed of the random generator.
rng(0)
Create a DQN agent from the environment observation and action specifications.
agent = rlDQNAgent(obsInfo,actInfo,initOpts);
Extract the deep neural network from the critic.
criticNet = getModel(getCritic(agent));
The default DQN agent uses a multi-output Q-value critic approximator. A multi-output approximator has observations as inputs and state-action values as outputs. Each output element represents the expected cumulative long-term reward for taking the corresponding discrete action from the state indicated by the observation inputs.
Display the layers of the critic network, and verify that each hidden fully connected layer has 128 neurons.
criticNet.Layers
ans = 
  11x1 Layer array with layers:

     1   'concat'         Concatenation     Concatenation of 2 inputs along dimension 1
     2   'relu_body'      ReLU              ReLU
     3   'fc_body'        Fully Connected   128 fully connected layer
     4   'body_output'    ReLU              ReLU
     5   'input_1'        Image Input       50x50x1 images
     6   'conv_1'         2-D Convolution   64 3x3x1 convolutions with stride [1 1] and padding [0 0 0 0]
     7   'relu_input_1'   ReLU              ReLU
     8   'fc_1'           Fully Connected   128 fully connected layer
     9   'input_2'        Feature Input     1 features
    10   'fc_2'           Fully Connected   128 fully connected layer
    11   'output'         Fully Connected   5 fully connected layer
Plot the critic network.
plot(layerGraph(criticNet))
To check your agent, use getAction to return the action from random observations.
getAction(agent,{rand(obsInfo(1).Dimension),rand(obsInfo(2).Dimension)})
ans = 1x1 cell array
    {[0]}
You can now test and train the agent within the environment.
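The multi-output critic described in this example returns one Q-value per discrete action, and the greedy policy picks the action whose Q-value is largest. The following sketch, which assumes that getAction applies this greedy rule when UseExplorationPolicy is false (its default), compares the two on the same random observation.

```matlab
% Recreate the environment specifications and a default DQN agent
env = rlPredefinedEnv("SimplePendulumWithImage-Discrete");
obsInfo = getObservationInfo(env);
actInfo = getActionInfo(env);
rng(0)
agent = rlDQNAgent(obsInfo,actInfo);

% Random observation, one cell per observation channel
obs = {rand(obsInfo(1).Dimension),rand(obsInfo(2).Dimension)};

% Q-values from the multi-output critic: one element per discrete action
q = getValue(getCritic(agent),obs);

% Greedy action: the element of actInfo.Elements with the largest Q-value
[~,iBest] = max(q);
aGreedy = actInfo.Elements(iBest);

% Action returned by the agent (greedy, since UseExplorationPolicy is false)
aAgent = getAction(agent,obs);
```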
Create a DQN Agent Using a Multi-Output Critic
Create an environment interface and obtain its observation and action specifications. For this example, load the predefined environment used for the Train DQN Agent to Balance Cart-Pole System example. This environment has a continuous four-dimensional observation space (the positions and velocities of both cart and pole) and a discrete one-dimensional action space consisting of the application of two possible forces, -10 N or 10 N.
Create the predefined environment.
env = rlPredefinedEnv("CartPole-Discrete");
Get the observation and action specification objects.
obsInfo = getObservationInfo(env); actInfo = getActionInfo(env);
To approximate the Q-value function within the critic, use a deep neural network. For DQN agents with a discrete action space, you have the option to create a multi-output Q-value function critic, which is generally more efficient than a comparable single-output critic.
A network for this critic must take only the observation as input and return a vector of values for each action. Therefore, it must have an input layer with as many elements as the dimension of the observation space and an output layer having as many elements as the number of possible discrete actions. Each output element represents the expected cumulative long-term reward following from the observation given as input, when the corresponding action is taken.
Define the network as an array of layer objects, and get the dimensions of the observation space (that is, prod(obsInfo.Dimension)) and the number of possible actions (that is, numel(actInfo.Elements)) directly from the environment specification objects.
dnn = [
    featureInputLayer(prod(obsInfo.Dimension))
    fullyConnectedLayer(24)
    reluLayer
    fullyConnectedLayer(24)
    reluLayer
    fullyConnectedLayer(numel(actInfo.Elements))
    ];
Convert the network to a dlnetwork object and display the number of weights.
dnn = dlnetwork(dnn);
summary(dnn)
   Initialized: true

   Number of learnables: 770

   Inputs:
      1   'input'   4 features
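The count of 770 learnables can be checked by hand: each fully connected layer has one weight per input-output pair plus one bias per output, and ReLU layers have no learnable parameters.

```matlab
nObs = 4; nHidden = 24; nActions = 2;   % sizes from the cart-pole network above

% Weights (inputs x outputs) plus one bias per output, per fully connected layer
fc1 = nObs*nHidden + nHidden;           % 120
fc2 = nHidden*nHidden + nHidden;        % 600
fc3 = nHidden*nActions + nActions;      % 50
total = fc1 + fc2 + fc3;                % 770, matching the summary
```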
Create the critic using rlVectorQValueFunction, the network dnn, and the observation and action specifications.
critic = rlVectorQValueFunction(dnn,obsInfo,actInfo);
Check that the critic works with a random observation input.
getValue(critic,{rand(obsInfo.Dimension)})
ans = 2x1 single column vector

   -0.0361
    0.0913
Create the DQN agent using the critic.
agent = rlDQNAgent(critic)
agent = 
  rlDQNAgent with properties:

        ExperienceBuffer: [1x1 rl.replay.rlReplayMemory]
            AgentOptions: [1x1 rl.option.rlDQNAgentOptions]
    UseExplorationPolicy: 0
         ObservationInfo: [1x1 rl.util.rlNumericSpec]
              ActionInfo: [1x1 rl.util.rlFiniteSetSpec]
              SampleTime: 1
Specify agent options, including training options for the critic.
agent.AgentOptions.UseDoubleDQN=false;
agent.AgentOptions.TargetUpdateMethod="periodic";
agent.AgentOptions.TargetUpdateFrequency=4;
agent.AgentOptions.ExperienceBufferLength=100000;
agent.AgentOptions.DiscountFactor=0.99;
agent.AgentOptions.MiniBatchSize=256;
agent.AgentOptions.CriticOptimizerOptions.LearnRate=1e-2;
agent.AgentOptions.CriticOptimizerOptions.GradientThreshold=1;
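This example sets UseDoubleDQN=false. To illustrate what that option switches, the plain-MATLAB sketch below computes minibatch targets both ways, assuming the standard double DQN formulation from the literature (the current critic selects the next action, the target critic evaluates it). The Q-values are made-up numbers, and the code is not the toolbox's internal implementation.

```matlab
r      = [1 0 2];            % rewards R
isDone = [0 0 1];            % is-done flags D
gamma  = 0.99;               % DiscountFactor used above
% Q-values for the next observations S' (rows: actions, columns: experiences)
qNextTarget = [0.5  2  1;
               1.5 -1  3];   % from the target critic
qNextOnline = [1.0  0  0;
               0.0  1  1];   % from the current (online) critic

% UseDoubleDQN = false: the target critic both selects and evaluates the action
yDQN = r + gamma .* (1 - isDone) .* max(qNextTarget,[],1);

% UseDoubleDQN = true: the online critic selects, the target critic evaluates
[~,aStar] = max(qNextOnline,[],1);                        % greedy action indices
idx = sub2ind(size(qNextTarget),aStar,1:size(qNextTarget,2));
yDoubleDQN = r + gamma .* (1 - isDone) .* qNextTarget(idx);
```

Decoupling action selection from evaluation is intended to reduce the overestimation that the max operator introduces.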
To check your agent, use getAction to return the action from a random observation.
getAction(agent,{rand(obsInfo.Dimension)})
ans = 1x1 cell array
    {[10]}
You can now test and train the agent within the environment.
Create a DQN Agent Using a Single-Output Critic
Create an environment interface and obtain its observation and action specifications. For this example, load the predefined environment used for the Train DQN Agent to Balance Cart-Pole System example. This environment has a continuous four-dimensional observation space (the positions and velocities of both cart and pole) and a discrete one-dimensional action space consisting of the application of two possible forces, -10 N or 10 N.
Create the predefined environment.
env = rlPredefinedEnv("CartPole-Discrete");
Get the observation and action specification objects.
obsInfo = getObservationInfo(env); actInfo = getActionInfo(env);
Create a deep neural network to be used as approximation model within the critic. For DQN agents, you have the option to create a multi-output Q-value function critic, which is generally more efficient than a comparable single-output critic. However, for this example, create a single-output Q-value function critic instead.
The network for this critic must have two input layers, one for the observation and the other for the action, and return a scalar value representing the expected cumulative long-term reward following from the given observation and action.
Define each network path as an array of layer objects. Get the dimensions of the observation and action spaces from the environment specification objects and specify a name for the input layers, so you can later explicitly associate them with the appropriate environment channel.
% Observation path
obsPath = [
    featureInputLayer(prod(obsInfo.Dimension),Name="netOin")
    fullyConnectedLayer(24)
    reluLayer
    fullyConnectedLayer(24,Name="fcObsPath")
    ];

% Action path
actPath = [
    featureInputLayer(prod(actInfo.Dimension),Name="netAin")
    fullyConnectedLayer(24,Name="fcActPath")
    ];

% Common path (concatenate inputs along dim #1)
commonPath = [
    concatenationLayer(1,2,Name="cat")
    reluLayer
    fullyConnectedLayer(1,Name="out")
    ];

% Add paths to network
net = layerGraph;
net = addLayers(net,obsPath);
net = addLayers(net,actPath);
net = addLayers(net,commonPath);

% Connect layers
net = connectLayers(net,'fcObsPath','cat/in1');
net = connectLayers(net,'fcActPath','cat/in2');

% Plot network
plot(net)
% Convert to dlnetwork object
net = dlnetwork(net);

% Display the number of weights
summary(net)
   Initialized: true

   Number of learnables: 817

   Inputs:
      1   'netOin'   4 features
      2   'netAin'   1 features
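The same bookkeeping explains the 817 learnables of this two-input network. The final fully connected layer receives the 48-element output of the concatenation layer (24 elements from each path), and the ReLU and concatenation layers have no learnable parameters.

```matlab
nObs = 4; nAct = 1; nHidden = 24;        % sizes from the network above

fcObs1 = nObs*nHidden + nHidden;         % 120 (first observation-path layer)
fcObs2 = nHidden*nHidden + nHidden;      % 600 ('fcObsPath')
fcAct  = nAct*nHidden + nHidden;         % 48  ('fcActPath')
fcOut  = 2*nHidden*1 + 1;                % 49  ('out', fed by the 48-element concatenation)
total = fcObs1 + fcObs2 + fcAct + fcOut; % 817, matching the summary
```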
Create the critic using rlQValueFunction. Specify the names of the layers to be associated with the observation and action channels.
critic = rlQValueFunction(net,...
    obsInfo,...
    actInfo,...
    ObservationInputNames="netOin",...
    ActionInputNames="netAin");
Check the critic with a random observation and action input.
getValue(critic,{rand(obsInfo.Dimension)},{rand(actInfo.Dimension)})
ans = single

   -0.0232
Create the DQN agent using the critic.
agent = rlDQNAgent(critic)
agent = 
  rlDQNAgent with properties:

        ExperienceBuffer: [1x1 rl.replay.rlReplayMemory]
            AgentOptions: [1x1 rl.option.rlDQNAgentOptions]
    UseExplorationPolicy: 0
         ObservationInfo: [1x1 rl.util.rlNumericSpec]
              ActionInfo: [1x1 rl.util.rlFiniteSetSpec]
              SampleTime: 1
Specify agent options, including training options for the critic.
agent.AgentOptions.UseDoubleDQN=false;
agent.AgentOptions.TargetUpdateMethod="periodic";
agent.AgentOptions.TargetUpdateFrequency=4;
agent.AgentOptions.ExperienceBufferLength=100000;
agent.AgentOptions.DiscountFactor=0.99;
agent.AgentOptions.MiniBatchSize=256;
agent.AgentOptions.CriticOptimizerOptions.LearnRate=1e-2;
agent.AgentOptions.CriticOptimizerOptions.GradientThreshold=1;
To check your agent, use getAction to return the action from a random observation.
getAction(agent,{rand(obsInfo.Dimension)})
ans = 1x1 cell array
    {[10]}
You can now test and train the agent within the environment.
Create DQN Agent with Recurrent Neural Network
For this example, load the predefined environment used for the Train DQN Agent to Balance Cart-Pole System example. This environment has a continuous four-dimensional observation space (the positions and velocities of both cart and pole) and a discrete one-dimensional action space consisting of the application of two possible forces, -10 N or 10 N.
env = rlPredefinedEnv('CartPole-Discrete');
Get the observation and action specification objects.
obsInfo = getObservationInfo(env); actInfo = getActionInfo(env);
To approximate the Q-value function within the critic, use a recurrent deep neural network. For DQN agents, only the vector function approximator, rlVectorQValueFunction, supports recurrent neural network models. For vector Q-value function critics, the number of elements of the output layer must equal the number of possible actions: numel(actInfo.Elements).
Define the network as an array of layer objects. Get the dimensions of the observation space from the environment specification object (prod(obsInfo.Dimension)). To create a recurrent neural network, use a sequenceInputLayer as the input layer and include an lstmLayer as one of the other network layers.
net = [
    sequenceInputLayer(prod(obsInfo.Dimension))
    fullyConnectedLayer(50)
    reluLayer
    lstmLayer(20,OutputMode="sequence")
    fullyConnectedLayer(20)
    reluLayer
    fullyConnectedLayer(numel(actInfo.Elements))
    ];
Convert to a dlnetwork object and display the number of weights.
net = dlnetwork(net); summary(net);
   Initialized: true

   Number of learnables: 6.3k

   Inputs:
      1   'sequenceinput'   Sequence input with 4 dimensions
Create a critic using the recurrent neural network.
critic = rlVectorQValueFunction(net,obsInfo,actInfo);
Check your critic with a random input observation.
getValue(critic,{rand(obsInfo.Dimension)})
ans = 2x1 single column vector

    0.0136
    0.0067
Define some training options for the critic.
criticOptions = rlOptimizerOptions(...
    LearnRate=1e-3,...
    GradientThreshold=1);
Specify options for creating the DQN agent. To use a recurrent neural network, you must specify a SequenceLength greater than 1.
agentOptions = rlDQNAgentOptions(...
    UseDoubleDQN=false,...
    TargetSmoothFactor=5e-3,...
    ExperienceBufferLength=1e6,...
    SequenceLength=32,...
    CriticOptimizerOptions=criticOptions);
agentOptions.EpsilonGreedyExploration.EpsilonDecay = 1e-4;
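These options update the target critic by smoothing with TargetSmoothFactor=5e-3. Assuming the usual soft-update rule, thetaTarget = tau*theta + (1-tau)*thetaTarget, this plain-MATLAB sketch applies it to a small hypothetical parameter vector held fixed, showing that the target approaches the critic parameters geometrically.

```matlab
tau = 5e-3;                      % TargetSmoothFactor
theta       = [1.0; -2.0; 0.5];  % hypothetical critic parameters, held fixed
thetaTarget = zeros(3,1);        % hypothetical target critic parameters

% Apply the assumed smoothing update once per simulated update step
for k = 1:1000
    thetaTarget = tau*theta + (1 - tau)*thetaTarget;
end

% With theta fixed, the remaining gap is (1 - tau)^k * theta
gap = theta - thetaTarget;
```

A small smoothing factor makes the target critic change slowly, which tends to stabilize the bootstrapped targets.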
Create the agent from the critic and the agent options.
agent = rlDQNAgent(critic,agentOptions);
Check your agent using getAction to return the action from a random observation.
getAction(agent,{rand(obsInfo.Dimension)})
ans = 1x1 cell array
    {[-10]}
You can now test and train the agent within the environment.
Version History
Introduced in R2019a