将强化学习重新引入 RLHF
我们很高兴在 TRL 中介绍 RLOO (REINFORCE Leave One-Out) 训练器。作为一种替代 PPO 的方法,RLOO 是一种新的在线 RLHF 训练算法,旨在使其更易于访问和实施。特别是, RLOO 需要的 GPU 内存更少,并且达到收敛所需的挂钟时间也更短。如下面的图表所示:
-
🤑根据模型大小,RLOO 使用的 vRAM 比 PPO 少大约 50-70%; -
🚀对于 1B 参数模型,RLOO 的运行速度比 PPO 快 2 倍,对于 6.9B 参数模型,RLOO 的运行速度比 PPO 快 3 倍。 -
🔥在响应胜率 (由 GPT4 判断) 方面,RLOO 与 PPO 相当,并且始终优于 DPO 等流行的离线方法。
通过 RLOO,我们将强化学习重新引入 RLHF,使社区能够更轻松地探索在线 RL 方法。这令人兴奋,因为越来越多的研究表明,在线 RL 比 DPO 等离线方法更有效 (
这篇博客将解释 RLOO 训练器的背后的动机,它是如何工作的,以及如何在 TRL 中使用它。
动机
PPO 是一种有效的在线 RLHF 训练算法,用于训练最先进的模型,如 GPT-4。然而,由于其对 GPU 内存的高要求,PPO 在实际使用中可能相当具有挑战性。特别是,PPO 需要将模型的 4 个副本加载到内存中: 1) 策略模型,2) 参考策略模型,3) 奖励模型,以及 4) 价值模型,如下面的图所示。PPO 还有许多微妙的实现细节,这些细节可能很难正确把握 (
-
Engstrom 等人; 2020 https://openreview.net/forum?id=r1etN1rtPB -
Huang 等人 2022 https://iclr-blog-track.github.io/2022/03/25/ppo-implementation-details/
在 Cohere 的一篇新论文中,
-
Ahmadian 等人 (2024) https://cohere.com/research/papers/back-to-basics-revisiting-reinforce-style-optimization-for-learning-from-human-feedback-in-llms-2024-02-23
重要的是,RLOO 需要更少的内存,这意味着它更容易:
-
在不出现 OOMs (内存不足错误) 的情况下运行 -
能够加载更大的批量大小 -
运行更高效且更快。
此外,RLOO 将整个补全 token 作为单一动作进行建模,如下图所示。在下一节中,我们将通过代码片段进一步详细介绍。
RLOO 是如何工作的
RLOO 和 PPO 有几个共同的步骤:
-
策略模型会生成一些补全 token ,并获取当前策略和参考策略下的每个 token 的对数概率。 -
然后我们计算每个 token 的 KL 惩罚,作为当前策略和参考策略下对数概率的差异。 -
接着我们从奖励模型中获得整个补全的得分。
从这里开始,常规的 PPO 和 RLOO 在方法上有所不同。RLOO 有几个关键想法。首先,它将 整个模型补全 视为单一动作,而常规 PPO 将 每个补全 token 视为单独的动作。通常,只有 EOS token 获得真正的奖励,这非常稀疏。常规 PPO 会将奖励归因于 EOS token,而 RLOO 会将 EOS 奖励归因于整个补全,如下所示。
from torch import Tensor
response = Tensor([4., 5., 6.])
per_token_logprobs = Tensor([-12.3, -8.3, -2.3])
reference_per_token_logprobs = Tensor([-11.3, -8.4, -2.0])
kl = per_token_logprobs - reference_per_token_logprobs
score_from_rm = 1.0
print(f"{kl=}") # kl=tensor([-1.0000, 0.1000, -0.3000])
per_token_reward = kl.clone()
per_token_reward[-1] += score_from_rm # assume last token is the EOS token
print(f"{per_token_reward=}") # per_token_reward=tensor([-1.0000, 0.1000, 0.7000])
print(f"{score_from_rm=}") # score_from_rm=1.0
print("#### Modeling each token as an action")
for action, reward in zip(response, per_token_reward):
print(f"{action=}, {reward=}")
# action=tensor(4.), reward=tensor(-1.)
# action=tensor(5.), reward=tensor(0.1000)
# action=tensor(6.), reward=tensor(0.7000)
print("#### Modeling the entire response as an action")
entire_generation_reward = per_token_reward.sum()
print(f"action='entire completion', reward={entire_generation_reward}")
# action='entire completion', reward=-0.2000 (-1 + 0.1 + 0.7)
其次,RLOO 使用 REINFORCE 损失,它基本上将 (奖励 - 基线) 与动作的对数概率相乘。在这里,我们突出了每个 token 的 REINFORCE 损失与整个补全的 REINFORCE 损失之间的区别。请注意,对于 PPO 的损失,我们还需要基于价值模型和
-
广义优势估计 (GAE) https://arxiv.org/abs/1506.02438
from torch import Tensor
response = Tensor([4., 5., 6.])
per_token_logprobs = Tensor([-12.3, -8.3, -2.3])
reference_per_token_logprobs = Tensor([-11.3, -8.4, -2.0])
kl = per_token_logprobs - reference_per_token_logprobs
score_from_rm = 1.0
print(f"{kl=}") # kl=tensor([-1.0000, 0.1000, -0.3000])
per_token_reward = kl.clone()
per_token_reward[-1] += score_from_rm # assume last token is the EOS token
print(f"{per_token_reward=}") # per_token_reward=tensor([-1.0000, 0.1000, 0.7000])
print(f"{score_from_rm=}") # score_from_rm=1.0
print("#### Modeling each token as an action")
for action, reward in zip(response, per_token_reward):
print(f"{action=}, {reward=}")
# action=tensor(4.), reward=tensor(-1.)
# action=tensor(5.), reward=tensor(0.1000)
# action=tensor(6.), reward=tensor(0.7000)
print("#### Modeling the entire response as an action")
entire_generation_reward = per_token_reward.sum()
print(f"action='entire completion', reward={entire_generation_reward}")
# action='entire completion', reward=-0.2000 (-1 + 0.1 + 0.7)
baseline = Tensor([0.2, 0.3, 0.4]) # dummy baseline
print("#### Modeling each token as an action")
advantage = per_token_reward - baseline
per_token_reinforce_loss = per_token_logprobs * advantage
print(f"{advantage=}") # advantage=tensor([-1.2000, -0.2000, 0.3000])
print(f"{per_token_reinforce_loss=}") # per_token_reinforce_loss=tensor([14.7600, 1.6600, -0.6900])
print(f"{per_token_reinforce_loss.mean()=}") # per_token_reinforce_loss.mean()=tensor(5.2433)
print("#### Modeling the entire response as an action")
advantage = entire_generation_reward - baseline.sum()
reinforce_loss = per_token_logprobs.sum()* advantage
print(f"{advantage=}") # advantage=tensor(-1.1000)
print(f"{reinforce_loss=}") # reinforce_loss=tensor(25.1900)
第三,RLOO 聪明地计算基线。注意我们上面使用了一个虚拟基线。在实际操作中,RLOO 使用批次中所有其他样本的奖励作为基线。下面是一个有 3 个提示和每个提示 4 个补全的例子。我们通过平均同一提示的所有其他补全的奖励来计算每个补全的基线。
import torch
local_batch_size = 3
rloo_k = 4
rlhf_reward = torch.tensor([
1, 2, 3, # first rlhf reward for three prompts
2, 3, 4, # second rlhf reward for three prompts
5, 6, 7, # third rlhf reward for three prompts
8, 9, 10, # fourth rlhf reward for three prompts
]).float() # here we have 3 prompts which have 4 completions each
# slow impl
baseline = (rlhf_reward.sum(0) - rlhf_reward) / (rloo_k - 1)
advantages = torch.zeros_like(rlhf_reward)
for i in range(0, len(advantages), local_batch_size):
other_response_rlhf_rewards = []
for j in range(0, len(advantages), local_batch_size):
if i != j:
other_response_rlhf_rewards.append(rlhf_reward[j : j + local_batch_size])
advantages[i : i + local_batch_size] = rlhf_reward[i : i + local_batch_size] - torch.stack(
other_response_rlhf_rewards
).mean(0)
assert (1 - (2 + 5 + 8) / 3 - advantages[0].item()) < 1e-6
assert (6 - (3 + 2 + 9) / 3 - advantages[7].item()) < 1e-6
# vectorized impl
rlhf_reward = rlhf_reward.reshape(rloo_k, local_batch_size)
baseline = (rlhf_reward.sum(0) - rlhf_reward) / (rloo_k - 1)
vec_advantages = rlhf_reward - baseline
torch.testing.assert_close(vec_advantages.flatten(), advantages)
向 Arash Ahmadian 致谢,他提供了上述优势计算的向量化实现。
开始使用 TRL 的 RLOO
要开始使用 RLOO,你可以通过 pip install --upgrade trl
安装 TRL 的最新版本,并导入 RLOOTrainer。下面是一个展示一些高级 API 使用的简短代码片段。你可以随时查阅这些文档:
-
https://hf.co/docs/trl/main/en/rloo_trainer -
https://hf.co/docs/trl/main/en/ppov2_trainer
from transformers import (
AutoModelForCausalLM,
AutoModelForSequenceClassification,
AutoTokenizer,
)
from trl.trainer.rloo_trainer import RLOOConfig, RLOOTrainer
from trl.trainer.utils import SIMPLE_QUERY_CHAT_TEMPLATE
base_model_name = "EleutherAI/pythia-1b-deduped"
tokenizer = AutoTokenizer.from_pretrained(base_model_name, padding_side="left")
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
if tokenizer.chat_template is None:
tokenizer.chat_template = SIMPLE_QUERY_CHAT_TEMPLATE
reward_model = AutoModelForSequenceClassification.from_pretrained(base_model_name, num_labels=1)
ref_policy = AutoModelForCausalLM.from_pretrained(base_model_name)
policy = AutoModelForCausalLM.from_pretrained(base_model_name)
train_dataset = ... # make sure to have columns "input_ids"
eval_dataset = ...
trainer = RLOOTrainer(
config=RLOOConfig(
per_device_train_batch_size=1,
gradient_accumulation_steps=64,
total_episodes=30000,
),
tokenizer=tokenizer,
policy=policy,
ref_policy=ref_policy,
reward_model=reward_model,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
)
trainer.train()
这是一个跟踪权重和偏差实验的
-
案例地址 https://wandb.ai/huggingface/trl/runs/dd2o3g35
在编写 RLOO 和 PPOv2 实现时,我们强调使模型开发的透明度更容易提升。特别是,我们已经增强了文档,包括对记录指标的解释以及阅读和调试这些指标的操作指南。例如,我们建议在训练期间密切监控 objective/rlhf_reward,这是 RLHF 训练的最终目标。
为了帮助可视化训练进度,我们定期记录模型的一些示例补全。这里是一个补全的例子。在一个权重和偏差跟踪运行的
-
示例地址 https://wandb.ai/huggingface/trl/runs/dd2o3g35
我们如何在 TRL 中实现 RLOO 训练器
我们基于新的实验性 PPOv2Trainer
实现了 RLOO 训练器,后者又是基于 https://arxiv.org/abs/2403.17031。有趣的是,我们实现的 RLOO 训练器仍然使用 PPO 损失。这是因为 REINFORCE 的损失是 PPO 的一个特例 (https://arxiv.org/abs/2205.09123)。请注意,即使对数概率明确出现在 REINFORCE 损失中,它也隐含在 PPO 损失中。眼见为实,所以让我们用一个简单的例子来证明这一点。
import torch.nn.functional as F
from torch import LongTensor, Tensor, gather, no_grad
action = LongTensor([1])
advantage = Tensor([1.0])
logits = Tensor([[1.0, 2.0, 1.0, 1.0]])
logits.requires_grad = True
all_logprob = F.log_softmax(logits, dim=-1)
with no_grad():
old_logprob = gather(all_logprob, 1, action.unsqueeze(-1)).squeeze(-1)
logprob = gather(all_logprob, 1, action.unsqueeze(-1)).squeeze(-1)
ratio = (logprob - old_logprob).exp()
ppo_loss = (ratio * advantage).mean() # [πθ(at | st) / πθ_old(at | st)* At]
# when the πθ and πθ_old are the same, the ratio is 1, and PPO's clipping has no effect
ppo_loss.backward()
print(f"{logits.grad=}") # tensor([[-0.1749, 0.5246, -0.1749, -0.1749]])
logits2 = Tensor([[1.0, 2.0, 1.0, 1.0]])
logits2.requires_grad = True
all_logprob2 = F.log_softmax(logits2, dim=-1)
logprob2 = gather(all_logprob2, 1, action.unsqueeze(-1)).squeeze(-1)
reinforce_loss = logprob2 * advantage # [log πθ(at | st)* At]
reinforce_loss.mean().backward()
print(f"{logits2.grad=}") # tensor([[-0.1749, 0.5246, -0.1749, -0.1749]])
实验
为了验证 RLOO 实现的有效性,我们在 Pythia 1B 和 6.9B 模型上进行了实验,并在这里发布了训练后的
-
检查点链接 https://hf.co/collections/vwxyzjn/rloo-ppov2-tl-dr-summarize-checkpoints-66679a3bfd95ddf66c97420d
我们从
-
Huang 等人,2024 https://arxiv.org/abs/2403.17031 -
https://hf.co/docs/trl/main/en/rloo_trainer#benchmark-experiments -
https://hf.co/docs/trl/main/en/rloo_trainer#benchmark-experiments
关键结果如下:
-
🚀高性能 RLOO 检查点: 使用 GPT4 作为评判模型,6.9B 检查点获得了 78.7% (k=2) 的偏好率,这甚至超过了原始
论文 中报告的最佳性能 77.9% (k=4) 和 74.2 (k=2)。这是一个很好的迹象,表明我们的 RLOO 训练按预期工作。论文地址 https://arxiv.org/abs/2402.14740 -
RLOO 1B 检查点的胜率为 40.1%,而 SFT 检查点的胜率为 21.3%。这是一个很好的迹象,表明 RLOO 训练按预期工作。
-
🤑 减少 GPU 内存并运行更快: RLOO 训练使用更少的内存并运行更快,使其成为在线 RL 训练中非常有用的算法。
数值稳定性: 黑暗面
尽管 RLOO 在性能和计算效率方面有优势,但我们想要强调一些数值问题。具体来说,生成过程中获得的响应对数概率与 bf16
下训练前向传递期间获得的对数概率在数值上略有不同。这给 PPO 和 RLOO 都带来了问题,但对于 RLOO 来说,问题更严重,如下所述。
例如,假设我们正在为两个序列生成 10 个 token。在 fp32
精度下,输出如下所示,其中 ratio = (forward_logprob - generation_logprob).exp()
,这是 PPO 用来裁剪的。在第一个周期和第一个小批量中,比率应该是完全相同的,因为模型还没有进行任何更新:
generation_logprob=tensor([[ -0.1527, -0.2258, -3.5535, -3.4805, -0.0519,
-2.3097, -2.0275, -0.4597, -0.1687, -0.0000],
[ -0.1527, -0.2258, -5.2855, -0.1686, -8.4760,
-4.3118, -1.0368, -0.8274, -1.6342, -2.6128]],
device='cuda:0')
forward_logprob=tensor([[-0.1527, -0.2258, -3.5535, -3.4805, -0.0519, -2.3097, -2.0275, -0.4597,
-0.1687],
[-0.1527, -0.2258, -5.2855, -0.1686, -8.4760, -4.3118, -1.0368, -0.8274,
-1.6342]], device='cuda:0', grad_fn=<SqueezeBackward1>)
ratio=tensor([[1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
[1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000]],
device='cuda:0', grad_fn=<ExpBackward0>)
ratio.mean()=0.9999998211860657
ratio.std()=6.592738373001339e-06
ratio.max()=1.0000133514404297
ratio.min()=0.9999887943267822
然而,在 bf16 精度下,我们得到结果:
generation_logprob=tensor([[ -0.1426, -0.1904, -3.5938, -3.4688, -0.0618,
-2.3906, -2.0781, -0.4375, -0.1562, -0.0000],
[ -0.1426, -0.1904, -5.2812, -0.1641, -8.5625,
-4.2812, -1.0078, -0.8398, -1.5781, -2.5781]],
device='cuda:0', dtype=torch.bfloat16)
forward_logprob=tensor([[-0.1445, -0.1670, -3.5938, -3.5156, -0.0554, -2.2969, -1.9688, -0.5273,
-0.1953],
[-0.1445, -0.1670, -5.2812, -0.1533, -8.5625, -4.3125, -1.0000, -0.7852,
-1.6641]], device='cuda:0', dtype=torch.bfloat16,
grad_fn=<SqueezeBackward1>)
ratio=tensor([[1.0000, 0.9766, 1.0000, 1.0469, 0.9922, 0.9102, 0.8945, 1.0938, 1.0391],
[1.0000, 0.9766, 1.0000, 0.9883, 1.0000, 1.0312, 0.9922, 0.9453, 1.0859]],
device='cuda:0', dtype=torch.bfloat16, grad_fn=<ExpBackward0>)
ratio.mean()=1.0
ratio.std()=0.051025390625
ratio.max()=1.09375
ratio.min()=0.89453125
和在 fp16 精度下,我们得到的结果
generation_logprob=tensor([[ -0.1486, -0.2212, -3.5586, -3.4688, -0.0526,
-2.3105, -2.0254, -0.4629, -0.1677, -0.0000],
[ -0.1486, -0.2212, -5.2852, -0.1681, -8.4844,
-4.3008, -1.0322, -0.8286, -1.6348, -2.6074]],
device='cuda:0', dtype=torch.float16)
forward_logprob=tensor([[-0.1486, -0.2212, -3.5586, -3.4805, -0.0529, -2.3066, -2.0332, -0.4629,
-0.1676],
[-0.1486, -0.2212, -5.2852, -0.1682, -8.4766, -4.3008, -1.0322, -0.8281,
-1.6299]], device='cuda:0', dtype=torch.float16,
grad_fn=<SqueezeBackward1>)
ratio=tensor([[1.0000, 1.0000, 1.0000, 1.0117, 1.0000, 0.9961, 1.0078, 1.0000, 1.0000],
[1.0000, 1.0000, 1.0000, 1.0000, 0.9922, 1.0000, 1.0000, 0.9995, 0.9951]],
device='cuda:0', dtype=torch.float16, grad_fn=<ExpBackward0>)
ratio.mean()=1.0
ratio.std()=0.00418853759765625
ratio.max()=1.01171875
ratio.min()=0.9921875
请注意, bf16
的比率由于某种原因非常不稳定。当比率变得很大时,PPO 的裁剪系数 = 0.2 开始发挥作用, 取消 那些比率大于 1.2 或小于 0.8 的 token 的梯度。对于 RLOO,这个问题更为极端,因为我们看到的是 (forward_logprob.sum(1) - generation_logprob.sum(1)).exp() = [ 1.0625, 12.1875]
,这意味着整个第二个序列的梯度被取消了。
在实际操作中,我们注意到 PPO 取消了大约 3% 的批次数据的梯度,而 RLOO 取消了大约 20-40% 的批次数据。从理论上讲,当不使用小批量时,RLOO 应该取消 0 %的批次数据。重要的是,我们观察到,一旦我们增加了在生成新批次之前的梯度步骤数 (通过 num_ppo_epochs 和 num_mini_batches),RLOO 的裁剪比率并没有显著变化; 这提供了实证证据,表明裁剪比率确实是由于 bf16 的数值问题,而不是因为行为和最新策略有很大不同,正如论文中所定位的。
要了解有关最新问题更新的更多信息,请查看
-
Issue #31267 链接 https://github.com/huggingface/transformers/issues/31267
结论
TRL 中引入的 RLOO (REINFORCE Leave One-Out) 训练器是在线 RLHF 训练中一个令人兴奋的算法,它提供了一个比 PPO 更易访问和高效的替代方案。通过减少 GPU 内存使用和简化训练过程,RLOO 使得可以使用更大的批量大小和更快的训练时间。我们的实验表明,RLOO 在响应胜率方面与 PPO 竞争,并且优于 DPO 检查点,使其成为有效的在线 RLHF 的有力工具。查看我们的文档来开始使用吧!
-
https://hf.co/docs/trl/main/en/rloo_trainer -
https://hf.co/docs/trl/main/en/ppov2_trainer
致谢和感谢
我们要感谢 Lewis Tunstall, Sara Hooker, Omar Sanseviero 和 Leandro Von Werra 对这篇博客提供的宝贵反馈。
原文链接: https://hf.co/blog/putting_rl_back_in_rlhf_with_rloo
原文作者: Shengyi Costa Huang, Arash Ahmadian
译者: innovation64
本文分享自微信公众号 - Hugging Face(gh_504339124f0f)。
如有侵权,请联系 support@oschina.cn 删除。
本文参与“OSC源创计划”,欢迎正在阅读的你也加入,一起分享。

低调大师中文资讯倾力打造互联网数据资讯、行业资源、电子商务、移动互联网、网络营销平台。
持续更新报道IT业界、互联网、市场资讯、驱动更新,是最及时权威的产业资讯及硬件资讯报道平台。
转载内容版权归作者及来源网站所有,本站原创内容转载请注明来源。
- 上一篇
Go-Job让你的任务调度不再繁琐
一、背景 在选择任务调度平台时,团队遇到了一些实际的问题。现有的开源项目如XXL-Job、Elastic-Job,虽然功能强大,但主要是围绕Java设计,而我们团队主要使用Go语言进行开发。这使得我们在集成和使用这些工具时遇到了诸多不顺。经过深入的调研和讨论,决定开发一个适合Go语言的任务调度框架,以满足我们的特定业务需求。于是,Go-Job应运而生。 为了让大家有个全面的了解,接下来主要探讨它的架构设计和功能特性。 本文的另一亮点是借助GPT生成了一些精美的章节彩色插图。看看大家和GPT是否是"同款"理解! 二、架构设计 Go-Job的目标是充分利用Go语言的优势,提供高性能、易扩展的分布式任务调度解决方案,满足不同业务团队的复杂需求和快速变化的技术环境。就像量脚定制鞋一样,专为Go项目量身打造! 整体架构 首先,让我们先认识一些关键术语: namespace:命名空间,用于资源隔离。 handler:任务处理类,用户自定义实现。 worker:运行handler的业务服务,通过SDK接入,与Go-Job服务端通信。 trigger:调度平台上的触发器,包含调度规则、模式、超时配置等...
- 下一篇
探索新升级!在 ART-Pi Smart 体验 RT-Thread Smart v5.1.0
1.引言 RT-Thread Smart v5.1.0 已经正式发布。这一版本在内核和功能上做了大量的改进与增强。我们可以在ART-Pi Smart开发板尽情探索这一新版更完善更强大的RT-Thread Smart操作系统。ART-Pi Smart开发板搭载了米尔科技的i.MX6ULL核心板,硬件设计和制作由韦东山团队完成,整体的BSP板级支持包由社区完成。这是首个运行 RT-Thread Smart 的 ARM Cortex-A7 开发板,具有强大的硬件性能和丰富的外设接口,非常适合学习,开发和调试RT-Thread Smart系统。 硬件规格情况 🔗开发板购买链接:https://m.tb.cn/h.gUlY2J1?tk=7ah8Wy7fME1 在ART-Pi Smart的板卡上包括了百兆以太网,TF卡插槽,USB device,USB host接口,以及RW007 WiFi模块,Type-C的USB UART,RGB888 FPC座,Camera FPC座等资源。核心板上板载i.MX6ULL 800MHz ARM Cortex-A7单核32位处理器,512MB DDR3内存,4...
相关文章
文章评论
共有0条评论来说两句吧...
文章二维码
点击排行
推荐阅读
最新文章
- Red5直播服务器,属于Java语言的直播服务器
- CentOS7,8上快速安装Gitea,搭建Git服务器
- CentOS6,7,8上安装Nginx,支持https2.0的开启
- CentOS8,CentOS7,CentOS6编译安装Redis5.0.7
- SpringBoot2更换Tomcat为Jetty,小型站点的福音
- Jdk安装(Linux,MacOS,Windows),包含三大操作系统的最全安装
- SpringBoot2整合MyBatis,连接MySql数据库做增删改查操作
- SpringBoot2全家桶,快速入门学习开发网站教程
- CentOS8安装MyCat,轻松搞定数据库的读写分离、垂直分库、水平分库
- CentOS8编译安装MySQL8.0.19