DPO

astra_rl.algorithms.dpo

DPO

Bases: Algorithm[StateT, ActionT, DPOStep[StateT, ActionT], DPOBatch[StateT, ActionT]], Generic[StateT, ActionT]

Direct Preference Optimization (DPO) algorithm.

Source code in src/astra_rl/algorithms/dpo.py
class DPO(
    Algorithm[StateT, ActionT, DPOStep[StateT, ActionT], DPOBatch[StateT, ActionT]],
    Generic[StateT, ActionT],
):
    """Direct Preference Optimization (DPO) algorithm."""

    def __init__(self, system: System[StateT, ActionT], beta: float = 0.1):
        super().__init__(system)

        self.beta = beta

    def flatten(
        self, graph: Graph[StateT, ActionT]
    ) -> Sequence[DPOStep[StateT, ActionT]]:
        # In DPO, we select from each branching point the highest-reward
        # and lowest-reward actions to serve as our contrastive
        # (chosen/rejected) pair.
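        # For example (an illustrative sketch, not actual rollout data): if a
        # node's children carry rewards [0.9, 0.4, 0.1], the child with reward
        # 0.9 supplies the "chosen" suffix and the child with reward 0.1 the
        # "rejected" suffix, with the node's context as the shared prefix.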

        pairs: List[DPOStep[StateT, ActionT]] = []
        bfs = [graph.children]
        while len(bfs):
            front = bfs.pop(0)
            sorted_list = sorted(list(front), key=lambda x: x.reward, reverse=True)

            if len(sorted_list) < 2:
                # we cannot form a contrastive pair, so skip this frontier
                # (its children are not explored further)
                continue

            pos_entry = sorted_list[0]
            neg_entry = sorted_list[-1]

            assert pos_entry.context == neg_entry.context, (
                "paired rollouts for DPO must share the same prefix!"
            )

            pairs.append(
                DPOStep(
                    prefix=pos_entry.context,
                    suffix_pos=pos_entry.probe,
                    suffix_neg=neg_entry.probe,
                )
            )

            for i in sorted_list:
                bfs.append(i.children)

        return pairs

    @staticmethod
    def collate_fn(x: Sequence[DPOStep[StateT, ActionT]]) -> DPOBatch[StateT, ActionT]:
        prefixes = [i.prefix for i in x]
        suffix_pos = [i.suffix_pos for i in x]
        suffix_neg = [i.suffix_neg for i in x]

        return DPOBatch(prefixes=prefixes, suffix_pos=suffix_pos, suffix_neg=suffix_neg)

    def step(
        self, batch: DPOBatch[StateT, ActionT]
    ) -> tuple[torch.Tensor, Dict[Any, Any]]:
        tester_logprobs_win = self.system._get_tester_logprobs_and_validate(
            batch.prefixes, batch.suffix_pos
        ).sum(dim=-1)  # Sum per-token logprobs to get sequence logprobs
        tester_logprobs_loss = self.system._get_tester_logprobs_and_validate(
            batch.prefixes, batch.suffix_neg
        ).sum(dim=-1)  # Sum per-token logprobs to get sequence logprobs
        baseline_logprobs_win = self.system._get_baseline_logprobs_and_validate(
            batch.prefixes, batch.suffix_pos
        ).sum(dim=-1)  # Sum per-token logprobs to get sequence logprobs
        baseline_logprobs_loss = self.system._get_baseline_logprobs_and_validate(
            batch.prefixes, batch.suffix_neg
        ).sum(dim=-1)  # Sum per-token logprobs to get sequence logprobs

        # https://github.com/eric-mitchell/direct-preference-optimization/blob/ \
        # f8b8c0f49dc92a430bae41585f9d467d3618fe2f/trainers.py#L70-L87
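        # Per-pair DPO objective (Rafailov et al., 2023):
        #   loss = -log sigmoid( beta * [ (log pi(y_w|x) - log pi(y_l|x))
        #                               - (log pi_ref(y_w|x) - log pi_ref(y_l|x)) ] )
        # where pi is the tester (policy) model and pi_ref is the frozen baseline.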
        pi_logratios = tester_logprobs_win - tester_logprobs_loss
        ref_logratios = baseline_logprobs_win - baseline_logprobs_loss
        logits = pi_logratios - ref_logratios

        loss = -F.logsigmoid(self.beta * logits)

        # Calculate additional quantities for logging
        # TODO: CHECK ME for correctness and completion!
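        # Under the DPO parametrization, beta * (log pi(y|x) - log pi_ref(y|x))
        # is the model's implicit reward for a completion (up to a
        # prompt-dependent constant), which is what the "chosen"/"rejected"
        # reward diagnostics below track.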
        chosen_rewards = self.beta * (tester_logprobs_win - baseline_logprobs_win)
        rejected_rewards = self.beta * (tester_logprobs_loss - baseline_logprobs_loss)
        reward_accuracies = (chosen_rewards > rejected_rewards).float()
        reward_margin = chosen_rewards - rejected_rewards

        logging_dict: Dict[Any, Any] = {
            "training/loss": loss.mean().cpu().item(),
            "reward/chosen_rewards": chosen_rewards.mean().cpu().item(),
            "reward/rejected_rewards": rejected_rewards.mean().cpu().item(),
            "reward/reward_accuracies": reward_accuracies.mean().cpu().item(),
            "reward/reward_margin": reward_margin.mean().cpu().item(),
            "policy/logprobs_chosen": tester_logprobs_win.mean().detach().cpu().item(),
            "policy/logprobs_rejected": tester_logprobs_loss.mean()
            .detach()
            .cpu()
            .item(),
            "ref/logprobs_chosen": baseline_logprobs_win.mean().detach().cpu().item(),
            "ref/logprobs_rejected": baseline_logprobs_loss.mean()
            .detach()
            .cpu()
            .item(),
        }
        # TODO: Add this from old code?
        # "policy/rollout": wandb.Html(str(r"<span>"+batch["prompt_win"][0][0]+"</span><span style='color:Tomato;'>"+batch["prompt_win"][0][1]+r"</span><span style='color:DodgerBlue'>"+batch["prompt_win"][0][2]+r"</span>")),

        return loss.mean(), logging_dict
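
A minimal usage sketch, not taken from the library's documentation: it assumes a user-provided System implementation, a sampled rollout Graph with per-action rewards, and a standard PyTorch optimizer over the tester model's parameters. The names my_system, rollout_graph, and optimizer below are placeholders; only DPO, flatten, collate_fn, and step come from the class shown above.

from astra_rl.algorithms.dpo import DPO

# Placeholders: construct these with whatever System subclass and rollout
# machinery your setup provides; they are not defined by this module.
my_system = ...        # a System[StateT, ActionT] implementation
rollout_graph = ...    # a Graph[StateT, ActionT] of sampled rollouts with rewards
optimizer = ...        # e.g. a torch.optim optimizer over the tester model's parameters

algorithm = DPO(my_system, beta=0.1)

steps = algorithm.flatten(rollout_graph)   # extract (prefix, chosen, rejected) pairs
batch = DPO.collate_fn(steps)              # collate into a DPOBatch
loss, logs = algorithm.step(batch)         # scalar DPO loss plus logging dict

loss.backward()
optimizer.step()
optimizer.zero_grad()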