Solver (RL Algorithm)
Solvers (a.k.a. algorithms) define how learning happens. They consume rollout graphs from the Environment, ask the Problem for model log-probs/rewards, and return a scalar loss (plus optional logs) to the Trainer. In ASTRA-RL a solver subclasses `Algorithm[...]` and typically implements three things:
- `flatten(graph)` → turn a rollout `Graph` into per-sample `Step`s
- `collate_fn(steps)` → batch those steps into a `Batch`
- `step(batch)` → compute the training loss and a `logs` dict
1. What Solvers Do
Given rollouts (graphs of attacker–target turns), a solver decides what examples to learn from (via `flatten`), how to batch them (`collate_fn`), and what objective to optimize (`step`). This keeps “how we learn” separate from:
- Environment: how data is collected/structured (single path vs tree, etc.)
- Problem: how models are run (log-probs, rewards, advance logic)
2. Built-in Solvers/Examples
ASTRA-RL includes preference-learning solvers commonly used for LM alignment/red-teaming:
- DPO — Direct Preference Optimization (pairwise preferred vs rejected)
- IPO — Implicit Preference Optimization (margin-style objective over log-ratio differences)
- PPO - Proximal Policy Optimization
These serve as concrete references for writing your own solver. Find the code for these solvers here!
3. Ways to Customize
3.1 Fast path: adapt a built-in (e.g., DPO → IPO)
If your rollout selection and batching are the same, you can reuse `flatten` and `collate_fn` and only change the loss in `step`. IPO in our codebase demonstrates this pattern by inheriting from DPO and overriding `step`. So if you are only making a small change to how the loss is calculated, a good option is to inherit from DPO, IPO, or PPO and override `step` with your custom loss calculation.
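As a rough sketch of that pattern (not the library's actual IPO implementation): the subclass below keeps DPO's `flatten`/`collate_fn` and swaps in an IPO-style squared-margin loss. It assumes the DPO batch exposes `prefixes`, `pos`, and `neg` fields (the same names used in the shape check of section 7), that the base class stores `problem` and `beta` from its constructor, and that the import path matches your install.

```python
import torch

from astra_rl import DPO  # import path is an assumption; adjust to your install


class SquaredMarginDPO(DPO):
    """Reuses DPO's flatten/collate_fn; only the loss in `step` changes."""

    def step(self, batch):
        # Attacker log-probs must carry gradients...
        pos_lp = self.problem._get_attacker_logprobs_and_validate(batch.prefixes, batch.pos)
        neg_lp = self.problem._get_attacker_logprobs_and_validate(batch.prefixes, batch.neg)
        # ...while the frozen baseline/reference does not.
        with torch.no_grad():
            pos_ref = self.problem._get_baseline_logprobs_and_validate(batch.prefixes, batch.pos)
            neg_ref = self.problem._get_baseline_logprobs_and_validate(batch.prefixes, batch.neg)

        # IPO-style objective: regress the log-ratio margin toward 1 / (2 * beta).
        margin = (pos_lp - pos_ref) - (neg_lp - neg_ref)
        loss = ((margin - 1.0 / (2.0 * self.beta)) ** 2).mean()

        logs = {"loss": loss.item(), "margin": margin.mean().item()}
        return loss, logs
```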
3.2 Full control: subclass `Algorithm`
When your algorithm needs a different sampling strategy, subclass `Algorithm[...]` and implement `flatten`, `collate_fn`, and `step` to match your data/learning objective.
4. Required Interface
4.1 Step/Batch data contracts
Define explicit dataclasses that encode exactly what your algorithm needs.
```python
from dataclasses import dataclass
from typing import Generic, Sequence

import torch

from astra_rl.core.common import StateT, ActionT


@dataclass
class MyStep(Generic[StateT, ActionT]):
    context: StateT
    action: ActionT
    reward: float  # or advantage/return/log-ratio/etc.


@dataclass
class MyBatch(Generic[StateT, ActionT]):
    contexts: Sequence[StateT]
    actions: Sequence[ActionT]
    rewards: torch.Tensor  # tensors for math
```
Keep these minimal and algorithm-specific. They are the contract between your data selection (`flatten`) and your loss (`step`).
4.2 `flatten`, `collate_fn`, `step` contracts
- `flatten(graph: Graph) -> Sequence[Step]`: select and transform nodes/edges from the rollout graph into per-sample `Step`s (BFS/DFS as you like).
- `collate_fn(steps: Sequence[Step]) -> Batch`: convert a list of steps into batched tensors/sequences for efficient training.
- `step(batch: Batch) -> tuple[torch.Tensor, dict]`: compute a scalar loss (used for backprop) and a `logs` dict of floats (the base trainer may ignore them; custom trainers can log them).
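Putting the pieces together, a minimal end-to-end solver might look like the sketch below: a REINFORCE-style objective that reuses the `MyStep`/`MyBatch` dataclasses from section 4.1. The `Algorithm` import path, the constructor, and the graph/node attribute names (`children`, `context`, `action`, `reward`) are all assumptions to adapt to your setup.

```python
from typing import Sequence

import torch

from astra_rl.core.algorithm import Algorithm  # import path is an assumption


class RewardWeightedSolver(Algorithm):
    """Maximize reward-weighted attacker log-probs (REINFORCE-style sketch)."""

    def __init__(self, problem):
        # Adapt to the actual Algorithm constructor (e.g., super().__init__(problem)).
        self.problem = problem

    def flatten(self, graph) -> Sequence[MyStep]:
        # DFS over the rollout graph, emitting one Step per attacker turn.
        steps, stack = [], list(graph.children)
        while stack:
            node = stack.pop()
            steps.append(MyStep(node.context, node.action, node.reward))
            stack.extend(node.children)
        return steps

    def collate_fn(self, steps: Sequence[MyStep]) -> MyBatch:
        return MyBatch(
            contexts=[s.context for s in steps],
            actions=[s.action for s in steps],
            rewards=torch.tensor([s.reward for s in steps], dtype=torch.float32),
        )

    def step(self, batch: MyBatch) -> tuple[torch.Tensor, dict]:
        # Per-sample attacker log-probs; gradients flow only through these.
        logprobs = self.problem._get_attacker_logprobs_and_validate(
            batch.contexts, batch.actions
        )
        loss = -(batch.rewards.to(logprobs.device) * logprobs).mean()
        return loss, {"loss": loss.item(), "mean_reward": batch.rewards.mean().item()}
```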
4.3 Interacting with `Problem`
Your solver calls into the `Problem` for model computations:
- `problem._get_attacker_logprobs_and_validate(contexts, actions)`
- `problem._get_baseline_logprobs_and_validate(contexts, actions)`
- optionally: `problem.get_target_logprobs(...)`, `problem.reward(...)`, etc.
Tip: Target/baseline log-prob calls should usually be wrapped in `torch.no_grad()`; the attacker’s log-probs must require grad.
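For instance, the usual pattern for a log-ratio term looks like this (a sketch, assuming both helpers return a 1-D tensor of per-sample log-probs):

```python
import torch


def attacker_baseline_log_ratio(problem, contexts, actions) -> torch.Tensor:
    """log(pi_attacker / pi_baseline) per sample; gradients reach only the attacker."""
    attacker_lp = problem._get_attacker_logprobs_and_validate(contexts, actions)
    with torch.no_grad():  # the baseline/reference model stays frozen
        baseline_lp = problem._get_baseline_logprobs_and_validate(contexts, actions)
    return attacker_lp - baseline_lp
```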
5. Best Practices & Sanity Checks
- Pairwise methods need width ≥ 2. For DPO/IPO, set `tree_width >= 2` so each context has at least two candidate actions.
- Stable scales. Keep losses well-scaled (e.g., use a `beta` as in DPO/IPO). Normalize or clip rewards if needed.
- Efficient batching. Vectorize log-prob calls; avoid per-item model runs.
- Validate shapes. Collated tensors must be aligned and the same length.
- Freeze the ref/baseline. Only attacker params should receive gradients.
- KL anchor (when applicable). If training drifts, increase KL pressure (or use adaptive control) where appropriate.
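For example, a small helper along these lines can keep reward scales stable before they reach your loss (a minimal sketch; the clip threshold and epsilon are arbitrary choices, not library defaults):

```python
import torch


def stabilize_rewards(rewards: torch.Tensor, clip: float = 5.0) -> torch.Tensor:
    """Replace non-finite values, standardize, and clip outlier rewards."""
    rewards = torch.nan_to_num(rewards, nan=0.0, posinf=clip, neginf=-clip)
    rewards = (rewards - rewards.mean()) / (rewards.std() + 1e-8)
    return rewards.clamp(-clip, clip)
```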
6. Plug into the Trainer
Instantiate and pass your solver to the trainer:
```python
solver = DPO(problem, beta=0.1)  # or IPO(...)
trainer = Trainer(config=config, environment=env, algorithm=solver)
trainer.train()
```
Under the hood, the Trainer will:
- collect rollout graphs,
- call your solver’s `flatten` to produce `Step`s,
- use your solver’s `collate_fn` to form batches, and
- call your solver’s `step` to get `(loss, logs)`.
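Schematically, one iteration of that loop looks roughly like this (a sketch, not the actual Trainer code; the `environment.rollout()` method name is an assumption):

```python
def train_one_iteration(environment, solver, optimizer):
    """Schematic of the Trainer's inner loop; names on `environment` are assumed."""
    graph = environment.rollout()         # collect a rollout graph
    steps = solver.flatten(graph)         # select per-sample Steps
    batch = solver.collate_fn(steps)      # batch them for efficient training
    loss, logs = solver.step(batch)       # scalar loss + logs dict
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return logs
```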
The base `Trainer` uses `loss` for optimization and may ignore `logs`. Use `HFASTTrainer` or a custom trainer to evaluate and checkpoint.
7. Debug Checklist
- Shapes match: `len(prefixes) == len(pos) == len(neg)` (or analogous fields).
- Gradients only through attacker: wrap baseline/target log-prob calls in `torch.no_grad()` if you surface them directly.
- Finite values: check for `nan`/`inf` in losses and rewards (clip/normalize if necessary).
- Tree width OK: preference solvers require `tree_width >= 2`.
- KL anchor: if the attacker drifts, increase β or add an explicit KL penalty to the loss.
- Determinism: set seeds and/or make selection in `flatten` deterministic to reproduce bugs.
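A small seeding helper covers the determinism point (standard-library and PyTorch seeding only; add your own sampler or data-loader seeds as needed):

```python
import random

import torch


def set_seed(seed: int = 0) -> None:
    """Seed Python's and PyTorch's RNGs so rollouts and flatten selection are reproducible."""
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
```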