基于 P-Tuning v2 进行 ChatGLM2-6B 微调实践 | 京东云技术团队
微调类型简介
1. SFT监督微调:适用于在源任务中具有较高性能的模型进行微调,学习率较小。常见任务包括中文实体识别、语言模型训练、UIE模型微调。优点是可以快速适应目标任务,但缺点是可能需要较长的训练时间和大量数据。
2. LoRA微调:通过高阶矩阵秩的分解减少微调参数量,不改变预训练模型参数,新增参数。优点是减少了微调的参数量和成本,同时能达到与全模型微调相近的效果。
3. P-tuning v2微调:引入了prefix-tuning的思想,每一层都加入了prefix,并采用了多任务学习。解决了P-tuning v1中序列标注任务效果不佳和普遍性差的问题。其参数对象是各层的prefix。优点是适用于多任务学习,但在自然语言理解任务上表现可能不佳。
4. Freeze微调:主要用于大语言模型的微调,后几层网络提取语义特征,前几层提取文本表层特征。优点是参数高效,适用于提取特定层次的特征。
综上所述,各种微调方法适用于不同的场景和任务。SFT监督微调适用于快速适应目标任务,LoRA适用于减少参数量和成本,P-tuning v2适用于多任务学习,而Freeze适用于提取特定层次的特征。
1.下载glm2训练脚本
git clone https://github.com/THUDM/ChatGLM2-6B.git
2.然后使用 pip 安装依赖
pip install -r requirements.txt -i https://pypi.douban.com/simple/
运行行微调除 ChatGLM2-6B 的依赖之外,还需要安装以下依赖
pip install rouge_chinese nltk jieba datasets transformers[torch] -i https://pypi.douban.com/simple/
3.下载样例数据或者自己构建样例
{"content": "类型#裙*材质#网纱*颜色#粉红色*图案#线条*图案#刺绣*裙腰型#高腰*裙长#连衣裙*裙袖长#短袖*裙领型#圆领", "summary": "这款连衣裙,由上到下都透出女性魅力,经典圆领型,开口度恰好,露出修长的脖颈线条,很是优雅气质,短袖设计,这款对身材有很好的修饰作用,穿起来很女神;裙身粉红色花枝重工刺绣,让人一眼难忘!而且在这种网纱面料上做繁复图案的绣花,是很考验工艺的,对机器的要求会更高,更加凸显我们的高品质做工;"}
可以根据以上格式,构建自己的训练样本,我们可以用一些行业生产数据,如会话记录对模型进行训练,
官方示例数据下载:
https%3A//cloud.tsinghua.edu.cn/f/b3f119a008264b1cabd1/%3Fdl%3D1
4.根据自己的环境修改训练脚本中对应的文件地址
PRE_SEQ_LEN=128 #序列的预设长度为128 LR=2e-2 #学习率为0.02 NUM_GPUS=4 #用几颗GPU进行训练 torchrun --standalone --nnodes=1 --nproc_per_node=$NUM_GPUS main.py \ --do_train \ --train_file /export/data/train.json \ #设置训练数据文件的目录 --validation_file /export/data/validation.json \ #设置验证文件的目录 --preprocessing_num_workers 10 \ --prompt_column content \ --response_column summary \ --overwrite_cache \ --model_name_or_path /opt/tritonserver/python_backend/models/chatglm2-6b \ #模型目录 --output_dir /export/models/trained-chatglm2-6b-pt-$PRE_SEQ_LEN-$LR \ #训练后的模型目录 --overwrite_output_dir \ --max_source_length 64 \ --max_target_length 128 \ --per_device_train_batch_size 1 \ --per_device_eval_batch_size 1 \ --gradient_accumulation_steps 16 \ --predict_with_generate \ --max_steps 3000 \ --logging_steps 10 \ --save_steps 1000 \ --learning_rate $LR \ --pre_seq_len $PRE_SEQ_LEN \ --quantization_bit 4
5.开始训练吧
sh train.sh
训练中
快要训练完成
6.训练完成
Training completed. Do not forget to share your model on huggingface.co/models =)
{'train_runtime': 4598.3849, 'train_samples_per_second': 41.754, 'train_steps_per_second': 0.652, 'train_loss': 0.1287700497706731, 'epoch': 2400.0}
100%|██████████| 3000/3000 [1:16:37<00:00, 1.53s/it]
***** train metrics *****
epoch = 2400.0
train_loss = 0.1288
train_runtime = 1:16:38.38
train_samples = 24
train_samples_per_second = 41.754
train_steps_per_second = 0.652
7.部署训练后的模型
在 P-tuning v2 训练时模型只保存 PrefixEncoder 部分的参数,所以在推理时需要同时加载原 ChatGLM-6B 模型以及 PrefixEncoder 的权重
model_path = "/opt/tritonserver/python_backend/models/chatglm2-6b" model = AutoModel.from_pretrained(model_path, config=config, trust_remote_code=True) prefix_state_dict = torch.load(os.path.join('/opt/train/trained-chatglm2-6b-pt-128-1e-4/checkpoint-3000', "pytorch_model.bin")) new_prefix_state_dict = {} for k, v in prefix_state_dict.items(): if k.startswith("transformer.prefix_encoder."): new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
8.过程中遇到的问题
8.1 微调后无法应答
PRE_SEQ_LEN=128 LR=2e-2 NUM_GPUS=1 torchrun --standalone --nnodes=1 --nproc_per_node=$NUM_GPUS main.py \ --do_train \ --train_file train.json \ --validation_file dev.json \ --preprocessing_num_workers 10 \ --prompt_column content \ --response_column summary \ --overwrite_cache \ --model_name_or_path /opt/tritonserver/python_backend/models/chatglm2-6b \ --output_dir trained-chatglm2-6b-pt-$PRE_SEQ_LEN-$LR \ --overwrite_output_dir \ --max_source_length 64 \ --max_target_length 64 \ --per_device_train_batch_size 16 \ --per_device_eval_batch_size 1 \ --gradient_accumulation_steps 1 \ --predict_with_generate \ --max_steps 3000 \ --logging_steps 10 \ --save_steps 1000 \ --learning_rate $LR \ --pre_seq_len $PRE_SEQ_LEN \
使用官方脚本中的学习率设置 LR=2e-2 (0.02)
模型出现无法应答,灾难性遗忘,基本上原有的知识都遗忘了,无法应答普通提问 , 比如"你好.."
于是尝试使用 LR=1e-4 (0.0001) 进行训练
"1e-4" 表示 1 乘以 10 的 -4 次方,即等于 0.0001,"2e-2" 表示 2 乘以 10 的 -2 次方,即等于 0.02。
模型最终可以应答.
镜像问题:
https://github.com/THUDM/ChatGLM-6B/issues/1148
8.2 关于学习率:
我理解是,学习率大小像看书看的粗细,看的太粗就学的快(收敛快)但啥也学不到,
学习率是影响模型训练效果的重要参数。过大的学习率可能导致模型不稳定,过小的学习率则可能导致训练速度变慢。因此,需要反复试验,找到合适的学习率。
学习率(lr)表示每次更新权重参数的尺度(步长),ΔΘ=Θ0−(lr)(loss′)。
学习率与batch_size在权重更新中的关系
学习率(lr)直观可以看出lr越大,权重更新的跨度越大,模型参数调整变化越快。
batch_size对模型的影响,在于模型每次更新时,计算梯度是计算整个Batch的平均梯度,
即权重更新公式中的loss′=1batchsize(lossbatch)′, 整合就是 ΔΘ=Θ0−(lr)1batchsize(lossbatch)′ 。即lr与batch_size共同影响模型更新。
作者:京东科技 杨建
来源:京东云开发者社区 转发请注明来源

