前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >【玩转GPU】训练一个自己的ChatGLM2-6B微调模型

【玩转GPU】训练一个自己的ChatGLM2-6B微调模型

原创
作者头像
Moemu
发布2023-07-17 18:54:12
4.1K1
发布2023-07-17 18:54:12
举报
文章被收录于专栏:Moemu's NotepadMoemu's Notepad

前言

最近清华大学又给我们整出了ChatGLM2-6b,其性能相比上一代拥有了较大的提升。如果想要微调现有的大语言模型,现在也许是个不错的时机。

本篇文章将介绍如何使用较低的成本在云上微调自己的模型。

相关链接:

THUDM/ChatGLM2-6B: ChatGLM2-6B: An Open Bilingual Chat LLM | 开源双语对话语言模型 (github.com)

THUDM/chatglm2-6b · Hugging Face

训练集准备

想办法整来足够的对话训练集,分为训练集(train.json)和验证集(dev.json)。

下面是官方给出训练集格式示例

代码语言:json
复制
{  

  "content": "类型#上衣*版型#宽松*版型#显瘦*图案#线条*衣样式#衬衫*衣袖型#泡泡袖*衣款式#抽绳",  

  "summary": "这件衬衫的款式非常的宽松,利落的线条可以很好的隐藏身材上的小缺点,穿在身上有着很好的显瘦效果。领口装饰了一个可爱的抽绳,漂亮的绳结展现出了十足的个性,配合时尚的泡泡袖型,尽显女性甜美可爱的气息。"  

}

具体的格式如下:

代码语言:json
复制
{  

  "(question_tag)": "(question)",  

  "(answer_tag)": "(answer)"  

}

{  

  "(question_tag)": "(question)",  

  "(answer_tag)": "(answer)"  

}

我们本次使用的训练集部分内容(为了训练独一无二的模型,训练集最好自己撰写,但是很耗时间):

代码语言:json
复制
{"prompt": "你是谁", "completion": "在下名为沐雪,你也可以叫我雪雪,是只AI女孩子,使命是传播爱与和平⭐"}

{"prompt": "你多大了", "completion": "咱是永远的16岁哦⭐"}

// (剩余118行)

买服务器

在腾讯云中购买GPU服务器,相关配置如下:

配置: A100/32G显存 (注:32G显存是比较稳妥的配置,如果你没钱,选择A10/16G也未尝不可)

镜像: Docker基础镜像

系统盘: 50G以上

当然在GPU LAB购买服务器也是个不错的选择

环境配置

看来你已经买好了服务器,现在我们可以开始配置环境了。

代码语言:shell
复制
apt update

apt install git

apt install git-lfs

git clone https://github.com/THUDM/ChatGLM2-6B

cd ChatGLM2-6B

git lfs install

现在我们拥有了ChatGLM2-6B的源码,现在来安装Python环境(如果提前配置好的话可以跳过)

代码语言:shell
复制
apt install python3.8

sudo update-alternatives --install /usr/bin/python3 /usr/bin/python3.5 1

sudo update-alternatives --install /usr/bin/python3 /usr/bin/python3.8 2

sudo update-alternatives --install /usr/bin/python /usr/bin/python2 3

sudo update-alternatives --install /usr/bin/python /usr/bin/python3 4

apt install python3-pip

pip3 install --upgrade pip

然后安装依赖

代码语言:shell
复制
pip3 install -r requirements.txt

至此环境搭建成功,我们可以跑一下web_demo.py来看看能不能跑起来ChatGLM2-6B,但是没必要这么做,除非你手动指定模型位置,不然它会在.cache目录下自动下载模型文件

现在来做训练前的准备

代码语言:shell
复制
cd ptuing

pip3 install rouge_chinese nltk jieba datasets transformers[torch]

git clone https://huggingface.co/THUDM/chatglm2-6b

现在上传你的训练集和验证集,然后可以开始训练

修改train.sh并运行以开始训练

代码语言:shell
复制
PRE_SEQ_LEN=32

CHECKPOINT=adgen-chatglm2-6b-pt-32-2e-2

STEP=3000

NUM_GPUS=1



