Examples
For a full set of examples, see the examples/ directory, which includes:
- Reinforcement learning examples
- Imitation learning examples
- Adversarial RL examples
- Offline RL examples
Minimal RL Example (Pure Julia)
As a minimal self-contained example, we'll show how to solve a GridWorld problem using DQN with no external dependencies:
using Crux
mdp = SimpleGridWorld()
S = state_space(mdp)
A() = DiscreteNetwork(Chain(Dense(2, 8, relu), Dense(8, 4)), actions(mdp))
solver = DQN(π=A(), S=S, N=100_000)
policy = solve(solver, mdp)

Gym Environments
For OpenAI Gym environments like CartPole, you need to install POMDPGym.jl:
] add https://github.com/ancorso/POMDPGym.jl

POMDPGym requires Python with Gymnasium installed (pip install gymnasium).
Here's an example using REINFORCE to solve CartPole:
using Crux, POMDPGym
# Problem setup
mdp = GymPOMDP(:CartPole)
as = actions(mdp)
S = state_space(mdp)
# Flux network: Map states to actions
A() = DiscreteNetwork(Chain(Dense(dim(S)..., 64, relu), Dense(64, length(as))), as)
# Setup REINFORCE solver
solver_reinforce = REINFORCE(S=S, π=A())
# Solve the `mdp` to get the `policy`
policy_reinforce = solve(solver_reinforce, mdp)

You can run other algorithms, such as A2C and PPO, to generate different policies:
# Set up the critic network for actor-critic algorithms
V() = ContinuousNetwork(Chain(Dense(dim(S)..., 64, relu), Dense(64, 1)))
solver_a2c = A2C(S=S, π=ActorCritic(A(), V()))
policy_a2c = solve(solver_a2c, mdp)
solver_ppo = PPO(S=S, π=ActorCritic(A(), V()))
policy_ppo = solve(solver_ppo, mdp)

You may also want to adjust the total number of environment interactions N or the number of interactions between updates ΔN:
solver_reinforce = REINFORCE(S=S, π=A(), N=10_000, ΔN=500)
policy_reinforce = solve(solver_reinforce, mdp)
solver_a2c = A2C(S=S, π=ActorCritic(A(), V()), N=10_000, ΔN=500)
policy_a2c = solve(solver_a2c, mdp)
solver_ppo = PPO(S=S, π=ActorCritic(A(), V()), N=10_000, ΔN=500)
policy_ppo = solve(solver_ppo, mdp)

Plotting and Animations
You can take the above results and plot the learning curves:
p = plot_learning([solver_reinforce, solver_a2c, solver_ppo],
                  title="CartPole Training Curves",
                  labels=["REINFORCE", "A2C", "PPO"])
Crux.savefig(p, "cartpole_training.pdf")

For a more involved comparison, see examples/rl/half_cheetah_mujoco.jl, which compares four RL algorithms on the half-cheetah MuJoCo problem.
You can also create an animated gif of the final policy:
gif(mdp, policy_ppo, "cartpole_policy.gif", max_steps=100)

Note: You may need to install pygame via pip install "gymnasium[classic-control]".
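Beyond saving plots and gifs, you can also query a trained policy directly. Crux policies implement the POMDPs.jl policy interface, so (assuming POMDPs.jl is loaded) `action` returns the policy's action for a given state. A minimal sketch, where the example state vector is an assumed placeholder for a CartPole observation:

```julia
using POMDPs  # provides the `action(policy, s)` interface that Crux policies implement

# An example CartPole-style observation:
# cart position, cart velocity, pole angle, pole angular velocity
s = Float32[0.0, 0.0, 0.0, 0.0]

# Query the trained policy for its action in state `s`
a = action(policy_ppo, s)
```

This is useful for inspecting a policy's behavior on specific states or for embedding the trained policy in your own rollout or evaluation loop.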