Trainer¶
Trainers run the optimization loop that updates your attacker. They wire together the environment (rollout collection), the algorithm/solver (computes a loss from rollouts), and the optimizer (updates model weights). In ASTRA-RL you can use a minimal, no-frills base trainer or a preconfigured, Hugging Face–friendly trainer that handles evaluation and checkpointing.
This guide explains what a trainer does, what ASTRA-RL ships with, and how to implement or customize your own trainer and training configuration.
1. What Trainers Do¶
The base trainer in `astra_rl/training/trainer.py` is responsible for:

- Optimizer setup — creates the optimizer that updates the attacker’s weights.
- Harness orchestration — uses the Harness to collect rollouts and feed batches to your Algorithm (solver).
- Main training loop — calls `train()` to iterate for `training_steps`.
Important: the base loop is intentionally minimal. It does not perform evaluation, checkpointing, or external logging.
Note

The Harness invokes your Algorithm on each batch and returns `(loss, step_logs)`. The base `Trainer` uses the loss for backprop and discards `step_logs`. Use `HFASTTrainer` or subclass `Trainer` if you need logging/eval/checkpointing.
2. Built-in Trainers¶
- `Trainer` (base) — minimal loop: collect → compute loss → optimize. No eval/saving/logging.
- `HFASTTrainer` — Hugging Face–friendly trainer that performs periodic dev evaluation and saves HF checkpoints, driven by `HFASTConfiguration` (or your own config).
3. Ways to Customize¶
3.1 Fast path: use the base `Trainer`¶
If you just want a lean optimization loop (no eval/checkpointing/logging), use `Trainer` with a `TrainingConfiguration`. See 6.1.
3.2 Fast path: HF-compatible trainer (`HFASTTrainer`)¶
If you want periodic dev evaluation and automatic Hugging Face checkpointing, use `HFASTTrainer` with `HFASTConfiguration`. See 6.2.
3.3 Full control: subclass `Trainer`¶
Need custom evaluation cadence, model-saving policy, learning-rate schedules, gradient accumulation, early stopping, or logging destinations (e.g., Weights & Biases)? Subclass `Trainer` and override `train()` (and optional helpers). See 6.3 and 6.4.
4. Required Interface¶
4.1 `TrainingConfiguration` knobs¶
Instantiate `TrainingConfiguration` with the hyperparameters you care about. These values drive how the trainer and optimizer are initialized.
```python
from astra_rl import TrainingConfiguration

config = TrainingConfiguration(
    lr=1e-5,
    batch_size=4,
    optimizer="adamw",              # one of ["adam", "adamw", "sgd", "rmsprop", "adagrad"]
    gradient_accumulation_steps=1,  # call optimizer.step() every N backward passes
    training_steps=1000,            # number of experience() calls
    num_episodes_per_experience=2,  # rollouts sampled per experience() call
)
```
- `training_steps` = number of Harness `experience()` iterations.
- Approx. total rollouts ≈ `training_steps × num_episodes_per_experience` (with the config above, roughly 1000 × 2 = 2000 rollouts).
4.2 What the Trainer expects from the Harness/Algorithm¶
- The Harness returns an iterable of experience batches.
- For each batch, the trainer calls the Algorithm (solver) via the Harness to process the experience and obtain:

```python
loss, step_logs = harness.step(batch)
```

- `loss`: a scalar tensor used for backprop.
- `step_logs`: optional dict of scalars (e.g., reward stats). The base trainer ignores these; custom trainers can log them.
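To make the data flow concrete, here is a rough sketch of the loop a base-style trainer runs over these values. It is schematic, not the library's actual implementation: the exact return type of `experience()`, the batching, and the variable names are assumptions layered on the config and examples in this guide.

```python
# Schematic sketch only — not the library's exact training loop.
# Assumes `harness`, `optimizer`, and `config` are set up as in this guide's examples,
# and that experience() yields batches that harness.step() can consume.
optimizer.zero_grad()
for step in range(config.training_steps):
    for i, batch in enumerate(harness.experience()):        # collect rollouts, iterate batches
        loss, step_logs = harness.step(batch)                # algorithm/solver computes the loss
        (loss / config.gradient_accumulation_steps).backward()
        if (i + 1) % config.gradient_accumulation_steps == 0:
            optimizer.step()                                 # update the attacker's weights
            optimizer.zero_grad()
        # The base Trainer discards step_logs; a custom trainer could log them here.
```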
5. Best Practices & Sanity Checks¶
- Don’t hack the Harness unless you truly need different data-collection semantics; most customization belongs in the trainer/config/environment/algorithm.
- Detach when logging: `logs["loss"] = float(loss.detach().item())` to avoid holding computation graphs.
- Checkpoint sensibly: attacker + tokenizer is usually enough; if your algorithm has extra state, save it too.
- Batching vs. accumulation: prefer reasonable batch sizes; use `gradient_accumulation_steps` when memory is tight.
- Reproducibility: seed PyTorch/NumPy and pass a `seed` through your environment when possible (a seeding sketch follows this list).
- Validation cadence: dev eval can be slow—choose an `eval_every` that matches your budget.
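For the reproducibility point, a minimal seeding helper might look like this; the function name and the exact set of libraries seeded are illustrative, not part of ASTRA-RL:

```python
import random

import numpy as np
import torch


def set_seed(seed: int = 0) -> None:
    """Seed Python, NumPy, and PyTorch (CPU and CUDA) for reproducible runs."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # no-op when CUDA is unavailable
```

Call it once at startup, and pass the same seed to your environment if it accepts one.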
6. How-Tos¶
6.1 Minimal usage of the base `Trainer`¶
A simple loop that optimizes the algorithm loss with no eval, checkpoints, or logging:
```python
from astra_rl import Trainer

trainer = Trainer(
    config=config,
    environment=env,
    algorithm=solver,
)
trainer.train()

# Optionally save the final HF model/tokenizer:
problem.attacker.save_pretrained("final_ckpt")
problem.tokenizer.save_pretrained("final_ckpt")
```
6.2 Periodic eval + HF checkpoints via `HFASTTrainer`¶
Use the preconfigured, Hugging Face–compatible trainer for periodic dev evaluation and automatic checkpointing.
```python
from astra_rl.ext.transformers.hf_ast_problem import (
    HFASTTrainer,
    HFASTConfiguration,
)

config = HFASTConfiguration()  # or your own TrainingConfiguration
trainer = HFASTTrainer(
    config=config,
    environment=env,
    algorithm=solver,
    dev_prompts=DEV_PROMPTS,  # iterable of prompts for evaluation
    eval_every=100,           # run dev eval every N steps
    ckpt_dir="checkpoints",   # HF-format checkpoints saved here
)
trainer.train()
```
6.3 Write a custom trainer with eval, saving, and grad accumulation¶
Subclass `Trainer` to add evaluation cadence, HF-style saving, gradient accumulation, and logging.
```python
import os

import torch

from astra_rl import Trainer, TrainingConfiguration
from astra_rl.logging import logger


class MyConfig(TrainingConfiguration):
    def __init__(self):
        super().__init__(
            lr=1e-5,
            batch_size=4,
            optimizer="adamw",
            gradient_accumulation_steps=1,
            training_steps=1000,
            num_episodes_per_experience=2,
        )
        # Custom fields for your subclass:
        self.eval_every = 100
        self.ckpt_dir = "checkpoints"


class MyTrainer(Trainer):
    """
    Extends the base trainer with:
    - periodic dev-set evaluation
    - HF-format checkpointing
    - optional grad accumulation
    """

    def __init__(self, config: MyConfig, environment, algorithm, dev_prompts=None):
        super().__init__(config, environment, algorithm)
        self.dev_prompts = dev_prompts or []
        os.makedirs(self.config.ckpt_dir, exist_ok=True)

    # optional but encouraged
    def _save_hf(self, step: int) -> None:
        """Save your attacker and its tokenizer."""

    # optional method
    @torch.no_grad()
    def _eval_dev(self, step: int, tag: str = "dev"):
        """Run a lightweight evaluation on dev prompts. Fill in your logic."""
        pass

    # required method
    def train(self):
        """Implement your custom training loop!"""
        # See astra_rl.ext.transformers.hf_ast_problem for an implemented custom class example.
        pass
```
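Constructing and running the custom trainer then mirrors the earlier examples (assuming `env`, `solver`, and `DEV_PROMPTS` are defined as in 6.1 and 6.2):

```python
trainer = MyTrainer(MyConfig(), environment=env, algorithm=solver, dev_prompts=DEV_PROMPTS)
trainer.train()
```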
6.4 Early stopping or custom LR schedules¶
Inside your custom `train()`:

- Early stopping: track a validation metric (e.g., dev score from `_eval_dev`) and stop after `x` steps with no improvement.
- LR scheduling: instantiate a PyTorch scheduler after the optimizer and call `scheduler.step()` each iteration or epoch; store scheduler state alongside checkpoints if you need exact resumption.
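A rough sketch of how both pieces could slot into the `MyTrainer` skeleton from 6.3 is shown below. The scheduler choice, the patience value, the assumption that `_eval_dev` returns a scalar score, and the `self.optimizer` attribute are all illustrative assumptions, and the experience-collection step is elided:

```python
from torch.optim.lr_scheduler import CosineAnnealingLR


# Illustrative sketch: a train() method for the MyTrainer subclass from 6.3.
def train(self):
    # Assumes the base Trainer exposes its optimizer as self.optimizer.
    scheduler = CosineAnnealingLR(self.optimizer, T_max=self.config.training_steps)
    best_score, bad_evals, patience = float("-inf"), 0, 5  # patience value is arbitrary

    for step in range(self.config.training_steps):
        ...  # collect experience, compute loss, backward(), optimizer.step()
        scheduler.step()  # advance the LR schedule once per training step

        if (step + 1) % self.config.eval_every == 0:
            score = self._eval_dev(step)  # assumed to return a scalar dev score
            if score is not None and score > best_score:
                best_score, bad_evals = score, 0
                self._save_hf(step)  # checkpoint on improvement
            else:
                bad_evals += 1
                if bad_evals >= patience:  # early stopping
                    break
```

If you need exact resumption, save `scheduler.state_dict()` (and restore it with `load_state_dict()`) alongside your model checkpoints.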
7. Full Examples¶
- Custom AST problem with trainer: `examples/GPT2_v_GPT2/ast_trainer.py`
- Source for HF-compatible trainer/config: `astra_rl/ext/transformers/hf_ast_problem.py`

Use these as references when writing your own training loop or extending the provided trainers.