Examples

For a full set of examples, please see the examples/ directory.

Minimal RL Example

As a minimal example, we'll show how to set up a cart-pole problem and solve it with a simple Flux network using the REINFORCE algorithm.

using Crux, POMDPGym

# Problem setup
mdp = GymPOMDP(:CartPole)
as = actions(mdp)
S = state_space(mdp)

# Flux network: Map states to actions
A() = DiscreteNetwork(Chain(Dense(dim(S)..., 64, relu), Dense(64, length(as))), as)

# Setup REINFORCE solver
solver_reinforce = REINFORCE(S=S, π=A())

# Solve the `mdp` to get the `policy`
policy_reinforce = solve(solver_reinforce, mdp)
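REINFORCE pushes up the log-probability of each action a_t in proportion to the discounted return collected from step t onward (the "reward-to-go"). The plain-Julia sketch below shows that weighting for a single episode. The names `reward_to_go` and `reinforce_loss` and the discount γ = 0.9 are illustrative assumptions, not part of the Crux API, and Crux's own loss may add refinements such as baselines or normalization.

```julia
# Hypothetical helper (not part of Crux): discounted reward-to-go
# G_t = r_t + γ r_{t+1} + γ^2 r_{t+2} + ... for one episode.
function reward_to_go(rewards::AbstractVector{<:Real}, γ::Real)
    G = zeros(Float64, length(rewards))
    running = 0.0
    # Walk backwards so each G_t reuses G_{t+1}
    for t in length(rewards):-1:1
        running = rewards[t] + γ * running
        G[t] = running
    end
    return G
end

# Hypothetical surrogate loss whose gradient is the REINFORCE estimate:
# L = -(1/T) Σ_t log π(a_t | s_t) * G_t
reinforce_loss(logp, G) = -sum(logp .* G) / length(G)

# CartPole pays a reward of 1 for every step the pole stays up
G = reward_to_go([1.0, 1.0, 1.0], 0.9)         # ≈ [2.71, 1.9, 1.0]
loss = reinforce_loss(log.(fill(0.5, 3)), G)   # log-probs of the taken actions
```

Iterating backwards keeps the computation linear in the episode length instead of recomputing each discounted sum from scratch.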

You can run other algorithms, such as A2C and PPO, to generate different policies:

# Set up the critic network for actor-critic algorithms
V() = ContinuousNetwork(Chain(Dense(dim(S)..., 64, relu), Dense(64, 1)))

solver_a2c = A2C(S=S, π=ActorCritic(A(), V()))
policy_a2c = solve(solver_a2c, mdp)

solver_ppo = PPO(S=S, π=ActorCritic(A(), V()))
policy_ppo = solve(solver_ppo, mdp)
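A2C and PPO use the critic `V()` to reduce the variance of the policy gradient by replacing the raw return with an advantage estimate, and PPO additionally clips the probability ratio between the new and old policies. The plain-Julia sketch below shows a one-step temporal-difference advantage and PPO's clipped surrogate objective. The function names, γ, and the clip parameter ϵ = 0.2 are illustrative assumptions; Crux's solvers may use a different advantage estimator (for example, generalized advantage estimation) and different defaults.

```julia
# Hypothetical helpers (not the Crux implementation).

# One-step TD advantage: A_t = r_t + γ * (1 - done_t) * V(s_{t+1}) - V(s_t).
# `v` holds V(s_t), `v_next` holds V(s_{t+1}), `done` marks terminal steps.
function td_advantage(r, v, v_next, done, γ)
    return r .+ γ .* (1 .- done) .* v_next .- v
end

# PPO clipped surrogate objective (to be maximized):
# mean_t min(ρ_t A_t, clip(ρ_t, 1 - ϵ, 1 + ϵ) A_t),
# with ρ_t = π_new(a_t | s_t) / π_old(a_t | s_t) = exp(logp_new - logp_old).
function ppo_clip_objective(logp_new, logp_old, A; ϵ=0.2)
    ρ = exp.(logp_new .- logp_old)
    unclipped = ρ .* A
    clipped = clamp.(ρ, 1 - ϵ, 1 + ϵ) .* A
    return sum(min.(unclipped, clipped)) / length(A)
end

A = td_advantage([1.0, 1.0], [0.5, 0.4], [0.4, 0.0], [false, true], 0.9)  # ≈ [0.86, 0.6]
obj = ppo_clip_objective(log.([0.8, 0.3]), log.([0.5, 0.5]), [1.0, -1.0])  # ≈ 0.2
```

Taking the minimum of the clipped and unclipped terms means the objective gains nothing from moving the ratio outside [1 - ϵ, 1 + ϵ], which keeps each PPO update close to the data-collecting policy.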

You may also want to adjust the total number of environment interactions N or the number of interactions between policy updates ΔN:

solver_reinforce = REINFORCE(S=S, π=A(), N=10_000, ΔN=500)
policy_reinforce = solve(solver_reinforce, mdp)

solver_a2c = A2C(S=S, π=ActorCritic(A(), V()), N=10_000, ΔN=500)
policy_a2c = solve(solver_a2c, mdp)

solver_ppo = PPO(S=S, π=ActorCritic(A(), V()), N=10_000, ΔN=500)
policy_ppo = solve(solver_ppo, mdp)
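With N = 10_000 and ΔN = 500, an on-policy solver collects data and updates the policy roughly N ÷ ΔN = 20 times. The loop below is only a schematic sketch of that bookkeeping; `collect_samples!` and `update!` are hypothetical placeholders named in comments, and Crux's solvers may count steps or schedule gradient steps differently.

```julia
# Schematic on-policy training schedule (hypothetical, not Crux's internals).
# N: total environment interactions; ΔN: interactions gathered per update round.
function train_schedule(N::Int, ΔN::Int)
    steps = 0
    rounds = 0
    while steps + ΔN <= N
        steps += ΔN   # stands in for collect_samples!(buffer, ΔN)
        rounds += 1   # stands in for update!(policy, buffer)
    end
    return rounds, steps
end

train_schedule(10_000, 500)   # (20, 10000)
```

A smaller ΔN gives more frequent but noisier updates for the same N.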

Plotting and Animations

You can then plot the learning curves of the three solvers trained above:

p = plot_learning([solver_reinforce, solver_a2c, solver_ppo],
                  title="CartPole Training Curves",
                  labels=["REINFORCE", "A2C", "PPO"])
Crux.savefig(p, "cartpole_training.pdf")

Here's an example plot for the half cheetah MuJoCo problem comparing four RL algorithms, from examples/rl/half_cheetah_mujoco.jl:

[Figure: half cheetah MuJoCo learning curves]
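A learning curve plots a performance measure, typically the average undiscounted episode return, against the number of environment interactions. Per-episode returns are noisy, so they are often smoothed before plotting. The sketch below applies a trailing moving average in plain Julia; the name `moving_average` and the window size are illustrative assumptions and do not describe how `plot_learning` processes its data.

```julia
# Hypothetical helper: trailing moving average with window `w`.
# Early points average over however many values are available so far.
function moving_average(x::AbstractVector{<:Real}, w::Int)
    @assert w >= 1 "window must be positive"
    y = similar(x, Float64)
    for i in eachindex(x)
        lo = max(firstindex(x), i - w + 1)
        y[i] = sum(@view x[lo:i]) / (i - lo + 1)
    end
    return y
end

returns = [10.0, 20.0, 30.0, 40.0]
moving_average(returns, 2)   # [10.0, 15.0, 25.0, 35.0]
```

A trailing window only uses past data, so the smoothed curve lags the raw one by about w/2 points.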

You can also create an animated gif of the final policy:

gif(mdp, policy_ppo, "cartpole_policy.gif", max_steps=100)

Note: Creating the gif may require pygame, which you can install via pip install "gymnasium[classic-control]".