Trainer
astra_rl.training.trainer
trainer.py
The trainer is an opinionated interface designed to make training new models easy. For full control over the model training pipeline, we recommend using the lower-level `Harness` interface in `harness.py`.
Trainer
Bases: Generic[StateT, ActionT, Step, Batch]
A high-level trainer that trains your policy at the push of a button.
Example
Here is an example of how to use the `Trainer` class with the DPO algorithm and an AST problem environment:

```python
import torch
from astra_rl import (
    Trainer,
    TrainingConfiguration,
)
from astra_rl.algorithms.dpo import (
    DPO,
)
from astra_rl.methods.ast import (
    ASTProblem,
    ASTEnvironment,
)

problem = ASTProblem()
environment = ASTEnvironment(problem, ...)  # remaining arguments elided in the source
algorithm = DPO(...)  # arguments elided in the source
config = TrainingConfiguration(
    lr=1e-3,
    batch_size=16,
    optimizer="adamw",
    gradient_accumulation_steps=1,
    training_steps=1024,
    num_episodes_per_experience=8,
)
trainer = Trainer(
    config,
    environment,
    algorithm,
)
trainer.train()
```
Attributes:
Name | Type | Description |
---|---|---|
`config` | `TrainingConfiguration` | The configuration for the training process. |
`harness` | `Harness` | The harness that manages the training loop and interactions with the environment. |
`optimizer` | `Optimizer` | The optimizer used for updating the model parameters. |
`_global_step_counter` | `int` | A counter for global steps, used for gradient accumulation. |
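The role of a global step counter in gradient accumulation can be illustrated with a small sketch. This is not the library's implementation: `AccumulationCounter` and `update_points` are hypothetical names, and the rule "apply an update whenever the counter is a multiple of `gradient_accumulation_steps`" is an assumption about how such a counter is typically used.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class AccumulationCounter:
    """Hypothetical helper: decides when accumulated gradients should be applied."""

    gradient_accumulation_steps: int
    global_step: int = 0  # plays the role of a global step counter

    def tick(self) -> bool:
        """Record one backward pass; return True if weights should be updated now."""
        self.global_step += 1
        return self.global_step % self.gradient_accumulation_steps == 0


def update_points(num_batches: int, gradient_accumulation_steps: int) -> List[int]:
    """Return the 1-based batch indices at which a weight update would occur."""
    counter = AccumulationCounter(gradient_accumulation_steps)
    points = []
    for _ in range(num_batches):
        # In a real loop, the loss for this batch would be backpropagated here.
        if counter.tick():
            # In a real loop, the optimizer would step and gradients would be cleared here.
            points.append(counter.global_step)
    return points


print(update_points(num_batches=10, gradient_accumulation_steps=4))  # [4, 8]
print(update_points(num_batches=3, gradient_accumulation_steps=1))   # [1, 2, 3]
```

With `gradient_accumulation_steps=1`, as in the example configuration, every batch triggers an update, so accumulation is effectively disabled.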
Source code in src/astra_rl/training/trainer.py
__init__(config, environment, algorithm)
Parameters:
Name | Type | Description | Default |
---|---|---|---|
`config` | `TrainingConfiguration` | The configuration for the training process. | required |
`environment` | `Environment` | The environment to run our algorithm in. | required |
`algorithm` | `Algorithm` | The algorithm used for training the attacker agent. | required |
Source code in src/astra_rl/training/trainer.py
train()
Run training as specified by the config.
Note
This method takes no arguments and returns nothing; it is used only for its side effects. It exists mainly so that the user can control when training actually starts, rather than having it begin immediately on Trainer construction.
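The separation between construction and training can be sketched with a toy class. `SketchTrainer` is a hypothetical stand-in, not the library's `Trainer`; it only illustrates that building the object does no work and that `train()` acts purely through side effects.

```python
class SketchTrainer:
    """Hypothetical stand-in showing the construct-then-train pattern."""

    def __init__(self, training_steps: int) -> None:
        # Construction only stores settings; no training happens here.
        self.training_steps = training_steps
        self.steps_run = 0

    def train(self) -> None:
        # Takes no arguments and returns nothing; only mutates internal state.
        for _ in range(self.training_steps):
            self.steps_run += 1


trainer = SketchTrainer(training_steps=3)
print(trainer.steps_run)  # 0: nothing has run yet
trainer.train()
print(trainer.steps_run)  # 3
```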
Source code in src/astra_rl/training/trainer.py
TrainingConfiguration
Bases: BaseModel
A typechecked dataclass which configures the training procedure.
Attributes:
Name | Type | Description |
---|---|---|
`lr` | `float` | Learning rate for the optimizer. |
`batch_size` | `int` | Size of each batch (after flattening from experience) for training. |
`optimizer` | `str` | Type of optimizer to use [choices: "adam", "adamw", "sgd", "rmsprop", "adagrad"]. |
`gradient_accumulation_steps` | `int` | Number of steps to accumulate gradients before updating the model weights. |
`training_steps` | `int` | Total number of rollouts to run and train for. |
`num_episodes_per_experience` | `int` | Number of rollouts to run before making a gradient update. |
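To show what a type-checked configuration with a restricted optimizer choice might look like, here is a minimal sketch using only the standard library. `SketchConfiguration` and `OPTIMIZER_CHOICES` are hypothetical names, a frozen dataclass stands in for the `BaseModel` base class, and the positivity checks are assumptions; the source only documents the field names, types, and the list of optimizer choices.

```python
from dataclasses import dataclass

# Choices as listed in the attribute table above.
OPTIMIZER_CHOICES = ("adam", "adamw", "sgd", "rmsprop", "adagrad")


@dataclass(frozen=True)
class SketchConfiguration:
    """Hypothetical stand-in for TrainingConfiguration; not the library class."""

    lr: float
    batch_size: int
    optimizer: str
    gradient_accumulation_steps: int
    training_steps: int
    num_episodes_per_experience: int

    def __post_init__(self) -> None:
        # Reject optimizer names outside the documented choices.
        if self.optimizer not in OPTIMIZER_CHOICES:
            raise ValueError(
                f"optimizer must be one of {OPTIMIZER_CHOICES}, got {self.optimizer!r}"
            )
        # Assumed sanity checks; the real class may validate differently.
        if self.lr <= 0:
            raise ValueError("lr must be positive")
        for name in (
            "batch_size",
            "gradient_accumulation_steps",
            "training_steps",
            "num_episodes_per_experience",
        ):
            if getattr(self, name) < 1:
                raise ValueError(f"{name} must be at least 1")


# Same values as the Trainer example in this page.
config = SketchConfiguration(
    lr=1e-3,
    batch_size=16,
    optimizer="adamw",
    gradient_accumulation_steps=1,
    training_steps=1024,
    num_episodes_per_experience=8,
)
print(config.optimizer)  # adamw
```

Validating at construction time means a misspelled optimizer name fails immediately instead of partway through a training run.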
Source code in src/astra_rl/training/trainer.py