Algorithm

astra_rl.core.algorithm

algorithm.py

Algorithm

Bases: ABC, Generic[StateT, ActionT, Step, Batch]

An Algorithm used for performing training.

Specifically, the Algorithm object encodes how a rollout graph is processed into a loss, which is then used to update the weights of the model. Subclasses implement the abstract methods below, calling self.problem's methods to push values through the network.

Attributes:

Name     Type     Description
problem  Problem  The problem instance that defines the environment and actions.

Generics

StateT (type): The type of the state in the environment.
ActionT (type): The type of the action in the environment.
Step (type): The type of a single step in the environment.
Batch (type): The type of a batch of steps, passed to the .step() function for the gradient computation.
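
As a rough illustration of how the three abstract methods fit together, the sketch below uses a hypothetical stand-in class, AlgorithmSketch, instead of importing astra_rl. The stand-in, the toy RewardOnly subclass, and the use of a plain float in place of torch.Tensor are all assumptions, made so the example runs with only the standard library.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict, Generic, List, Sequence, Tuple, TypeVar

StateT = TypeVar("StateT")
ActionT = TypeVar("ActionT")
Step = TypeVar("Step")
Batch = TypeVar("Batch")


class AlgorithmSketch(ABC, Generic[StateT, ActionT, Step, Batch]):
    """Hypothetical stand-in with the same three abstract methods."""

    def __init__(self, problem: Any) -> None:
        self.problem = problem

    @abstractmethod
    def flatten(self, graph: Any) -> Sequence[Step]: ...

    @staticmethod
    @abstractmethod
    def collate_fn(batch: Sequence[Step]) -> Batch: ...

    # Assumption: a float stands in for the torch.Tensor loss.
    @abstractmethod
    def step(self, batch: Batch) -> Tuple[float, Dict[Any, Any]]: ...


class RewardOnly(AlgorithmSketch[str, str, float, List[float]]):
    """Toy subclass: a step is one reward, a batch is a list of rewards."""

    def flatten(self, graph: Sequence[float]) -> Sequence[float]:
        # The toy "graph" is already a flat sequence of rewards.
        return list(graph)

    @staticmethod
    def collate_fn(batch: Sequence[float]) -> List[float]:
        return list(batch)

    def step(self, batch: List[float]) -> Tuple[float, Dict[Any, Any]]:
        mean_reward = sum(batch) / len(batch)
        # Minimising the negative mean reward maximises the reward.
        return -mean_reward, {"mean_reward": mean_reward}


algo = RewardOnly(problem=None)        # the toy never touches the problem
steps = algo.flatten([1.0, 0.0, 0.5])  # rollout graph -> steps
batch = RewardOnly.collate_fn(steps)   # steps -> batch
loss, logs = algo.step(batch)          # batch -> (loss, logging dict)
```

Because every method is abstract, Python refuses to instantiate a subclass until all three are implemented; calling AlgorithmSketch(problem=None) directly raises a TypeError.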

Source code in src/astra_rl/core/algorithm.py
class Algorithm(ABC, Generic[StateT, ActionT, Step, Batch]):
    """An Algorithm used for performing training.

    Specifically, the Algorithm object is responsible for encoding
    how a particular rollout graph becomes processed into a loss
    which updates the weights of the model. To implement its children,
    you basically call self.problem's various methods to push values
    through the network.


    Attributes:
        problem (Problem): The problem instance that defines the environment and actions.

    Generics:
        StateT (type): The type of the state in the environment.
        ActionT (type): The type of the action in the environment.
        Step (type): The type of a single step in the environment.
        Batch (type): The type of a batch of steps, passed to the .step() function for gradient.
    """

    def __init__(self, problem: Problem[StateT, ActionT]):
        self.problem = problem

    @abstractmethod
    def flatten(self, graph: Graph[StateT, ActionT]) -> Sequence[Step]:
        """Process a rollout graph into a sequence of steps.

        Args:
            graph (Graph[StateT, ActionT]): The graph to flatten.

        Returns:
            Sequence[Step]: A sequence of steps representing the flattened graph.
        """
        pass

    @staticmethod
    @abstractmethod
    def collate_fn(batch: Sequence[Step]) -> Batch:
        """The collate_fn for torch dataloaders for batching.

        We use this as the literal collate_fn to a torch DataLoader, and
        it is responsible for emitting well-formed batches of data.

        Args:
            batch (Sequence[Step]): A sequence of steps to collate.

        Returns:
            Batch: A batch of data ready for processing using .step().
        """
        pass

    @abstractmethod
    def step(self, batch: Batch) -> tuple[torch.Tensor, Dict[Any, Any]]:
        """Take a batch and compute loss of this batch.

        Args:
            batch (Batch): A batch of data to process.

        Returns:
            tuple[torch.Tensor, Dict[Any, Any]]: A tuple containing:
                - torch.Tensor: The loss computed by the algorithm (for current batch).
                - Dict[Any, Any]: Additional information for logging.
        """
        pass

collate_fn(batch) abstractmethod staticmethod

The collate_fn for torch dataloaders for batching.

This function is passed directly as the collate_fn argument of a torch DataLoader, and it is responsible for emitting well-formed batches of data.

Parameters:

Name   Type            Description                      Default
batch  Sequence[Step]  A sequence of steps to collate.  required

Returns:

Name   Type   Description
Batch  Batch  A batch of data ready for processing using .step().
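
One possible shape for a collate function is sketched below. ToyStep, ToyBatch, and PAD_ID are hypothetical names that do not come from astra_rl, and plain lists stand in for tensors so the sketch runs without PyTorch. It pads variable-length token sequences to a common length and records a mask marking the real tokens.

```python
from dataclasses import dataclass
from typing import List, Sequence

PAD_ID = 0  # assumption: token id used for padding


@dataclass
class ToyStep:
    token_ids: List[int]
    reward: float


@dataclass
class ToyBatch:
    token_ids: List[List[int]]  # every row padded to the same length
    mask: List[List[int]]       # 1 for real tokens, 0 for padding
    rewards: List[float]


def collate_fn(batch: Sequence[ToyStep]) -> ToyBatch:
    """Pad each step's tokens to the longest sequence in the batch."""
    max_len = max(len(step.token_ids) for step in batch)
    token_ids: List[List[int]] = []
    mask: List[List[int]] = []
    for step in batch:
        n_pad = max_len - len(step.token_ids)
        token_ids.append(step.token_ids + [PAD_ID] * n_pad)
        mask.append([1] * len(step.token_ids) + [0] * n_pad)
    return ToyBatch(token_ids, mask, [step.reward for step in batch])


example = collate_fn([ToyStep([5, 6, 7], 1.0), ToyStep([8], 0.5)])
```

In a real subclass the lists would typically be converted to tensors, and the function would be handed to PyTorch as DataLoader(steps, batch_size=..., collate_fn=MyAlgorithm.collate_fn), where MyAlgorithm is a placeholder name. Declaring it as a staticmethod lets it be passed this way without an instance.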

Source code in src/astra_rl/core/algorithm.py
@staticmethod
@abstractmethod
def collate_fn(batch: Sequence[Step]) -> Batch:
    """The collate_fn for torch dataloaders for batching.

    We use this as the literal collate_fn to a torch DataLoader, and
    it is responsible for emitting well-formed batches of data.

    Args:
        batch (Sequence[Step]): A sequence of steps to collate.

    Returns:
        Batch: A batch of data ready for processing using .step().
    """
    pass

flatten(graph) abstractmethod

Process a rollout graph into a sequence of steps.

Parameters:

Name   Type                    Description            Default
graph  Graph[StateT, ActionT]  The graph to flatten.  required

Returns:

Type            Description
Sequence[Step]  A sequence of steps representing the flattened graph.
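
This page does not describe the layout of Graph, so the following sketch assumes a simple tree: ToyGraph, ToyNode, and ToyStep are hypothetical stand-ins, not astra_rl types. Under that assumption, one way to flatten a rollout is a pre-order depth-first walk that emits one step per node.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class ToyNode:
    state: str
    action: str
    reward: float
    children: List["ToyNode"] = field(default_factory=list)


@dataclass
class ToyGraph:
    root_state: str
    children: List[ToyNode] = field(default_factory=list)


@dataclass
class ToyStep:
    state: str
    action: str
    reward: float


def flatten(graph: ToyGraph) -> List[ToyStep]:
    """Walk the tree depth-first (parents before children), one step per node."""
    steps: List[ToyStep] = []
    # Push children in reverse so the left-most child is popped first.
    stack = list(reversed(graph.children))
    while stack:
        node = stack.pop()
        steps.append(ToyStep(node.state, node.action, node.reward))
        stack.extend(reversed(node.children))
    return steps


toy_graph = ToyGraph(
    "root",
    [
        ToyNode("s0", "a", 1.0, [ToyNode("s1", "a1", 0.5), ToyNode("s1", "a2", 0.0)]),
        ToyNode("s0", "b", 0.2),
    ],
)
toy_steps = flatten(toy_graph)
```

Other algorithms may need a different flattening, for example pairing sibling nodes for preference-based losses; the per-node walk here is only the simplest case.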

Source code in src/astra_rl/core/algorithm.py
@abstractmethod
def flatten(self, graph: Graph[StateT, ActionT]) -> Sequence[Step]:
    """Process a rollout graph into a sequence of steps.

    Args:
        graph (Graph[StateT, ActionT]): The graph to flatten.

    Returns:
        Sequence[Step]: A sequence of steps representing the flattened graph.
    """
    pass

step(batch) abstractmethod

Take a batch and compute its loss.

Parameters:

Name   Type   Description                  Default
batch  Batch  A batch of data to process.  required

Returns:

Type                           Description
tuple[Tensor, Dict[Any, Any]]  A tuple containing:
                               - torch.Tensor: The loss computed by the algorithm (for the current batch).
                               - Dict[Any, Any]: Additional information for logging.

Source code in src/astra_rl/core/algorithm.py
@abstractmethod
def step(self, batch: Batch) -> tuple[torch.Tensor, Dict[Any, Any]]:
    """Take a batch and compute loss of this batch.

    Args:
        batch (Batch): A batch of data to process.

    Returns:
        tuple[torch.Tensor, Dict[Any, Any]]: A tuple containing:
            - torch.Tensor: The loss computed by the algorithm (for current batch).
            - Dict[Any, Any]: Additional information for logging.
    """
    pass