chatglm2-6b在P40上做LORA微调 | 京东云技术团队
背景:
目前,大模型的技术应用已经遍地开花。最快的应用方式无非是利用自有垂直领域的数据进行模型微调。chatglm2-6b在国内开源的大模型上,效果比较突出。本文章分享的内容是用chatglm2-6b模型在集团EA的P40机器上进行垂直领域的LORA微调。
一、chatglm2-6b介绍
github: https://github.com/THUDM/ChatGLM2-6B
chatglm2-6b相比于chatglm有几方面的提升:
1. 性能提升: 相比初代模型,升级了 ChatGLM2-6B 的基座模型,同时在各项数据集评测上取得了不错的成绩;
2. 更长的上下文: 我们将基座模型的上下文长度(Context Length)由 ChatGLM-6B 的 2K 扩展到了 32K,并在对话阶段使用 8K 的上下文长度训练;
3. 更高效的推理: 基于 Multi-Query Attention 技术,ChatGLM2-6B 有更高效的推理速度和更低的显存占用:在官方的模型实现下,推理速度相比初代提升了 42%;
4. 更开放的协议:ChatGLM2-6B 权重对学术研究完全开放,在填写问卷进行登记后亦允许免费商业使用。
二、微调环境介绍
2.1 性能要求
推理这块,chatglm2-6b在精度是fp16上只需要14G的显存,所以P40是可以cover的。
EA上P40显卡的配置如下:
2.2 镜像环境
做微调之前,需要编译环境进行配置,我这块用的是docker镜像的方式来加载镜像环境,具体配置如下:
FROM base-clone-mamba-py37-cuda11.0-gpu # mpich RUN yum install mpich # create my own environment RUN conda create -c https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/conda-forge/ --override --yes --name py39 python=3.9 # display my own environment in Launcher RUN source activate py39 \ && conda install --yes --quiet ipykernel \ && python -m ipykernel install --name py39 --display-name "py39" # install your own requirement package RUN source activate py39 \ && conda install -y -c https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/pytorch/ \ pytorch torchvision torchaudio faiss-gpu \ && pip install --no-cache-dir --ignore-installed -i https://pypi.tuna.tsinghua.edu.cn/simple \ protobuf \ streamlit \ transformers==4.29.1 \ cpm_kernels \ mdtex2html \ gradio==3.28.3 \ sentencepiece \ accelerate \ langchain \ pymupdf \ unstructured[local-inference] \ layoutparser[layoutmodels,tesseract] \ nltk~=3.8.1 \ sentence-transformers \ beautifulsoup4 \ icetk \ fastapi~=0.95.0 \ uvicorn~=0.21.1 \ pypinyin~=0.48.0 \ click~=8.1.3 \ tabulate \ feedparser \ azure-core \ openai \ pydantic~=1.10.7 \ starlette~=0.26.1 \ numpy~=1.23.5 \ tqdm~=4.65.0 \ requests~=2.28.2 \ rouge_chinese \ jieba \ datasets \ deepspeed \ pdf2image \ urllib3==1.26.15 \ tenacity~=8.2.2 \ autopep8 \ paddleocr \ mpi4py \ tiktoken
如果需要使用deepspeed方式来训练, EA上缺少mpich信息传递工具包,需要自己手动安装。
2.3 模型下载
huggingface地址: https://huggingface.co/THUDM/chatglm2-6b/tree/main
三、LORA微调
3.1 LORA介绍
paper: https://arxiv.org/pdf/2106.09685.pdf
LORA(Low-Rank Adaptation of Large Language Models)微调方法: 冻结预训练好的模型权重参数,在冻结原模型参数的情况下,通过往模型中加入额外的网络层,并只训练这些新增的网络层参数。
LoRA 的思想:
- 在原始 PLM (Pre-trained Language Model) 旁边增加一个旁路,做一个降维再升维的操作。
- 训练的时候固定 PLM 的参数,只训练降维矩阵A与升维矩B。而模型的输入输出维度不变,输出时将BA与 PLM 的参数叠加。
- 用随机高斯分布初始化A,用 0 矩阵初始化B,保证训练的开始此旁路矩阵依然是 0 矩阵。
3.2 微调
huggingface提供的peft工具可以方便微调PLM模型,这里也是采用的peft工具来创建LORA。
peft的github: https://gitcode.net/mirrors/huggingface/peft?utm_source=csdn_github_accelerator
加载模型和lora微调:
# load model tokenizer = AutoTokenizer.from_pretrained(args.model_dir, trust_remote_code=True) model = AutoModel.from_pretrained(args.model_dir, trust_remote_code=True) print("tokenizer:", tokenizer) # get LoRA model config = LoraConfig( r=args.lora_r, lora_alpha=32, lora_dropout=0.1, bias="none",) # 加载lora模型 model = get_peft_model(model, config) # 半精度方式 model = model.half().to(device)
这里需要注意的是,用huggingface加载本地模型,需要创建work文件,EA上没有权限在没有在.cache创建,这里需要自己先制定work路径。
import os os.environ['TRANSFORMERS_CACHE'] = os.path.dirname(os.path.abspath(__file__))+"/work/" os.environ['HF_MODULES_CACHE'] = os.path.dirname(os.path.abspath(__file__))+"/work/"
如果需要用deepspeed方式训练,选择你需要的zero-stage方式:
conf = {"train_micro_batch_size_per_gpu": args.train_batch_size, "gradient_accumulation_steps": args.gradient_accumulation_steps, "optimizer": { "type": "Adam", "params": { "lr": 1e-5, "betas": [ 0.9, 0.95 ], "eps": 1e-8, "weight_decay": 5e-4 } }, "fp16": { "enabled": True }, "zero_optimization": { "stage": 1, "offload_optimizer": { "device": "cpu", "pin_memory": True }, "allgather_partitions": True, "allgather_bucket_size": 2e8, "overlap_comm": True, "reduce_scatter": True, "reduce_bucket_size": 2e8, "contiguous_gradients": True }, "steps_per_print": args.log_steps }
其他都是数据处理处理方面的工作,需要关注的就是怎么去构建prompt,个人认为在领域内做微调构建prompt非常重要,最终对模型的影响也比较大。
四、微调结果
目前模型还在finetune中,batch=1,epoch=3,已经迭代一轮。
作者:京东零售 郑少强
来源:京东云开发者社区 转载请注明来源

