为什么某些 batch size 会突然导致性能下降?
编者按:你是否曾在优化深度学习模型时感到困惑,明明增加了 batch size,GPU 利用率却没有如预期提升?在实际项目中,这个问题可能导致资源浪费、训练效率低下,甚至影响整个 AI 产品的交付周期。
本文作者深入剖析了现代 GPU 批处理的工作原理,揭示了内存带宽与计算能力之间的微妙关系。通过建立理论模型并结合实际实验,作者不仅解释了为什么某些 batch size 会突然导致性能下降,还提供了如何找到最佳 batch size 的方法。
作者 | Finbarr Timbers
编译 | 岳扬
一般来说,对于现代深度学习系统而言,你能做的第一个也是最重要的优化措施就是实现批处理(batching)。在进行推理时,不是单独处理单个输入,而是同时处理包含 N 个输入的一批数据。大多数情况下,这个操作是无需额外成本的 ------ 无论是处理单个输入还是 N 个输入,推理所需的时间几乎相同。这又是为何呢?表面上看,批量处理数据似乎应该消耗更多资源,毕竟,工作量增加了 N 倍。
然而,如果我们使用一个简单或者不成熟的模型来理解神经网络的工作方式,那么批处理(batching)的计算并不是没有成本的。实际上,批处理确实需要 N 倍的计算能力。如果你在 CPU 上运行某个特定的计算任务,你会发现前文提到的这一点是成立的。
然而,在现代 GPU 上运行相同的计算任务时,情况却并非如此。以下是我们在一款 T4 GPU 上观察到的情况:
从图中可以看到,batch size 从 1 到 3 时,所消耗的时间并不会增加。但是,一旦 batch size 超过 3,时间消耗就会呈线性增长。
这是什么原因呢?关键在于并发处理能力。现代 GPU 能够同时执行多次运算(尽管在单线程处理时,它们其实比 CPU 要慢)。
通常,当我们谈论"用模型对单个数据样本进行推理"时,容易把模型看作一个整体块(single block)。但实际上,模型是由众多矩阵组成的。推理过程中,每个矩阵都会被加载到内存中。具体来说,矩阵的每个块都会被加载到设备的共享内存单元(在 A100 显卡上仅有 192 kb)。这个块随后用于计算 batch 中每个元素的结果。需要注意的是,这与 GPU RAM(即 HBM)不同。A100 显卡根据型号不同,配备了 40 GB 或 80 GB 的 HBM,但设备内存仅有 192 kb。这导致在执行数学运算时,内存带宽成为了性能瓶颈,因为数据需要不断地在设备内存中读写。我们可以通过模型大小除以内存带宽来估算传输权重所需的时间,通过模型的浮点运算次数(FLOPS)除以 GPU 的 FLOPS 来估算计算所需的时间。
使用多层感知机(MLP),浮点运算次数(FLOPS)大约是参数数量的两倍乘以 batch 中元素的数量[1](即为 2 * m * n * b,数据批次大小(batch size)为 b ,矩阵为 m x n )。因此,当传输时间等于计算时间时,意味着:
在此,我们可以观察到左右两边的参数数量是可以相互抵消的:
同时,我们可以根据 batch size 来重新排列:
当 batch size 小于 FLOPS 与内存带宽的比值时,内存带宽将成为性能瓶颈。而一旦 batch size 超过了这个比值,计算能力(FLOPS)则成为新的瓶颈。 请注意,这一分析仅适用于多层感知机(MLP),对于像 ResNet50 这样的卷积神经网络来说,情况会更为复杂。
在 T4 GPU(产品规格表[2])上,其浮点运算能力达到 65 TFLOPS(32位浮点数),内存带宽则是 300 GB/s,按照这个数据,理想的运算效率比(magic ratio)应该是 216。实际运行一个深度为 8、宽度为 1024 的多层感知机(MLP)模型时,我们得到的结果与预期相吻合。
尽管数据中存在一些噪声干扰,但总体趋势与我们的预测一致:推理时间在接近 128 的阈值时开始急剧增加(在此,我们采取逐步加倍的方式来观察和记录不同 batch size 对推理时间(inference time)的影响)。如果我们改变 MLP 层的宽度,会发现这一现象在多种架构中都存在(下面是一张对数-对数(log-log)坐标图,以便所有的数据点都能在图表中清晰地显示)。
这真是太酷🆒了!我们可以看到在多种不同的模型架构中,都存在一个关键的阈值。有趣的是,较小的网络在处理速度上并没有随着 batch sizes(从 1 到 512)的增加而变化,基本保持恒定。 我对此的初步解释是,这是因为 GPU 在执行数学运算时速度极快,而其他硬件(如 CPU)则相对较慢。在实验初期,我们观察到了大量的噪声干扰,对于这一现象,我暂时只能归咎于"系统开销(overhead)"。
对于许多机器学习工程师而言,他们的时间往往没有花在机器学习本身,而是花在消除这些系统开销上,这些开销大多出现在非机器学习相关的代码中。在强化学习(RL)研究领域,尤其是那些专注于持续学习问题(continual learning problems)的研究者,除非1)他们拥有一个非常大的神经网络,或者2)对整个技术栈进行了极致优化,否则在实验中使用 GPU 往往并不划算。如果你想让一位曾在 DeepMind 工作过的工程师感到尴尬,可以问他们关于"内置计算图环境"(in-graph environments)的问题------在某个阶段,我们甚至是在 TensorFlow 的计算图中实现 RL 环境。
那么,卷积神经网络的情况又是如何呢?
在卷积神经网络中,权重的总数是滤波器数量与滤波器尺寸的乘积。以 torch.nn.Conv2d 为例,权重的计算方式是 kernel_size^2 乘以 out_channels。假设我们处理的是一张分辨率为 (224, 224) 的图像,步长为 1,卷积核大小为 3,那么每个滤波器会被重复使用 224 次。这就意味着,在卷积层中,批处理的优势并不明显,因为我们会反复使用相同的权重。至于池化层,其处理计算量与像素数量呈线性关系,这一点与你所想的相符。
Transformers 的情况又是怎么样呢?
Transformers 本质上就是多层感知机(MLPs),我们可以将它们视为相同的东西。它们具有注意力机制,但是,由于有了 KV 缓存(能够将计算数据保留在内存中),注意力机制所消耗的时间被大幅减少。我之前已经撰写文章对此进行了深入的探讨[3]。
这一观点同样适用于混合专家模型(Mixture of Experts model)。在许多 Transformers 的实现中,KV 缓存是内置于注意力类中的(例如,MaxText[4] 就是一个典型案例[5])。由于 MoE 模型与普通解码器之间的差异仅在于,某些前馈网络层被替换为了 MoE 层,因此 KV 缓存的表现将保持一致,推理过程也是如此,但有一点不同。
MoE 层中的门控机制会将数据批次(batch)分配到不同的专家上。如果门控没有均匀分配数据批次,就可能会引发一些问题。虽然有避免这种情况的路由机制(如"expert's choice"),但在自回归解码器中,我们通常只能采用"token's choice",这可能会导致门控出现偏差。强制门控均匀分配 tokens 是1)当前研究的焦点,并且是2)在训练过程中需要优化的一个重要目标。
Thanks for reading!
Hope you have enjoyed and learned new things from this blog!
About the authors
Finbarr Timbers
empiricist. ml researcher. previously: engineering at deepmind 🧠
END
本期互动内容 🍻
❓你在实际项目中是如何选择 batch size 的?有没有遇到过意外的性能瓶颈?
🔗文中链接🔗
[1]https://www.stat.cmu.edu/~ryantibs/convexopt-F18/scribes/Lecture_19.pdf
[3]https://www.artfintel.com/p/where-do-llms-spend-their-flops
[4]https://github.com/google/maxtext
[5]https://github.com/google/maxtext/blob/main/MaxText/layers/attentions.py#L91
原文链接:
https://www.artfintel.com/p/how-does-batching-work-on-modern

