Problem
astra_rl.core.problem
A "Problem" is one of the core abstractions in Astra RL, defining how to interact
with the system under test. The interface is defined by the Problem
class, which
defines a set of abstract methods that users must implement to create a custom problem.
This provides flexibility in terms of how users can define their own applications
while still adhering to a common interface that enables the Astra RL framework
to function correctly.
Problem
Bases: ABC, Generic[StateT, ActionT]
Defines the core problem interface for Astra RL.
This class defines exactly how to interact with the system under test, including how to get probabilities and rollouts from the attacker and target models.
Through its generics, it stays agnostic to the types of states and actions as well as how they are measured. A moderator is required so that subclasses remain generic over the exact metric and are only opinionated about how that metric is achieved.
Attributes:

| Name | Type | Description |
|---|---|---|
| moderator | Moderator[StateT, ActionT] | The moderator used to evaluate sequences. |
Generics
StateT (type): The type of the state in the environment.
ActionT (type): The type of the action in the environment.
Source code in src/astra_rl/core/problem.py
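For orientation, a subclass might look like the following skeleton. This is a sketch rather than a reference implementation: it assumes StateT and ActionT are both `str`, that the base constructor accepts the moderator listed in the attributes above, and the class name `MyStringProblem` is made up.

```python
from typing import Iterator, Sequence

import torch

from astra_rl.core.problem import Problem


class MyStringProblem(Problem[str, str]):
    """Hypothetical Problem over plain-string states and actions."""

    def __init__(self, moderator):
        # The attributes table above lists a moderator, so the base
        # constructor is assumed to accept it.
        super().__init__(moderator)

    def advance(self, context: str, attack: str, response: str) -> str:
        raise NotImplementedError

    def get_attacker_logprobs(
        self, context: Sequence[str], continuation: Sequence[str]
    ) -> torch.Tensor:
        raise NotImplementedError

    def get_baseline_logprobs(
        self, context: Sequence[str], continuation: Sequence[str]
    ) -> torch.Tensor:
        raise NotImplementedError

    def get_target_logprobs(
        self, context: Sequence[str], continuation: Sequence[str]
    ) -> torch.Tensor:
        raise NotImplementedError

    def rollout_prompt_with_attacker(self, x: Sequence[str]) -> Sequence[str]:
        raise NotImplementedError

    def rollout_prompt_with_target(self, x: Sequence[str]) -> Sequence[str]:
        raise NotImplementedError

    def parameters(self) -> Iterator[torch.nn.parameter.Parameter]:
        raise NotImplementedError
```

Each of these methods is sketched in more detail under its corresponding section below.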
advance(context, attack, response) (abstractmethod)
Given a context, an attack, and the target's response, returns the next state.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| context | str | Sequence of strings representing the context. | required |
| attack | str | Sequence of strings representing the attack given context. | required |
| response | str | Sequence of strings representing the defense against attack. | required |
Returns:

| Type | Description |
|---|---|
| StateT | The next state after applying the continuation to the context. |
Source code in src/astra_rl/core/problem.py
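As a minimal sketch, assuming StateT and ActionT are plain strings and the state is just the running transcript, `advance` could simply concatenate:

```python
# A minimal sketch, assuming the state is the running conversation transcript:
# the next state appends the attack and the target's response to the context.
def advance(self, context: str, attack: str, response: str) -> str:
    return context + attack + response
```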
get_attacker_logprobs(context, continuation) (abstractmethod)
Evaluates P(continuation|context) on the attacker. This must return a tensor with gradients attached!
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| context | Sequence[str] | Sequence of strings, where each string is a context on which the continuation's probability is conditioned. | required |
| continuation | Sequence[str] | Sequence of strings, where each string is a continuation whose probability is measured. | required |
Note
This should be batched; i.e., len(context) == len(continuation) and each represents a batch element.
Returns:

| Type | Description |
|---|---|
| torch.Tensor | The per-token log probabilities of the continuations given their contexts. Shape: (batch_size, max_continuation_length) |
Source code in src/astra_rl/core/problem.py
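One possible shape for this method, sketched with Hugging Face transformers. The attributes `self.attacker_model` and `self.attacker_tokenizer` and the helper `_per_token_logprobs` are assumptions for illustration, not part of the interface:

```python
from typing import Sequence

import torch
import torch.nn.functional as F


def _per_token_logprobs(model, tokenizer, context, continuation):
    """Per-token log P(continuation | context) for a batch of string pairs."""
    rows = []
    for ctx, cont in zip(context, continuation):
        ctx_ids = tokenizer(ctx, return_tensors="pt").input_ids
        cont_ids = tokenizer(cont, add_special_tokens=False, return_tensors="pt").input_ids
        input_ids = torch.cat([ctx_ids, cont_ids], dim=1)

        logits = model(input_ids).logits
        # Logits at position i predict token i + 1: drop the final position and
        # keep only the slots that predict the continuation tokens.
        log_probs = F.log_softmax(logits[:, :-1, :], dim=-1)
        cont_logprobs = (
            log_probs[:, -cont_ids.size(1):, :]
            .gather(2, cont_ids.unsqueeze(-1))
            .squeeze(-1)
        )
        rows.append(cont_logprobs.squeeze(0))

    # Zero-pad to (batch_size, max_continuation_length).
    return torch.nn.utils.rnn.pad_sequence(rows, batch_first=True)


# On the Problem subclass sketched earlier:
def get_attacker_logprobs(
    self, context: Sequence[str], continuation: Sequence[str]
) -> torch.Tensor:
    # No torch.no_grad() here: the result must carry gradients for training.
    return _per_token_logprobs(
        self.attacker_model, self.attacker_tokenizer, context, continuation
    )
```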
get_baseline_logprobs(context, continuation) (abstractmethod)
Evaluates P(continuation|context) on the attacker's baseline distribution, used for KL-divergence measurements.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| context | Sequence[str] | Sequence of strings, where each string is a context on which the continuation's probability is conditioned. | required |
| continuation | Sequence[str] | Sequence of strings, where each string is a continuation whose probability is measured. | required |
Note
This should be batched; i.e., len(context) == len(continuation) and each represents a batch element. Note that this is not the defender's model, but rather the baseline model used for measuring KL divergence to make sure that the trained attacker stays an LM.
Returns:

| Type | Description |
|---|---|
| torch.Tensor | The per-token log probabilities of the continuations given their contexts. Shape: (batch_size, max_continuation_length) |
Source code in src/astra_rl/core/problem.py
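Reusing the `_per_token_logprobs` helper sketched above, one assumed implementation keeps a frozen copy of the attacker's starting checkpoint as `self.baseline_model` (a hypothetical attribute). Because it only serves as a KL reference, this sketch evaluates it without gradients:

```python
from typing import Sequence

import torch


# On the Problem subclass; reuses the _per_token_logprobs helper sketched above.
def get_baseline_logprobs(
    self, context: Sequence[str], continuation: Sequence[str]
) -> torch.Tensor:
    # The baseline is only a KL reference, so gradients are not needed here.
    with torch.no_grad():
        return _per_token_logprobs(
            self.baseline_model, self.attacker_tokenizer, context, continuation
        )
```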
get_target_logprobs(context, continuation) (abstractmethod)
Evaluates P(continuation|context) on the model under test.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| context | Sequence[str] | Sequence of strings, where each string is a context on which the continuation's probability is conditioned. | required |
| continuation | Sequence[str] | Sequence of strings, where each string is a continuation whose probability is measured. | required |
Note
This should be batched; i.e., len(context) == len(continuation) and each represents a batch element.
Returns:

| Type | Description |
|---|---|
| torch.Tensor | The per-token log probabilities of the continuations given their contexts. Shape: (batch_size, max_continuation_length) |
Source code in src/astra_rl/core/problem.py
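The same helper can be pointed at the model under test; `self.target_model` and `self.target_tokenizer` are again hypothetical attributes, and gradients are skipped here on the assumption that the target is not being trained:

```python
from typing import Sequence

import torch


# On the Problem subclass; reuses the _per_token_logprobs helper sketched above.
def get_target_logprobs(
    self, context: Sequence[str], continuation: Sequence[str]
) -> torch.Tensor:
    with torch.no_grad():
        return _per_token_logprobs(
            self.target_model, self.target_tokenizer, context, continuation
        )
```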
parameters() (abstractmethod)
Return the trainable parameters in this problem.
Returns:

| Type | Description |
|---|---|
| Iterator[torch.nn.parameter.Parameter] | An iterator over the trainable parameters, usually obtained just by calling model.parameters(). |
Source code in src/astra_rl/core/problem.py
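Following the note above that this is usually just model.parameters(), a sketch might be:

```python
from typing import Iterator

import torch


# On the Problem subclass: the attacker is the only component being trained
# in this sketch, so its weights are the trainable parameters.
def parameters(self) -> Iterator[torch.nn.parameter.Parameter]:
    return self.attacker_model.parameters()
```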
rollout_prompt_with_attacker(x) (abstractmethod)
Rolls out the prompt with the attacker model, sampling a ~ \pi(s). The returned continuation should not include the prompt itself.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| x | Sequence[str] | Sequence of strings representing the prompts to be rolled out. | required |
Returns:

| Type | Description |
|---|---|
| Sequence[ActionT] | The continuations rolled out by the adversary (attacker) model, excluding the prompt. |
Source code in src/astra_rl/core/problem.py
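A sketch using Hugging Face generate; the attribute names are hypothetical, and only the newly generated tokens are decoded so the prompt is not returned:

```python
from typing import Sequence


# On the Problem subclass. For batched generation with a decoder-only model,
# the tokenizer typically needs a pad token and padding_side="left".
def rollout_prompt_with_attacker(self, x: Sequence[str]) -> Sequence[str]:
    inputs = self.attacker_tokenizer(list(x), return_tensors="pt", padding=True)
    outputs = self.attacker_model.generate(
        **inputs, max_new_tokens=32, do_sample=True
    )
    # Strip the prompt tokens so only the sampled action a ~ pi(s) is returned.
    new_tokens = outputs[:, inputs["input_ids"].shape[1]:]
    return self.attacker_tokenizer.batch_decode(new_tokens, skip_special_tokens=True)
```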
rollout_prompt_with_target(x) (abstractmethod)
Rolls out the prompt with the model under test, sampling s' ~ \sum_a T(s, a). The returned continuation should not include the prompt itself.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| x | Sequence[str] | Sequence of strings representing the prompts to be rolled out. | required |
Returns:

| Type | Description |
|---|---|
| Sequence[StateT] | The continuations rolled out by the model under test, excluding the prompt. |
Source code in src/astra_rl/core/problem.py
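The same pattern applies here, but sampling the reply from the model under test (again with hypothetical attribute names):

```python
from typing import Sequence


# On the Problem subclass: identical pattern to the attacker rollout, but the
# defender's reply is sampled from the model under test.
def rollout_prompt_with_target(self, x: Sequence[str]) -> Sequence[str]:
    inputs = self.target_tokenizer(list(x), return_tensors="pt", padding=True)
    outputs = self.target_model.generate(**inputs, max_new_tokens=32, do_sample=True)
    new_tokens = outputs[:, inputs["input_ids"].shape[1]:]
    return self.target_tokenizer.batch_decode(new_tokens, skip_special_tokens=True)
```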
ValueFunctionProblem
Bases: Problem[StateT, ActionT], ABC
Extends Problem to return per-token sequence values via a value head.
Note
This is useful for value-laden solution methods such as Actor-Critic derivatives (e.g., PPO).
Attributes:

| Name | Type | Description |
|---|---|---|
| moderator | Moderator[StateT, ActionT] | The moderator used to evaluate sequences. |
Generics
StateT (type): The type of the state in the environment.
ActionT (type): The type of the action in the environment.
Source code in src/astra_rl/core/problem.py
value(context, continuation) (abstractmethod)
Given a sequence, evaluates its token-wise value using a value function.
Notes
This is typically done by the same neural network used for rollouts, passing its intermediate activations through an additional layer (the value head).
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| elem | Sequence[StateT] | The sequence to evaluate. | required |
Returns:

| Type | Description |
|---|---|
| torch.Tensor | The per-token values of the given sequence from the sequence predictor. Shape: (batch_size, max_continuation_length). Do not include the values of the input prefixes. |
Source code in src/astra_rl/core/problem.py
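A sketch of how a value head might be wired up, following the (context, continuation) signature in the heading above; `self.attacker_model`, `self.attacker_tokenizer`, and `self.value_head` (e.g., a torch.nn.Linear(hidden_size, 1)) are assumed attributes, not part of the interface:

```python
from typing import Sequence

import torch


# On the ValueFunctionProblem subclass. Only continuation positions are
# scored, matching the "do not include the input prefixes" note above.
def value(self, context: Sequence[str], continuation: Sequence[str]) -> torch.Tensor:
    rows = []
    for ctx, cont in zip(context, continuation):
        ctx_ids = self.attacker_tokenizer(ctx, return_tensors="pt").input_ids
        cont_ids = self.attacker_tokenizer(
            cont, add_special_tokens=False, return_tensors="pt"
        ).input_ids
        input_ids = torch.cat([ctx_ids, cont_ids], dim=1)

        hidden = self.attacker_model(
            input_ids, output_hidden_states=True
        ).hidden_states[-1]
        # One scalar value per continuation token.
        values = self.value_head(hidden[:, -cont_ids.size(1):, :]).squeeze(-1)
        rows.append(values.squeeze(0))

    # (batch_size, max_continuation_length), zero-padded.
    return torch.nn.utils.rnn.pad_sequence(rows, batch_first=True)
```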