更快的辅助生成: 动态推测
⭐ 在这篇博客文章中,我们将探讨 动态推测解码 ——这是由英特尔实验室和 Hugging Face 开发的一种新方法,可以加速文本生成高达 2.7 倍,具体取决于任务。从
-
Transformers🤗 https://github.com/huggingface/transformers -
4.45.0 版本发布信息 https://github.com/huggingface/transformers/releases/tag/v4.45.0
推测解码
推测解码的单次迭代
动态推测解码
-
Transformers🤗 https://github.com/huggingface/transformers -
Leviathan 等人 https://arxiv.org/pdf/2211.17192 -
基于启发式方法的方法 https://hf.co/blog/assisted-generation
我们预计,通过增强优化策略来管理生成的草稿标记数量,可以进一步减少延迟。为了测试这个论点,我们利用一个预测器来确定每个推测迭代的最佳推测前瞻值 (SL)。该预测器利用草稿模型自回归的生成标记,直到草稿模型和目标模型之间的预测标记出现不一致。该过程在每个推测迭代中重复进行,最终确定每次迭代接受的草稿标记的最佳 (最大) 数量。草稿/目标标记不匹配是通过在零温度下 Leviathan 等人提出的拒绝抽样算法 (rejection sampling algorithm) 来识别的。该预测器通过在每一步生成最大数量的有效草稿标记,并最小化对草稿和目标模型的调用次数,实现了推测解码的全部潜力。我们称使用该预测器得到 SL 值的推测解码过程为预知 (orcale) 的推测解码。
下面的左图展示了来自
-
MBPP https://hf.co/datasets/google-research-datasets/mbpp -
Alpaca https://hf.co/datasets/tatsu-lab/alpaca
在 MBPP 的一个例子上的预知和静态推测前瞻值 (SL)。
在整个 Alpaca 数据集上平均的预知 SL 值。
上面的两个图表展示了预知推测前瞻值的多变性,这说明静态的推测解码可能使次优的。
为了更接近预知的推测解码并获得额外的加速,我们开发了一种简单的方法来在每次迭代中动态调整推测前瞻值。在生成每个草稿令牌后,我们确定草稿模型是否应继续生成下一个令牌或切换到目标模型进行验证。这个决定基于草稿模型对其预测的信心,通过 logits 的 softmax 估计。如果草稿模型对当前令牌预测的信心低于预定义的阈值,即 assistant_confidence_threshold
,它将在该迭代中停止令牌生成过程,即使尚未达到最大推测令牌数 num_assistant_tokens
。一旦停止,当前迭代中生成的草稿令牌将被发送到目标模型进行验证。
基准测试
我们在一系列任务和模型组合中对动态方法与启发式方法进行了基准测试。动态方法在所有测试中表现出更好的性能。值得注意的是,使用动态方法将 Llama3.2-1B
作为 Llama3.1-8B
的助手时,我们观察到速度提升高达 1.52 倍,而使用相同设置的启发式方法则没有显著的速度提升。另一个观察结果是, codegen-6B-mono
在使用启发式方法时表现出速度下降,而使用动态方法则表现出速度提升。
目标模型 | 草稿模型 | 任务类型 | 加速比 - 启发式策略 | 加速比 - 动态策略 |
---|---|---|---|---|
facebook/opt-6.7b | facebook/opt-125m | summarization | 1.82x | 2.71x |
facebook/opt-6.7b | facebook/opt-125m | open-ended generation | 1.23x | 1.59x |
Salesforce/codegen-6B-mono | Salesforce/codegen-350M-mono | code generation (python) | 0.89x | 1.09x |
google/flan-t5-xl | google/flan-t5-small | summarization | 1.18x | 1.31x |
meta-llama/Llama-3.1-8B | meta-llama/Llama-3.2-1B | summarization | 1.00x | 1.52x |
meta-llama/Llama-3.1-8B | meta-llama/Llama-3.2-1B | open-ended generation | 1.00x | 1.18x |
meta-llama/Llama-3.1-8B | meta-llama/Llama-3.2-1B | code generation (python) | 1.09x | 1.15x |
-
表格中的结果反映了贪婪解码 (temperature = 0)。在使用采样 (temperature > 0) 时也观察到了类似的趋势。 -
所有测试均在 RTX 4090 上进行。 -
我们的基准测试是公开的,允许任何人评估进一步的改进: https://github.com/gante/huggingface-demos/tree/main/experiments/faster_generation
代码
动态推测已经整合到 Hugging Face Transformers 库的 4.45.0 版本中,并且现在作为辅助解码的默认操作模式。要使用带有动态推测的辅助生成,无需进行任何代码更改,只需像平常一样执行代码即可:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
prompt = "Alice and Bob"
checkpoint = "EleutherAI/pythia-1.4b-deduped"
assistant_checkpoint = "EleutherAI/pythia-160m-deduped"
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
inputs = tokenizer(prompt, return_tensors="pt").to(device)
model = AutoModelForCausalLM.from_pretrained(checkpoint).to(device)
assistant_model = AutoModelForCausalLM.from_pretrained(assistant_checkpoint).to(device)
outputs = model.generate(**inputs, assistant_model=assistant_model)
默认的动态推测前瞻的参数反应了最优的值,但是可以使用下面的代码进行调整来在特定模型和数据上获得更好的性能:
# confidence threshold
assistant_model.generation_config.assistant_confidence_threshold=0.4
# 'constant' means that num_assistant_tokens stays unchanged during generation
assistant_model.generation_config.num_assistant_tokens_schedule='constant'
# the maximum number of tokens generated by the assistant model.
# after 20 tokens the draft halts even if the confidence is above the threshold
assistant_model.generation_config.num_assistant_tokens=20
要恢复到 启发式 或 静态 方法 (如 num_assistant_tokens_schedule
设置为 'heuristic'
或 'constant'
,将 assistant_confidence_threshold=0
和 num_assistant_tokens=5
设置如下:
# Use 'heuristic' or 'constant' or 'dynamic'
assistant_model.generation_config.num_assistant_tokens_schedule='heuristic'
assistant_model.generation_config.assistant_confidence_threshold=0
assistant_model.generation_config.num_assistant_tokens=5
接下来是什么?
我们介绍了一种更快的辅助生成策略,名为动态推测解码,它优于启发式方法以及固定数量候选标记的方法。
在即将发布的博客文章中,我们将展示一种新的辅助生成方法: 将任何目标模型与任何助手模型结合起来!这将为在 Hugging Face Hub 上加速无法获得足够小的助手变体的无数模型打开大门。例如, Phi 3
、 Gemma 2
、 CodeLlama
等等都将有资格进行推测解码。敬请关注!
参考资料
-
Dynamic Speculation Lookahead Accelerates Speculative Decoding of Large Language Models 。https://arxiv.org/abs/2405.04304
在这篇论文中,我们介绍了 DISCO,一种动态推测前瞻优化方法,利用分类器决定草稿模型是否应该继续生成下一个标记,还是暂停,并切换到目标模型进行验证,而不是仅仅使用对预测概率的简单阈值。
-
Assisted Generation: a new direction toward low-latency text generation https://hf.co/blog/assisted-generation -
Fast Inference from Transformers via Speculative Decoding https://arxiv.org/pdf/2211.17192
原文链接:
https://hf.co/blog/dynamic_speculation_lookahead 原文作者: Jonathan Mamou, Oren Pereg, Joao Gante, Lewis Tunstall, Daniel Korat, Nadav Timor, Moshe Wasserblat
译者: Zipxuan
本文分享自微信公众号 - Hugging Face(gh_504339124f0f)。
如有侵权,请联系 support@oschina.cn 删除。
本文参与“OSC源创计划”,欢迎正在阅读的你也加入,一起分享。

低调大师中文资讯倾力打造互联网数据资讯、行业资源、电子商务、移动互联网、网络营销平台。
持续更新报道IT业界、互联网、市场资讯、驱动更新,是最及时权威的产业资讯及硬件资讯报道平台。
转载内容版权归作者及来源网站所有,本站原创内容转载请注明来源。
- 上一篇
🥇荣誉上新|Alluxio 斩获「 OSCAR尖峰开源项目及开源社区 」
2024年10月16日,由中国通信标准化协会主办,中国信息通信研究院承办,中国信息通信研究院云计算开源产业联盟、金融行业开源技术应用社区、通信行业开源社区、科技制造开源社区、汽车行业开源社区、可信开源社区共同体、可信开源合规计划支持的开源领域顶级盛会——“OSCAR开源产业大会”在京成功举办,旨在进一步探索中国开源生态发展模式,加速开源技术在国内市场落地,提升企业开源治理能力,推动国内开源生态快速、健康有序发展。 🎯 大会特设立 “OSCAR 开源尖峰案例”评选,经过几个月多轮筛选,Alluxio在技术创新、社区建设和应用推广方面受到专家评委们的一致认可,从众多优秀的开源项目和社区中脱颖而出,斩获「OSCAR尖峰开源项目及开源社区」称号。
- 下一篇
使用 Optimum-Intel 和 OpenVINO GenAI 优化和部署模型
在端侧部署 Transformer 模型需要仔细考虑性能和兼容性。Python 虽然功能强大,但对于部署来说有时并不算理想,特别是在由 C++ 主导的环境中。这篇博客将指导您如何使用 Optimum-Intel 和 OpenVINO™ GenAI 来优化和部署 Hugging Face Transformers 模型,确保在最小依赖性的情况下进行高效的 AI 推理。 为什么使用 OpenVINO 来进行端侧部署 OpenVINO™ 最初是作为 C++ AI 推理解决方案开发的,使其非常适合在端侧设备部署中,其中最小化依赖性至关重要。随着引入 GenAI API,将大型语言模型 (LLMs) 集成到 C++ 或 Python 应用程序中变得更加简单,其特性旨在简化部署并提升性能。 第一步: 创建环境 预先准备 开始之前,请确保您的环境已正确配置了 Python 和 C++。安装必要的 Python 包: pipinstall--upgrade--upgrade-strategyeageroptimum[openvino] 以下是本文中使用的具体包: transformers==4.44o...
相关文章
文章评论
共有0条评论来说两句吧...
文章二维码
点击排行
推荐阅读
最新文章
- Windows10,CentOS7,CentOS8安装Nodejs环境
- Docker快速安装Oracle11G,搭建oracle11g学习环境
- 2048小游戏-低调大师作品
- SpringBoot2初体验,简单认识spring boot2并且搭建基础工程
- CentOS7编译安装Cmake3.16.3,解决mysql等软件编译问题
- SpringBoot2配置默认Tomcat设置,开启更多高级功能
- CentOS7编译安装Gcc9.2.0,解决mysql等软件编译问题
- CentOS6,7,8上安装Nginx,支持https2.0的开启
- 设置Eclipse缩进为4个空格,增强代码规范
- CentOS关闭SELinux安全模块