torchrun --standalone --nnodes=1 --nproc-per-node=1 main.py \

  --do_train \

  --train_file (训练集文件路径) \

  --validation_file (验证集文件路径) \

  --preprocessing_num_workers 10 \

  --prompt_column (训练集中的qustion_key) \

  --response_column (训练集中的answer_key) \

  --overwrite_cache \

  --model_name_or_path chatglm2-6b \

  --output_dir output/(模型名)-chatglm2-6b-pt-$PRE_SEQ_LEN-$LR \

  --overwrite_output_dir \

  --max_source_length 64 \

  --max_target_length 128 \

  --per_device_train_batch_size 1 \

  --per_device_eval_batch_size 1 \

  --gradient_accumulation_steps 16 \

  --predict_with_generate \

  --max_steps 3000 \

  --logging_steps 10 \

  --save_steps 200 \

  --learning_rate 2e-2 \

  --pre_seq_len 128 \

  --quantization_bit 4

其中,PRE_SEQ_LENLR分别是 soft prompt 长度和训练的学习率,可以进行调节以取得最佳的效果。

疑难杂症

RuntimeError: Library cublasLt is not initialized

解决方法:

代码语言:txt
复制
sudo apt install nvidia-cuda-toolkit

没有apt

自行去官网下源码编译安装

代码语言:shell
复制
yum install -y zlib-devel bzip2-devel ncurses-devel readline-devel openssl-devel openssl-static xz lzma xz-devel sqlite-devel gdbm-devel libffi-devel tk-devel gcc make

 

wget https://www.python.org/ftp/python/3.8.10/Python-3.8.10.tgz

tar -xf Python-3.8.10.tgz && cd Python-3.8.10

 

mkdir /usr/local/python3.8.10

./configure --prefix=/usr/local/python3.8.10 \

--with-ssl

 

make && make install



ln -s /usr/local/python3.8.10/bin/python3 /usr/local/bin/python3

ln -s /usr/local/python3.8.10/bin/pip3 /usr/local/bin/pip3

 

error: RPC failed; result=35, HTTP code = 0 fatal: The remote end hung up unexpectedly

解决方法

代码语言:txt
复制
sudo git config --global http.postBuffer 524288000000

推理

修改evaluate.sh以开始测试推理

代码语言:shell
复制
PRE_SEQ_LEN=32

CHECKPOINT=(模型名)-chatglm2-6b-pt-32-2e-2

STEP=3000

NUM_GPUS=1



torchrun --standalone --nnodes=1 --nproc-per-node=$NUM_GPUS main.py \

  --do_predict \

  --validation_file (验证集) \

  --test_file (测试集) \

  --overwrite_cache \

  --prompt_column (训练集中的qustion_key) \

  --response_column (训练集中的answer_key) \

  --model_name_or_path chatglm2-6b \

  --ptuning_checkpoint ./output/$CHECKPOINT/checkpoint-$STEP \

  --output_dir ./output/$CHECKPOINT \

  --overwrite_output_dir \

  --max_source_length 128 \

  --max_target_length 128 \

  --per_device_eval_batch_size 1 \

  --predict_with_generate \

  --pre_seq_len $PRE_SEQ_LEN \

  --quantization_bit 4

执行完成后,会生成评测文件,评测指标为中文 Rouge score 和 BLEU-4。生成的结果保存在 ./output/$CHECKPOINT/generated_predictions.txt。在评测数据中,含有与验证集相同的输入,labels 是 dev.json中的预测输出,predict 是 ChatGLM2-6B 生成的结果,对比预测输出和生成结果,评测模型训练的好坏。如果不满意调整训练的参数再次进行训练。

现在我们来使用web_demo.sh部署我们的模型

代码语言:shell
复制
PRE_SEQ_LEN=32



CUDA_VISIBLE_DEVICES=0 python3 web_demo.py \

  --model_name_or_path chatglm2-6b \

  --ptuning_checkpoint output/(模型名)-chatglm2-6b-pt-32-2e-2/checkpoint-3000 \

  --pre_seq_len $PRE_SEQ_LEN

注:有些时候你无法访问云服务器的本地接口,检查你的服务器配置或联系客服以解决问题

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 前言
  • 训练集准备
  • 买服务器
  • 环境配置
  • 疑难杂症
    • RuntimeError: Library cublasLt is not initialized
      • 没有apt
        • error: RPC failed; result=35, HTTP code = 0 fatal: The remote end hung up unexpectedly
        • 推理
        相关产品与服务
        GPU 云服务器
        GPU 云服务器(Cloud GPU Service,GPU)是提供 GPU 算力的弹性计算服务,具有超强的并行计算能力,作为 IaaS 层的尖兵利器,服务于深度学习训练、科学计算、图形图像处理、视频编解码等场景。腾讯云随时提供触手可得的算力,有效缓解您的计算压力,提升业务效率与竞争力。
        领券
        问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档