Problem

astra_rl.core.problem

A "Problem" is one of the core abstractions in Astra RL, defining how to interact with the system under test. The interface is defined by the Problem class, which declares a set of abstract methods that users implement to create a custom problem. This lets users define their own applications while adhering to a common interface that the Astra RL framework relies on.

Problem

Bases: ABC, Generic[StateT, ActionT]

Defines the core problem interface for Astra RL.

This class is responsible for defining how exactly to interact with the system under test---with generics in terms of how to get probabilities and rollouts from the attacker and target models.

This allows us to be generic over the types of states and actions, as well as how to measure them. We ask for a moderator to ensure that subclasses stay generic over the exact metric, and are only opinionated about how to achieve it.

Attributes:

Name Type Description
moderator Moderator[StateT, ActionT]

The moderator used to evaluate sequences.

Generics

StateT (type): The type of the state in the environment.
ActionT (type): The type of the action in the environment.

Source code in src/astra_rl/core/problem.py
class Problem(ABC, Generic[StateT, ActionT]):
    """Defines the core problem interface for Astra RL.

    This class is responsible for defining how exactly to interact
    with the system under test---with generics in terms of how to get
    probabilities and rollouts from the attacker and target models.

    This allows us to be generic over the types of states and actions,
    as well as how to measure them. We ask for a moderator to ensure
    that subclasses stay generic over the exact metric, and are only
    opinionated about how to achieve it.

    Attributes:
        moderator (Moderator[StateT, ActionT]): The moderator used to evaluate sequences.

    Generics:
        StateT (type): The type of the state in the environment.
        ActionT (type): The type of the action in the environment.
    """

    def __init__(self, moderator: Moderator[StateT, ActionT]) -> None:
        # we check all asserts once, and then disable them
        self._disable_asserts: Dict[str, bool] = defaultdict(bool)
        self.moderator = moderator

    @abstractmethod
    def get_target_logprobs(
        self, context: Sequence[StateT], continuation: Sequence[ActionT]
    ) -> torch.Tensor:
        """Evaluates P(continuation|context) on *model under test*.

        Args:
            context (Sequence[str]): Sequence of strings, where each string is a context on which the
                                 continuation's probability is conditioned.
            continuation (Sequence[str]): Sequence of strings, where each string is a continuation whose
                                      probability is measured.

        Note:
            This should be batched; i.e., len(context) == len(continuation) and each
            represents a batch element.

        Returns:
            torch.Tensor: The per-token log probabilities of the continuations given their contexts.
                         Shape: (batch_size, max_continuation_length)
        """

        pass

    @abstractmethod
    def get_baseline_logprobs(
        self, context: Sequence[StateT], continuation: Sequence[ActionT]
    ) -> torch.Tensor:
        """Evaluates P(continuation|context) on *attacker's baseline distribution* for KL
           divergence measurements.

        Args:
            context (Sequence[str]): Sequence of strings, where each string is a context on which the
                                 continuation's probability is conditioned.
            continuation (Sequence[str]): Sequence of strings, where each string is a continuation whose
                                      probability is measured.

        Note:
            This should be batched; i.e., len(context) == len(continuation) and each
            represents a batch element. Note that this is *not* the defender's model, but
            rather the baseline model used for measuring KL divergence to make sure that
            the trained attacker stays an LM.

        Returns:
            torch.Tensor: The per-token log probabilities of the continuations given their contexts.
                         Shape: (batch_size, max_continuation_length)
        """

        pass

    @abstractmethod
    def get_attacker_logprobs(
        self, context: Sequence[StateT], continuation: Sequence[ActionT]
    ) -> torch.Tensor:
        """Evaluates P(continuation|context) on the *attacker*. The returned tensor must carry gradients!

        Args:
            context (Sequence[str]): Sequence of strings, where each string is a context on which the
                                 continuation's probability is conditioned.
            continuation (Sequence[str]): Sequence of strings, where each string is a continuation whose
                                      probability is measured.

        Note:
            This should be batched; i.e., len(context) == len(continuation) and each
            represents a batch element.

        Returns:
            torch.Tensor: The per-token log probabilities of the continuations given their contexts.
                         Shape: (batch_size, max_continuation_length)
        """

        pass

    @abstractmethod
    def rollout_prompt_with_attacker(self, x: Sequence[StateT]) -> Sequence[ActionT]:
        """Rolls out the prompt with the attacker model. Do *not* return the prompt.

        a ~ \\pi(s)

        Args:
            x (Sequence[str]): Sequence of strings representing the prompt to be rolled out.

        Returns:
            Sequence[str]: The attacker's continuation of the prompt (the prompt itself is not included).
        """
        pass

    @abstractmethod
    def rollout_prompt_with_target(self, x: Sequence[StateT]) -> Sequence[StateT]:
        """Rolls out the prompt with the model under test. Do *not* return the prompt.

        s' ~ \\sum_a T(s, a)

        Args:
            x (Sequence[str]): Sequence of strings representing the prompt to be rolled out.

        Returns:
            Sequence[str]: The target's continuation of the prompt (the prompt itself is not included).
        """
        pass

    @abstractmethod
    def advance(self, context: StateT, attack: ActionT, response: StateT) -> StateT:
        """Given a context, an attack, and the target's response, returns the next state.

        Args:
            context (StateT): The context (conversation so far).
            attack (ActionT): The attack produced by the attacker given the context.
            response (StateT): The target's response to the attack.

        Returns:
            StateT: The next state after appending the attack and response to the context.
        """
        pass

    @abstractmethod
    def parameters(self) -> Iterator[torch.nn.parameter.Parameter]:
        """Return the trainable parameters in this problem.

        Returns:
            Iterator[torch.nn.parameter.Parameter]: An iterator over the trainable
                parameters, usually obtained by calling model.parameters().
        """
        pass

    @abstractmethod
    def reward(
        self,
        context: Sequence[StateT],
        attack: Sequence[ActionT],
        response: Sequence[StateT],
    ) -> Sequence[float]:
        pass

    ##### Utility methods for validation and checks #####

    def _check_continuation(
        self,
        check_key: str,
        context: Sequence[StateT],
        continuation: Sequence[Union[ActionT, StateT]],
    ) -> None:
        if self._disable_asserts[check_key]:
            return
        self._disable_asserts[check_key] = True

    def _check_logprobs(
        self,
        check_key: str,
        logprobs: torch.Tensor,
        ctx_length: int,
        requires_grad: bool = False,
    ) -> None:
        if self._disable_asserts[check_key]:
            return
        # check that logprobs is a tensor and has gradients
        assert isinstance(logprobs, torch.Tensor), "Logprobs must be a torch.Tensor."
        if requires_grad:
            assert logprobs.requires_grad, (
                "Attacker logprobs must carry gradient information."
            )
        # check that the size of the tensor is B x T, where B is the batch size and T is max_continuation_length
        assert logprobs.dim() == 2, (
            "Logprobs must be a 2D tensor (batch_size, max_continuation_length)."
        )
        # check that the first dimension is the batch size
        assert logprobs.size(0) == ctx_length, (
            "Logprobs must have the same batch size as the context."
        )
        # warn if everything is between 0 and 1
        if ((logprobs >= 0.0) & (logprobs <= 1.0)).all():
            logger.warning(
                "Logprobs looks suspiciously like probabilities, "
                "try taking the .log() of your tensor?"
            )
        self._disable_asserts[check_key] = True

    def _get_attacker_logprobs_and_validate(
        self, context: Sequence[StateT], continuation: Sequence[ActionT]
    ) -> torch.Tensor:
        logprobs = self.get_attacker_logprobs(context, continuation)
        self._check_logprobs("attacker_logprobs", logprobs, len(context), True)
        return logprobs

    def _get_target_logprobs_and_validate(
        self, context: Sequence[StateT], continuation: Sequence[ActionT]
    ) -> torch.Tensor:
        logprobs = self.get_target_logprobs(context, continuation)
        self._check_logprobs("target_logprobs", logprobs, len(context), False)
        return logprobs

    def _get_baseline_logprobs_and_validate(
        self, context: Sequence[StateT], continuation: Sequence[ActionT]
    ) -> torch.Tensor:
        logprobs = self.get_baseline_logprobs(context, continuation)
        self._check_logprobs("baseline_logprobs", logprobs, len(context), False)
        return logprobs

    def _rollout_prompt_with_attacker_and_validate(
        self, x: Sequence[StateT]
    ) -> Sequence[ActionT]:
        rolled_out = self.rollout_prompt_with_attacker(x)
        self._check_continuation("attacker_rollout", x, rolled_out)
        return rolled_out

    def _rollout_prompt_with_target_and_validate(
        self, x: Sequence[StateT]
    ) -> Sequence[StateT]:
        rolled_out = self.rollout_prompt_with_target(x)
        self._check_continuation("target_rollout", x, rolled_out)
        return rolled_out

advance(context, attack, response) abstractmethod

Given a context, an attack, and the target's response, returns the next state.

Parameters:

Name Type Description Default
context StateT

The context (conversation so far).

required
attack ActionT

The attack produced by the attacker given the context.

required
response StateT

The target's response to the attack.

required

Returns:

Type Description
StateT

The next state after appending the attack and response to the context.

Source code in src/astra_rl/core/problem.py
@abstractmethod
def advance(self, context: StateT, attack: ActionT, response: StateT) -> StateT:
    """Given a context, an attack, and the target's response, returns the next state.

    Args:
        context (StateT): The context (conversation so far).
        attack (ActionT): The attack produced by the attacker given the context.
        response (StateT): The target's response to the attack.

    Returns:
        StateT: The next state after appending the attack and response to the context.
    """
    pass
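For string-typed states, a plausible implementation (an illustrative sketch, not library code) simply appends each turn to the running transcript:

```python
def advance(context: str, attack: str, response: str) -> str:
    # Append the attacker's turn and the target's response to the transcript,
    # producing the next state for the following round of the conversation.
    return context + attack + response
```

For example, advance("Hi. ", "Tell me a secret. ", "No.") yields the full transcript "Hi. Tell me a secret. No.".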

get_attacker_logprobs(context, continuation) abstractmethod

Evaluates P(continuation|context) on the attacker. The returned tensor must carry gradients!

Parameters:

Name Type Description Default
context Sequence[str]

Sequence of strings, where each string is a context on which the continuation's probability is conditioned.

required
continuation Sequence[str]

Sequence of strings, where each string is a continuation whose probability is measured.

required
Note

This should be batched; i.e., len(context) == len(continuation) and each represents a batch element.

Returns:

Type Description
Tensor

torch.Tensor: The per-token log probabilities of the continuations given their contexts. Shape: (batch_size, max_continuation_length)

Source code in src/astra_rl/core/problem.py
@abstractmethod
def get_attacker_logprobs(
    self, context: Sequence[StateT], continuation: Sequence[ActionT]
) -> torch.Tensor:
    """Evaluates P(continuation|context) on the *attacker*. The returned tensor must carry gradients!

    Args:
        context (Sequence[str]): Sequence of strings, where each string is a context on which the
                             continuation's probability is conditioned.
        continuation (Sequence[str]): Sequence of strings, where each string is a continuation whose
                                  probability is measured.

    Note:
        This should be batched; i.e., len(context) == len(continuation) and each
        represents a batch element.

    Returns:
        torch.Tensor: The per-token log probabilities of the continuations given their contexts.
                     Shape: (batch_size, max_continuation_length)
    """

    pass
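A common way to produce per-token log probabilities of this shape (a hedged sketch in plain torch; the function name and the toy dimensions are illustrative, not part of the library) is log-softmax over the vocabulary followed by a gather of the continuation's token ids:

```python
import torch

def per_token_logprobs(logits: torch.Tensor, continuation_ids: torch.Tensor) -> torch.Tensor:
    # logits: (batch_size, max_continuation_length, vocab_size) from a forward pass
    # continuation_ids: (batch_size, max_continuation_length) token ids to score
    log_probs = torch.log_softmax(logits, dim=-1)
    # Pick out the log probability of each realized continuation token.
    return log_probs.gather(-1, continuation_ids.unsqueeze(-1)).squeeze(-1)

logits = torch.randn(2, 4, 10, requires_grad=True)  # toy batch of 2, length 4, vocab 10
ids = torch.randint(0, 10, (2, 4))
scores = per_token_logprobs(logits, ids)            # shape (2, 4)
```

Because `logits` requires gradients here, `scores` carries gradients too, satisfying the attacker contract; for the target and baseline the forward pass would typically run under `torch.no_grad()`.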

get_baseline_logprobs(context, continuation) abstractmethod

Evaluates P(continuation|context) on attacker's baseline distribution for KL divergence measurements.

Parameters:

Name Type Description Default
context Sequence[str]

Sequence of strings, where each string is a context on which the continuation's probability is conditioned.

required
continuation Sequence[str]

Sequence of strings, where each string is a continuation whose probability is measured.

required
Note

This should be batched; i.e., len(context) == len(continuation) and each represents a batch element. Note that this is not the defender's model, but rather the baseline model used for measuring KL divergence to make sure that the trained attacker stays an LM.

Returns:

Type Description
Tensor

torch.Tensor: The per-token log probabilities of the continuations given their contexts. Shape: (batch_size, max_continuation_length)

Source code in src/astra_rl/core/problem.py
@abstractmethod
def get_baseline_logprobs(
    self, context: Sequence[StateT], continuation: Sequence[ActionT]
) -> torch.Tensor:
    """Evaluates P(continuation|context) on *attacker's baseline distribution* for KL
       divergence measurements.

    Args:
        context (Sequence[str]): Sequence of strings, where each string is a context on which the
                             continuation's probability is conditioned.
        continuation (Sequence[str]): Sequence of strings, where each string is a continuation whose
                                  probability is measured.

    Note:
        This should be batched; i.e., len(context) == len(continuation) and each
        represents a batch element. Note that this is *not* the defender's model, but
        rather the baseline model used for measuring KL divergence to make sure that
        the trained attacker stays an LM.

    Returns:
        torch.Tensor: The per-token log probabilities of the continuations given their contexts.
                     Shape: (batch_size, max_continuation_length)
    """

    pass

get_target_logprobs(context, continuation) abstractmethod

Evaluates P(continuation|context) on model under test.

Parameters:

Name Type Description Default
context Sequence[str]

Sequence of strings, where each string is a context on which the continuation's probability is conditioned.

required
continuation Sequence[str]

Sequence of strings, where each string is a continuation whose probability is measured.

required
Note

This should be batched; i.e., len(context) == len(continuation) and each represents a batch element.

Returns:

Type Description
Tensor

torch.Tensor: The per-token log probabilities of the continuations given their contexts. Shape: (batch_size, max_continuation_length)

Source code in src/astra_rl/core/problem.py
@abstractmethod
def get_target_logprobs(
    self, context: Sequence[StateT], continuation: Sequence[ActionT]
) -> torch.Tensor:
    """Evaluates P(continuation|context) on *model under test*.

    Args:
        context (Sequence[str]): Sequence of strings, where each string is a context on which the
                             continuation's probability is conditioned.
        continuation (Sequence[str]): Sequence of strings, where each string is a continuation whose
                                  probability is measured.

    Note:
        This should be batched; i.e., len(context) == len(continuation) and each
        represents a batch element.

    Returns:
        torch.Tensor: The per-token log probabilities of the continuations given their contexts.
                     Shape: (batch_size, max_continuation_length)
    """

    pass

parameters() abstractmethod

Return the trainable parameters in this problem.

Returns:

Type Description
Iterator[Parameter]

An iterator over the trainable parameters, usually obtained by calling model.parameters().

Source code in src/astra_rl/core/problem.py
@abstractmethod
def parameters(self) -> Iterator[torch.nn.parameter.Parameter]:
    """Return the trainable parameters in this problem.

    Returns:
        Iterator[torch.nn.parameter.Parameter]: An iterator over the trainable
            parameters, usually obtained by calling model.parameters().
    """
    pass
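A typical implementation simply delegates to the trainable attacker model (an illustrative sketch; MyProblemSketch and its Linear stand-in are hypothetical, not library code):

```python
import torch

class MyProblemSketch:
    def __init__(self) -> None:
        # Stand-in for the trainable attacker LM; the target and baseline
        # models are frozen, so their parameters are not exposed here.
        self.attacker_model = torch.nn.Linear(4, 4)

    def parameters(self):
        return self.attacker_model.parameters()

params = list(MyProblemSketch().parameters())  # weight and bias of the stand-in
```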

rollout_prompt_with_attacker(x) abstractmethod

Rolls out the prompt with the attacker model. Do not return the prompt.

a ~ \pi(s)

Parameters:

Name Type Description Default
x Sequence[str]

Sequence of strings representing the prompt to be rolled out.

required

Returns:

Type Description
Sequence[ActionT]

Sequence[str]: The attacker's continuation of the prompt (the prompt itself is not included).

Source code in src/astra_rl/core/problem.py
@abstractmethod
def rollout_prompt_with_attacker(self, x: Sequence[StateT]) -> Sequence[ActionT]:
    """Rolls out the prompt with the attacker model. Do *not* return the prompt.

    a ~ \\pi(s)

    Args:
        x (Sequence[str]): Sequence of strings representing the prompt to be rolled out.

    Returns:
        Sequence[str]: The attacker's continuation of the prompt (the prompt itself is not included).
    """
    pass

rollout_prompt_with_target(x) abstractmethod

Rolls out the prompt with the model under test. Do not return the prompt.

s' ~ \sum_a T(s, a)

Parameters:

Name Type Description Default
x Sequence[str]

Sequence of strings representing the prompt to be rolled out.

required

Returns:

Type Description
Sequence[StateT]

Sequence[str]: The target's continuation of the prompt (the prompt itself is not included).

Source code in src/astra_rl/core/problem.py
@abstractmethod
def rollout_prompt_with_target(self, x: Sequence[StateT]) -> Sequence[StateT]:
    """Rolls out the prompt with the model under test. Do *not* return the prompt.

    s' ~ \\sum_a T(s, a)

    Args:
        x (Sequence[str]): Sequence of strings representing the prompt to be rolled out.

    Returns:
        Sequence[str]: The target's continuation of the prompt (the prompt itself is not included).
    """
    pass
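Decoder-only LMs usually return the prompt concatenated with the generation, so implementations of both rollout methods typically strip the prompt before returning (a sketch; strip_prompt is a hypothetical helper, not library code):

```python
def strip_prompt(prompt: str, generated: str) -> str:
    # Keep only the continuation, honoring the "do *not* return the prompt" contract.
    return generated[len(prompt):] if generated.startswith(prompt) else generated

# Applied element-wise over a batch of (prompt, decoded_text) pairs:
continuations = [strip_prompt(p, g) for p, g in [("Hello", "Hello there!")]]
```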

ValueFunctionProblem

Bases: Problem[StateT, ActionT], ABC

Extends Problem to be able to return sequence values with a value head.

Note

This is useful for value-laden solution methods such as Actor-Critic derivatives (e.g., PPO).

Attributes:

Name Type Description
moderator Moderator[StateT, ActionT]

The moderator used to evaluate sequences.

Generics

StateT (type): The type of the state in the environment.
ActionT (type): The type of the action in the environment.

Source code in src/astra_rl/core/problem.py
class ValueFunctionProblem(Problem[StateT, ActionT], ABC):
    """Extends `Problem` to be able to return sequence values with a value head.

    Note:
        This is useful for value-laden solution methods such as
        Actor-Critic derivatives (e.g., PPO).

    Attributes:
        moderator (Moderator[StateT, ActionT]): The moderator used to evaluate sequences.

    Generics:
        StateT (type): The type of the state in the environment.
        ActionT (type): The type of the action in the environment.
    """

    @abstractmethod
    def value(
        self, context: Sequence[StateT], continuation: Sequence[ActionT]
    ) -> torch.Tensor:
        """Given a sequence, evaluate its token-wise value using a value function.

        Notes:
           This is typically done by the same neural network you use for rollouts,
           just passing the intermediate activations through another layer.

        Args:
            context (Sequence[StateT]): The contexts on which each continuation is conditioned.
            continuation (Sequence[ActionT]): The continuations to evaluate.

        Returns:
            torch.Tensor: The per-token values of the given sequences.
                Shape: (batch_size, max_continuation_length). Do not include the
                values of the input prefixes.
        """

        pass

value(context, continuation) abstractmethod

Given a sequence, evaluate its token-wise value using a value function.

Notes

This is typically done by the same neural network you use for rollouts just passing the intermediate activations through another layer.

Parameters:

Name Type Description Default
context Sequence[StateT]

The contexts on which each continuation is conditioned.

required
continuation Sequence[ActionT]

The continuations to evaluate.

required

Returns:

Type Description
Tensor

The per-token values of the given sequences, shape (batch_size, max_continuation_length). The values of the input prefixes are not included.

Source code in src/astra_rl/core/problem.py
@abstractmethod
def value(
    self, context: Sequence[StateT], continuation: Sequence[ActionT]
) -> torch.Tensor:
    """Given a sequence, evaluate its token-wise value using a value function.

    Notes:
       This is typically done by the same neural network you use for rollouts,
       just passing the intermediate activations through another layer.

    Args:
        context (Sequence[StateT]): The contexts on which each continuation is conditioned.
        continuation (Sequence[ActionT]): The continuations to evaluate.

    Returns:
        torch.Tensor: The per-token values of the given sequences.
            Shape: (batch_size, max_continuation_length). Do not include the
            values of the input prefixes.
    """

    pass
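The note above ("the same neural network ... another layer") can be sketched as a small value head over the backbone's hidden states (toy dimensions; in practice `hidden` would come from the rollout model's intermediate activations):

```python
import torch

hidden = torch.randn(2, 5, 16)           # (batch, continuation_len, hidden_size) activations
value_head = torch.nn.Linear(16, 1)      # extra projection trained alongside the policy
values = value_head(hidden).squeeze(-1)  # (batch_size, max_continuation_length) per-token values
```

The head adds only hidden_size + 1 parameters, which is why sharing the backbone between policy and value function is the usual design for PPO-style training.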