Library/Interface

This section details the interface and functions provided by Crux.jl.

Reinforcement Learning

The RL solvers below are grouped into on-policy and off-policy methods.

On-Policy

Crux.OnPolicySolver - Type

On-policy solver type.

Fields

  • agent::PolicyParams Policy parameters (PolicyParams)
  • S::AbstractSpace State space
  • N::Int = 1000 Number of environment interactions
  • ΔN::Int = 200 Number of interactions between updates
  • max_steps::Int = 100 Maximum number of steps per episode
  • log::Union{Nothing, LoggerParams} = nothing The logging parameters
  • i::Int = 0 The current number of environment interactions
  • param_optimizers::Dict{Any, TrainingParams} = Dict() Training parameters for any additional trainable parameters
  • a_opt::TrainingParams Training parameters for the actor
  • c_opt::Union{Nothing, TrainingParams} = nothing Training parameters for the critic
  • 𝒫::NamedTuple = (;) Parameters of the algorithm
  • interaction_storage = nothing If initialized to an array, all interactions will be stored here
  • post_sample_callback = (𝒟; kwargs...) -> nothing Callback that runs after sampling experience
  • post_batch_callback = (𝒟; kwargs...) -> nothing Callback that runs after sampling a batch

On-policy-specific parameters

  • λ_gae::Float32 = 0.95 Generalized advantage estimation parameter
  • required_columns = Symbol[] Extra data columns to store

Parameters specific to cost constraints (which use a separate value network)

  • Vc::Union{ContinuousNetwork, Nothing} = nothing Cost value approximator
  • cost_opt::Union{Nothing, TrainingParams} = nothing Training parameters for the cost value
source
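
An OnPolicySolver is rarely constructed directly; the concrete on-policy solvers documented below (A2C, PPO, LagrangePPO, REINFORCE) forward their keyword arguments into these fields. The following sketch (not from the Crux docs) overrides a few of them through PPO; π_ac and S are placeholders for an ActorCritic policy and state space built as in the examples that follow.

    # π_ac and S are placeholders; see the PPO example below for a full setup
    𝒮 = PPO(π=π_ac, S=S,
            N=50_000,               # total environment interactions
            ΔN=1_000,               # interactions between updates
            max_steps=200,          # episode length cap
            λ_gae=0.97f0,           # GAE parameter (on-policy specific)
            interaction_storage=[]) # keep every sampled interaction for later inspection
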
Crux.A2C - Function

Advantage actor-critic (A2C) solver.

A2C(;
    π::ActorCritic, 
    a_opt::NamedTuple=(;), 
    c_opt::NamedTuple=(;), 
    log::NamedTuple=(;), 
    λp::Float32=1f0, 
    λe::Float32=0.1f0, 
    required_columns=[],
    kwargs...)
source
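
As a usage sketch (not from the Crux documentation): training A2C on CartPole, assuming POMDPGym's GymPOMDP wrapper and a working gym installation; network sizes and interaction counts are illustrative only.

    using POMDPs, Crux, Flux, POMDPGym

    mdp = GymPOMDP(:CartPole, version=:v1)
    as  = actions(mdp)
    S   = state_space(mdp)

    # Discrete actor over the action set and a scalar value critic
    A() = DiscreteNetwork(Chain(Dense(dim(S)..., 64, relu), Dense(64, 64, relu), Dense(64, length(as))), as)
    V() = ContinuousNetwork(Chain(Dense(dim(S)..., 64, relu), Dense(64, 64, relu), Dense(64, 1)))

    𝒮_a2c = A2C(π=ActorCritic(A(), V()), S=S, N=10_000, ΔN=500)
    π_a2c = solve(𝒮_a2c, mdp)
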
Crux.PPO - Function

Proximal policy optimization (PPO) solver.

PPO(;
    π::ActorCritic,
    ϵ::Float32 = 0.2f0,
    λp::Float32 = 1f0,
    λe::Float32 = 0.1f0,
    target_kl = 0.012f0,
    a_opt::NamedTuple=(;),
    c_opt::NamedTuple=(;),
    log::NamedTuple=(;),
    required_columns=[],
    kwargs...)
source
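
A sketch of PPO on the same CartPole setup, overriding the clipping, KL-stop, and entropy coefficients listed above; all values are illustrative and the POMDPGym assumptions match the A2C example.

    using POMDPs, Crux, Flux, POMDPGym

    mdp = GymPOMDP(:CartPole, version=:v1)
    as  = actions(mdp)
    S   = state_space(mdp)

    A() = DiscreteNetwork(Chain(Dense(dim(S)..., 64, relu), Dense(64, 64, relu), Dense(64, length(as))), as)
    V() = ContinuousNetwork(Chain(Dense(dim(S)..., 64, relu), Dense(64, 64, relu), Dense(64, 1)))

    # Tighter clipping, an earlier KL stop, and a smaller entropy bonus than the defaults
    𝒮_ppo = PPO(π=ActorCritic(A(), V()), S=S, N=20_000, ΔN=1_000,
                ϵ=0.1f0, target_kl=0.01f0, λe=0.01f0,
                c_opt=(epochs=10,))
    π_ppo = solve(𝒮_ppo, mdp)
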
Crux.LagrangePPO - Function

Lagrange-Constrained PPO solver.

LagrangePPO(;
    π::ActorCritic,
    Vc::ContinuousNetwork, # value network for estimating cost
    ϵ::Float32 = 0.2f0,
    λp::Float32 = 1f0,
    λe::Float32 = 0.1f0,
    λ_gae = 0.95f0,
    target_kl = 0.012f0,
    target_cost = 0.025f0,
    penalty_scale = 1f0,
    penalty_max = Inf32,
    Ki_max = 10f0,
    Ki = 1f-3,
    Kp = 1,
    Kd = 0,
    ema_α = 0.95,
    a_opt::NamedTuple=(;),
    c_opt::NamedTuple=(;),
    cost_opt::NamedTuple=(;),
    log::NamedTuple=(;),
    required_columns=[],
    kwargs...)
source
Crux.REINFORCE - Function

REINFORCE solver.

REINFORCE(;
    π,
    a_opt::NamedTuple=(;), 
    log::NamedTuple=(;),
    required_columns=[],
    kwargs...)
source
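
REINFORCE needs only a policy network. A minimal sketch on CartPole under the same POMDPGym assumptions as the examples above:

    using POMDPs, Crux, Flux, POMDPGym

    mdp = GymPOMDP(:CartPole, version=:v1)
    as  = actions(mdp)
    S   = state_space(mdp)

    A() = DiscreteNetwork(Chain(Dense(dim(S)..., 64, relu), Dense(64, 64, relu), Dense(64, length(as))), as)

    𝒮_rf = REINFORCE(π=A(), S=S, N=10_000, ΔN=500)
    π_rf = solve(𝒮_rf, mdp)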

Off-Policy

Crux.OffPolicySolver - Type

Off-policy solver type.

Fields

  • agent::PolicyParams Policy parameters (PolicyParams)
  • S::AbstractSpace State space
  • N::Int = 1000 Number of environment interactions
  • ΔN::Int = 4 Number of interactions between updates
  • max_steps::Int = 100 Maximum number of steps per episode
  • log::Union{Nothing, LoggerParams} = nothing The logging parameters
  • i::Int = 0 The current number of environment interactions
  • param_optimizers::Dict{Any, TrainingParams} = Dict() Training parameters for any additional trainable parameters
  • a_opt::Union{Nothing, TrainingParams} = nothing Training parameters for the actor
  • c_opt::TrainingParams Training parameters for the critic
  • 𝒫::NamedTuple = (;) Parameters of the algorithm
  • interaction_storage = nothing If initialized to an array, all interactions will be stored here
  • post_sample_callback = (𝒟; kwargs...) -> nothing Callback that runs after sampling experience

Off-policy-specific parameters

  • post_batch_callback = (𝒟; kwargs...) -> nothing Callback that runs after sampling a batch
  • pre_train_callback = (𝒮; kwargs...) -> nothing Callback that gets called once prior to training
  • target_update = (π⁻, π; kwargs...) -> polyak_average!(π⁻, π, 0.005f0) Function for updating the target network
  • target_fn Target for critic regression with input signature (π⁻, 𝒟, γ; i)
  • buffer_size = 1000 Size of the buffer
  • required_columns = Symbol[] Extra data columns to store
  • buffer = ExperienceBuffer(S, agent.space, buffer_size, required_columns) The replay buffer
  • priority_fn = td_error Priority function for prioritized replay
  • buffer_init::Int = max(c_opt.batch_size, 200) Number of observations to initialize the buffer with
  • extra_buffers = [] Extra buffers (e.g., for experience replay in continual learning)
  • buffer_fractions = [1.0] Fraction of the minibatch devoted to each buffer
source
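
As with the on-policy case, these fields are usually set through the keyword arguments of the concrete off-policy solvers below (DQN, DDPG, SAC, SoftQ, TD3). A sketch of adjusting the replay-buffer behavior through DQN; Q and S are placeholders for a DiscreteNetwork and state space built as in the DQN example below.

    # Q and S are placeholders; see the DQN example below for a full setup
    𝒮 = DQN(π=Q, S=S, N=50_000,
            ΔN=4,                     # interactions between gradient updates
            buffer_size=100_000,      # replay buffer capacity
            buffer_init=1_000,        # transitions collected before training starts
            c_opt=(batch_size=128,))  # critic minibatch size
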
Crux.ddpg_target - Function

DDPG target function.

Set yᵢ = rᵢ + γQ′(sᵢ₊₁, μ′(sᵢ₊₁ | θᵘ′) | θᶜ′)

source
Crux.ddpg_actor_loss - Function

DDPG actor loss function.

∇_θᵘ J ≈ (1/N) Σᵢ ∇ₐQ(s, a | θᶜ)|_{s=sᵢ, a=μ(sᵢ)} ∇_θᵘ μ(s | θᵘ)|_{s=sᵢ}

source
Crux.DDPG - Function

Deep deterministic policy gradient (DDPG) solver.

  • T. P. Lillicrap, et al., "Continuous control with deep reinforcement learning", ICLR 2016.
DDPG(;
    π::ActorCritic, 
    ΔN=50, 
    π_explore=GaussianNoiseExplorationPolicy(0.1f0),  
    a_opt::NamedTuple=(;), 
    c_opt::NamedTuple=(;),
    a_loss=ddpg_actor_loss,
    c_loss=td_loss(),
    target_fn=ddpg_target,
    prefix="",
    log::NamedTuple=(;), 
    π_smooth=GaussianNoiseExplorationPolicy(0.1f0, ϵ_min=-0.5f0, ϵ_max=0.5f0), kwargs...)
source
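
A continuous-control sketch, assuming POMDPGym's Pendulum environment with a single action dimension and actions roughly normalized to [-1, 1] (scale the tanh output if your environment uses a different range). The Q(s, a) critic takes the concatenated state and action as input.

    using POMDPs, Crux, Flux, POMDPGym

    mdp  = GymPOMDP(:Pendulum, version=:v1)   # any continuous-action MDP works here
    S    = state_space(mdp)
    adim = 1                                  # single torque dimension (assumption)

    # Deterministic actor a = μ(s) and a Q(s, a) critic over the concatenated input [s; a]
    A()   = ContinuousNetwork(Chain(Dense(dim(S)[1], 64, relu), Dense(64, 64, relu), Dense(64, adim, tanh)), adim)
    QSA() = ContinuousNetwork(Chain(Dense(dim(S)[1] + adim, 64, relu), Dense(64, 64, relu), Dense(64, 1)))

    𝒮_ddpg = DDPG(π=ActorCritic(A(), QSA()), S=S, N=30_000)
    π_ddpg = solve(𝒮_ddpg, mdp)
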
Crux.DQN - Function

Deep Q-network (DQN) solver.

  • V. Mnih, et al., "Human-level control through deep reinforcement learning", Nature 2015.
DQN(;
      π::DiscreteNetwork, 
      N::Int, 
      ΔN=4, 
      π_explore=ϵGreedyPolicy(LinearDecaySchedule(1., 0.1, floor(Int, N/2)), π.outputs), 
      c_opt::NamedTuple=(;), 
      log::NamedTuple=(;),
      c_loss=td_loss(),
      target_fn=dqn_target,
      prefix="",
      kwargs...)
source
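
A minimal DQN sketch on CartPole under the same POMDPGym assumptions as the earlier examples. With the default exploration policy shown above, ϵ decays from 1.0 to 0.1 over the first N/2 interactions.

    using POMDPs, Crux, Flux, POMDPGym

    mdp = GymPOMDP(:CartPole, version=:v1)
    as  = actions(mdp)
    S   = state_space(mdp)

    Q() = DiscreteNetwork(Chain(Dense(dim(S)..., 64, relu), Dense(64, 64, relu), Dense(64, length(as))), as)

    𝒮_dqn = DQN(π=Q(), S=S, N=10_000)
    π_dqn = solve(𝒮_dqn, mdp)
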
Crux.SAC - Function

Soft Actor-Critic (SAC) solver.

SAC(;
    π::ActorCritic{T, DoubleNetwork{ContinuousNetwork, ContinuousNetwork}},
    ΔN=50,
    SAC_α::Float32=1f0,
    SAC_H_target::Float32 = Float32(-prod(dim(action_space(π)))),
    π_explore=GaussianNoiseExplorationPolicy(0.1f0),
    SAC_α_opt::NamedTuple=(;),
    a_opt::NamedTuple=(;),
    c_opt::NamedTuple=(;),
    a_loss=sac_actor_loss,
    c_loss=double_Q_loss(),
    target_fn=sac_target(π),
    prefix="",
    log::NamedTuple=(;),
    𝒫::NamedTuple=(;),
    param_optimizers=Dict(),
    kwargs...)
source
Crux.SoftQ - Function

Soft Q-learning solver.

SoftQ(;
    π::DiscreteNetwork, 
    N::Int, 
    ΔN=4, 
    c_opt::NamedTuple=(;epochs=4), 
    log::NamedTuple=(;),
    c_loss=td_loss(),
    α=Float32(1.),
    prefix="",
    kwargs...)
source
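
SoftQ uses the same discrete Q-network setup as DQN. A sketch under the same assumptions, with a smaller temperature than the default α = 1:

    using POMDPs, Crux, Flux, POMDPGym

    mdp = GymPOMDP(:CartPole, version=:v1)
    as  = actions(mdp)
    S   = state_space(mdp)

    Q() = DiscreteNetwork(Chain(Dense(dim(S)..., 64, relu), Dense(64, 64, relu), Dense(64, length(as))), as)

    # A lower α sharpens the soft-max policy toward the greedy one
    𝒮_softq = SoftQ(π=Q(), S=S, N=10_000, α=0.5f0)
    π_softq = solve(𝒮_softq, mdp)
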
Crux.TD3 - Function

Twin Delayed DDPG (TD3) solver.

TD3(;
    π,
    ΔN=50,
    π_smooth::Policy=GaussianNoiseExplorationPolicy(0.1f0, ϵ_min=-0.5f0, ϵ_max=0.5f0),
    π_explore=GaussianNoiseExplorationPolicy(0.1f0),
    a_opt::NamedTuple=(;),
    c_opt::NamedTuple=(;),
    a_loss=td3_actor_loss,
    c_loss=double_Q_loss(),
    target_fn=td3_target,
    prefix="",
    log::NamedTuple=(;),
    𝒫::NamedTuple=(;),
    kwargs...)
source

Imitation Learning

Crux.AdRIL - Function

Adversarial Reward-moment Imitation Learning (AdRIL) solver.

AdRIL(;
    π, 
    S,
    ΔN=50,
    solver=SAC, 
    𝒟_demo, 
    normalize_demo::Bool=true,
    expert_frac=0.5, 
    buffer_size = 1000, 
    buffer_init=0,
    log::NamedTuple=(;),
    buffer::ExperienceBuffer = ExperienceBuffer(S, action_space(π), buffer_size, [:i]), 
    kwargs...)
source
Crux.AdVIL - Function

Adversarial Value-moment Imitation Learning (AdVIL) solver.

AdVIL(;
    π, 
    S,
    𝒟_demo, 
    normalize_demo::Bool=true, 
    λ_GP::Float32=10f0, 
    λ_orth::Float32=1f-4, 
    λ_BC::Float32=2f-1, 
    a_opt::NamedTuple=(;), 
    c_opt::NamedTuple=(;), 
    log::NamedTuple=(;), 
    kwargs...)
source
Crux.ASAF - Function

Adversarial Soft Advantage Fitting (ASAF) solver.

ASAF(;
    π,
    S,
    𝒟_demo,
    normalize_demo::Bool=true,
    ΔN=50,
    λ_orth=1f-4,
    a_opt::NamedTuple=(;),
    c_opt::NamedTuple=(;),
    log::NamedTuple=(;),
    kwargs...)
source
Crux.BC - Function

Behavioral cloning (BC) solver.

BC(;
    π,
    S,
    𝒟_demo,
    normalize_demo::Bool=true,
    loss=nothing,
    validation_fraction=0.3,
    window=100,
    λe::Float32=1f-3,
    opt::NamedTuple=(;),
    log::NamedTuple=(;),
    kwargs...)
source
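
A behavioral-cloning sketch on a discrete task. load_demonstrations() is a hypothetical placeholder for however the expert ExperienceBuffer is obtained (collected with an expert policy or loaded from disk); the environment and network assumptions match the earlier examples, and the epoch count is illustrative.

    using POMDPs, Crux, Flux, POMDPGym

    mdp = GymPOMDP(:CartPole, version=:v1)
    as  = actions(mdp)
    S   = state_space(mdp)

    A() = DiscreteNetwork(Chain(Dense(dim(S)..., 64, relu), Dense(64, 64, relu), Dense(64, length(as))), as)

    𝒟_demo = load_demonstrations()   # hypothetical: an ExperienceBuffer of expert transitions

    𝒮_bc = BC(π=A(), S=S, 𝒟_demo=𝒟_demo, opt=(epochs=1_000,))
    π_bc = solve(𝒮_bc, mdp)
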
Crux.OnlineIQLearn - Function

Online Inverse Q-Learning solver.

OnlineIQLearn(;
    π, 
    S, 
    𝒟_demo, 
    γ=Float32(0.9),
    normalize_demo::Bool=true, 
    solver=SoftQ, # or SAC for continuous states 
    log::NamedTuple=(;period=500), 
    reg::Bool=true,
    α_reg=Float32(0.5),
    gp::Bool=true,
    λ_gp=Float32(10.),
    kwargs...)
source
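
An online IQ-Learn sketch for the discrete case, which keeps the default SoftQ inner solver noted in the signature above (pass solver=SAC for continuous control). load_demonstrations() is again a hypothetical placeholder for the expert ExperienceBuffer, and N is forwarded to the inner solver.

    using POMDPs, Crux, Flux, POMDPGym

    mdp = GymPOMDP(:CartPole, version=:v1)
    as  = actions(mdp)
    S   = state_space(mdp)

    Q() = DiscreteNetwork(Chain(Dense(dim(S)..., 64, relu), Dense(64, 64, relu), Dense(64, length(as))), as)

    𝒟_demo = load_demonstrations()   # hypothetical: an ExperienceBuffer of expert transitions

    𝒮_iq = OnlineIQLearn(π=Q(), S=S, 𝒟_demo=𝒟_demo, N=10_000)
    π_iq = solve(𝒮_iq, mdp)
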
Crux.OffPolicyGAIL - Function

Off-policy generative adversarial imitation learning (GAIL) solver.

OffPolicyGAIL(;
    π,
    S, 
    𝒟_demo, 
    𝒟_ndas::Array{ExperienceBuffer} = ExperienceBuffer[], 
    normalize_demo::Bool=true, 
    D::ContinuousNetwork, 
    solver=SAC, 
    d_opt::NamedTuple=(epochs=5,), 
    log::NamedTuple=(;), 
    kwargs...)
source
Crux.OnPolicyGAIL - Function

On-policy generative adversarial imitation learning (GAIL) solver.

OnPolicyGAIL(;
    π,
    S,
    γ,
    λ_gae::Float32 = 0.95f0,
    𝒟_demo,
    αr::Float32 = 0.5f0,
    normalize_demo::Bool=true,
    D::ContinuousNetwork,
    solver=PPO,
    gan_loss::GANLoss=GAN_BCELoss(),
    d_opt::NamedTuple=(;),
    log::NamedTuple=(;),
    Rscale=1f0,
    kwargs...)
source
Crux.SQIL - Function

Soft Q Imitation Learning (SQIL) solver.

SQIL(;
    π, 
    S, 
    𝒟_demo, 
    normalize_demo::Bool=true, 
    solver=SAC, 
    log::NamedTuple=(;), 
    kwargs...)
source

Adversarial RL

Crux.AdversarialOffPolicySolver - Type

Adversarial off-policy solver type.

Fields

  • 𝒮_pro::OffPolicySolver Solver parameters for the protagonist
  • 𝒮_ant::OffPolicySolver Solver parameters for the antagonist
  • px::PolicyParams Nominal disturbance policy
  • train_pro_every::Int = 1 How often (in iterations) to train the protagonist
  • train_ant_every::Int = 1 How often (in iterations) to train the antagonist
  • log::Union{Nothing, LoggerParams} = nothing The logging parameters
  • i::Int = 0 The current number of environment interactions
source
Crux.RARL - Function

Robust Adversarial RL (RARL) solver.

RARL(;
    𝒮_pro,
    𝒮_ant,
    px,
    log::NamedTuple=(;), 
    train_pro_every::Int=1,
    train_ant_every::Int=1,
    buffer_size=1000, # Size of the buffer
    required_columns=Symbol[:x, :fail],
    buffer::ExperienceBuffer=ExperienceBuffer(𝒮_pro.S, 𝒮_pro.agent.space, buffer_size, required_columns), # The replay buffer
    buffer_init::Int=max(max(𝒮_pro.c_opt.batch_size, 𝒮_ant.c_opt.batch_size), 200) # Number of observations to initialize the buffer with
)
source

Continual Learning

Batched RL

Crux.BatchSolver - Type

Batch solver type.

Fields

  • agent::PolicyParams Policy parameters (PolicyParams)
  • S::AbstractSpace State space
  • max_steps::Int = 100 Maximum number of steps per episode
  • 𝒟_train Training data
  • param_optimizers::Dict{Any, TrainingParams} = Dict() Training parameters for any additional trainable parameters
  • a_opt::TrainingParams Training parameters for the actor
  • c_opt::Union{Nothing, TrainingParams} = nothing Training parameters for the discriminator
  • target_fn = nothing The target function for value-based methods
  • target_update = (π⁻, π; kwargs...) -> polyak_average!(π⁻, π, 0.005f0) Function for updating the target network
  • 𝒫::NamedTuple = (;) Parameters of the algorithm
  • log::Union{Nothing, LoggerParams} = nothing The logging parameters
  • required_columns = Symbol[] Extra columns to sample
  • epoch = 0 Number of epochs of training
source
Crux.CQL - Function

Conservative Q-Learning (CQL) solver.

CQL(;
    π::ActorCritic{T, DoubleNetwork{ContinuousNetwork, ContinuousNetwork}},
    solver_type=BatchSAC,
    CQL_α::Float32=1f0,
    CQL_is_distribution=DistributionPolicy(product_distribution([Uniform(-1,1) for i=1:dim(action_space(π))[1]])),
    CQL_α_thresh::Float32=10f0,
    CQL_n_action_samples::Int=10,
    CQL_α_opt::NamedTuple=(;),
    a_opt::NamedTuple=(;), 
    c_opt::NamedTuple=(;), 
    log::NamedTuple=(;),
    kwargs...)
source
Crux.BatchSAC - Function

Batched soft actor-critic (SAC) solver.

BatchSAC(;
    π::ActorCritic{T, DoubleNetwork{ContinuousNetwork, ContinuousNetwork}}, 
    S,
    ΔN=50, 
    SAC_α::Float32=1f0, 
    SAC_H_target::Float32 = Float32(-prod(dim(action_space(π)))), 
    𝒟_train, 
    SAC_α_opt::NamedTuple=(;), 
    a_opt::NamedTuple=(;), 
    c_opt::NamedTuple=(;), 
    log::NamedTuple=(;), 
    𝒫::NamedTuple=(;), 
    param_optimizers=Dict(), 
    normalize_training_data = true, 
    kwargs...)
source

Policies

Crux.PolicyParams - Type

Struct for combining useful policy parameters.

    π::Pol
    space::T2 = action_space(π)
    π_explore = π
    π⁻ = nothing
    pa = nothing # nominal action distribution
source
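
The solvers above construct PolicyParams internally from their π and π_explore keyword arguments, but it can also be built directly. A sketch, assuming keyword construction as implied by the field defaults; the Q-network is illustrative and the ϵ-greedy schedule mirrors the DQN default shown earlier.

    using Crux, Flux

    # Illustrative Q-network over two actions labeled 1 and 2
    Q_net = DiscreteNetwork(Chain(Dense(4, 64, relu), Dense(64, 2)), [1, 2])

    agent = PolicyParams(π=Q_net,
                         π_explore=ϵGreedyPolicy(LinearDecaySchedule(1., 0.1, 5000), Q_net.outputs),
                         π⁻=deepcopy(Q_net))   # target-network copy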