低调大师中文资讯倾力打造互联网数据资讯、行业资源、电子商务、移动互联网、网络营销平台。
持续更新报道IT业界、互联网、市场资讯、驱动更新,是最及时权威的产业资讯及硬件资讯报道平台。
转载内容版权归作者及来源网站所有,本站原创内容转载请注明来源。
- 上一篇
Gitee AI 助力医疗科研:医用耗材使用分析研究
本文作者:铂金小猪 我叫铂金小猪,万万没想到,我一个公立医院负责医疗耗材管理的野生程序员,有一天我会成为医疗界的「AI技术专家」(自己都觉得脸红)。上一篇文章《Gitee AI+Dify 双剑合璧,打造另类 RAG 知识库》中,我提到我们用 AI 技术参与了一个国家卫健委的科研项目。今天,我继续分享一下,我们是如何利用 AI 来参与到医用耗材管理工作里面的。 项目名称:2024医学工程研究项目 项目发起部门:国家卫健委医院管理研究所 课题名称:基于病种和临床路径的医用耗材使用管理和评价体系研究 课题单位:云南省第一人民医院、富源县人民医院 根据国家卫健委医院管理研究所的相关要求,基于病种和临床路径的医用耗材使用管理和评价体系研究项目的主要研究方向是,聚焦医院医用耗材精细化使用管理需求,研究建立基于病种和临床路径的医用耗材使用管理方法及评价体系,针对植介入重点管控医用耗材开展应用评价,优化医用耗材使用管理,服务于临床专科能力建设和医院高质量发展。 按照这一要求,我们对业务进行了分析,发现这事儿得要用 AI 的一些能力才能干完。 主要痛点分析 整个项目,基本就是围绕病历、结算数据去进行分...
- 下一篇
MyBatis布尔字段映射陷阱全过程解析
在开发过程中,我们常常会遇到一些看似简单却令人困惑的问题。本文记录了一次将 boolean 改为 Boolean 后,MyBatis 插入数据时出现的意外情况。本文不仅逐步揭示了问题的根本原因,还提供了解决方案,并强调了在开发中遵循规范和仔细排查问题的重要性。 背景 为了实现某个功能,需要为已有的表新增字段,其中有一个字段需要表达的含义是:是否有对话条数。 加字段要遵守规范,咱就去看了《阿里巴巴开发规约》的“MySQL规约”,有这么一段描述: 因此,“是否有对话条数”的字段名为 is_has_messages,数据类型为:unsigned tinyint(1表示是,0表示否;默认为1) 给mysql加好字段了,咱还得给 xxxDO 加上字段,按照上面的说法“POJO类中的任何布尔类型的变量,都不要加is前缀”,那就这么写: /** * 是否有对话条数;1表示是,0表示否 <br> * 默认为1 */private boolean hasMessages; 一切看起来是那么的自然~ 翻车 ▐奇怪的结果 is_has_messages 咋是0了?true 不应该映射为1吗?...
相关文章
文章评论
共有0条评论来说两句吧...
文章二维码
点击排行
推荐阅读
最新文章
- CentOS7编译安装Gcc9.2.0,解决mysql等软件编译问题
- 2048小游戏-低调大师作品
- SpringBoot2全家桶,快速入门学习开发网站教程
- CentOS8安装MyCat,轻松搞定数据库的读写分离、垂直分库、水平分库
- Docker快速安装Oracle11G,搭建oracle11g学习环境
- Hadoop3单机部署,实现最简伪集群
- CentOS7,CentOS8安装Elasticsearch6.8.6
- Eclipse初始化配置,告别卡顿、闪退、编译时间过长
- SpringBoot2更换Tomcat为Jetty,小型站点的福音
- SpringBoot2编写第一个Controller,响应你的http请求并返回结果