Skip to content

AST Problem

astra_rl.methods.ast_problem

ast_problem.py ASTProblem

ASTEnvironment

Bases: Environment[str, str]

The ASTPrompter Rollout Environment

Implements https://arxiv.org/abs/2407.09447.

Specifically, this is the original rollout system used in the ASTPrompter paper: the red-teaming case where the attacker and defender generate successive turns of strings, each of which is appended to the prompt of the other. They do not have IFT or other types of structure.

For usage examples, see astra_rl.core.environment.Environment.

Attributes:

Name Type Description
problem ASTProblem

The problem instance that defines the environment and actions.

prompts Sequence[str]

A sequence of initial prompts to start the rollout.

tree_width int

The number of branches at each node in the rollout tree.

tree_depth int

The depth of the rollout tree.

Generics

StateT (str): The type of the state in the environment, which is a string. ActionT (str): The type of the action in the environment, which is also a string.

Source code in src/astra_rl/methods/ast_problem.py
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
class ASTEnvironment(Environment[str, str]):
    """The ASTPrompter Rollout Environment

    Implements https://arxiv.org/abs/2407.09447.

    Specifically, this is the original rollout system used in the
    ASTPrompter paper: the red-teaming case where the attacker and
    defender generate successive turns of strings, each of which is
    appended to the prompt of the other. They do not have IFT or
    other types of structure.

    For usage examples, see `astra_rl.core.environment.Environment`.

    Attributes:
        problem (ASTProblem): The problem instance that defines the environment and actions.
        prompts (Sequence[str]): A sequence of initial prompts to start the rollout.
        tree_width (int): The number of branches at each node in the rollout tree.
        tree_depth (int): The depth of the rollout tree.

    Generics:
        StateT (str): The type of the state in the environment, which is a string.
        ActionT (str): The type of the action in the environment, which is also a string.
    """

    def __init__(
        self,
        problem: ASTProblem,
        prompts: Sequence[str],
        tree_width: int = 2,
        tree_depth: int = 3,
    ):
        super().__init__(problem)

        self.prompts = prompts
        self.tree_width = tree_width
        self.tree_depth = tree_depth

    def __handle_prompt(self, prompt: str, depth: int = 3) -> Sequence[Node[str, str]]:
        """Recursively expand one prompt into a subtree of rollout nodes.

        Args:
            prompt: The accumulated conversation text so far.
            depth: Remaining recursion depth; 0 terminates the branch.

        Returns:
            A list of `tree_width` child `Node`s, each carrying its own
            recursively-built children (empty list once depth reaches 0).
        """
        if depth == 0:
            return []

        # Branch the tree: sample `tree_width` attacks from the same prompt.
        prompts = [prompt] * self.tree_width
        attacks = self.problem._rollout_prompt_with_attacker_and_validate(prompts)
        # The defender responds to the prompt with the attack appended.
        defenses = self.problem._rollout_prompt_with_target_and_validate(
            [prompt + attack for attack in attacks]
        )
        rewards = self.problem.reward(prompts, attacks, defenses)

        return [
            Node(
                prompt,
                attack,
                defense,
                reward,
                self.__handle_prompt(
                    self.problem.advance(prompt, attack, defense), depth - 1
                ),
            )
            for prompt, attack, defense, reward in zip(
                prompts, attacks, defenses, rewards
            )
        ]

    def rollout(self, seed: Optional[int] = None) -> Graph[str, str]:
        """Perform one tree rollout starting from a randomly chosen prompt.

        Args:
            seed: Optional seed for reproducible prompt selection. When
                None, the module-level `random` state is used.

        Returns:
            A `Graph` rooted at the chosen prompt with a tree of
            attack/defense nodes beneath it.
        """
        # Check against None rather than truthiness: `if seed:` would
        # silently ignore the valid seed value 0.
        rng: Union[Random, ModuleType] = random if seed is None else Random(seed)
        prompt = rng.choice(self.prompts)
        nodes = self.__handle_prompt(prompt, self.tree_depth)

        return Graph(prompt, nodes)

ASTProblem

Bases: Problem[str, str]

Adaptive Stress Testing for Large Language Model Red-Teaming.

Implements https://arxiv.org/abs/2407.09447.

For usage examples, see astra_rl.core.problem.Problem.

Attributes:

Name Type Description
moderator Moderator

The moderator used to evaluate the sequences.

ppl_weight float

Weight for perplexity in the reward calculation.

combined_weight float

Weight for combined scores in the reward calculation.

defense_toxicity_weight float

Weight for defense toxicity in the reward calculation.

Source code in src/astra_rl/methods/ast_problem.py
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
class ASTProblem(Problem[str, str]):
    """Adaptive Stress Testing for Large Language Model Red-Teaming.

    Implements https://arxiv.org/abs/2407.09447.

    For usage examples, see `astra_rl.core.problem.Problem`.

    Attributes:
        moderator (Moderator): The moderator used to evaluate the sequences.
        ppl_weight (float): Weight for perplexity in the reward calculation.
        combined_weight (float): Weight for combined scores in the reward calculation.
        defense_toxicity_weight (float): Weight for defense toxicity in the reward calculation.
    """

    def __init__(
        self,
        moderator: Moderator[str, str],
        ppl_weight: float = 0.025,
        combined_weight: float = 0.1,
        defense_toxicity_weight: float = 0.1,
    ):
        super().__init__(moderator)

        self.ppl_weight = ppl_weight
        self.combined_weight = combined_weight
        self.defense_toxicity_weight = defense_toxicity_weight

    def advance(self, state: str, action: str, next: str) -> str:
        """Advance the rollout state by appending the attack and the response."""
        return f"{state}{action}{next}"

    def reward(
        self, context: Sequence[str], attack: Sequence[str], response: Sequence[str]
    ) -> Sequence[float]:
        """Score each (context, attack, response) triple.

        The reward is a weighted sum of: the attack's likelihood under
        the target model (exp of summed per-token logprobs), the
        moderator score of the attack+response pair, and the moderator
        score of the response alone.

        Args:
            context: The prompt text for each rollout branch.
            attack: The attacker utterance for each branch.
            response: The defender utterance for each branch.

        Returns:
            One float reward per branch.
        """
        # Moderate both the combined attack+response text and the
        # response on its own.
        pair_texts = [a + r for a, r in zip(attack, response)]
        pair_scores: Sequence[float] = self.moderator.moderate(pair_texts)
        response_scores: Sequence[float] = self.moderator.moderate(response)

        # Sum per-token logprobs before exponentiating.
        logprobs = self._get_target_logprobs_and_validate(context, attack)
        attack_ppls: torch.Tensor = logprobs.sum(dim=-1).exp()

        device = attack_ppls.device
        total = self.ppl_weight * attack_ppls
        total = total + self.combined_weight * torch.tensor(pair_scores, device=device)
        total = total + self.defense_toxicity_weight * torch.tensor(
            response_scores, device=device
        )

        return total.cpu().tolist()