DPO

astra_rl.algorithms.dpo

DPO

Bases: Algorithm[StateT, ActionT, DPOStep[StateT, ActionT], DPOBatch[StateT, ActionT]], Generic[StateT, ActionT]

Direct Preference Optimization (DPO) algorithm.
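
For orientation, the per-pair objective computed in step() below is the standard DPO loss, where π_θ is the attacker (policy) model, π_ref the frozen baseline, x the shared prefix, (y_w, y_l) the chosen/rejected suffixes, and β the beta parameter:

\mathcal{L}_{\mathrm{DPO}} = -\log \sigma\!\Big( \beta \big[ \big(\log \pi_\theta(y_w \mid x) - \log \pi_\theta(y_l \mid x)\big) - \big(\log \pi_{\mathrm{ref}}(y_w \mid x) - \log \pi_{\mathrm{ref}}(y_l \mid x)\big) \big] \Big)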

Source code in src/astra_rl/algorithms/dpo.py, lines 29-139:
class DPO(
    Algorithm[StateT, ActionT, DPOStep[StateT, ActionT], DPOBatch[StateT, ActionT]],
    Generic[StateT, ActionT],
):
    """Direct Preference Optimization (DPO) algorithm."""

    def __init__(self, problem: Problem[StateT, ActionT], beta: float = 0.1):
        super().__init__(problem)

        self.beta = beta

    def flatten(
        self, graph: Graph[StateT, ActionT]
    ) -> Sequence[DPOStep[StateT, ActionT]]:
        # in DPO, we select from each group of sibling rollouts the most
        # rewarded and least rewarded actions and use them as our
        # contrastive (chosen, rejected) pair.

        pairs: List[DPOStep[StateT, ActionT]] = []
        bfs = [graph.children]
        while len(bfs):
            front = bfs.pop(0)
            sorted_list = sorted(list(front), key=lambda x: x.reward, reverse=True)

            if len(sorted_list) < 2:
                # fewer than two sibling rollouts: no contrastive pair can be formed, so skip this group
                continue

            pos_entry = sorted_list[0]
            neg_entry = sorted_list[-1]

            assert pos_entry.context == neg_entry.context, (
                "paired rollouts for DPO must share the same prefix!"
            )

            pairs.append(
                DPOStep(
                    prefix=pos_entry.context,
                    suffix_pos=pos_entry.attack,
                    suffix_neg=neg_entry.attack,
                )
            )

            for i in sorted_list:
                bfs.append(i.children)

        return pairs

    @staticmethod
    def collate_fn(x: Sequence[DPOStep[StateT, ActionT]]) -> DPOBatch[StateT, ActionT]:
        prefixes = [i.prefix for i in x]
        suffix_pos = [i.suffix_pos for i in x]
        suffix_neg = [i.suffix_neg for i in x]

        return DPOBatch(prefixes=prefixes, suffix_pos=suffix_pos, suffix_neg=suffix_neg)

    def step(
        self, batch: DPOBatch[StateT, ActionT]
    ) -> tuple[torch.Tensor, Dict[Any, Any]]:
        attacker_logprobs_win = self.problem._get_attacker_logprobs_and_validate(
            batch.prefixes, batch.suffix_pos
        ).sum(dim=-1)  # Sum per-token logprobs to get sequence logprobs
        attacker_logprobs_loss = self.problem._get_attacker_logprobs_and_validate(
            batch.prefixes, batch.suffix_neg
        ).sum(dim=-1)  # Sum per-token logprobs to get sequence logprobs
        baseline_logprobs_win = self.problem._get_baseline_logprobs_and_validate(
            batch.prefixes, batch.suffix_pos
        ).sum(dim=-1)  # Sum per-token logprobs to get sequence logprobs
        baseline_logprobs_loss = self.problem._get_baseline_logprobs_and_validate(
            batch.prefixes, batch.suffix_neg
        ).sum(dim=-1)  # Sum per-token logprobs to get sequence logprobs

        # https://github.com/eric-mitchell/direct-preference-optimization/blob/ \
        # f8b8c0f49dc92a430bae41585f9d467d3618fe2f/trainers.py#L70-L87
        pi_logratios = attacker_logprobs_win - attacker_logprobs_loss
        ref_logratios = baseline_logprobs_win - baseline_logprobs_loss
        logits = pi_logratios - ref_logratios

        loss = -F.logsigmoid(self.beta * logits)

        # Calculate additional quantities for logging
        # TODO: CHECK ME for correctness and completion!
        chosen_rewards = self.beta * (attacker_logprobs_win - baseline_logprobs_win)
        rejected_rewards = self.beta * (attacker_logprobs_loss - baseline_logprobs_loss)
        reward_accuracies = (chosen_rewards > rejected_rewards).float()
        reward_margin = chosen_rewards - rejected_rewards

        logging_dict: Dict[Any, Any] = {
            "training/loss": loss.mean().cpu().item(),
            "reward/chosen_rewards": chosen_rewards.mean().cpu().item(),
            "reward/rejected_rewards": rejected_rewards.mean().cpu().item(),
            "reward/reward_accuracies": reward_accuracies.mean().cpu().item(),
            "reward/reward_margin": reward_margin.mean().cpu().item(),
            "policy/logprobs_chosen": attacker_logprobs_win.mean()
            .detach()
            .cpu()
            .item(),
            "policy/logprobs_rejected": attacker_logprobs_loss.mean()
            .detach()
            .cpu()
            .item(),
            "ref/logprobs_chosen": baseline_logprobs_win.mean().detach().cpu().item(),
            "ref/logprobs_rejected": baseline_logprobs_loss.mean()
            .detach()
            .cpu()
            .item(),
        }
        # TODO: Add this from old code?
        # "policy/rollout": wandb.Html(str(r"<span>"+batch["prompt_win"][0][0]+"</span><span style='color:Tomato;'>"+batch["prompt_win"][0][1]+r"</span><span style='color:DodgerBlue'>"+batch["prompt_win"][0][2]+r"</span>")),

        return loss.mean(), logging_dict
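
The loss arithmetic in step() can be reproduced in isolation. The sketch below is an illustration, not part of astra_rl: the function name dpo_loss and the tensor values are made up, and it assumes you already have summed sequence log-probabilities for the chosen and rejected suffixes under the attacker (policy) and baseline (reference) models.

import torch
import torch.nn.functional as F

def dpo_loss(
    policy_logp_chosen: torch.Tensor,    # log pi_theta(y_w | x), summed over tokens
    policy_logp_rejected: torch.Tensor,  # log pi_theta(y_l | x)
    ref_logp_chosen: torch.Tensor,       # log pi_ref(y_w | x)
    ref_logp_rejected: torch.Tensor,     # log pi_ref(y_l | x)
    beta: float = 0.1,
) -> torch.Tensor:
    # Log-ratio of chosen vs. rejected under the policy and the reference,
    # then -logsigmoid of their scaled difference, mirroring DPO.step().
    pi_logratios = policy_logp_chosen - policy_logp_rejected
    ref_logratios = ref_logp_chosen - ref_logp_rejected
    return -F.logsigmoid(beta * (pi_logratios - ref_logratios)).mean()

# Toy values: the policy separates the pair more than the reference does,
# so the loss (about 0.51) falls below log 2 (about 0.693), the value when
# policy and reference agree.
loss = dpo_loss(
    policy_logp_chosen=torch.tensor([-4.0]),
    policy_logp_rejected=torch.tensor([-9.0]),
    ref_logp_chosen=torch.tensor([-6.0]),
    ref_logp_rejected=torch.tensor([-7.0]),
)
print(float(loss))

Summing per-token log-probabilities before taking differences (the .sum(dim=-1) calls above) is what makes the implicit reward a sequence-level quantity, so the loss compares whole suffixes rather than individual tokens.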