Library/Interface
This section details the interface and functions provided by Crux.jl.
Contents
Reinforcement Learning
Imitation Learning
Adversarial RL
Continual Learning
Batched RL
Policies
Index
Crux.AdversarialOffPolicySolver
Crux.BatchSolver
Crux.OffPolicySolver
Crux.OnPolicySolver
Crux.PolicyParams
Crux.A2C
Crux.ASAF
Crux.AdRIL
Crux.AdVIL
Crux.BC
Crux.BatchSAC
Crux.CQL
Crux.DDPG
Crux.DQN
Crux.ExperienceReplay
Crux.LagrangePPO
Crux.OffPolicyGAIL
Crux.OnPolicyGAIL
Crux.OnlineIQLearn
Crux.PPO
Crux.RARL
Crux.RARL_DQN
Crux.RARL_TD3
Crux.REINFORCE
Crux.SAC
Crux.SQIL
Crux.SoftQ
Crux.TD3
Crux.a2c_loss
Crux.ddpg_actor_loss
Crux.ddpg_target
Crux.dqn_target
Crux.lagrange_ppo_loss
Crux.ppo_loss
Crux.reinforce_loss
Crux.sac_actor_loss
Crux.sac_deterministic_target
Crux.sac_max_q_target
Crux.sac_target
Crux.sac_temp_loss
Crux.smoothed_ddpg_target
Crux.softq_target
Crux.td3_actor_loss
Crux.td3_target
Reinforcement Learning
The RL solvers are grouped into on-policy and off-policy methods below.
On-Policy
Crux.OnPolicySolver — Type
On-policy solver type.
Fields
agent::PolicyParams # Policy parameters (PolicyParams)
S::AbstractSpace # State space
N::Int = 1000 # Number of environment interactions
ΔN::Int = 200 # Number of interactions between updates
max_steps::Int = 100 # Maximum number of steps per episode
log::Union{Nothing, LoggerParams} = nothing # The logging parameters
i::Int = 0 # The current number of environment interactions
param_optimizers::Dict{Any, TrainingParams} = Dict() # Training parameters for the parameters
a_opt::TrainingParams # Training parameters for the actor
c_opt::Union{Nothing, TrainingParams} = nothing # Training parameters for the critic
𝒫::NamedTuple = (;) # Parameters of the algorithm
interaction_storage = nothing # If initialized to an array, it will store all interactions
post_sample_callback = (𝒟; kwargs...) -> nothing # Callback that happens after sampling experience
post_batch_callback = (𝒟; kwargs...) -> nothing # Callback that happens after sampling a batch

On-policy-specific parameters
λ_gae::Float32 = 0.95 # Generalized advantage estimation parameter
required_columns = Symbol[] # Extra data columns to store

Parameters specific to cost constraints (a separate value network)
Vc::Union{ContinuousNetwork, Nothing} = nothing # Cost value approximator
cost_opt::Union{Nothing, TrainingParams} = nothing # Training parameters for the cost value
Crux.a2c_loss — Function
A2C loss function.
Crux.A2C — Function
Advantage actor critic (A2C) solver.
A2C(;
π::ActorCritic,
a_opt::NamedTuple=(;),
c_opt::NamedTuple=(;),
log::NamedTuple=(;),
λp::Float32=1f0,
λe::Float32=0.1f0,
required_columns=[],
kwargs...)
Crux.ppo_loss — Function
PPO loss function.
Crux.PPO — Function
Proximal policy optimization (PPO) solver.
PPO(;
π::ActorCritic,
ϵ::Float32 = 0.2f0,
λp::Float32 = 1f0,
λe::Float32 = 0.1f0,
target_kl = 0.012f0,
a_opt::NamedTuple=(;),
c_opt::NamedTuple=(;),
log::NamedTuple=(;),
required_columns=[],
kwargs...)
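A minimal usage sketch (not taken from the Crux documentation itself): it assumes mdp is a discrete-action POMDPs.jl environment (e.g. built with POMDPGym) and uses the POMDPs.jl solve interface that Crux extends; network sizes and N/ΔN values are illustrative.

using Crux, Flux, POMDPs
# Assumes `mdp` is a discrete-action POMDPs.jl environment (e.g. from POMDPGym).
as = actions(mdp)        # discrete action set
S = state_space(mdp)     # state space of the environment

# Stochastic discrete actor and scalar state-value critic (illustrative sizes)
A() = DiscreteNetwork(Chain(Dense(dim(S)..., 64, relu), Dense(64, 64, relu),
                            Dense(64, length(as))), as)
V() = ContinuousNetwork(Chain(Dense(dim(S)..., 64, relu), Dense(64, 64, relu),
                              Dense(64, 1)))

𝒮 = PPO(π=ActorCritic(A(), V()), S=S, N=20_000, ΔN=400)
π_ppo = solve(𝒮, mdp)    # POMDPs.jl solve interface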
Crux.lagrange_ppo_loss — Function
PPO loss with a penalty.
Crux.LagrangePPO — Function
Lagrange-Constrained PPO solver.
LagrangePPO(;
π::ActorCritic,
Vc::ContinuousNetwork, # value network for estimating cost
ϵ::Float32 = 0.2f0,
λp::Float32 = 1f0,
λe::Float32 = 0.1f0,
λ_gae = 0.95f0,
target_kl = 0.012f0,
target_cost = 0.025f0,
penalty_scale = 1f0,
penalty_max = Inf32,
Ki_max = 10f0,
Ki = 1f-3,
Kp = 1,
Kd = 0,
ema_α = 0.95,
a_opt::NamedTuple=(;),
c_opt::NamedTuple=(;),
cost_opt::NamedTuple=(;),
log::NamedTuple=(;),
required_columns=[],
kwargs...)
Crux.reinforce_loss — Function
REINFORCE loss function.
Crux.REINFORCE — Function
REINFORCE solver.
REINFORCE(;
π,
a_opt::NamedTuple=(;),
log::NamedTuple=(;),
required_columns=[],
kwargs...)
Off-Policy
Crux.OffPolicySolver — Type
Off-policy solver type.
Fields
agent::PolicyParams # Policy parameters (PolicyParams)
S::AbstractSpace # State space
N::Int = 1000 # Number of environment interactions
ΔN::Int = 4 # Number of interactions between updates
max_steps::Int = 100 # Maximum number of steps per episode
log::Union{Nothing, LoggerParams} = nothing # The logging parameters
i::Int = 0 # The current number of environment interactions
param_optimizers::Dict{Any, TrainingParams} = Dict() # Training parameters for the parameters
a_opt::Union{Nothing, TrainingParams} = nothing # Training parameters for the actor
c_opt::TrainingParams # Training parameters for the critic
𝒫::NamedTuple = (;) # Parameters of the algorithm
interaction_storage = nothing # If initialized to an array, it will store all interactions
post_sample_callback = (𝒟; kwargs...) -> nothing # Callback that happens after sampling experience

Off-policy-specific parameters
post_batch_callback = (𝒟; kwargs...) -> nothing # Callback that happens after sampling a batch
pre_train_callback = (𝒮; kwargs...) -> nothing # Callback that is called once prior to training
target_update = (π⁻, π; kwargs...) -> polyak_average!(π⁻, π, 0.005f0) # Function for updating the target network
target_fn # Target for critic regression with input signature (π⁻, 𝒟, γ; i); see the sketch below
buffer_size = 1000 # Size of the buffer
required_columns = Symbol[] # Extra data columns to store
buffer = ExperienceBuffer(S, agent.space, buffer_size, required_columns) # The replay buffer
priority_fn = td_error # Function for prioritized replay
buffer_init::Int = max(c_opt.batch_size, 200) # Number of observations to initialize the buffer with
extra_buffers = [] # Extra buffers (e.g., for experience replay in continual learning)
buffer_fractions = [1.0] # Fraction of the minibatch devoted to each buffer
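As referenced above, target_fn computes the critic regression target from a sampled minibatch. The sketch below is a hypothetical max-Q target, not the library's built-in dqn_target: the function name my_max_q_target is invented for illustration, and it assumes the minibatch 𝒟 exposes :r, :sp, and :done columns and that value(π⁻, s) returns the target network's value estimates.

# Hypothetical target function matching the (π⁻, 𝒟, γ; i) signature above.
# Assumes the minibatch 𝒟 carries :r, :sp, and :done columns.
function my_max_q_target(π⁻, 𝒟, γ::Float32; kwargs...)
    Qmax = maximum(value(π⁻, 𝒟[:sp]), dims=1)         # greedy value of the next state
    return 𝒟[:r] .+ γ .* (1f0 .- 𝒟[:done]) .* Qmax    # bootstrap unless terminal
end

# It could then be passed to an off-policy solver, e.g.
# 𝒮 = DQN(π=Q(), S=S, N=30_000, target_fn=my_max_q_target)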
Crux.ddpg_target — Function
DDPG target function.
Set yᵢ = rᵢ + γQ′(sᵢ₊₁, μ′(sᵢ₊₁ | θᵘ′) | θᶜ′)
Crux.smoothed_ddpg_target — Function
Smooth DDPG target.
Crux.ddpg_actor_loss — Function
DDPG actor loss function.
∇_θᵘ 𝐽 ≈ 1/𝑁 Σᵢ ∇ₐQ(s, a | θᶜ)|ₛ₌ₛᵢ, ₐ₌ᵤ₍ₛᵢ₎ ∇_θᵘ μ(s | θᵘ)|ₛᵢ
Crux.DDPG — Function
Deep deterministic policy gradient (DDPG) solver.
- T. P. Lillicrap, et al., "Continuous control with deep reinforcement learning", ICLR 2016.
DDPG(;
π::ActorCritic,
ΔN=50,
π_explore=GaussianNoiseExplorationPolicy(0.1f0),
a_opt::NamedTuple=(;),
c_opt::NamedTuple=(;),
a_loss=ddpg_actor_loss,
c_loss=td_loss(),
target_fn=ddpg_target,
prefix="",
log::NamedTuple=(;),
π_smooth=GaussianNoiseExplorationPolicy(0.1f0, ϵ_min=-0.5f0, ϵ_max=0.5f0),
kwargs...)
Crux.dqn_target — Function
DQN target function.
Crux.DQN — Function
Deep Q-learning (DQN) solver.
- V. Mnih, et al., "Human-level control through deep reinforcement learning", Nature 2015.
DQN(;
π::DiscreteNetwork,
N::Int,
ΔN=4,
π_explore=ϵGreedyPolicy(LinearDecaySchedule(1., 0.1, floor(Int, N/2)), π.outputs),
c_opt::NamedTuple=(;),
log::NamedTuple=(;),
c_loss=td_loss(),
target_fn=dqn_target,
prefix="",
kwargs...)
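A minimal usage sketch (not taken from the Crux documentation itself), assuming a discrete-action POMDPs.jl environment mdp as in the PPO example above; network sizes and N are illustrative.

using Crux, Flux, POMDPs
as = actions(mdp)
S = state_space(mdp)

# Q-network mapping a state to one value per action (illustrative sizes)
Q() = DiscreteNetwork(Chain(Dense(dim(S)..., 64, relu), Dense(64, 64, relu),
                            Dense(64, length(as))), as)

𝒮 = DQN(π=Q(), S=S, N=30_000)    # N: total environment interactions
π_dqn = solve(𝒮, mdp)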
Crux.sac_target — Function
SAC target function.
Crux.sac_deterministic_target — Function
Deterministic SAC target function.
Crux.sac_max_q_target — Function
Max-Q SAC target function.
Crux.sac_actor_loss — Function
SAC actor loss function.
Crux.sac_temp_loss — Function
SAC temperature (entropy coefficient) loss function.
Crux.SAC — Function
Soft Actor Critic (SAC) solver.
SAC(;
π::ActorCritic{T, DoubleNetwork{ContinuousNetwork, ContinuousNetwork}},
ΔN=50,
SAC_α::Float32=1f0,
SAC_H_target::Float32 = Float32(-prod(dim(action_space(π)))),
π_explore=GaussianNoiseExplorationPolicy(0.1f0),
SAC_α_opt::NamedTuple=(;),
a_opt::NamedTuple=(;),
c_opt::NamedTuple=(;),
a_loss=sac_actor_loss,
c_loss=double_Q_loss(),
target_fn=sac_target(π),
prefix="",
log::NamedTuple=(;),
𝒫::NamedTuple=(;),
param_optimizers=Dict(),
kwargs...)
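A construction sketch with hypothetical helpers: build_stochastic_actor() and build_q_critic() stand in for user-defined networks and are not Crux functions. The critic must be a DoubleNetwork of two state-action value networks, as the signature above requires; N, ΔN, and SAC_α are illustrative.

# Hypothetical helpers (not part of Crux):
#   build_stochastic_actor() -> a stochastic continuous-action policy network
#   build_q_critic()         -> a ContinuousNetwork Q(s, a) critic
π_sac = ActorCritic(build_stochastic_actor(),
                    DoubleNetwork(build_q_critic(), build_q_critic()))

𝒮 = SAC(π=π_sac, S=S, N=30_000, ΔN=50, SAC_α=0.2f0)
π_trained = solve(𝒮, mdp)    # POMDPs.jl solve interface, as in the examples above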
Crux.softq_target — Function
Soft Q-learning target function.
Crux.SoftQ — Function
Soft Q-learning solver.
SoftQ(;
π::DiscreteNetwork,
N::Int,
ΔN=4,
c_opt::NamedTuple=(;epochs=4),
log::NamedTuple=(;),
c_loss=td_loss(),
α=Float32(1.),
prefix="",
kwargs...)
Crux.td3_target — Function
TD3 target function.
Crux.td3_actor_loss — Function
TD3 actor loss function.
Crux.TD3 — Function
Twin Delayed DDPG (TD3) solver.
TD3(;
π,
ΔN=50,
π_smooth::Policy=GaussianNoiseExplorationPolicy(0.1f0, ϵ_min=-0.5f0, ϵ_max=0.5f0),
π_explore=GaussianNoiseExplorationPolicy(0.1f0),
a_opt::NamedTuple=(;),
c_opt::NamedTuple=(;),
a_loss=td3_actor_loss,
c_loss=double_Q_loss(),
target_fn=td3_target,
prefix="",
log::NamedTuple=(;),
𝒫::NamedTuple=(;),
kwargs...)
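Construction mirrors the SAC sketch above, but with a deterministic actor; build_deterministic_actor() and build_q_critic() are again hypothetical user-defined helpers, not Crux functions.

# Hypothetical helpers (not part of Crux): a deterministic continuous-action actor
# and two independent Q(s, a) critics wrapped in a DoubleNetwork.
π_td3 = ActorCritic(build_deterministic_actor(),
                    DoubleNetwork(build_q_critic(), build_q_critic()))

𝒮 = TD3(π=π_td3, S=S, N=30_000)
π_trained = solve(𝒮, mdp)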
Imitation Learning
Crux.AdRIL — Function
Adversarial Reward-moment Imitation Learning (AdRIL) solver.
AdRIL(;
π,
S,
ΔN=50,
solver=SAC,
𝒟_demo,
normalize_demo::Bool=true,
expert_frac=0.5,
buffer_size = 1000,
buffer_init=0,
log::NamedTuple=(;),
buffer::ExperienceBuffer = ExperienceBuffer(S, action_space(π), buffer_size, [:i]),
kwargs...)
Crux.AdVIL — Function
Adversarial Value Moment Imitation Learning (AdVIL) solver.
AdVIL(;
π,
S,
𝒟_demo,
normalize_demo::Bool=true,
λ_GP::Float32=10f0,
λ_orth::Float32=1f-4,
λ_BC::Float32=2f-1,
a_opt::NamedTuple=(;),
c_opt::NamedTuple=(;),
log::NamedTuple=(;),
kwargs...)
Crux.ASAF — Function
Adversarial Soft Advantage Fitting (ASAF) solver.
ASAF(;
π,
S,
𝒟_demo,
normalize_demo::Bool=true,
ΔN=50,
λ_orth=1f-4,
a_opt::NamedTuple=(;),
c_opt::NamedTuple=(;),
log::NamedTuple=(;),
kwargs...)
Crux.BC — Function
Behavioral cloning solver.
BC(;
π,
S,
𝒟_demo,
normalize_demo::Bool=true,
loss=nothing,
validation_fraction=0.3,
window=100,
λe::Float32=1f-3,
opt::NamedTuple=(;),
log::NamedTuple=(;),
kwargs...)
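A usage sketch, assuming 𝒟_demo is an ExperienceBuffer of expert transitions (collected beforehand or loaded from disk) and reusing the discrete actor A() from the PPO sketch above; the epochs and period values are illustrative.

# Assumes `𝒟_demo` is an ExperienceBuffer of expert (s, a) pairs; `A()` is the
# discrete policy network from the PPO sketch above.
𝒮 = BC(π=A(), S=S, 𝒟_demo=𝒟_demo, opt=(epochs=10_000,), log=(period=500,))
π_bc = solve(𝒮, mdp)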
Crux.OnlineIQLearn — Function
Online Inverse Q-Learning solver.
OnlineIQLearn(;
π,
S,
𝒟_demo,
γ=Float32(0.9),
normalize_demo::Bool=true,
solver=SoftQ, # or SAC for continuous states
log::NamedTuple=(;period=500),
reg::Bool=true,
α_reg=Float32(0.5),
gp::Bool=true,
λ_gp=Float32(10.),
kwargs...)
Crux.OffPolicyGAIL — Function
Off-policy generative adversarial imitation learning (GAIL) solver.
OffPolicyGAIL(;
π,
S,
𝒟_demo,
𝒟_ndas::Array{ExperienceBuffer} = ExperienceBuffer[],
normalize_demo::Bool=true,
D::ContinuousNetwork,
solver=SAC,
d_opt::NamedTuple=(epochs=5,),
log::NamedTuple=(;),
kwargs...)
Crux.OnPolicyGAIL — Function
On-policy generative adversarial imitation learning (GAIL) solver.
OnPolicyGAIL(;
π,
S,
γ,
λ_gae::Float32 = 0.95f0,
𝒟_demo,
αr::Float32 = 0.5f0,
normalize_demo::Bool=true,
D::ContinuousNetwork,
solver=PPO,
gan_loss::GANLoss=GAN_BCELoss(),
d_opt::NamedTuple=(;),
log::NamedTuple=(;),
Rscale=1f0,
kwargs...)
Crux.SQIL — Function
Soft Q Imitation Learning (SQIL) solver.
SQIL(;
π,
S,
𝒟_demo,
normalize_demo::Bool=true,
solver=SAC,
log::NamedTuple=(;),
kwargs...)
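A usage sketch, reusing the hypothetical actor-critic π_sac from the SAC sketch together with an expert buffer 𝒟_demo; the inner RL solver defaults to SAC, and N is illustrative.

# Hypothetical inputs: `π_sac` is an ActorCritic built as in the SAC sketch above
# and `𝒟_demo` is an ExperienceBuffer of expert demonstrations.
𝒮 = SQIL(π=π_sac, S=S, 𝒟_demo=𝒟_demo, N=30_000)
π_sqil = solve(𝒮, mdp)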
Adversarial RL
Crux.AdversarialOffPolicySolver — Type
Adversarial off-policy solver.
Fields
𝒮_pro::OffPolicySolver # Solver parameters for the protagonist
𝒮_ant::OffPolicySolver # Solver parameters for the antagonist
px::PolicyParams # Nominal disturbance policy
train_pro_every::Int = 1
train_ant_every::Int = 1
log::Union{Nothing, LoggerParams} = nothing # The logging parameters
i::Int = 0 # The current number of environment interactions
Crux.RARL_DQN — Function
Robust Adversarial RL (RARL) deep Q-learning solver.
Crux.RARL_TD3 — Function
Robust Adversarial RL (RARL) TD3 solver.
Crux.RARL — Function
Robust Adversarial RL (RARL) solver.
RARL(;
𝒮_pro,
𝒮_ant,
px,
log::NamedTuple=(;),
train_pro_every::Int=1,
train_ant_every::Int=1,
buffer_size=1000, # Size of the buffer
required_columns=Symbol[:x, :fail],
buffer::ExperienceBuffer=ExperienceBuffer(𝒮_pro.S, 𝒮_pro.agent.space, buffer_size, required_columns), # The replay buffer
buffer_init::Int=max(max(𝒮_pro.c_opt.batch_size, 𝒮_ant.c_opt.batch_size), 200) # Number of observations to initialize the buffer with
)
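A composition sketch with hypothetical placeholders: 𝒮_protagonist and 𝒮_antagonist stand for two pre-built OffPolicySolvers (e.g. constructed with TD3 above) and π_disturbance for a nominal disturbance policy; the update frequencies are illustrative.

# Hypothetical placeholders: two pre-built OffPolicySolvers and a nominal
# disturbance policy for the adversary.
𝒮 = RARL(𝒮_pro = 𝒮_protagonist,
         𝒮_ant = 𝒮_antagonist,
         px = PolicyParams(π = π_disturbance),
         train_pro_every = 1,
         train_ant_every = 5)    # illustrative: update the adversary less often
π_robust = solve(𝒮, mdp)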
Continual Learning
Crux.ExperienceReplay — Function
Experience replay buffer.
Batched RL
Crux.BatchSolver — Type
Batch solver type.
Fields
agent::PolicyParams # Policy parameters (PolicyParams)
S::AbstractSpace # State space
max_steps::Int = 100 # Maximum number of steps per episode
𝒟_train # Training data
param_optimizers::Dict{Any, TrainingParams} = Dict() # Training parameters for the parameters
a_opt::TrainingParams # Training parameters for the actor
c_opt::Union{Nothing, TrainingParams} = nothing # Training parameters for the discriminator
target_fn = nothing # The target function for value-based methods
target_update = (π⁻, π; kwargs...) -> polyak_average!(π⁻, π, 0.005f0) # Function for updating the target network
𝒫::NamedTuple = (;) # Parameters of the algorithm
log::Union{Nothing, LoggerParams} = nothing # The logging parameters
required_columns = Symbol[] # Extra columns to sample
epoch = 0 # Number of epochs of training
Crux.CQL — Function
Conservative Q-Learning (CQL) solver.
CQL(;
π::ActorCritic{T, DoubleNetwork{ContinuousNetwork, ContinuousNetwork}},
solver_type=BatchSAC,
CQL_α::Float32=1f0,
CQL_is_distribution=DistributionPolicy(product_distribution([Uniform(-1,1) for i=1:dim(action_space(π))[1]])),
CQL_α_thresh::Float32=10f0,
CQL_n_action_samples::Int=10,
CQL_α_opt::NamedTuple=(;),
a_opt::NamedTuple=(;),
c_opt::NamedTuple=(;),
log::NamedTuple=(;),
kwargs...)
Crux.BatchSAC — Function
Batched soft actor critic (SAC) solver.
BatchSAC(;
π::ActorCritic{T, DoubleNetwork{ContinuousNetwork, ContinuousNetwork}},
S,
ΔN=50,
SAC_α::Float32=1f0,
SAC_H_target::Float32 = Float32(-prod(dim(action_space(π)))),
𝒟_train,
SAC_α_opt::NamedTuple=(;),
a_opt::NamedTuple=(;),
c_opt::NamedTuple=(;),
log::NamedTuple=(;),
𝒫::NamedTuple=(;),
param_optimizers=Dict(),
normalize_training_data = true,
kwargs...)
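An offline-training sketch: 𝒟_train is assumed to be a pre-collected ExperienceBuffer of transitions and π_sac the hypothetical actor-critic from the SAC sketch above; the epoch count is illustrative.

# Offline (batch) RL: the policy is trained from `𝒟_train` without collecting
# new environment interactions during optimization.
# Hypothetical inputs: `𝒟_train` is a pre-collected ExperienceBuffer and `π_sac`
# the actor-critic from the SAC sketch above.
𝒮 = BatchSAC(π=π_sac, S=S, 𝒟_train=𝒟_train, c_opt=(epochs=10_000,))
π_offline = solve(𝒮, mdp)    # the environment is assumed to be used for evaluation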
Policies
Crux.PolicyParams — Type
Struct for combining useful policy parameters together.
Fields
π::Pol # The policy
space::T2 = action_space(π) # Action space of the policy
π_explore = π # Exploration policy
π⁻ = nothing # Target copy of the policy (if used)
pa = nothing # Nominal action distribution
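A construction sketch, assuming the keyword constructor implied by the defaults above: only π is required, and π_explore can wrap the policy with an exploration strategy (here the ϵ-greedy pattern from the DQN signature). Qnet reuses the hypothetical Q() helper from the DQN sketch.

# Minimal construction: only π is required; the remaining fields take the
# defaults listed above.
Qnet = Q()    # the Q-network from the DQN sketch above
agent = PolicyParams(π = Qnet,
                     π_explore = ϵGreedyPolicy(LinearDecaySchedule(1., 0.1, 5000), Qnet.outputs))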