社区供稿 | 源码解析ChatGLM2多轮对话训练方法的不足,以及改进方法
01
前言
🎉Firefly项目支持微调ChatGLM2模型啦,我们实现了一种比ChatGLM2官方更加充分高效的多轮对话训练方法,并且沿袭了官方的数据组织格式。
在此之前,很多同学询问Firefly项目是否支持微调ChatGLM或ChatGLM2模型,而我们迟迟未进行适配的原因主要如下:
-
此前,Firefly虽然已支持微调Llma2、Llama、Baichuan、InternLM、Ziya、Bloom等开源大模型,但都是在Pretrain模型上进行指令微调,指令数据的组织格式相对自由,可按需自行设计。
-
ChatGLM不属于严格意义上的Causal Language Model(因果语言模型),因为它存在prefix attention mask的设计。对于prefix而言,它的attention是双向的,而预测部分的attention是单向的,存在一定的适配成本。但ChatGLM2做出了改变,它的注意力是单向的。
-
ChatGLM2是一个经过指令微调的chat模型,微调时遵从官方的数据组织格式,才能达到最优效果。
-
Firefly项目有自己独特的多轮对话训练方式。
对于预训练模型,可以自由设计训练数据的组织格式;对于chat模型,最好遵从官方的数据组织格式。
在适配ChatGLM2的过程中,我们阅读了一些ChatGLM2的官方代码,发现ChatGLM2的多轮对话训练方式存在不足之处,在后续章节中,我们也将从源码对其进行分析。我们也将分享Firefly如何实现对ChatGLM2进行更加充分高效的多轮对话训练,以及训练效果。
此前,我们专门分享过多轮对话的训练方法,结合阅读有助于理解:一文看懂:如何充分高效训练多轮对话大模型。
Firefly项目链接:
https://github.com/yangjianxin1/Firefly
firefly-chatglm2-6b权重:
https://huggingface.co/YeungNLP/firefly-chatglm2-6b
02
微调效果
对话示例1:
对话示例2:
03
ChatGLM2源码解析
-
ChatGLM2如何组织多轮对话训练数据? -
ChatGLM2采用何种方式训练多轮对话?
[Round 1]
问:{input1}
答:{target1}
[Round 2]
问:{input2}
答:{target2}
[Round 3]
问:{input3}
答:{target3}</s>
04
Firefly方法
方法概述
Firefly微调ChatGLM2的方法如下图所示,该方法的优势如下:
-
推理时候,模型不会出现“自问自答”和“不停止”的情况。
-
训练时,多轮对话中的每个回复都被充分利用。
-
计算高效,不需要将一条多轮对话数据拆分成多条数据。
在微调ChatGLM2时,Firefly基本上沿袭了ChatGLM2的数据组织格式,仅在每个target后面添加了</s>停止符。对于一条多轮对话数据,所有"{target}</s>"都会并行参与计算loss。并且因为</s>停止符的妙用,在推理时,模型不会遇到“自问自答”和“不停止”的情况。
[Round 1]
问:{input1}
答:{target1}</s>
[Round 2]
问:{input2}
答:{target2}</s>
[Round 3]
问:{input2}
答:{target2}</s>
为什么这种做法是可行的?详见文章:一文看懂:如何充分高效训练多轮对话大模型。
代码实现
Talk is cheap,Show me the code。接下来将从代码层面介绍我们是如何充分高效地实现多轮对话训练。
微调ChatGLM2时,Firefly将多轮对话拼接成如下格式。
[ ]
问:{input1}
答:{target1}</s>
[ ]
问:{input2}
答:{target2}</s>
[ ]
问:{input2}
答:{target2}</s>
在生成input_ids的时候,我们还会生成一个target_mask,取值为0或1,用来标记每个token是否属于target部分,即是否参与loss计算。其中“target</s>”部分的target_mask均为1,其他部分均为0。
我们会并行计算每个位置的loss,但只有target_mask=1的部分的loss,才会参与权重更新。这种方式充分利用了模型并行计算的优势,更加高效,并且多轮对话中的每个target部分都参与了训练,更加充分利用了数据。
数据组织格式如下:
class ChatGLM2SFTDataset(SFTDataset):
def __getitem__(self, index):
"""
基本沿袭ChatGLM2的指令微调的格式,做了小修改,多轮对话如下。
"""
# 每条数据格式为: [Round 1]\n\n问:{input1}\n\n答:{target1}</s>[Round 2]\n\n问:{input2}\n\n答:{target2}</s>...
data = self.data_list[index]
data = json.loads(data)
conversation = data['conversation']
input_format = '[Round {}]\n\n问:{}\n\n答:'
target_format = '{}'
# 收集多轮对话
utterances = []
for i, x in enumerate(conversation):
human = input_format.format(i+1, x['human'])
assistant = target_format.format(x['assistant'])
utterances += ([human, assistant])
utterances_ids = self.tokenizer(utterances, add_special_tokens=False).input_ids
# 每条数据格式为: [Round 1]\n\n问:{input1}\n\n答:{target1}</s>[Round 2]\n\n问:{input2}\n\n答:{target2}</s>...
input_ids = []
target_mask = [] # 用于对input进行mask,只计算target部分的loss
for i, utterances_id in enumerate(utterances_ids):
input_ids += utterances_id
# input部分
if i % 2 == 0:
target_mask += [0] * (len(utterances_id))
# target部分
else:
input_ids += [self.eos_token_id]
target_mask += [1] * (len(utterances_id) + 1)
assert len(input_ids) == len(target_mask)
# 对长度进行截断
input_ids = input_ids[:self.max_seq_length]
target_mask = target_mask[:self.max_seq_length]
attention_mask = [1] * len(input_ids)
assert len(input_ids) == len(target_mask) == len(attention_mask)
inputs = {
'input_ids': input_ids,
'attention_mask': attention_mask,
'target_mask': target_mask
}
return inputs
loss计算方式如下:
class TargetLMLoss(Loss):
def __init__(self, ignore_index):
super().__init__()
self.ignore_index = ignore_index
self.loss_fn = nn.CrossEntropyLoss(ignore_index=ignore_index)
def __call__(self, model, inputs, training_args, return_outputs=False):
input_ids = inputs['input_ids']
attention_mask = inputs['attention_mask']
target_mask = inputs['target_mask']
# 模型前馈预测
outputs = model(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
logits = outputs["logits"] if isinstance(outputs, dict) else outputs[0]
# 将labels中不属于target的部分,设为ignore_index,只计算target部分的loss
labels = torch.where(target_mask == 1, input_ids, self.ignore_index)
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss = self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
return (loss, outputs) if return_outputs else loss
本文转载自社区供稿内容,不代表官方立场。了解更多,请关注微信公众号"YeungNLP":
https://hf.link/tougao
本文分享自微信公众号 - Hugging Face(gh_504339124f0f)。
如有侵权,请联系 support@oschina.cn 删除。
本文参与“OSC源创计划”,欢迎正在阅读的你也加入,一起分享。

低调大师中文资讯倾力打造互联网数据资讯、行业资源、电子商务、移动互联网、网络营销平台。
持续更新报道IT业界、互联网、市场资讯、驱动更新,是最及时权威的产业资讯及硬件资讯报道平台。
转载内容版权归作者及来源网站所有,本站原创内容转载请注明来源。
- 上一篇
带你快速上手HetuEngine
本文分享自华为云社区《【手把手带你玩转HetuEngine】(一)HetuEngine快速上手》,作者:HetuEngine九级代言。 HetuEngine是什么 HetuEngine是华为推出的高性能交互式SQL分析及数据虚拟化引擎。与大数据生态无缝融合,实现海量数据秒级交互式查询;支持跨源跨域统一访问,使能数据湖内、湖间、湖仓一站式SQL融合分析。 HetuEngine适合做什么 适用于Hadoop集群(FusionInsight MRS)的Hive、Hudi数据源的交互式快速查询场景; 适用于跨源(多种数据源,如Hive,Hudi,HBase,GaussDB(DWS),Elasticsearch,ClickHouse等)查询; 适用于跨域(多个地域或数据中心)的快速联合查询; 不擅长大批量、复杂逻辑的跑批处理、创建事务、数据入库操作等。 HetuEngine特点 HetuEngine基本架构 HetuEngine面向企业级能力方面,构建了极致稳定、高性能的企业级交互式分析引擎。 云服务层:提供了企业级的运维管理监控能力,认证与业务接入统一访问入口,友好的可视化界面操作,一键式参数...
- 下一篇
介绍一下我们的开源“充电之旅” - 两位新晋 Apache Flink Committer 专访
本文出自字节跳动流式计算团队的方勇、胡伟华同学专访。两位同学在 Apache Flink 社区主要贡献了包括 Runtime Coordinator、Streaming Warehouse 等相关 Feature。于2023年7月正式受邀成为 Apache Flink Committer。 在软件开发的世界中,开源已成为普遍关注的话题。越来越多的企业和开发者认识到开源的重要性,并开始积极拥抱开源、贡献开源。自2017年开始,字节跳动流式计算团队开始尝试使用 Apache Flink 作为流式计算引擎,并逐步加大对开源社区的关注和投入。 近两个月来,团队方勇、胡伟华两位同学先后受邀成为 Apache Flink Committer。本文将对两位新晋 Committer 参与开源的心路历程进行专访。 我的开源参与之路 Apache Flink 是一个高性能的分布式计算框架,目前也已经是流式计算的事实标准,很大程度上推动了整个流式数据处理方面的发展。对于两位新晋 Committer 而言,Flink 在 Apache 中是不可忽视的明星项目。 作为一个非常活跃的社区,用户提出的问题很快就会...
相关文章
文章评论
共有0条评论来说两句吧...
文章二维码
点击排行
推荐阅读
最新文章
- CentOS6,7,8上安装Nginx,支持https2.0的开启
- CentOS8,CentOS7,CentOS6编译安装Redis5.0.7
- SpringBoot2更换Tomcat为Jetty,小型站点的福音
- Jdk安装(Linux,MacOS,Windows),包含三大操作系统的最全安装
- CentOS7,8上快速安装Gitea,搭建Git服务器
- SpringBoot2整合MyBatis,连接MySql数据库做增删改查操作
- SpringBoot2全家桶,快速入门学习开发网站教程
- CentOS8安装MyCat,轻松搞定数据库的读写分离、垂直分库、水平分库
- CentOS8编译安装MySQL8.0.19
- CentOS7,CentOS8安装Elasticsearch6.8.6