
Trainer

astra_rl.training.trainer

trainer.py

The trainer is an opinionated interface designed to make training new models easy. For full customization of the model training pipeline, we recommend using the lower-level Harness interface in harness.py.

Trainer

Bases: Generic[StateT, ActionT, Step, Batch]

A high-level trainer that push-button trains your policy.

Example

Here is an example of how to use the Trainer class with the DPO algorithm and an AST problem environment:

import torch
from astra_rl import Trainer, TrainingConfiguration
from astra_rl.algorithms.dpo import DPO
from astra_rl.methods.ast import ASTProblem, ASTEnvironment

problem = ASTProblem()
environment = ASTEnvironment(problem, ...)
algorithm = DPO(...)
config = TrainingConfiguration(
    lr=1e-3,
    batch_size=16,
    optimizer="adamw",
    gradient_accumulation_steps=1,
    training_steps=1024,
    num_episodes_per_experience=8,
)
trainer = Trainer(config, environment, algorithm)
trainer.train()

Attributes:

config (TrainingConfiguration): The configuration for the training process.
harness (Harness): The harness that manages the training loop and interactions with the environment. See astra_rl.training.harness for what it does.
optimizer (Optimizer): The optimizer used for updating the model parameters.
_global_step_counter (int): A counter for global steps, used for gradient accumulation.

Source code in src/astra_rl/training/trainer.py
class Trainer(Generic[StateT, ActionT, Step, Batch]):
    """A high-level trainer that pushbutton trains your policy

    Example:
        Here is an example of how to use the `Trainer` class with the DPO algorithm
        and an AST problem environment

        >>> import torch
        >>> from astra_rl import (
        ...     Trainer,
        ...     TrainingConfiguration,
        ... )
        >>> from astra_rl.algorithms.dpo import (
        ...     DPO,
        ... )
        >>> from astra_rl.methods.ast import (
        ...     ASTProblem,
        ...     ASTEnvironment,
        ... )
        >>>
        >>> problem = (
        ...     ASTProblem()
        ... )
        >>> environment = (
        ...     ASTEnvironment(
        ...         problem, ...
        ...     )
        ... )
        >>> algorithm = DPO(...)
        >>> config = TrainingConfiguration(
        ...     lr=1e-3,
        ...     batch_size=16,
        ...     optimizer="adamw",
        ...     gradient_accumulation_steps=1,
        ...     training_steps=1024,
        ...     num_episodes_per_experience=8,
        ... )
        >>> trainer = Trainer(
        ...     config,
        ...     environment,
        ...     algorithm,
        ... )
        >>> trainer.train()

    Attributes:
        config (TrainingConfiguration): The configuration for the training process.
        harness (Harness): The harness that manages the training loop and interactions with the environment. See `astra_rl.training.harness` for what it does.
        optimizer (Optimizer): The optimizer used for updating the model parameters.
        _global_step_counter (int): A counter for global steps, used for gradient accumulation.
    """

    optimizer: Optimizer

    def __init__(
        self,
        config: TrainingConfiguration,
        environment: Environment[StateT, ActionT],
        algorithm: Algorithm[StateT, ActionT, Step, Batch],
    ):
        """
        Args:
            config (TrainingConfiguration): The configuration for the training process.
            environment (Environment): The environment to run our algorithm in.
            algorithm (Algorithm): The algorithm used for training the attacker agent.
        """

        self.config = config
        self.harness = Harness(
            environment, algorithm, config.num_episodes_per_experience
        )

        # TODO initialize LR scheduler?
        # ?????????????????????????????

        # initialize optimizer
        if config.optimizer == "adam":
            from torch.optim import Adam

            self.optimizer = Adam(environment.problem.parameters(), config.lr)
        elif config.optimizer == "adamw":
            from torch.optim import AdamW

            self.optimizer = AdamW(environment.problem.parameters(), config.lr)
        elif config.optimizer == "sgd":
            from torch.optim import SGD

            self.optimizer = SGD(environment.problem.parameters(), config.lr)
        elif config.optimizer == "rmsprop":
            from torch.optim import RMSprop

            self.optimizer = RMSprop(environment.problem.parameters(), config.lr)
        elif config.optimizer == "adagrad":
            from torch.optim import Adagrad

            self.optimizer = Adagrad(environment.problem.parameters(), config.lr)
        else:
            raise ValueError(f"Unknown optimizer configured: {config.optimizer}")

        # step counter, for gradient accumulation, etc.
        self._global_step_counter = 0

    def train(self) -> None:
        """Run training by the specified config!

        Note:
            This method takes no arguments and returns nothing, and its
            only used for side effects. We don't really need it other than
            it's helpful for allowing the user to contro when training
            actually starts (instead of immediately after Trainer construction).
        """
        for _ in range(self.config.training_steps):
            buf = self.harness.experience()
            for batch in buf:
                # increment counter first for accumulation
                self._global_step_counter += 1
                loss: torch.Tensor = (
                    self.harness.step(batch)[0]
                    / self.config.gradient_accumulation_steps
                )
                # typing disabled here b/c mypy can't statically verify
                # that the loss has gradients
                loss.backward()  # type: ignore[no-untyped-call]

                # if gradient accumulation happens, step!
                if (
                    self._global_step_counter % self.config.gradient_accumulation_steps
                    == 0
                ):
                    self.optimizer.step()
                    self.optimizer.zero_grad()

__init__(config, environment, algorithm)

Parameters:

config (TrainingConfiguration): The configuration for the training process. Required.
environment (Environment): The environment to run our algorithm in. Required.
algorithm (Algorithm): The algorithm used for training the attacker agent. Required.
Source code in src/astra_rl/training/trainer.py
def __init__(
    self,
    config: TrainingConfiguration,
    environment: Environment[StateT, ActionT],
    algorithm: Algorithm[StateT, ActionT, Step, Batch],
):
    """
    Args:
        config (TrainingConfiguration): The configuration for the training process.
        environment (Environment): The environment to run our algorithm in.
        algorithm (Algorithm): The algorithm used for training the attacker agent.
    """

    self.config = config
    self.harness = Harness(
        environment, algorithm, config.num_episodes_per_experience
    )

    # TODO initialize LR scheduler?
    # ?????????????????????????????

    # initialize optimizer
    if config.optimizer == "adam":
        from torch.optim import Adam

        self.optimizer = Adam(environment.problem.parameters(), config.lr)
    elif config.optimizer == "adamw":
        from torch.optim import AdamW

        self.optimizer = AdamW(environment.problem.parameters(), config.lr)
    elif config.optimizer == "sgd":
        from torch.optim import SGD

        self.optimizer = SGD(environment.problem.parameters(), config.lr)
    elif config.optimizer == "rmsprop":
        from torch.optim import RMSprop

        self.optimizer = RMSprop(environment.problem.parameters(), config.lr)
    elif config.optimizer == "adagrad":
        from torch.optim import Adagrad

        self.optimizer = Adagrad(environment.problem.parameters(), config.lr)
    else:
        raise ValueError(f"Unknown optimizer configured: {config.optimizer}")

    # step counter, for gradient accumulation, etc.
    self._global_step_counter = 0
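
The config's optimizer field is matched against a small fixed set of names ("adam", "adamw", "sgd", "rmsprop", "adagrad") and mapped onto the corresponding torch.optim class; any other string raises a ValueError at construction time. A minimal sketch of what this means for callers, reusing the environment and algorithm objects from the example at the top of this page (the lr values are illustrative):

config = TrainingConfiguration(optimizer="sgd", lr=1e-2)
trainer = Trainer(config, environment, algorithm)  # uses torch.optim.SGD under the hood

config = TrainingConfiguration(optimizer="lion", lr=1e-2)
trainer = Trainer(config, environment, algorithm)  # raises ValueError: Unknown optimizer configured: lion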

train()

Run training as specified by the config.

Note

This method takes no arguments and returns nothing; it is used only for its side effects. It exists mainly so the user can control when training actually starts (instead of starting immediately after Trainer construction).

Source code in src/astra_rl/training/trainer.py
def train(self) -> None:
    """Run training by the specified config!

    Note:
        This method takes no arguments and returns nothing, and its
        only used for side effects. We don't really need it other than
        it's helpful for allowing the user to contro when training
        actually starts (instead of immediately after Trainer construction).
    """
    for _ in range(self.config.training_steps):
        buf = self.harness.experience()
        for batch in buf:
            # increment counter first for accumulation
            self._global_step_counter += 1
            loss: torch.Tensor = (
                self.harness.step(batch)[0]
                / self.config.gradient_accumulation_steps
            )
            # typing disabled here b/c mypy can't statically verify
            # that the loss has gradients
            loss.backward()  # type: ignore[no-untyped-call]

            # if gradient accumulation happens, step!
            if (
                self._global_step_counter % self.config.gradient_accumulation_steps
                == 0
            ):
                self.optimizer.step()
                self.optimizer.zero_grad()
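
As noted at the top of this page, users who want full control over this loop can drive the Harness directly instead of calling Trainer.train(). The sketch below mirrors the loop above so that custom logging, checkpointing, or a learning-rate scheduler can be slotted in between steps; it assumes Harness is importable from astra_rl.training.harness (the module referenced in the attributes above) and reuses the config, environment, and algorithm objects from the example at the top of this page.

from torch.optim import AdamW

from astra_rl.training.harness import Harness

harness = Harness(environment, algorithm, config.num_episodes_per_experience)
optimizer = AdamW(environment.problem.parameters(), config.lr)

step = 0
for _ in range(config.training_steps):
    for batch in harness.experience():  # collect rollouts and yield training batches
        step += 1
        loss = harness.step(batch)[0] / config.gradient_accumulation_steps
        loss.backward()
        if step % config.gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
        # custom hooks (logging, LR scheduling, checkpointing) would go here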

TrainingConfiguration

Bases: BaseModel

A typechecked dataclass which configures the training procedure.

Attributes:

lr (float): Learning rate for the optimizer.
batch_size (int): Size of each batch (after flattening from experience) for training.
optimizer (str): Type of optimizer to use [choices: "adam", "adamw", "sgd", "rmsprop", "adagrad"].
gradient_accumulation_steps (int): Number of steps to accumulate gradients before updating the model weights.
training_steps (int): Total number of rollouts to run and train for.
num_episodes_per_experience (int): Number of rollouts to run before making a gradient update.

Source code in src/astra_rl/training/trainer.py
class TrainingConfiguration(BaseModel):
    """A typechecked dataclass which configures the training procedure.

    Attributes:
        lr (float): Learning rate for the optimizer.
        batch_size (int): Size of each batch (after flattening from experience) for training.
        optimizer (str): Type of optimizer to use [choices: "adam", "adamw", "sgd", "rmsprop", "adagrad"].
        gradient_accumulation_steps (int): Number of steps to accumulate gradients before updating the model weights.
        training_steps (int): Total number of rollouts to run and train for.
        num_episodes_per_experience (int): Number of rollouts to run before making a gradient update.
    """

    # optimization configuration
    lr: float = 3e-3
    batch_size: int = 16
    optimizer: str = "adamw"
    gradient_accumulation_steps: int = 1  # how many batches to accumulate before each optimizer update

    # training configuration
    training_steps: int = 1024  # how many rollouts to train for

    # rollout configuration
    num_episodes_per_experience: int = 8  # how many rollouts per gradient update
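
Because TrainingConfiguration subclasses BaseModel, every field has a default (shown above) and values are checked on construction, so only the fields you want to override need to be passed. A minimal sketch; the effective-batch-size arithmetic at the end is a general property of gradient accumulation, not something this class computes for you:

from astra_rl import TrainingConfiguration

config = TrainingConfiguration(lr=1e-4, gradient_accumulation_steps=4)

print(config.optimizer)   # "adamw" (default)
print(config.batch_size)  # 16 (default)

# The optimizer only steps every gradient_accumulation_steps batches,
# so each weight update effectively sees
# config.batch_size * config.gradient_accumulation_steps == 64 samples.
effective_batch_size = config.batch_size * config.gradient_accumulation_steps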