低调大师中文资讯倾力打造互联网数据资讯、行业资源、电子商务、移动互联网、网络营销平台。
持续更新报道IT业界、互联网、市场资讯、驱动更新,是最及时权威的产业资讯及硬件资讯报道平台。
转载内容版权归作者及来源网站所有,本站原创内容转载请注明来源。
- 上一篇
一款国产开源 Web 防火墙神器!
随着开源 Web 框架和各种建站工具的兴起,搭建网站已经是一件成本非常低的事情,但是网站的安全性很少有人关注,以至于 WAF 这个品类也鲜为人知。 一、WAF 是什么? WAF 是 Web 应用防火墙(Web Application Firewall)的缩写,也就是我们俗称的网站防火墙。它可以保护网站不被黑客所攻击,通常以 Web 网关的形式存在,作为反向代理接入。WAF 可以识别常见的 Web 攻击并实施阻断,比如:SQL 注入、跨站脚本攻击(XSS)、跨站请求伪造(CSRF)、服务端请求伪造(SSRF)、WebShell 上传与通信等等。 二、雷池 今天 HelloGitHub 给大家带来的是一款开箱即用、功能强大、广受好评的网站防护工具——雷池 WAF,不让黑客越雷池半步。 GitHub 地址:https://github.com/chaitin/safeline 雷池是一款简单易用、广受好评的社区 WAF 项目,它底层基于 Nginx 的 Web 网关,作为反向代理接入网络,清洗来自黑客的恶意流量,保护你的网站不受黑客攻击。雷池拥有友好的 Web 界面,就算你不具备网络安全技术...
- 下一篇
快速理解DDD领域驱动设计架构思想-基础篇 | 京东物流技术团队
1 前言 本文与大家一起学习并介绍领域驱动设计(Domain Drive Design) 简称DDD,以及为什么我们需要领域驱动设计,它有哪些优缺点,尽量用一些通俗易懂文字来描述讲解领域驱动设计,本篇并不会从深层大论述讲解落地实现,这些大家可以在了解入门后再去深层次学习探讨或在后续进阶和高级篇了解,希望通过本文介绍,可以让大家快速了解DDD并有一个基础的认知,DDD本身就是理论的集合,很难在不积累理论情况下来有效的实施DDD,仅仅看一些代码案例后就开搞,最终出来东西也是东施效颦,莫要好高骛远。 最后期望大家在工作中能多思考,如你所负责项目如果用DDD如何设计、以及会面临哪些挑战。 学习了解DDD之前,期望大家可在温顾下以往我们所了解掌握一些知识,努力让自己所学所掌握的内容沉淀下来,推荐阅读系列。 Head First 设计模式:基础面向对象概念和重要的设计模式; UML面向对象建模基础:从需求到分析,从分析到设计,从设计到编码,UML都有用武之地 实现领域驱动设计:很厚,更加务实,推荐阅读 领域驱动设计:张逸-DDD开山之作,挺玄幻的,多读几遍受益匪浅; 2 定义与概念 领域驱动设计(...
相关文章
文章评论
共有0条评论来说两句吧...
文章二维码
点击排行
推荐阅读
最新文章
- SpringBoot2全家桶,快速入门学习开发网站教程
- Springboot2将连接池hikari替换为druid,体验最强大的数据库连接池
- Eclipse初始化配置,告别卡顿、闪退、编译时间过长
- Docker快速安装Oracle11G,搭建oracle11g学习环境
- 2048小游戏-低调大师作品
- SpringBoot2更换Tomcat为Jetty,小型站点的福音
- CentOS8,CentOS7,CentOS6编译安装Redis5.0.7
- CentOS7,8上快速安装Gitea,搭建Git服务器
- CentOS6,CentOS7官方镜像安装Oracle11G
- CentOS关闭SELinux安全模块