class DPO(
    Algorithm[StateT, ActionT, DPOStep[StateT, ActionT], DPOBatch[StateT, ActionT]],
    Generic[StateT, ActionT],
):
    """Direct Preference Optimization (DPO) algorithm."""

    def __init__(self, problem: Problem[StateT, ActionT], beta: float = 0.1):
        super().__init__(problem)
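        # beta scales the policy/reference log-ratio in the DPO loss (see `step`);
        # it plays the role of the KL coefficient in the underlying RLHF objective,
        # so larger values keep the attacker closer to the baseline model.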
        self.beta = beta

    def flatten(
        self, graph: Graph[StateT, ActionT]
    ) -> Sequence[DPOStep[StateT, ActionT]]:
        # In DPO, we select from each branch the highest- and lowest-rewarded
        # actions and use them as our contrastive pairs.
        pairs: List[DPOStep[StateT, ActionT]] = []
        bfs = [graph.children]
        while bfs:
            front = bfs.pop(0)
            sorted_list = sorted(front, key=lambda x: x.reward, reverse=True)
            # Keep traversing this subtree even when the current frontier has
            # fewer than two sibling rollouts and therefore yields no pair.
            for entry in sorted_list:
                bfs.append(entry.children)
            if len(sorted_list) < 2:
                # no contrastive pair to extract from this node
                continue
            pos_entry = sorted_list[0]
            neg_entry = sorted_list[-1]
            assert pos_entry.context == neg_entry.context, (
                "paired rollouts for DPO must share the same prefix!"
            )
            pairs.append(
                DPOStep(
                    prefix=pos_entry.context,
                    suffix_pos=pos_entry.attack,
                    suffix_neg=neg_entry.attack,
                )
            )
        return pairs

    @staticmethod
    def collate_fn(x: Sequence[DPOStep[StateT, ActionT]]) -> DPOBatch[StateT, ActionT]:
        prefixes = [i.prefix for i in x]
        suffix_pos = [i.suffix_pos for i in x]
        suffix_neg = [i.suffix_neg for i in x]
        return DPOBatch(prefixes=prefixes, suffix_pos=suffix_pos, suffix_neg=suffix_neg)

    def step(
        self, batch: DPOBatch[StateT, ActionT]
    ) -> tuple[torch.Tensor, Dict[Any, Any]]:
        attacker_logprobs_win = self.problem._get_attacker_logprobs_and_validate(
            batch.prefixes, batch.suffix_pos
        ).sum(dim=-1)  # Sum per-token logprobs to get sequence logprobs
        attacker_logprobs_loss = self.problem._get_attacker_logprobs_and_validate(
            batch.prefixes, batch.suffix_neg
        ).sum(dim=-1)  # Sum per-token logprobs to get sequence logprobs
        baseline_logprobs_win = self.problem._get_baseline_logprobs_and_validate(
            batch.prefixes, batch.suffix_pos
        ).sum(dim=-1)  # Sum per-token logprobs to get sequence logprobs
        baseline_logprobs_loss = self.problem._get_baseline_logprobs_and_validate(
            batch.prefixes, batch.suffix_neg
        ).sum(dim=-1)  # Sum per-token logprobs to get sequence logprobs
        # https://github.com/eric-mitchell/direct-preference-optimization/blob/ \
        # f8b8c0f49dc92a430bae41585f9d467d3618fe2f/trainers.py#L70-L87
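        # DPO loss: -log sigmoid(beta * ((log pi(y_w | x) - log pi(y_l | x))
        #                                - (log pi_ref(y_w | x) - log pi_ref(y_l | x))))
        # where y_w / y_l are the chosen / rejected suffixes and pi_ref is the
        # frozen baseline model.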
        pi_logratios = attacker_logprobs_win - attacker_logprobs_loss
        ref_logratios = baseline_logprobs_win - baseline_logprobs_loss
        logits = pi_logratios - ref_logratios
        loss = -F.logsigmoid(self.beta * logits)
        # Calculate additional logging quantities
        # TODO: CHECK ME for correctness and completion!
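        # Implicit DPO reward for a suffix y: beta * (log pi(y | x) - log pi_ref(y | x));
        # reward_accuracies tracks how often the chosen suffix outranks the rejected one.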
        chosen_rewards = self.beta * (attacker_logprobs_win - baseline_logprobs_win)
        rejected_rewards = self.beta * (attacker_logprobs_loss - baseline_logprobs_loss)
        reward_accuracies = (chosen_rewards > rejected_rewards).float()
        reward_margin = chosen_rewards - rejected_rewards
        logging_dict: Dict[Any, Any] = {
            "training/loss": loss.mean().cpu().item(),
            "reward/chosen_rewards": chosen_rewards.mean().cpu().item(),
            "reward/rejected_rewards": rejected_rewards.mean().cpu().item(),
            "reward/reward_accuracies": reward_accuracies.mean().cpu().item(),
            "reward/reward_margin": reward_margin.mean().cpu().item(),
            "policy/logprobs_chosen": attacker_logprobs_win.mean()
            .detach()
            .cpu()
            .item(),
            "policy/logprobs_rejected": attacker_logprobs_loss.mean()
            .detach()
            .cpu()
            .item(),
            "ref/logprobs_chosen": baseline_logprobs_win.mean().detach().cpu().item(),
            "ref/logprobs_rejected": baseline_logprobs_loss.mean()
            .detach()
            .cpu()
            .item(),
        }
        # TODO: Add this from old code?
        # "policy/rollout": wandb.Html(str(r"<span>"+batch["prompt_win"][0][0]+"</span><span style='color:Tomato;'>"+batch["prompt_win"][0][1]+r"</span><span style='color:DodgerBlue'>"+batch["prompt_win"][0][2]+r"</span>")),
        return loss.mean(), logging_dict
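

# Example usage (sketch): `problem` and `graph` are placeholders for a concrete
# Problem implementation and a rollout Graph produced elsewhere in the framework;
# they are not defined in this module.
#
#   algo = DPO(problem, beta=0.1)
#   steps = algo.flatten(graph)      # contrastive (prefix, win, lose) triples
#   batch = DPO.collate_fn(steps)
#   loss, metrics = algo.step(batch)
#   loss.backward()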