Harness

astra_rl.training.harness

Harness

Bases: Generic[StateT, ActionT, Step, Batch]

Harness for running an algorithm in a given environment.

Example:

Here is an example of how to use the `Harness` class with the DPO algorithm
and an AST problem environment, for *a single experience-collection pass only*.
In practice you should add your own optimization details, such as weight decay,
learning-rate scheduling, and early stopping.

>>> import torch
>>> from astra_rl.training.harness import Harness
>>> from astra_rl.algorithms.dpo import DPO
>>> from astra_rl.methods.ast import ASTProblem, ASTEnvironment
>>>
>>> problem = ASTProblem()
>>> environment = ASTEnvironment(problem, ...)
>>> algorithm = DPO(...)
>>> harness = Harness(environment, algorithm)
>>> optimizer = torch.optim.Adam(problem.parameters(), lr=1e-4)
>>>
>>> for batch in harness.experience():
...     loss, logs = harness.step(batch)
...     optimizer.zero_grad()
...     loss.backward()
...     optimizer.step()

Attributes:

    environment (Environment[StateT, ActionT]): The environment to run the algorithm in.
    algorithm (Algorithm[StateT, ActionT, Step, Batch]): The algorithm to run.
    num_episodes_per_experience (int): Number of episodes per call to `.experience()`.
    dataloader_kwargs (Dict[str, Any]): Keyword arguments for the PyTorch DataLoader; batch size, for instance, should be set here.

Generics:

    StateT (type): The type of the state in the environment.
    ActionT (type): The type of the action in the environment.
    Step (type): The type of a single step in the environment.
    Batch (type): The type of a batch of steps, passed to the `.step()` function for gradient computation.
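
To make the generics concrete, here is a minimal sketch of plausible bindings for an AST-style text setup, together with a `dataloader_kwargs` that sets the batch size. The `TextStep` dataclass and the `Batch` alias are hypothetical illustrations, not types defined by astra_rl.

from dataclasses import dataclass
from typing import Any, Dict, List

# Hypothetical bindings for the Harness generics in a text red-teaming setup:
# states and actions are plain strings.
StateT = str
ActionT = str

@dataclass
class TextStep:
    """Hypothetical Step type: one preference pair flattened from a rollout."""
    prompt: str
    chosen: str
    rejected: str

# Hypothetical Batch type: the collated columns of a list of TextStep.
Batch = Dict[str, List[str]]

# dataloader_kwargs is forwarded to torch.utils.data.DataLoader,
# so the batch size is configured here.
dataloader_kwargs: Dict[str, Any] = {"batch_size": 8, "shuffle": True}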

Source code in src/astra_rl/training/harness.py
class Harness(Generic[StateT, ActionT, Step, Batch]):
    """Harness for running an algorithm in a given environment.

    Example:

        Here is an example of how to use the `Harness` class with the DPO algorithm
        and an AST problem environment, for *a single experience-collection pass only*.
        In practice you should add your own optimization details, such as weight decay,
        learning-rate scheduling, and early stopping.

        >>> import torch
        >>> from astra_rl.training.harness import Harness
        >>> from astra_rl.algorithms.dpo import DPO
        >>> from astra_rl.methods.ast import ASTProblem, ASTEnvironment
        >>>
        >>> problem = ASTProblem()
        >>> environment = ASTEnvironment(problem, ...)
        >>> algorithm = DPO(...)
        >>> harness = Harness(environment, algorithm)
        >>> optimizer = torch.optim.Adam(problem.parameters(), lr=1e-4)
        >>>
        >>> for batch in harness.experience():
        ...     loss, logs = harness.step(batch)
        ...     optimizer.zero_grad()
        ...     loss.backward()
        ...     optimizer.step()


    Attributes:
        environment (Environment[StateT, ActionT]): The environment to run the algorithm in.
        algorithm (Algorithm[StateT, ActionT, Step, Batch]): The algorithm to run.
        num_episodes_per_experience (int): Number of episodes per call to `.experience()`.
        dataloader_kwargs (Dict[str, Any]): Keyword arguments for the PyTorch data loader. Batch size, for instance, should be set.

    Generics:
        StateT (type): The type of the state in the environment.
        ActionT (type): The type of the action in the environment.
        Step (type): The type of a single step in the environment.
        Batch (type): The type of a batch of steps, passed to the `.step()` function for gradient computation.
    """

    def __init__(
        self,
        environment: Environment[StateT, ActionT],
        algorithm: Algorithm[StateT, ActionT, Step, Batch],
        num_episodes_per_experience: int = 32,
        use_wandb: bool = True,
        wandb_kwargs: Optional[Dict[str, Any]] = None,
        dataloader_kwargs: Optional[Dict[str, Any]] = None,
    ) -> None:
        """
        Args:
            environment (Environment): The environment to run the algorithm in.
            algorithm (Algorithm): The algorithm to run.
            num_episodes_per_experience (int, optional): Number of episodes per call to `.experience()`. Defaults to 32.
            use_wandb (bool, optional): Whether to log metrics to Weights & Biases. Defaults to True.
            wandb_kwargs (Optional[Dict[str, Any]], optional): Keyword arguments for configuring Weights & Biases. Defaults to None.
            dataloader_kwargs (Optional[Dict[str, Any]], optional): Keyword arguments for the PyTorch DataLoader, such as batch size and shuffle. Defaults to None.
        """

        self.environment = environment
        self.algorithm = algorithm
        self.num_episodes_per_experience = num_episodes_per_experience
        self.use_wandb = use_wandb
        self.wandb_kwargs = wandb_kwargs or {}
        self.dataloader_kwargs: Dict[str, Any] = dataloader_kwargs or {}

        if self.use_wandb:
            self.wandb = ASTRAWandbLogger(self.wandb_kwargs)

    def step(self, batch: Batch) -> tuple[torch.Tensor, Dict[Any, Any]]:
        """Run a step of the algorithm on the dataset.

        Args:
            batch (Batch): The dataset batch to run the algorithm on.

        Returns:
            tuple[torch.Tensor, Dict[Any, Any]]: A tuple containing:
                - torch.Tensor: The loss computed by the algorithm for the current batch.
                - Dict[Any, Any]: Additional information for logging.
        """

        result: torch.Tensor
        logging_dict: Dict[Any, Any]
        result, logging_dict = self.algorithm.step(batch)
        # TODO: Add other values here to logs besides algorithm specifics? Alternatively, can just return logging_dict
        step_logs: Dict[Any, Any] = {**logging_dict}

        return result, step_logs

    def experience(self, seed: Optional[int] = None) -> Iterator[Batch]:
        """Collect some experiences!

        Args:
            seed (Optional[int], optional): Seed for reproducibility. Defaults to None.

        Returns:
            Iterator[Batch]: An iterator over batches of steps collected from the environment's rollouts.
        """

        logger.debug(
            f"Collecting {self.num_episodes_per_experience} episodes of experience..."
        )

        graphs = []
        for _ in range(self.num_episodes_per_experience):
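            # Note: the same seed value (if any) is passed to every rollout
            # in this collection pass, rather than being varied per episode.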
            graph = self.environment.rollout(seed=seed)
            graphs.append(graph)

        steps = sum([list(self.algorithm.flatten(i)) for i in graphs], [])

        logger.debug(
            f"Done collecting {self.num_episodes_per_experience} episodes of experience"
            f", got {len(steps)} training steps."
        )

        return iter(
            DataLoader(
                ListDataset(steps),
                collate_fn=self.algorithm.collate_fn,
                **self.dataloader_kwargs,
            )
        )

    def log_current_step(self, current_logs: Dict[Any, Any]) -> None:
        """Log the current step metrics to Weights & Biases (if enabled) and logger.

        Args:
            current_logs (Dict[Any, Any]): The logs to be recorded.
        """
        if self.use_wandb:
            self.wandb.log(current_logs)

        # Always log to the logger
        # TODO: Do we want to log to the logger? Should be fine as used for debugging?
        logger.info(f"Current logs: {current_logs}")

__init__(environment, algorithm, num_episodes_per_experience=32, use_wandb=True, wandb_kwargs=None, dataloader_kwargs=None)

Parameters:

    environment (Environment): The environment to run the algorithm in. Required.
    algorithm (Algorithm): The algorithm to run. Required.
    num_episodes_per_experience (int, optional): Number of episodes per call to `.experience()`. Defaults to 32.
    use_wandb (bool, optional): Whether to log metrics to Weights & Biases. Defaults to True.
    wandb_kwargs (Optional[Dict[str, Any]], optional): Keyword arguments for configuring Weights & Biases. Defaults to None.
    dataloader_kwargs (Optional[Dict[str, Any]], optional): Keyword arguments for the PyTorch DataLoader, such as batch size and shuffle. Defaults to None.
Source code in src/astra_rl/training/harness.py
def __init__(
    self,
    environment: Environment[StateT, ActionT],
    algorithm: Algorithm[StateT, ActionT, Step, Batch],
    num_episodes_per_experience: int = 32,
    use_wandb: bool = True,
    wandb_kwargs: Optional[Dict[str, Any]] = None,
    dataloader_kwargs: Optional[Dict[str, Any]] = None,
) -> None:
    """
    Args:
        environment (Environment): The environment to run the algorithm in.
        algorithm (Algorithm): The algorithm to run.
        num_episodes_per_experience (int, optional): Number of episodes per call to `.experience()`. Defaults to 32.
        use_wandb (bool, optional): Whether to log metrics to Weights & Biases. Defaults to True.
        wandb_kwargs (Optional[Dict[str, Any]], optional): Keyword arguments for configuring Weights & Biases. Defaults to None.
        dataloader_kwargs (Optional[Dict[str, Any]], optional): Keyword arguments for the PyTorch DataLoader, such as batch size and shuffle. Defaults to None.
    """

    self.environment = environment
    self.algorithm = algorithm
    self.num_episodes_per_experience = num_episodes_per_experience
    self.use_wandb = use_wandb
    self.wandb_kwargs = wandb_kwargs or {}
    self.dataloader_kwargs: Dict[str, Any] = dataloader_kwargs or {}

    if self.use_wandb:
        self.wandb = ASTRAWandbLogger(self.wandb_kwargs)
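
For example, a fully configured construction might look like the sketch below. The W&B project name is a placeholder, and this assumes `ASTRAWandbLogger` forwards its kwargs to `wandb.init`:

harness = Harness(
    environment,
    algorithm,
    num_episodes_per_experience=16,
    use_wandb=True,
    wandb_kwargs={"project": "astra-rl-demo"},  # placeholder project name
    dataloader_kwargs={"batch_size": 8, "shuffle": True},
)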

experience(seed=None)

Collect some experiences!

Parameters:

    seed (Optional[int], optional): Seed for reproducibility. Defaults to None.

Returns:

    Iterator[Batch]: An iterator over batches of steps collected from the environment's rollouts.

Source code in src/astra_rl/training/harness.py
def experience(self, seed: Optional[int] = None) -> Iterator[Batch]:
    """Collect some experiences!

    Args:
        seed (Optional[int], optional): Seed for reproducibility. Defaults to None.

    Returns:
        Iterator[Batch]: An iterator over batches of steps collected from the environment's rollouts.
    """

    logger.debug(
        f"Collecting {self.num_episodes_per_experience} episodes of experience..."
    )

    graphs = []
    for _ in range(self.num_episodes_per_experience):
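        # Note: the same seed value (if any) is passed to every rollout
        # in this collection pass, rather than being varied per episode.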
        graph = self.environment.rollout(seed=seed)
        graphs.append(graph)

    steps = sum([list(self.algorithm.flatten(i)) for i in graphs], [])

    logger.debug(
        f"Done collecting {self.num_episodes_per_experience} episodes of experience"
        f", got {len(steps)} training steps."
    )

    return iter(
        DataLoader(
            ListDataset(steps),
            collate_fn=self.algorithm.collate_fn,
            **self.dataloader_kwargs,
        )
    )
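
Since each call performs fresh rollouts, `experience()` can be called once per training round. A minimal sketch of that outer loop, assuming the `optimizer` from the class-level example (the round count is arbitrary):

for _ in range(10):  # arbitrary number of experience-collection rounds
    for batch in harness.experience():
        loss, logs = harness.step(batch)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        harness.log_current_step(logs)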

log_current_step(current_logs)

Log the current step metrics to Weights & Biases (if enabled) and logger.

Parameters:

    current_logs (Dict[Any, Any]): The logs to be recorded. Required.
Source code in src/astra_rl/training/harness.py
def log_current_step(self, current_logs: Dict[Any, Any]) -> None:
    """Log the current step metrics to Weights & Biases (if enabled) and logger.

    Args:
        current_logs (Dict[Any, Any]): The logs to be recorded.
    """
    if self.use_wandb:
        self.wandb.log(current_logs)

    # Always log to the logger
    # TODO: Do we want to log to the logger? Should be fine as used for debugging?
    logger.info(f"Current logs: {current_logs}")
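
For instance, the logging dict returned by `step()` can be passed straight through, optionally augmented with extra scalars; the key name below is illustrative rather than fixed by the API:

loss, logs = harness.step(batch)
harness.log_current_step({"train/loss": loss.item(), **logs})  # illustrative key name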

step(batch)

Run a step of the algorithm on the dataset.

Parameters:

    batch (Batch): The dataset batch to run the algorithm on. Required.

Returns:

    tuple[torch.Tensor, Dict[Any, Any]]: A tuple containing:
        - torch.Tensor: The loss computed by the algorithm for the current batch.
        - Dict[Any, Any]: Additional information for logging.
Source code in src/astra_rl/training/harness.py
def step(self, batch: Batch) -> tuple[torch.Tensor, Dict[Any, Any]]:
    """Run a step of the algorithm on the dataset.

    Args:
        batch (Batch): The dataset batch to run the algorithm on.

    Returns:
        tuple[torch.Tensor, Dict[Any, Any]]: A tuple containing:
            - torch.Tensor: The loss computed by the algorithm for the current batch.
            - Dict[Any, Any]: Additional information for logging.
    """

    result: torch.Tensor
    logging_dict: Dict[Any, Any]
    result, logging_dict = self.algorithm.step(batch)
    # TODO: Add other values here to logs besides algorithm specifics? Alternatively, can just return logging_dict
    step_logs: Dict[Any, Any] = {**logging_dict}

    return result, step_logs
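
Because `step()` only computes the loss, the caller owns the optimizer update. A minimal sketch, with optional gradient clipping added (the threshold is an arbitrary choice, not something the harness prescribes):

loss, logs = harness.step(batch)
optimizer.zero_grad()
loss.backward()
# Optional: clip gradients before the update (max_norm is arbitrary).
torch.nn.utils.clip_grad_norm_(problem.parameters(), max_norm=1.0)
optimizer.step()
harness.log_current_step(logs)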