低调大师中文资讯倾力打造互联网数据资讯、行业资源、电子商务、移动互联网、网络营销平台。
持续更新报道IT业界、互联网、市场资讯、驱动更新,是最及时权威的产业资讯及硬件资讯报道平台。
转载内容版权归作者及来源网站所有,本站原创内容转载请注明来源。
- 上一篇
【交付高质量,用户高增长】-用户增长质量保证方法论 | 京东云技术团队
前言 俗话说,“测试是质量的守护者”,但单凭测试本身却远远不够。大多数情况下,测试像“一面镜子”,照出系统的面貌,给开发者提供修改代码的依据,这个“照镜子”的过程,就是质量评估的过程,或者说,测试的过程更像“量体温”,虽然可以测量出温度进而判断健康状况,却不能靠量体温治病。同时,需求交付的高质量不仅仅体现在结果层面,如功能、性能、可靠性、可用性、可维护性、安全性以及用户体验,也应该包括交付的过程层面,如业务需求的高质量、产品文档的高质量、提测代码的高质量等等。所以,应该站在更高的维度、更宽的视野来看待质量保证。 本文基于C端用户拉新的业务场景,以质量保证的全视角,总结了质量保证过程中的框架、策略、流程、规范、方法、工具以及实践,全面阐述了用户增长质量保证的价值观、方法论以及我们所理解的内涵,即高质量=质量策略多样化+质量流程标准化+质量活动规范化+质量工具平台化+质量运营常态化。 限于自身认知、专业能力以及总结能力的局限,文中有些观点可能有些偏颇甚至是错误的,同时也并不意味着我们的水平足够高,可以在这儿传道解惑,更大的意义在于我们在梳理本篇文章的过程,本身就是一个不断学习、总结、反思及...
- 下一篇
体验提升-一个“小技巧”彻底解决锦礼商品可见不可售 | 京东云技术团队
一、背景 锦礼平台,作为一家企业级B2B2C电商平台,同时服务于企业客户和企业员工,因此需要遵循企业客户的政策规范,确保商城内商品符合规定,并提升员工购物体验。然而,这种独特的运营模式导致锦礼平台上商品的可见不可售问题较为突出,对最终消费者的购物体验和平台的产品和业务产生了较大的负面影响。 二、解决方案 如题,之所以说是小技巧,是因为我们并没有使用一些高精的技术,只是把多种成熟技术结合加入一些算法而已。 以下是我们经历的3个版本的方案迭代,也代表着一个技术人从技术思维到业务思维的转变 版本1.0:我们尝试在不可售商品上增加一个遮罩,标注其不可售的原因,以防止客户误操作。然而,这种方法并未完全解决问题,因为消费者可能仍然对某些商品为何不可售(例如为何在锦礼平台无法购买黄金,或为何看到的商品被列入黑名单)感到困惑。 版本2.0:我们努力提升搜索的效率,加快不可售商品出库的速度,并优化商品同步的机制,以降低不可售商品的出现频率。然而,随着锦礼平台不可售规则的扩展(如定价规则和价格倒挂限制等),这种定制化的方式对搜索团队来说过于复杂。 版本3.0:随着我们对消费者需求的深入理解,我们逐渐意识到...
相关文章
文章评论
共有0条评论来说两句吧...
文章二维码
点击排行
推荐阅读
最新文章
- 2048小游戏-低调大师作品
- Windows10,CentOS7,CentOS8安装MongoDB4.0.16
- SpringBoot2全家桶,快速入门学习开发网站教程
- Springboot2将连接池hikari替换为druid,体验最强大的数据库连接池
- Eclipse初始化配置,告别卡顿、闪退、编译时间过长
- Docker快速安装Oracle11G,搭建oracle11g学习环境
- SpringBoot2更换Tomcat为Jetty,小型站点的福音
- CentOS8,CentOS7,CentOS6编译安装Redis5.0.7
- CentOS7,8上快速安装Gitea,搭建Git服务器
- CentOS6,CentOS7官方镜像安装Oracle11G