class PPO(
Algorithm[StateT, ActionT, PPOStep[StateT, ActionT], PPOBatch[StateT, ActionT]],
ABC,
):
"""Proximal Policy Optimization (PPO) algorithm with value function."""

    def __init__(
self,
system: ValueFunctionSystem[StateT, ActionT],
clip_range: float = 0.1,
vf_loss_coef: float = 1.0,
):
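        """
        Args:
            system: value-function system providing tester/baseline
                log-probabilities and value estimates.
            clip_range: epsilon used by the clipped surrogate objective.
            vf_loss_coef: weight of the value-function loss in the total loss.
        """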
super().__init__(system)
self.system: ValueFunctionSystem[StateT, ActionT] = system
self.clip_range = clip_range
self.vf_loss_coef = vf_loss_coef

    def flatten(
        self, graph: Graph[StateT, ActionT]
    ) -> Sequence[PPOStep[StateT, ActionT]]:
        # walk the graph breadth-first and turn every expanded child into a
        # (prefix, suffix, reward) training step for PPO
res: List[PPOStep[StateT, ActionT]] = []
bfs = [graph.children]
while len(bfs):
front = bfs.pop(0)
if len(list(front)) < 2:
                # skip fronts with fewer than two children
continue
for i in front:
res.append(PPOStep(prefix=i.context, suffix=i.probe, reward=i.reward))
bfs.append(i.children)
return res

    @staticmethod
    def collate_fn(
        x: Sequence[PPOStep[StateT, ActionT]],
    ) -> PPOBatch[StateT, ActionT]:
        prefixes = [i.prefix for i in x]
        suffixes = [i.suffix for i in x]
        rewards = [i.reward for i in x]
        return PPOBatch(prefix=prefixes, suffix=suffixes, reward=rewards)

    def step(
        self, batch: PPOBatch[StateT, ActionT]
    ) -> tuple[torch.Tensor, Dict[Any, Any]]:
logprobs_tester = self.system._get_tester_logprobs_and_validate(
batch.prefix, batch.suffix
)
logprobs_baseline = self.system._get_baseline_logprobs_and_validate(
batch.prefix, batch.suffix
)
values = self.system.value(batch.prefix, batch.suffix)
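        # values is expected to be at least 3-D, e.g. (batch, timesteps, 1),
        # so that the broadcast below can tile the scalar reward over it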
        # Q(s, a) = R(s, a): the sequence-level reward is broadcast as the
        # return at every position in the stream, with no discounting;
        # a bit of a hack, but it appears to be standard practice
Q = (
torch.tensor(batch.reward)
.to(logprobs_tester.device)
.unsqueeze(-1)
.unsqueeze(-1)
.repeat(1, *values.shape[1:])
)
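        # advantage: how much better the observed return is than the critic's estimate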
A = Q - values
# normalize advantages
if A.size(-1) == 1:
A = ((A - A.mean()) / (A.std() + 1e-8)).squeeze(-1)
else:
A = (A - A.mean()) / (A.std() + 1e-8)
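        # PPO clipped surrogate objective (Schulman et al., 2017):
        #   L_clip = -E[ min(r * A, clip(r, 1 - eps, 1 + eps) * A) ]
        # where r = exp(logprobs_tester - logprobs_baseline), the baseline
        # playing the role of the "old" policy, and eps = clip_range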
        # compute the probability ratio; it should equal 1 on the first iteration
        ratio = torch.exp(logprobs_tester - logprobs_baseline.detach())
        # compute the clipped surrogate loss
        policy_loss_1 = A * ratio
        policy_loss_2 = A * torch.clamp(
            ratio, 1 - self.clip_range, 1 + self.clip_range
        )
        policy_loss = -(torch.min(policy_loss_1, policy_loss_2)).mean()
# compute value loss
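        # Q is built from batch.reward and carries no gradient, so this MSE
        # only trains the value head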
value_loss = F.mse_loss(Q, values)
        # compute the final loss
loss = policy_loss + self.vf_loss_coef * value_loss
# create logging dict
logging_dict: Dict[Any, Any] = {
"training/loss": loss.mean().cpu().item(),
"training/policy_loss": policy_loss.mean().cpu().item(),
"training/value_loss": value_loss.mean().cpu().item(),
"reward/mean_reward": torch.tensor(batch.reward).mean().cpu().item(),
"reward/std_reward": torch.tensor(batch.reward).std().cpu().item(),
"policy/logprobs": logprobs_tester.mean().detach().cpu().item(),
"ref/logprobs": logprobs_baseline.mean().detach().cpu().item(),
}
return loss, logging_dict
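

if __name__ == "__main__":
    # Minimal, self-contained sketch (not used by the class above): it replays
    # the clipped-surrogate and value-loss math from `step` on random tensors.
    # The shapes (4 sequences x 8 timesteps) and the clip_range value are
    # arbitrary assumptions chosen purely for illustration.
    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    clip_range = 0.1
    logprobs_tester = torch.randn(4, 8, requires_grad=True)
    logprobs_baseline = torch.randn(4, 8)
    values = torch.randn(4, 8, 1, requires_grad=True)
    rewards = torch.rand(4)

    # broadcast the sequence reward over every timestep, as in step()
    Q = rewards.unsqueeze(-1).unsqueeze(-1).repeat(1, *values.shape[1:])
    A = Q - values
    A = ((A - A.mean()) / (A.std() + 1e-8)).squeeze(-1)

    ratio = torch.exp(logprobs_tester - logprobs_baseline)
    policy_loss = -torch.min(
        A * ratio, A * torch.clamp(ratio, 1 - clip_range, 1 + clip_range)
    ).mean()
    value_loss = F.mse_loss(Q, values)
    print(f"policy_loss={policy_loss.item():.4f} value_loss={value_loss.item():.4f}")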