Examples
For a full set of examples, see the examples/ directory.
- Reinforcement learning examples
- Imitation learning examples
- Adversarial RL examples
- Offline RL examples
Minimal RL Example
As a minimal example, we'll show how to set up a cart-pole problem and solve it with a simple Flux network using the REINFORCE algorithm.
using Crux, POMDPGym
# Problem setup
mdp = GymPOMDP(:CartPole)
as = actions(mdp)
S = state_space(mdp)
# Flux network: Map states to actions
A() = DiscreteNetwork(Chain(Dense(dim(S)..., 64, relu), Dense(64, length(as))), as)
# Set up the REINFORCE solver
solver_reinforce = REINFORCE(S=S, π=A())
# Solve the `mdp` to get the `policy`
policy_reinforce = solve(solver_reinforce, mdp)
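To give a feel for what REINFORCE optimizes, here is a standalone sketch of the discounted return-to-go that the algorithm uses to weight each action's log-probability. It is plain Julia and does not call Crux; the function name discounted_returns and the discount value are illustrative assumptions, not part of the library's API or its internal implementation.

```julia
# Illustrative helper (hypothetical name, not a Crux function):
# computes the return-to-go G[t] = r[t] + γ * G[t+1] for one episode.
function discounted_returns(rewards::AbstractVector{<:Real}, γ::Real)
    G = zeros(Float64, length(rewards))
    running = 0.0
    # Walk the episode backwards so each step reuses the next step's return
    for t in length(rewards):-1:1
        running = rewards[t] + γ * running
        G[t] = running
    end
    return G
end

# Three steps with reward 1 each (as in cart-pole) and an assumed γ = 0.5
returns = discounted_returns([1.0, 1.0, 1.0], 0.5)
println(returns)  # [1.75, 1.5, 1.0]
```

Earlier actions receive larger returns because more reward follows them, which is why a policy that keeps the pole up longer gets a stronger gradient signal.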
You can run other algorithms, such as A2C and PPO, to generate different policies:
# Set up the critic network for actor-critic algorithms
V() = ContinuousNetwork(Chain(Dense(dim(S)..., 64, relu), Dense(64, 1)))
solver_a2c = A2C(S=S, π=ActorCritic(A(), V()))
policy_a2c = solve(solver_a2c, mdp)
solver_ppo = PPO(S=S, π=ActorCritic(A(), V()))
policy_ppo = solve(solver_ppo, mdp)
You may also want to adjust the total number of environment interactions N or the number of interactions between updates ΔN:
solver_reinforce = REINFORCE(S=S, π=A(), N=10_000, ΔN=500)
policy_reinforce = solve(solver_reinforce, mdp)
solver_a2c = A2C(S=S, π=ActorCritic(A(), V()), N=10_000, ΔN=500)
policy_a2c = solve(solver_a2c, mdp)
solver_ppo = PPO(S=S, π=ActorCritic(A(), V()), N=10_000, ΔN=500)
policy_ppo = solve(solver_ppo, mdp)
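As a quick sanity check on these settings, the arithmetic below shows how many updates the values above imply. This is a sketch that assumes the solver performs one update for every ΔN interactions, as the description of ΔN suggests; the variable n_updates is only for illustration.

```julia
N = 10_000   # total environment interactions
ΔN = 500     # interactions collected between updates

# Assumed relationship: one update per ΔN interactions
n_updates = N ÷ ΔN
println(n_updates)  # 20
```

Raising ΔN gives each update more data but fewer updates overall for the same N.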
Plotting and Animations
You can take the above results and plot the learning curves:
p = plot_learning([solver_reinforce, solver_a2c, solver_ppo],
                  title="CartPole Training Curves",
                  labels=["REINFORCE", "A2C", "PPO"])
Crux.savefig(p, "cartpole_training.pdf")
For a larger example, examples/rl/half_cheetah_mujoco.jl compares four RL algorithms on the half cheetah MuJoCo problem.
You can also create an animated gif of the final policy:
gif(mdp, policy_ppo, "cartpole_policy.gif", max_steps=100)
Note: You may need to install pygame via pip install "gymnasium[classic-control]".