Quick Start: Training¶
Do you want to train a HuggingFace tester using an ASTRA-supported algorithm (e.g., DPO, IPO, PPO) and system formulation (ASTPrompter, Perez et al., MALIBU, CRT*)? coming soon
Then this guide is for you. We'll walk through every step required to train an adversarial testing tester (llama3) using our pre-configured classes and point you to customization guides when your use case goes beyond the defaults. By using our pre-configured classes, you'll be training your tester in 7 easy steps!
Step 1: Setup¶
Please see the main documentation for full setup instructions. Here's a quick recap:
# Install the ASTRA-RL toolbox
pip install astra-rl
# Import ASTRA-RL in your python code
import astra_rl
from astra_rl.datasets import CONVOKIT_REDDIT_TRAIN, CONVOKIT_REDDIT_DEV
from astra_rl.ext.transformers.hf_ast_system import HFASTTrainer, HFASTConfiguration
Note
wandb is not automatically installed during this process. If you would like to use wandb, either install it as an optional dependency (pip install "astra-rl[wandb]") or install it directly (uv pip install wandb) and run export WANDB_API_KEY=YOUR_API_KEY_HERE in your terminal.
Step 2: Import Required Modules and set device¶
from torch.optim import AdamW
import json
# ASTRA-RL core components
from astra_rl import ASTSampler, DPO, DetoxifyScorer, Harness
# HuggingFace-friendly system wrapper for ASTPrompter-style red teaming
from astra_rl.ext.transformers import HFASTSystem
from astra_rl.training import Trainer, TrainingConfiguration
# training and dev data sets - serve as initial prompts in auditor-target rollouts
from astra_rl.datasets import CONVOKIT_REDDIT_TRAIN, CONVOKIT_REDDIT_DEV
DEVICE = "cuda" # or "cpu" if GPU is not available
Note
To train a tester, you'll need a list of comma-separated strings that act as initial prompts—these initiate tester-target rollouts used for online training.
Since ASTPrompter tests for harmful outputs in conversational settings, it uses the ConvoKit Reddit Small Corpus (filtered for proper formatting and for non-toxicity using Detoxify) as its default source of initial prompts. This data is imported from astra_rl.datasets as CONVOKIT_REDDIT_TRAIN and CONVOKIT_REDDIT_DEV
The ASTRA-RL toolbox easily supports external prompt datasets or APIs—just ensure the final PROMPTS variable is formatted as a list of strings.
Step 3: Instantiate Your System¶
The system is an important component of training that handles rollout step generation, reward computation, and log-probability calculation. To speed you along, we have implemented the HFASTSystem class that handles the technical backend so you just need to provide the huggingface model IDs of the tester, target and baseline models and a Scorer instance (DetexifyScorer(), LlamaGuardScorer() or your custom scorer).
Note: Your tester and target can be different models but your baseline and tester should typically be the same.
Note: HFASTSystem can not handle huggingface models that have a fixed length (e.g. GPT2) as it will run into an out of context error. If you would like to train a GPT2-based tester model, please see our GPT2_v_GPT2 examples which use a custom ExampleDetoxifySystem() designed for GPT2 models.
# Example system instantiation: llama3 tester, target, and baseline with Detoxify scorer (heavyweight setup - requires GPU)
system = HFASTSystem("meta-llama/Llama-3.1-8B", "meta-llama/Llama-3.1-8B", "meta-llama/Llama-3.1-8B", DetoxifyScorer(), DEVICE)
Need a custom model or rollout step logic? See the System Customization guide. Want to use a custom scorer? See the Scorer Customization guide.
Step 4: Instantiate the Sampler¶
The sampler defines how training rollouts are structured and collected. In ASTRA-RL, the default is the ASTSampler, which implements the conversation tree rollout used in the ASTPrompter paper.
# define the dataset of seed prompts that will initiate tester-target training rollouts
# Depending on your space constraints, adjust how long you allow seed prompts to be (here we clip to a maximum of 500 characters )
PROMPTS = [p[:500] for p in CONVOKIT_REDDIT_TRAIN]
# instantiate the sampler
sampler = ASTSampler(system, PROMPTS)
Curious about how this sampler structures rollouts?
This sampler builds a tree-structured conversation graph, where: - The root node starts from a random initial prompt (from `CONVOKIT_REDDIT_TRAIN`) - At each turn, the tester generates multiple (`tree_width`, default 2) candidate utterances - Each of those utterances is fed to the target model, which produces a response - The resulting tester–target tuples form child nodes - This process repeats for `tree_depth` levels (default 3), yielding a multi-turn tester–target dialogue tree. This structure enables preference-based learning algorithms like DPO and IPO to reason over multiple conversational branches at once, training the tester to elicit harmful responses in a multi-turn setting.By default, rollouts are configured with tree_width=2 and tree_depth=3, but you can customize both:
Want a different rollout graph structure or a multi-agent setup? See the Sampler Customization guide.
Step 5: Choose Your Learning Algorithm¶
The solver is the RL learning algorithm that will take in a graph of training rollouts and compute the loss. This loss will drive the optimization of tester model weights, teaching the tester to more effectively elicit target toxicity.
We use DPO as the default for this quickstart.
# pass the DPO algorithm your system and the beta hyperparamater you would like to train with
solver = DPO(system, beta=0.1)
To integrate your own RL algorithm, see the Solver Customization guide.
Step 6: Train the Tester¶
For the quick start approach, simply call our training configuration and trainer classes and start training!
# instantiate the pre-configured HF-compatible configuration and trainer class
config = HFASTConfiguration() # lr = 7e-6, batch size = 16, optimizer = "adamw", 16 gradient accumulation steps, 3000 training steps, 1 episode per experience
# this trainer will train the tester and save the best tester to the checkpoint directory
trainer = HFASTTrainer(
config,
sampler,
solver,
dev_prompts=CONVOKIT_REDDIT_DEV,
eval_every=25, # this determines how often the model is evaluated on the dev set, we save the best performing model
ckpt_dir="./checkpoints/custom_name", # add your model name here!
use_wandb=True,
show_progress=True,
)
trainer.train()
The source code for the training configuration and trainer are at trainer.py
The source code for HFASTConfiguration and HFASTTrainer are at hf_ast_system.py
Want to customize the training configuration/hyperparams, the training loop, or model saving/eval? See the Trainer Customization guide.
Full Examples:¶
We provide complete working examples that mirror this guide!
Hugging face example for llama3 models: examples/ast_huggingface_train.py
Custom AST system for GPT2 models with trainer: examples/ast_gpt2_train.py
Custom AST system for PPO algorithm: examples/ast_ppo.py