Algorithm
astra_rl.core.algorithm
algorithm.py
Algorithm
Bases: ABC, Generic[StateT, ActionT, Step, Batch]
An Algorithm used for performing training.
Specifically, the Algorithm object encodes how a rollout graph is processed into a loss that updates the model's weights. Subclasses implement this by calling the methods of self.problem to push values through the network.
Attributes:
Name | Type | Description |
---|---|---|
problem | Problem | The problem instance that defines the environment and actions. |
Generics:
StateT (type): The type of the state in the environment.
ActionT (type): The type of the action in the environment.
Step (type): The type of a single step in the environment.
Batch (type): The type of a batch of steps, passed to .step() to compute the loss for a gradient update.
Source code in src/astra_rl/core/algorithm.py (lines 15-78)
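The sketch below shows how the three abstract methods fit together: flatten turns a rollout graph into steps, collate_fn groups steps into a batch, and step turns a batch into a loss plus logging information. It is a minimal illustration under stated assumptions, not the astra_rl implementation: AlgorithmSketch is a hypothetical stand-in that mirrors the documented signatures; ToyProblem, ToyStep, ToyBatch and the list-of-triples "graph" are invented for illustration; and plain floats stand in for torch tensors so the example runs without torch.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, Generic, List, Sequence, Tuple, TypeVar

StateT = TypeVar("StateT")
ActionT = TypeVar("ActionT")
Step = TypeVar("Step")
Batch = TypeVar("Batch")


class AlgorithmSketch(ABC, Generic[StateT, ActionT, Step, Batch]):
    """Hypothetical stand-in mirroring the documented interface."""

    def __init__(self, problem: Any) -> None:
        self.problem = problem  # defines the environment and actions

    @abstractmethod
    def flatten(self, graph: Any) -> Sequence[Step]:
        """Process a rollout graph into a sequence of steps."""

    @staticmethod
    @abstractmethod
    def collate_fn(batch: Sequence[Step]) -> Batch:
        """Group steps into a batch (used as a DataLoader collate_fn)."""

    @abstractmethod
    def step(self, batch: Batch) -> Tuple[float, Dict[Any, Any]]:
        """Return (loss, logging info); a float stands in for torch.Tensor."""


@dataclass
class ToyStep:
    prompt: str
    action: str
    reward: float


@dataclass
class ToyBatch:
    prompts: List[str]
    actions: List[str]
    rewards: List[float]


class ToyProblem:
    """Hypothetical problem; score() stands in for a model forward pass."""

    def score(self, prompts: List[str], actions: List[str]) -> List[float]:
        return [float(len(a)) for a in actions]


class ToyAlgorithm(AlgorithmSketch[str, str, ToyStep, ToyBatch]):
    def flatten(self, graph: Any) -> Sequence[ToyStep]:
        # Hypothetical graph format: a list of (prompt, action, reward) triples.
        return [ToyStep(p, a, r) for p, a, r in graph]

    @staticmethod
    def collate_fn(batch: Sequence[ToyStep]) -> ToyBatch:
        # Transpose per-step records into per-field lists.
        return ToyBatch(
            [s.prompt for s in batch],
            [s.action for s in batch],
            [s.reward for s in batch],
        )

    def step(self, batch: ToyBatch) -> Tuple[float, Dict[Any, Any]]:
        # Push values "through the network" via self.problem, then reduce to a loss.
        scores = self.problem.score(batch.prompts, batch.actions)
        loss = -sum(r * s for r, s in zip(batch.rewards, scores)) / len(scores)
        return loss, {"mean_reward": sum(batch.rewards) / len(batch.rewards)}


algo = ToyAlgorithm(ToyProblem())
steps = algo.flatten([("hi", "ab", 1.0), ("yo", "abcd", 0.5)])
loss, info = algo.step(ToyAlgorithm.collate_fn(steps))
print(loss, info)  # -2.0 {'mean_reward': 0.75}
```

Because AlgorithmSketch leaves all three methods abstract, instantiating it directly raises a TypeError; only a subclass that implements flatten, collate_fn and step can be constructed.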
collate_fn(batch)
abstractmethod
staticmethod
The collate function used when batching with torch DataLoaders.
It is passed directly as the collate_fn argument of a torch DataLoader and is responsible for emitting well-formed batches of data.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
batch | Sequence[Step] | A sequence of steps to collate. | required |
Returns:
Name | Type | Description |
---|---|---|
Batch | Batch | A batch of data ready for processing using .step(). |
Source code in src/astra_rl/core/algorithm.py (lines 50-64)
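As a hedged illustration of the collate_fn contract, the sketch below defines a hypothetical step record (ExampleStep), a hypothetical batch record (ExampleBatch), and a collate function that transposes a list of steps into per-field lists. iterate_batches is a hypothetical helper that imitates, without shuffling, what a torch DataLoader does when given batch_size and collate_fn: split the dataset into consecutive chunks and collate each one. It avoids importing torch so it runs on its own.

```python
from dataclasses import dataclass
from typing import List, Sequence


@dataclass
class ExampleStep:
    """Hypothetical step: one (prefix, action, reward) sample."""
    prefix: str
    action: str
    reward: float


@dataclass
class ExampleBatch:
    """Hypothetical batch: the same fields, stored column-wise."""
    prefixes: List[str]
    actions: List[str]
    rewards: List[float]


def collate_fn(batch: Sequence[ExampleStep]) -> ExampleBatch:
    # Transpose a list of per-sample records into per-field lists.
    return ExampleBatch(
        prefixes=[s.prefix for s in batch],
        actions=[s.action for s in batch],
        rewards=[s.reward for s in batch],
    )


def iterate_batches(steps: Sequence[ExampleStep], batch_size: int) -> List[ExampleBatch]:
    # Chunk the sequence in order, keeping a final partial chunk,
    # then collate each chunk into one batch.
    return [
        collate_fn(steps[i : i + batch_size])
        for i in range(0, len(steps), batch_size)
    ]


steps = [ExampleStep("a", "x", 1.0), ExampleStep("b", "y", 0.0), ExampleStep("c", "z", 0.5)]
batches = iterate_batches(steps, batch_size=2)
print(len(batches), batches[0].actions)  # 2 ['x', 'y']
```

With torch installed, the same function could instead be passed as DataLoader(steps, batch_size=2, collate_fn=collate_fn). Because collate_fn is a staticmethod on the Algorithm, a subclass's version can be passed as MyAlgorithm.collate_fn without an instance (MyAlgorithm being a placeholder name).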
flatten(graph)
abstractmethod
Process a rollout graph into a sequence of steps.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
graph | Graph[StateT, ActionT] | The graph to flatten. | required |
Returns:
Type | Description |
---|---|
Sequence[Step] | A sequence of steps representing the flattened graph. |
Source code in src/astra_rl/core/algorithm.py (lines 38-48)
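The structure of Graph is not documented on this page, so the sketch below assumes a hypothetical tree of Node objects, each holding the conversation context (the state), the action taken, a reward, and its child nodes. This flatten performs an iterative depth-first walk and emits one (state, action, reward) tuple per node; a real subclass would emit its own Step type and might group nodes differently.

```python
from dataclasses import dataclass, field
from typing import List, Tuple


@dataclass
class Node:
    """Hypothetical rollout-graph node; the real Graph type may differ."""
    context: str   # conversation so far (the state)
    action: str    # utterance chosen at this node
    reward: float
    children: List["Node"] = field(default_factory=list)


def flatten(roots: List[Node]) -> List[Tuple[str, str, float]]:
    # Iterative depth-first walk emitting one (state, action, reward) step per node.
    steps: List[Tuple[str, str, float]] = []
    stack = list(reversed(roots))
    while stack:
        node = stack.pop()
        steps.append((node.context, node.action, node.reward))
        # Push children in reverse so they are visited left to right.
        stack.extend(reversed(node.children))
    return steps


root = Node("hi", "a", 0.0, [Node("hi|a", "b", 0.2), Node("hi|a", "c", 0.9)])
print(flatten([root]))  # [('hi', 'a', 0.0), ('hi|a', 'b', 0.2), ('hi|a', 'c', 0.9)]
```

A preference-based algorithm might instead pair sibling nodes into (preferred, rejected) steps; the traversal would stay the same and only the emitted record would change.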
step(batch)
abstractmethod
Take a batch and compute its loss.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
batch | Batch | A batch of data to process. | required |
Returns:
Type | Description |
---|---|
tuple[Tensor, Dict[Any, Any]] | A tuple containing the loss computed by the algorithm for the current batch (torch.Tensor) and a dictionary of additional information for logging (Dict[Any, Any]). |
Source code in src/astra_rl/core/algorithm.py (lines 66-78)
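To make the (loss, info) return contract concrete, the sketch below swaps in a REINFORCE-style policy-gradient loss with a mean-reward baseline. This technique is chosen only for illustration; this page does not say which loss any astra_rl algorithm uses. It assumes a hypothetical batch layout (a dict with "logprobs" and "rewards" lists) and uses plain floats; in a real subclass the log-probabilities would come from self.problem as torch tensors so that the returned loss could be backpropagated.

```python
from typing import Any, Dict, List, Tuple


def step(batch: Dict[str, List[float]]) -> Tuple[float, Dict[Any, Any]]:
    """Hypothetical REINFORCE-style step on plain floats.

    Assumes batch["logprobs"][i] is log pi(action_i | state_i) and
    batch["rewards"][i] is the matching reward.
    """
    logprobs, rewards = batch["logprobs"], batch["rewards"]
    if not rewards or len(logprobs) != len(rewards):
        raise ValueError("expected non-empty, equal-length logprobs and rewards")
    # Mean reward as a simple baseline; advantages are treated as constants.
    baseline = sum(rewards) / len(rewards)
    advantages = [r - baseline for r in rewards]
    # Minimizing -mean(A * log pi) follows the REINFORCE gradient estimate.
    loss = -sum(a * lp for a, lp in zip(advantages, logprobs)) / len(rewards)
    return loss, {"loss": loss, "reward_mean": baseline}


loss, info = step({"logprobs": [-1.0, -2.0], "rewards": [1.0, 0.0]})
print(loss, info["reward_mean"])  # -0.25 0.5
```

The second element of the tuple is free-form, so an implementation can put whatever scalars it wants logged there.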