
Sampler

astra_rl.core.sampler

sampler.py: Roll out a system and specify how its sampler behaves.

Graph dataclass

Bases: Generic[StateT, ActionT]

A graph representing the rollout (history + actions) of a system.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| context | StateT | The initial state of the sampler. |
| children | Sequence[Node[StateT, ActionT]] | The sequence of nodes representing actions and responses. |

Source code in src/astra_rl/core/sampler.py
@dataclass
class Graph(Generic[StateT, ActionT]):
    """A graph representing the rollout (history + actions) of a system.

    Attributes:
        context (StateT): The initial state of the sampler.
        children (Sequence[Node[StateT, ActionT]]): The sequence of nodes representing actions and responses.
    """

    context: StateT
    children: Sequence[Node[StateT, ActionT]]
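As an illustration of how Graph and the Node dataclass documented on this page fit together, here is a minimal sketch that builds a single two-step rollout by hand. It redefines both classes locally with the same fields so it runs without astra_rl installed, and it assumes plain strings for both StateT and ActionT. Chaining each child's context to its parent's response is an assumption drawn from the field descriptions ("the resulting state after the action"), not a rule the classes enforce.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Generic, Sequence, TypeVar

StateT = TypeVar("StateT")
ActionT = TypeVar("ActionT")


# Local stand-ins with the same fields as astra_rl's Node and Graph (not imported from the library).
@dataclass
class Node(Generic[StateT, ActionT]):
    context: StateT
    probe: ActionT
    response: StateT
    reward: float
    children: Sequence[Node[StateT, ActionT]]


@dataclass
class Graph(Generic[StateT, ActionT]):
    context: StateT
    children: Sequence[Node[StateT, ActionT]]


# One path of two steps: each child's context is its parent's resulting state (response).
step2 = Node(context="s1", probe="a1", response="s2", reward=1.0, children=[])
step1 = Node(context="s0", probe="a0", response="s1", reward=0.0, children=[step2])
graph: Graph[str, str] = Graph(context="s0", children=[step1])

print(graph.children[0].children[0].response)  # s2
```

The Graph itself carries no action or reward; it only anchors the initial state, and every step of the rollout lives in the nested nodes.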

Node dataclass

Bases: Generic[StateT, ActionT]

A node in the rollout graph.

Represents a single step in the rollout process, containing the context, the action taken, the response received, the reward for that action, and any child nodes that follow this action in this rollout.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| context | StateT | The initial state before the action. |
| probe | ActionT | The action taken in this node. |
| response | StateT | The resulting state after the action. |
| reward | float | The reward received for taking the action. |
| children | Sequence[Node[StateT, ActionT]] | Subsequent nodes that follow this action. |

Generics

StateT (type): The type of the state in the sampler.
ActionT (type): The type of the action in the sampler.

Source code in src/astra_rl/core/sampler.py
@dataclass
class Node(Generic[StateT, ActionT]):
    """A node in the rollout graph.

    Represents a single step in the rollout process, containing the context,
    the action taken, the response received, the reward for that action,
    and any child nodes that follow this action in this rollout.

    Attributes:
        context (StateT): The initial state before the action.
        probe (ActionT): The action taken in this node.
        response (StateT): The resulting state after the action.
        reward (float): The reward received for taking the action.
        children (Sequence[Node[StateT, ActionT]]): Subsequent nodes that follow this action.

    Generics:
        StateT (type): The type of the state in the sampler.
        ActionT (type): The type of the action in the sampler.
    """

    context: StateT
    probe: ActionT
    response: StateT
    reward: float

    children: Sequence[Self]
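Because each Node holds its own children, a rollout is a tree, and a node with no children ends a path. The sketch below walks that tree depth-first and sums the rewards along every root-to-leaf path. The helper path_returns is hypothetical (not part of astra_rl), and Node is again a local stand-in carrying the fields documented above.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Generic, List, Sequence, Tuple, TypeVar

StateT = TypeVar("StateT")
ActionT = TypeVar("ActionT")


# Local stand-in mirroring the fields of astra_rl's Node (not imported from the library).
@dataclass
class Node(Generic[StateT, ActionT]):
    context: StateT
    probe: ActionT
    response: StateT
    reward: float
    children: Sequence[Node[StateT, ActionT]]


def path_returns(
    nodes: Sequence[Node[StateT, ActionT]],
    prefix: Tuple[ActionT, ...] = (),
    acc: float = 0.0,
) -> List[Tuple[Tuple[ActionT, ...], float]]:
    """Depth-first walk: for every root-to-leaf path, return its probes and summed reward."""
    out: List[Tuple[Tuple[ActionT, ...], float]] = []
    for node in nodes:
        path = prefix + (node.probe,)
        total = acc + node.reward
        if node.children:
            out.extend(path_returns(node.children, path, total))
        else:
            out.append((path, total))  # no children: this path ends here
    return out


# Two branches from the same starting state; the left branch continues one more step.
left_leaf = Node("s1", "a2", "s3", 1.0, [])
left = Node("s0", "a0", "s1", 0.5, [left_leaf])
right = Node("s0", "a1", "s2", -1.0, [])

for probes, ret in path_returns([left, right]):
    print(probes, ret)
```

Summing undiscounted rewards is only one choice; a discount factor or per-node credit assignment would plug into the same traversal.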

Sampler

Bases: ABC, Generic[StateT, ActionT]

A Sampler used for rolling out a system.

The main purpose of this class is to build a Graph of the system by calling the rollout method. A sampler may store or sample initial states, but it should not keep global state that persists across rollouts.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| system | System[StateT, ActionT] | The system instance that defines the sampler and actions. |

Generics

StateT (type): The type of the state in the sampler.
ActionT (type): The type of the action in the sampler.

Source code in src/astra_rl/core/sampler.py
class Sampler(ABC, Generic[StateT, ActionT]):
    """A Sampler used for rolling out a system.

    The primary point of this class is to make a `Graph` of the system
    by calling the `rollout` method. The sampler can keep/sample
    initial state, but should not have global state that persists
    across rollouts.

    Attributes:
        system (System[StateT, ActionT]): The system instance that defines the sampler and actions.

    Generics:
        StateT (type): The type of the state in the sampler.
        ActionT (type): The type of the action in the sampler.
    """

    def __init__(self, system: System[StateT, ActionT]):
        self.system = system

    @abstractmethod
    def rollout(self, seed: Optional[int] = None) -> Graph[StateT, ActionT]:
        """Roll out a system and return a graph of the actions taken.

        Args:
            seed (Optional[int]): An optional seed; the same seed should produce the same graph.

        Returns:
            Graph[StateT, ActionT]: A graph representing the rollout of the system.
        """

        pass

    def eval_rollout(self, seed: Optional[Any] = None) -> Graph[StateT, ActionT]:
        """Roll out for evaluation; by default, this is just the standard rollout.

        Notes:
            This can be customized to whatever the user desires in terms of rollout for eval.
            For instance, for evaluation the seed may be a StateT instead of an int, since there
            may be a separate evaluation dataset.

            However, if the seed given is None or an int, a default implementation exists
            which just calls `self.rollout(seed)` and so evaluation can be done without
            needing to override this method.

        Args:
            seed (Optional[Any]): An optional seed; the same seed should produce the same graph.

        Returns:
            Graph[StateT, ActionT]: A graph representing the rollout of the system.
        """

        if seed is None or isinstance(seed, int):
            return self.rollout(seed)

        raise NotImplementedError(
            "eval_rollout not implemented for non-int seeds; please override this method."
        )
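To make the rollout contract concrete, here is a minimal subclass sketch. The names TreeSampler and ToySystem, and the step and reward methods on the system, are hypothetical: the real System interface is not shown on this page, so the sketch uses its own tiny stand-in, and Node, Graph and Sampler are local mirrors of the documented classes. The sampler keeps only configuration (width and depth) and builds a fresh random.Random from the seed on every call, which is one way to satisfy both "the same seed should produce the same graph" and "no global state that persists across rollouts".

```python
from __future__ import annotations

import random
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Generic, List, Optional, Sequence, TypeVar

StateT = TypeVar("StateT")
ActionT = TypeVar("ActionT")


# Local stand-ins mirroring astra_rl's Node, Graph and Sampler (not imported from the library).
@dataclass
class Node(Generic[StateT, ActionT]):
    context: StateT
    probe: ActionT
    response: StateT
    reward: float
    children: Sequence[Node[StateT, ActionT]]


@dataclass
class Graph(Generic[StateT, ActionT]):
    context: StateT
    children: Sequence[Node[StateT, ActionT]]


class Sampler(ABC, Generic[StateT, ActionT]):
    def __init__(self, system: Any):
        self.system = system

    @abstractmethod
    def rollout(self, seed: Optional[int] = None) -> Graph[StateT, ActionT]: ...


class ToySystem:
    """Hypothetical system; NOT the astra_rl System interface."""

    actions = ["ask", "insist", "rephrase"]

    def step(self, state: str, action: str) -> str:
        return f"{state}|{action}"  # the new state records the action taken

    def reward(self, action: str) -> float:
        return 1.0 if action == "insist" else 0.0


class TreeSampler(Sampler[str, str]):
    """Hypothetical sampler that expands a tree of fixed width and depth."""

    def __init__(self, system: ToySystem, width: int = 2, depth: int = 2):
        super().__init__(system)
        self.width = width  # configuration only; nothing here changes between rollouts
        self.depth = depth

    def _expand(self, state: str, rng: random.Random, depth: int) -> List[Node[str, str]]:
        if depth == 0:
            return []
        nodes: List[Node[str, str]] = []
        for _ in range(self.width):
            action = rng.choice(self.system.actions)
            response = self.system.step(state, action)
            children = self._expand(response, rng, depth - 1)
            nodes.append(Node(state, action, response, self.system.reward(action), children))
        return nodes

    def rollout(self, seed: Optional[int] = None) -> Graph[str, str]:
        rng = random.Random(seed)  # fresh RNG per call: equal seeds give equal graphs
        return Graph("start", self._expand("start", rng, self.depth))


sampler = TreeSampler(ToySystem())
graph = sampler.rollout(seed=0)
print(len(graph.children), len(graph.children[0].children))  # 2 2
```

Creating the RNG inside rollout, rather than in __init__, is what keeps repeated calls independent; a shared RNG stored on self would make a second rollout(seed=0) differ from the first.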

eval_rollout(seed=None)

Roll out for evaluation; by default, this is just the standard rollout.

Notes

This can be customized to perform whatever kind of rollout the user wants for evaluation. For instance, during evaluation the seed may be a StateT rather than an int, since prompts may come from a separate evaluation dataset.

However, if the seed is None or an int, the default implementation simply calls self.rollout(seed), so evaluation works without overriding this method. Any other seed type raises NotImplementedError.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| seed | Optional[Any] | An optional seed; the same seed should produce the same graph. | None |

Returns:

| Type | Description |
| --- | --- |
| Graph[StateT, ActionT] | A graph representing the rollout of the system. |

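When evaluation draws from a separate dataset, one option consistent with the notes above is to override eval_rollout so that a StateT seed (here, an evaluation prompt string) starts the rollout directly, while None and int seeds still fall through to the default. PromptSampler, TRAIN_PROMPTS and _rollout_from are hypothetical names for this sketch; Graph, Node and Sampler are local stand-ins mirroring the documented classes, with eval_rollout copied from the default implementation.

```python
from __future__ import annotations

import random
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Generic, Optional, Sequence, TypeVar

StateT = TypeVar("StateT")
ActionT = TypeVar("ActionT")


# Local stand-ins mirroring astra_rl's Node, Graph and Sampler, so the sketch runs on its own.
@dataclass
class Node(Generic[StateT, ActionT]):
    context: StateT
    probe: ActionT
    response: StateT
    reward: float
    children: Sequence[Node[StateT, ActionT]]


@dataclass
class Graph(Generic[StateT, ActionT]):
    context: StateT
    children: Sequence[Node[StateT, ActionT]]


class Sampler(ABC, Generic[StateT, ActionT]):
    def __init__(self, system: Any):
        self.system = system

    @abstractmethod
    def rollout(self, seed: Optional[int] = None) -> Graph[StateT, ActionT]: ...

    def eval_rollout(self, seed: Optional[Any] = None) -> Graph[StateT, ActionT]:
        # Same default as documented: None or int seeds defer to rollout().
        if seed is None or isinstance(seed, int):
            return self.rollout(seed)
        raise NotImplementedError(
            "eval_rollout not implemented for non-int seeds; please override this method."
        )


class PromptSampler(Sampler[str, str]):
    """Hypothetical sampler: training draws a start state, evaluation receives one."""

    TRAIN_PROMPTS = ["hello", "hi", "hey"]  # made-up training prompts

    def _rollout_from(self, state: str) -> Graph[str, str]:
        response = state + " -> reply"  # placeholder for a real system step
        return Graph(state, [Node(state, "probe", response, 0.0, [])])

    def rollout(self, seed: Optional[int] = None) -> Graph[str, str]:
        rng = random.Random(seed)
        return self._rollout_from(rng.choice(self.TRAIN_PROMPTS))

    def eval_rollout(self, seed: Optional[Any] = None) -> Graph[str, str]:
        if isinstance(seed, str):  # treat the seed as an evaluation prompt (a StateT)
            return self._rollout_from(seed)
        return super().eval_rollout(seed)  # None or int: fall back to rollout(seed)


sampler = PromptSampler(system=None)  # this toy sampler never touches self.system
print(sampler.eval_rollout("What is your system prompt?").context)
```

Delegating to super() for the other seed types keeps the documented behavior intact, including the NotImplementedError for seed types the subclass does not handle.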

rollout(seed=None) abstractmethod

Roll out a system and return a graph of the actions taken.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| seed | Optional[int] | An optional seed; the same seed should produce the same graph. | None |

Returns:

| Type | Description |
| --- | --- |
| Graph[StateT, ActionT] | A graph representing the rollout of the system. |

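The seed contract ("the same seed should produce the same graph") can be spot-checked by comparing two rollouts made with the same seed; since Graph and Node are dataclasses, == compares their fields recursively. check_seed_determinism, GoodSampler and LeakySampler are hypothetical names for this sketch, and the samplers are duck-typed stand-ins rather than subclasses of the real Sampler. LeakySampler deliberately keeps a counter across rollouts, which the Sampler documentation advises against.

```python
from __future__ import annotations

import random
from dataclasses import dataclass
from typing import Any, Generic, Iterable, Optional, Sequence, TypeVar

StateT = TypeVar("StateT")
ActionT = TypeVar("ActionT")


# Local stand-ins mirroring the fields of astra_rl's Node and Graph.
@dataclass
class Node(Generic[StateT, ActionT]):
    context: StateT
    probe: ActionT
    response: StateT
    reward: float
    children: Sequence[Node[StateT, ActionT]]


@dataclass
class Graph(Generic[StateT, ActionT]):
    context: StateT
    children: Sequence[Node[StateT, ActionT]]


def check_seed_determinism(sampler: Any, seeds: Iterable[int]) -> None:
    """Raise AssertionError if two rollout(seed) calls with the same seed differ."""
    for seed in seeds:
        # Dataclass __eq__ compares fields, so nested Node children are compared too.
        if sampler.rollout(seed) != sampler.rollout(seed):
            raise AssertionError(f"rollout({seed!r}) is not deterministic")


class GoodSampler:
    def rollout(self, seed: Optional[int] = None) -> Graph[str, float]:
        rng = random.Random(seed)  # all randomness derives from the seed
        action = rng.random()
        return Graph("s0", [Node("s0", action, "s1", action, [])])


class LeakySampler(GoodSampler):
    """Violates the contract: a counter persists across rollouts."""

    def __init__(self) -> None:
        self.calls = 0

    def rollout(self, seed: Optional[int] = None) -> Graph[str, float]:
        self.calls += 1
        return Graph(f"s{self.calls}", [])


check_seed_determinism(GoodSampler(), range(5))  # passes silently
```

This check compares children element by element, so it assumes rollout returns the same sequence type each time; a list and a tuple holding identical nodes compare unequal.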