```python
import torch
import torch.nn.functional as F


class RLHF_PPO:
    def __init__(self, policy, ref_model, reward_model, value_model):
        self.policy = policy
        self.ref_model = ref_model
        self.reward_model = reward_model
        self.value_model = value_model
        self.kl_coef = 0.2      # weight of the KL penalty against the reference model
        self.gamma = 1.0        # discount factor (1.0 is common for single-episode RLHF)
        self.lam = 0.95         # GAE lambda
        self.clip_range = 0.2   # PPO clipping epsilon
        self.vf_coef = 0.1      # value-loss weight
        self.ent_coef = 0.01    # entropy-bonus weight
        # Optimizer over both the policy and the value head.
        self.optimizer = torch.optim.Adam(
            list(policy.parameters()) + list(value_model.parameters()),
            lr=1e-5,
        )

    def generate_and_score(self, queries):
        """Generate responses and compute rewards, values, and log-probs."""
        with torch.no_grad():
            responses = self.policy.generate(
                queries, max_length=512, do_sample=True, temperature=0.7
            )
            rm_scores = self.reward_model(queries, responses)
            policy_logp = self.policy.log_prob(queries, responses)
            ref_logp = self.ref_model.log_prob(queries, responses)
            # Per-sample KL estimate, so each response carries its own
            # penalty instead of a shared batch-mean penalty.
            kl_div = policy_logp - ref_logp
            rewards = rm_scores - self.kl_coef * kl_div
            values = self.value_model(queries, responses)
        return responses, rewards, values, policy_logp

    def compute_advantages(self, rewards, values):
        """Generalized Advantage Estimation (GAE), computed backwards in time."""
        advantages = []
        gae = 0.0
        next_value = 0.0  # bootstrap value after the final step
        for t in reversed(range(len(rewards))):
            delta = rewards[t] + self.gamma * next_value - values[t]
            gae = delta + self.gamma * self.lam * gae
            advantages.insert(0, gae)
            next_value = values[t]
        advantages = torch.tensor(advantages)
        # Normalize advantages for training stability.
        return (advantages - advantages.mean()) / (advantages.std() + 1e-8)

    def ppo_update(self, queries, responses, old_logprobs, rewards,
                   values, advantages):
        """One PPO update: clipped policy loss + value loss - entropy bonus."""
        # Value target: the empirical return, reconstructed here as
        # advantage + old value estimate.
        returns = advantages + values
        for epoch in range(4):
            self.optimizer.zero_grad()
            new_logprobs = self.policy.log_prob(queries, responses)
            # Importance ratio between the updated and rollout policies.
            ratio = torch.exp(new_logprobs - old_logprobs)
            surr1 = ratio * advantages
            surr2 = torch.clamp(
                ratio, 1 - self.clip_range, 1 + self.clip_range
            ) * advantages
            # Clipped surrogate objective (negated for gradient descent).
            policy_loss = -torch.min(surr1, surr2).mean()
            new_values = self.value_model(queries, responses)
            value_loss = F.mse_loss(new_values, returns)
            entropy = self.policy.entropy(queries, responses).mean()
            total_loss = (
                policy_loss
                + self.vf_coef * value_loss
                - self.ent_coef * entropy
            )
            total_loss.backward()
            torch.nn.utils.clip_grad_norm_(self.policy.parameters(), 0.5)
            self.optimizer.step()
```
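A minimal sketch of how these pieces might be wired into a training loop, assuming each model exposes the interface the class expects (`generate`, `log_prob`, `entropy`, and callable scoring); `load_query_batches` and the four model objects are hypothetical placeholders, not a real library API:

```python
# Illustrative only: policy, ref_model, reward_model, value_model, and
# load_query_batches are assumed to exist with the interfaces above.
trainer = RLHF_PPO(policy, ref_model, reward_model, value_model)

for queries in load_query_batches(batch_size=8):
    # Rollout: sample responses and score them (no gradients).
    responses, rewards, values, old_logprobs = trainer.generate_and_score(queries)
    # Estimate advantages from per-sample rewards and value predictions.
    advantages = trainer.compute_advantages(rewards, values)
    # Several clipped-PPO epochs over the sampled batch.
    trainer.ppo_update(queries, responses, old_logprobs,
                       rewards, values, advantages)
```

Note that the reference model and reward model stay frozen throughout; only the policy and value head receive gradient updates, with the KL term in `generate_and_score` keeping the policy anchored to the reference.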