Library/Interface
This section details the interface and functions provided by Crux.jl.
Index
Crux.AdversarialOffPolicySolver, Crux.BatchSolver, Crux.OffPolicySolver, Crux.OnPolicySolver, Crux.PolicyParams, Crux.A2C, Crux.ASAF, Crux.AdRIL, Crux.AdVIL, Crux.BC, Crux.BatchSAC, Crux.CQL, Crux.DDPG, Crux.DQN, Crux.ExperienceReplay, Crux.LagrangePPO, Crux.OffPolicyGAIL, Crux.OnPolicyGAIL, Crux.OnlineIQLearn, Crux.PPO, Crux.RARL, Crux.RARL_DQN, Crux.RARL_TD3, Crux.REINFORCE, Crux.SAC, Crux.SQIL, Crux.SoftQ, Crux.TD3, Crux.a2c_loss, Crux.ddpg_actor_loss, Crux.ddpg_target, Crux.dqn_target, Crux.lagrange_ppo_loss, Crux.ppo_loss, Crux.reinforce_loss, Crux.sac_actor_loss, Crux.sac_deterministic_target, Crux.sac_max_q_target, Crux.sac_target, Crux.sac_temp_loss, Crux.smoothed_ddpg_target, Crux.softq_target, Crux.td3_actor_loss, Crux.td3_target
Reinforcement Learning
RL interfaces are categorized by on-policy and off-policy below.
On-Policy
Crux.OnPolicySolver — Type
On-policy solver type.
Fields
agent::PolicyParams # Policy parameters (PolicyParams)
S::AbstractSpace # State space
N::Int = 1000 # Number of environment interactions
ΔN::Int = 200 # Number of interactions between updates
max_steps::Int = 100 # Maximum number of steps per episode
log::Union{Nothing, LoggerParams} = nothing # The logging parameters
i::Int = 0 # The current number of environment interactions
param_optimizers::Dict{Any, TrainingParams} = Dict() # Training parameters for the parameters
a_opt::TrainingParams # Training parameters for the actor
c_opt::Union{Nothing, TrainingParams} = nothing # Training parameters for the critic
𝒫::NamedTuple = (;) # Parameters of the algorithm
interaction_storage = nothing # If this is initialized to an array then it will store all interactions
post_sample_callback = (𝒟; kwargs...) -> nothing # Callback that happens after sampling experience
post_batch_callback = (𝒟; kwargs...) -> nothing # Callback that happens after sampling a batch
On-policy-specific parameters
λ_gae::Float32 = 0.95 # Generalized advantage estimation parameter
required_columns = Symbol[] # Extra data columns to store
Parameters specific to cost constraints (a separate value network)
Vc::Union{ContinuousNetwork, Nothing} = nothing # Cost value approximator
cost_opt::Union{Nothing, TrainingParams} = nothing # Training parameters for the cost value
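The on-policy solvers below (A2C, PPO, LagrangePPO, REINFORCE) each construct an OnPolicySolver with these fields. As a rough usage sketch, not prescribed by the API (the environment, network sizes, and hyperparameters here are illustrative assumptions), a discrete-action PPO run might look like:

using POMDPs, POMDPGym, Crux, Flux

mdp = GymPOMDP(:CartPole, version=:v1)   # illustrative environment
as = actions(mdp)
S = state_space(mdp)

# Actor (categorical policy over `as`) and critic (state-value network); sizes are illustrative
A() = DiscreteNetwork(Chain(Dense(dim(S)..., 64, relu), Dense(64, length(as))), as)
V() = ContinuousNetwork(Chain(Dense(dim(S)..., 64, relu), Dense(64, 1)))

𝒮 = PPO(π=ActorCritic(A(), V()), S=S, N=10_000, ΔN=200)
π_ppo = solve(𝒮, mdp)                    # returns the trained policy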
Crux.a2c_loss — Function
A2C loss function.
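In its basic form the advantage actor-critic policy-gradient loss is 𝐿(θ) = −1/N Σᵢ log π_θ(aᵢ | sᵢ) Âᵢ, where Âᵢ is the advantage estimate; an entropy regularization term is typically added on top.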
Crux.ppo_loss — Function
PPO loss function.
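For reference, the standard PPO clipped surrogate loss is 𝐿(θ) = −𝔼ₜ[min(rₜ(θ) Âₜ, clip(rₜ(θ), 1 − ϵ, 1 + ϵ) Âₜ)], where rₜ(θ) is the probability ratio between the current policy and the policy that collected the data, and Âₜ is the advantage estimate.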
Crux.lagrange_ppo_loss — Function
PPO loss with a cost-constraint (Lagrangian) penalty, used by LagrangePPO.
Crux.LagrangePPO — Function
Lagrange-Constrained PPO solver.
LagrangePPO(;
π::ActorCritic,
Vc::ContinuousNetwork, # value network for estimating cost
ϵ::Float32 = 0.2f0,
λp::Float32 = 1f0,
λe::Float32 = 0.1f0,
λ_gae = 0.95f0,
target_kl = 0.012f0,
target_cost = 0.025f0,
penalty_scale = 1f0,
penalty_max = Inf32,
Ki_max = 10f0,
Ki = 1f-3,
Kp = 1,
Kd = 0,
ema_α = 0.95,
a_opt::NamedTuple=(;),
c_opt::NamedTuple=(;),
cost_opt::NamedTuple=(;),
log::NamedTuple=(;),
required_columns=[],
kwargs...)

Crux.reinforce_loss — Function
REINFORCE loss function.
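For reference, the score-function (REINFORCE) loss is 𝐿(θ) = −𝔼ₜ[log π_θ(aₜ | sₜ) Gₜ], where Gₜ is the return from step t.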
Crux.REINFORCE — Function
REINFORCE solver.
REINFORCE(;
π,
a_opt::NamedTuple=(;),
log::NamedTuple=(;),
required_columns=[],
kwargs...)

Off-Policy
Crux.OffPolicySolver — Type
Off-policy solver type.
Fields
agent::PolicyParams # Policy parameters (PolicyParams)
S::AbstractSpace # State space
N::Int = 1000 # Number of environment interactions
ΔN::Int = 4 # Number of interactions between updates
max_steps::Int = 100 # Maximum number of steps per episode
log::Union{Nothing, LoggerParams} = nothing # The logging parameters
i::Int = 0 # The current number of environment interactions
param_optimizers::Dict{Any, TrainingParams} = Dict() # Training parameters for the parameters
a_opt::Union{Nothing, TrainingParams} = nothing # Training parameters for the actor
c_opt::TrainingParams # Training parameters for the critic
𝒫::NamedTuple = (;) # Parameters of the algorithm
interaction_storage = nothing # If this is initialized to an array then it will store all interactions
post_sample_callback = (𝒟; kwargs...) -> nothing # Callback that happens after sampling experience
Off-policy-specific parameters
post_batch_callback = (𝒟; kwargs...) -> nothing # Callback that happens after sampling a batch
pre_train_callback = (𝒮; kwargs...) -> nothing # Callback that gets called once prior to training
target_update = (π⁻, π; kwargs...) -> polyak_average!(π⁻, π, 0.005f0) # Function for updating the target network
target_fn # Target for critic regression with input signature (π⁻, 𝒟, γ; i)
buffer_size = 1000 # Size of the buffer
required_columns = Symbol[] # Extra data columns to store
buffer = ExperienceBuffer(S, agent.space, buffer_size, required_columns) # The replay buffer
priority_fn = td_error # Function for prioritized replay
buffer_init::Int = max(c_opt.batch_size, 200) # Number of observations to initialize the buffer with
extra_buffers = [] # Extra buffers (e.g. for experience replay in continual learning)
buffer_fractions = [1.0] # Fraction of the minibatch devoted to each buffer
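The off-policy solvers below (DDPG, DQN, SAC, SoftQ, TD3) fill in target_fn, the replay buffer, and the critic optimizer. A minimal usage sketch, mirroring the on-policy example above (environment, architecture, and hyperparameters are illustrative assumptions):

using POMDPs, POMDPGym, Crux, Flux

mdp = GymPOMDP(:CartPole, version=:v1)   # illustrative environment
as = actions(mdp)
S = state_space(mdp)

# Q-network over the discrete action set; architecture is illustrative
Q() = DiscreteNetwork(Chain(Dense(dim(S)..., 64, relu), Dense(64, 64, relu), Dense(64, length(as))), as)

𝒮 = DQN(π=Q(), S=S, N=20_000, buffer_size=10_000)
π_dqn = solve(𝒮, mdp)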
Crux.ddpg_target — Function
DDPG target function.
Set yᵢ = rᵢ + γQ′(sᵢ₊₁, μ′(sᵢ₊₁ | θᵘ′) | θᶜ′)
Crux.smoothed_ddpg_target — Function
Smoothed DDPG target function.
Crux.ddpg_actor_loss — Function
DDPG actor loss function.
∇_θᵘ J ≈ 1/N Σᵢ ∇ₐQ(sᵢ, a | θᶜ)|ₐ₌μ(sᵢ) ∇_θᵘ μ(sᵢ | θᵘ)
Crux.DDPG — Function
Deep deterministic policy gradient (DDPG) solver.
- T. P. Lillicrap, et al., "Continuous control with deep reinforcement learning", ICLR 2016.
DDPG(;
π::ActorCritic,
ΔN=50,
π_explore=GaussianNoiseExplorationPolicy(0.1f0),
a_opt::NamedTuple=(;),
c_opt::NamedTuple=(;),
a_loss=ddpg_actor_loss,
c_loss=td_loss(),
target_fn=ddpg_target,
prefix="",
log::NamedTuple=(;),
π_smooth=GaussianNoiseExplorationPolicy(0.1f0, ϵ_min=-0.5f0, ϵ_max=0.5f0), kwargs...)

Crux.dqn_target — Function
DQN target function.
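For reference, the standard DQN target is yᵢ = rᵢ + γ maxₐ′ Q⁻(sᵢ₊₁, a′), where Q⁻ is the target network (yᵢ = rᵢ for terminal transitions).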
Crux.DQN — Function
Deep Q-learning (DQN) solver.
- V. Mnih, et al., "Human-level control through deep reinforcement learning", Nature 2015.
DQN(;
π::DiscreteNetwork,
N::Int,
ΔN=4,
π_explore=ϵGreedyPolicy(LinearDecaySchedule(1., 0.1, floor(Int, N/2)), π.outputs),
c_opt::NamedTuple=(;),
log::NamedTuple=(;),
c_loss=td_loss(),
target_fn=dqn_target,
prefix="",
kwargs...)

Crux.sac_target — Function
SAC target function.
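For reference, the standard SAC target with temperature α and target critics Q₁⁻, Q₂⁻ is yᵢ = rᵢ + γ (min(Q₁⁻, Q₂⁻)(sᵢ₊₁, a′) − α log π(a′ | sᵢ₊₁)), with a′ ∼ π(· | sᵢ₊₁).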
Crux.sac_deterministic_target — Function
Deterministic SAC target function.
Crux.sac_max_q_target — Function
Max-Q SAC target function.
Crux.sac_actor_loss — Function
SAC actor loss function.
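For reference, the standard SAC actor objective minimizes 𝔼ₛ[α log π(a | s) − min(Q₁, Q₂)(s, a)], with a sampled from π via the reparameterization trick.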
Crux.sac_temp_loss — Function
SAC temperature (α) loss function, used for automatic entropy tuning.
Crux.SAC — Function
Soft Actor Critic (SAC) solver.
SAC(;
π::ActorCritic{T, DoubleNetwork{ContinuousNetwork, ContinuousNetwork}},
ΔN=50,
SAC_α::Float32=1f0,
SAC_H_target::Float32 = Float32(-prod(dim(action_space(π)))),
π_explore=GaussianNoiseExplorationPolicy(0.1f0),
SAC_α_opt::NamedTuple=(;),
a_opt::NamedTuple=(;),
c_opt::NamedTuple=(;),
a_loss=sac_actor_loss,
c_loss=double_Q_loss(),
target_fn=sac_target(π),
prefix="",
log::NamedTuple=(;),
𝒫::NamedTuple=(;),
param_optimizers=Dict(),
kwargs...)

Crux.softq_target — Function
Soft Q-learning target function.
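For reference, the standard soft Q-learning target with temperature α is yᵢ = rᵢ + γ α log Σₐ′ exp(Q⁻(sᵢ₊₁, a′) / α).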
Crux.SoftQ — Function
Soft Q-learning solver.
SoftQ(;
π::DiscreteNetwork,
N::Int,
ΔN=4,
c_opt::NamedTuple=(;epochs=4),
log::NamedTuple=(;),
c_loss=td_loss(),
α=Float32(1.),
prefix="",
kwargs...)

Crux.td3_target — Function
TD3 target function.
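For reference, the standard TD3 target combines clipped double-Q estimates with target-policy smoothing: yᵢ = rᵢ + γ min(Q₁⁻, Q₂⁻)(sᵢ₊₁, ã), where ã is the target action μ⁻(sᵢ₊₁) perturbed by clipped noise (the role of π_smooth below).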
Crux.td3_actor_loss — Function
TD3 actor loss function.
Crux.TD3 — Function
Twin Delayed DDPG (TD3) solver.
TD3(;
π,
ΔN=50,
π_smooth::Policy=GaussianNoiseExplorationPolicy(0.1f0, ϵ_min=-0.5f0, ϵ_max=0.5f0),
π_explore=GaussianNoiseExplorationPolicy(0.1f0),
a_opt::NamedTuple=(;),
c_opt::NamedTuple=(;),
a_loss=td3_actor_loss,
c_loss=double_Q_loss(),
target_fn=td3_target,
prefix="",
log::NamedTuple=(;),
𝒫::NamedTuple=(;),
kwargs...)

Imitation Learning
Crux.AdRIL — Function
Adversarial Reward-moment Imitation Learning (AdRIL) solver.
AdRIL(;
π,
S,
ΔN=50,
solver=SAC,
𝒟_demo,
normalize_demo::Bool=true,
expert_frac=0.5,
buffer_size = 1000,
buffer_init=0,
log::NamedTuple=(;),
buffer::ExperienceBuffer = ExperienceBuffer(S, action_space(π), buffer_size, [:i]),
kwargs...)

Crux.AdVIL — Function
Adversarial Value-moment Imitation Learning (AdVIL) solver.
AdVIL(;
π,
S,
𝒟_demo,
normalize_demo::Bool=true,
λ_GP::Float32=10f0,
λ_orth::Float32=1f-4,
λ_BC::Float32=2f-1,
a_opt::NamedTuple=(;),
c_opt::NamedTuple=(;),
log::NamedTuple=(;),
kwargs...)

Crux.OnlineIQLearn — Function
Online Inverse Q-Learning solver.
OnlineIQLearn(;
π,
S,
𝒟_demo,
γ=Float32(0.9),
normalize_demo::Bool=true,
solver=SoftQ, # or SAC for continuous action spaces
log::NamedTuple=(;period=500),
reg::Bool=true,
α_reg=Float32(0.5),
gp::Bool=true,
λ_gp=Float32(10.),
kwargs...)

Crux.OffPolicyGAIL — Function
Off-policy generative adversarial imitation learning (GAIL) solver.
OffPolicyGAIL(;
π,
S,
𝒟_demo,
𝒟_ndas::Array{ExperienceBuffer} = ExperienceBuffer[],
normalize_demo::Bool=true,
D::ContinuousNetwork,
solver=SAC,
d_opt::NamedTuple=(epochs=5,),
log::NamedTuple=(;),
kwargs...)

Crux.OnPolicyGAIL — Function
On-policy generative adversarial imitation learning (GAIL) solver.
OnPolicyGAIL(;
π,
S,
γ,
λ_gae::Float32 = 0.95f0,
𝒟_demo,
αr::Float32 = 0.5f0,
normalize_demo::Bool=true,
D::ContinuousNetwork,
solver=PPO,
gan_loss::GANLoss=GAN_BCELoss(),
d_opt::NamedTuple=(;),
log::NamedTuple=(;),
Rscale=1f0,
kwargs...)

Adversarial RL
Crux.AdversarialOffPolicySolver — Type
Adversarial off-policy solver.
Fields

𝒮_pro::OffPolicySolver # Solver parameters for the protagonist
𝒮_ant::OffPolicySolver # Solver parameters for the antagonist
px::PolicyParams # Nominal disturbance policy
train_pro_every::Int = 1
train_ant_every::Int = 1
log::Union{Nothing, LoggerParams} = nothing # The logging parameters
i::Int = 0 # The current number of environment interactions
Crux.RARL_DQN — Function
Robust Adversarial RL (RARL) deep Q-learning solver.
Crux.RARL_TD3 — Function
Robust Adversarial RL (RARL) TD3 solver.
Crux.RARL — Function
Robust Adversarial RL (RARL) solver.
RARL(;
𝒮_pro,
𝒮_ant,
px,
log::NamedTuple=(;),
train_pro_every::Int=1,
train_ant_every::Int=1,
buffer_size=1000, # Size of the buffer
required_columns=Symbol[:x, :fail],
buffer::ExperienceBuffer=ExperienceBuffer(𝒮_pro.S, 𝒮_pro.agent.space, buffer_size, required_columns), # The replay buffer
buffer_init::Int=max(max(𝒮_pro.c_opt.batch_size, 𝒮_ant.c_opt.batch_size), 200) # Number of observations to initialize the buffer with
)

Continual Learning
Crux.ExperienceReplay — Function
Experience replay buffer.
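Replay data throughout Crux is stored in ExperienceBuffer objects, constructed as in the solver defaults shown above. A small sketch (the spaces, capacity, and extra columns are illustrative, and it assumes an environment mdp and policy network π as in the earlier sketches):

using Crux

S = state_space(mdp)                  # state space of the assumed environment
A = action_space(π)                   # action space of the assumed policy network
buffer = ExperienceBuffer(S, A, 10_000, Symbol[])   # state space, action space, capacity, extra columns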
Batched RL
Crux.BatchSolver — Type
Batch solver type.
Fields
agent::PolicyParams # Policy parameters (PolicyParams)
S::AbstractSpace # State space
max_steps::Int = 100 # Maximum number of steps per episode
𝒟_train # Training data
param_optimizers::Dict{Any, TrainingParams} = Dict() # Training parameters for the parameters
a_opt::TrainingParams # Training parameters for the actor
c_opt::Union{Nothing, TrainingParams} = nothing # Training parameters for the discriminator
target_fn = nothing # The target function for value-based methods
target_update = (π⁻, π; kwargs...) -> polyak_average!(π⁻, π, 0.005f0) # Function for updating the target network
𝒫::NamedTuple = (;) # Parameters of the algorithm
log::Union{Nothing, LoggerParams} = nothing # The logging parameters
required_columns = Symbol[] # Extra columns to sample
epoch = 0 # Number of epochs of training
Crux.CQL — Function
Conservative Q-Learning (CQL) solver.
CQL(;
π::ActorCritic{T, DoubleNetwork{ContinuousNetwork, ContinuousNetwork}},
solver_type=BatchSAC,
CQL_α::Float32=1f0,
CQL_is_distribution=DistributionPolicy(product_distribution([Uniform(-1,1) for i=1:dim(action_space(π))[1]])),
CQL_α_thresh::Float32=10f0,
CQL_n_action_samples::Int=10,
CQL_α_opt::NamedTuple=(;),
a_opt::NamedTuple=(;),
c_opt::NamedTuple=(;),
log::NamedTuple=(;),
kwargs...)

Crux.BatchSAC — Function
Batched soft actor critic (SAC) solver.
BatchSAC(;
π::ActorCritic{T, DoubleNetwork{ContinuousNetwork, ContinuousNetwork}},
S,
ΔN=50,
SAC_α::Float32=1f0,
SAC_H_target::Float32 = Float32(-prod(dim(action_space(π)))),
𝒟_train,
SAC_α_opt::NamedTuple=(;),
a_opt::NamedTuple=(;),
c_opt::NamedTuple=(;),
log::NamedTuple=(;),
𝒫::NamedTuple=(;),
param_optimizers=Dict(),
normalize_training_data = true,
kwargs...)

Policies
Crux.PolicyParams — Type
Struct for combining useful policy parameters.
π::Pol
space::T2 = action_space(π)
π_explore = π
π⁻ = nothing
pa = nothing # nominal action distribution
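A sketch of constructing PolicyParams directly, e.g. to pair a Q-network with an ϵ-greedy exploration policy and a target-network copy (the network, action set, and schedule values are illustrative assumptions; the solver constructors above typically build this for you):

using Crux, Flux

Q = DiscreteNetwork(Chain(Dense(4, 32, relu), Dense(32, 2)), [1, 2])   # illustrative 2-action Q-network
agent = PolicyParams(π=Q,
                     π_explore=ϵGreedyPolicy(LinearDecaySchedule(1.0, 0.1, 5000), Q.outputs),
                     π⁻=deepcopy(Q))   # target network copy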