Algorithm
astra_rl.core.algorithm
algorithm.py
Algorithm
Bases: ABC, Generic[StateT, ActionT, Step, Batch]
An Algorithm used for performing training.
Specifically, an Algorithm encodes how a rollout graph is processed into a loss that updates the weights of the model. To implement a subclass, call the methods of self.system to push values through the network.
Attributes:
| Name | Type | Description |
|---|---|---|
| system | System | The system instance that defines the sampler and actions. |
Generics
StateT (type): The type of the state in the sampler.
ActionT (type): The type of the action in the sampler.
Step (type): The type of a single step in the environment.
Batch (type): The type of a batch of steps, passed to .step() to compute the loss used for gradient updates.
Source code in src/astra_rl/core/algorithm.py
collate_fn(batch)
abstractmethod
staticmethod
The collate_fn for torch dataloaders for batching.
This function is passed directly as the collate_fn argument of a torch DataLoader, and it is responsible for emitting well-formed batches of data.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| batch | Sequence[Step] | A sequence of steps to collate. | required |
Returns:
| Name | Type | Description |
|---|---|---|
| Batch | Batch | A batch of data ready for processing by .step(). |
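As a rough illustration of what "well-formed batches" can mean, the sketch below transposes a sequence of per-step records into one batch of parallel columns. `ToyStep` and `ToyBatch` are hypothetical types invented for this example; the real Step and Batch types are chosen by each Algorithm subclass and are not specified here.

```python
from dataclasses import dataclass
from typing import List, Sequence


@dataclass
class ToyStep:
    # Hypothetical per-step record (not an astra_rl type).
    prompt: str
    response: str
    reward: float


@dataclass
class ToyBatch:
    # Hypothetical batch layout: one list per field of ToyStep.
    prompts: List[str]
    responses: List[str]
    rewards: List[float]


def collate_fn(batch: Sequence[ToyStep]) -> ToyBatch:
    """Transpose a sequence of steps into a single batch of columns."""
    return ToyBatch(
        prompts=[s.prompt for s in batch],
        responses=[s.response for s in batch],
        rewards=[s.reward for s in batch],
    )


steps = [ToyStep("hi", "hello", 0.1), ToyStep("bye", "later", 0.9)]
toy_batch = collate_fn(steps)
```

In a subclass this function would be declared as a static method, and could then be handed to a loader as `DataLoader(dataset, collate_fn=MyAlgorithm.collate_fn)`, where `MyAlgorithm` and `dataset` are placeholders.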
flatten(graph)
abstractmethod
Process a rollout graph into a sequence of steps.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| graph | Graph[StateT, ActionT] | The graph to flatten. | required |
Returns:
| Type | Description |
|---|---|
| Sequence[Step] | A sequence of steps representing the flattened graph. |
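One possible way to flatten a tree-shaped rollout is a depth-first walk that emits one step per node. The sketch below assumes a toy structure: `ToyNode` and `ToyGraph` are hypothetical stand-ins, since the actual layout of Graph[StateT, ActionT] is not described in this section, and the traversal order is a choice made for illustration only.

```python
from dataclasses import dataclass, field
from typing import List, Sequence, Tuple


@dataclass
class ToyNode:
    # Hypothetical node: one (state, action, reward) plus its continuations.
    state: str
    action: str
    reward: float
    children: List["ToyNode"] = field(default_factory=list)


@dataclass
class ToyGraph:
    # Hypothetical stand-in for Graph[StateT, ActionT].
    children: List[ToyNode]


# Hypothetical Step type: a plain (state, action, reward) tuple.
ToyStep = Tuple[str, str, float]


def flatten(graph: ToyGraph) -> Sequence[ToyStep]:
    """Walk the rollout tree depth-first (pre-order), one step per node."""
    steps: List[ToyStep] = []
    # Reverse so that the leftmost child is popped first.
    stack = list(reversed(graph.children))
    while stack:
        node = stack.pop()
        steps.append((node.state, node.action, node.reward))
        stack.extend(reversed(node.children))
    return steps


leaf_d = ToyNode("s2", "d", 0.4)
node_b = ToyNode("s1", "b", 0.2, [leaf_d])
node_c = ToyNode("s1", "c", 0.3)
root_a = ToyNode("s0", "a", 0.1, [node_b, node_c])
flat = flatten(ToyGraph([root_a]))
```

An iterative walk is used here rather than recursion so that deep rollouts do not hit Python's recursion limit; a real subclass might instead emit pairs of sibling nodes or whole trajectories, depending on what its loss needs.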
step(batch)
abstractmethod
Take a batch and compute its loss.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| batch | Batch | A batch of data to process. | required |
Returns:
| Type | Description |
|---|---|
| tuple[Tensor, Dict[Any, Any]] | A tuple containing the loss computed by the algorithm for the current batch (torch.Tensor) and additional information for logging (Dict[Any, Any]). |
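The shape of the return value, a loss paired with a logging dictionary, can be sketched without any dependencies. Everything in this example is an assumption made for illustration: the batch is a plain dict with hypothetical `rewards` and `log_probs` keys, a Python float stands in for the torch.Tensor a real implementation returns, and the reward-weighted negative log-likelihood is just a simple placeholder loss, not the one used by any particular astra_rl algorithm.

```python
from typing import Any, Dict, List, Tuple


def step(batch: Dict[str, List[float]]) -> Tuple[float, Dict[Any, Any]]:
    """Return (loss, logging info) for one batch.

    Assumption: a float replaces the torch.Tensor loss so that the
    sketch runs without torch installed.
    """
    rewards = batch["rewards"]      # hypothetical key
    log_probs = batch["log_probs"]  # hypothetical key
    n = len(rewards)
    # Placeholder loss: mean of -(reward * log-probability).
    loss = -sum(r * lp for r, lp in zip(rewards, log_probs)) / n
    logs = {"loss": loss, "mean_reward": sum(rewards) / n}
    return loss, logs


loss, logs = step({"rewards": [1.0, 0.0], "log_probs": [-0.5, -2.0]})
```

In a real subclass the log-probabilities would come from calling methods on self.system, and the loss would stay a tensor so that the training loop can backpropagate through it.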