前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >chatGLM3-LoRA微调实战

chatGLM3-LoRA微调实战

作者头像
逍遥壮士
发布2024-03-22 17:33:20
3960
发布2024-03-22 17:33:20
举报
文章被收录于专栏:技术趋势技术趋势

为什么要微调?

微调的主要目的是通过在特定任务上对预训练模型进行进一步训练,以适应该任务的需求,从而提高模型在该任务上的性能。具体来说,微调可以带来以下好处:

  • 提高模型的精度:预训练模型通常在大规模数据上进行训练,可以获得丰富的语言知识和语境信息。通过在特定任务上对这些预训练模型进行微调,我们可以更好地利用这些知识和信息,并将其适应到特定任务的需求中,从而提高模型在该任务上的精度。
  • 减少训练时间:与从头开始训练模型相比,微调通常只需要在相对较小的数据集上进行,因此训练时间更短。此外,由于预训练模型已经具有较好的参数初始化,微调也可以加速训练过程并提高模型的收敛速度。
  • 增加模型的泛化能力:由于预训练模型已经在大规模数据上进行了训练,因此具有较好的泛化能力。通过微调,我们可以将这种泛化能力应用到特定任务中,从而使模型能够更好地适应未见过的数据。

可以利用预训练模型在特定任务上进行进一步训练,从而提高模型的性能,加速训练过程并增加模型的泛化能力。

注意:ChatGLM3-6B-Base 和 ChatGLM3-6B-32K 模型不支持微调。

啥是泛化能力?

泛化能力是指机器学习算法对新鲜样本的适应能力,即算法在从原有数据集学习后,对具有同一规律的学习集以外的数据也能给出合适的输出。

个人理解:比如做题,不会加减乘除,通过不断做题,学会加减乘除,只要出现类似的题目你就会做了,而不用去关注具体输入内容是什么。

有哪些微调方式?

微调方法

名称

GPU占用

备注

LoRA

低(秩)rank 自适应微调方法全称全称是:LOW-RANK ADAPTATION OF LARGE LANGUAGE MODELS

14082MiB 显存预计14GB

微软的研究者们于2021年通过论文《LORA: LOW-RANK ADAPTATION OF LARGE LANGUAGE MODELS》提出了低秩适应LoRA

Fine-Tuning

全参微调,一般结合:Accelerator和DeepSpeed框

至少上百GB

优点:全面全参微调效最效果比较好缺点:计算量太大及资用太大论文来源:https://arxiv.org/pdf/2110.07602.pdf

SFT全量微调

SFT是监督微调(Supervised Fine-Tun-ing)的缩写。这是一种常见的深度学习策略,通常在预训练的大语言模型上使用

4张显卡平均分配,每张显卡占用 48346MiB 显存大约是195GB

P-TuningV2

是ChatGLM模型专有的微调方式,

微调: 1张显卡,占用 18426MiB 显存,大约19GB内

论文来源:https://arxiv.org/pdf/2110.07602.pdf

模型精度

精度

说明

备注

32bit

32 比特全参数微调

FP16

16 比特冻结微调

或基于AQLM/AWQ/GPTQ/LLM.int8

8-bit

8比特

用于运行

4-bit

4比特

用于运行

微调实战

数据集准备

由于个人数据不方便公开这里就不贴具体的内容了,请参考官网的demo。

Tsinghua Cloud

参考:https://github.com/THUDM/ChatGLM3/blob/main/finetune_demo/README.md

说明:dev.json用于验证数据集的文件,train.json用于训练数据集的文件;

基于LORA微调方法

硬件要求

硬件名称

配置

备注

内存

16GB

GPU

15GB

多论对话模板
代码语言:javascript
复制
[
  {
"conversations": [
      {
"role": "system",
"content": "<system prompt text>"
      },
      {
"role": "user",
"content": "<user prompt text>"
      },
      {
"role": "assistant",
"content": "<assistant response text>"
      },
      // ... Muti Turn
      {
"role": "user",
"content": "<user prompt text>"
      },
      {
"role": "assistant",
"content": "<assistant response text>"
      }
    ]
  }
  // ...
]

工具能力模板

代码语言:javascript
复制
[
  {
"tools": [
      // available tools, format is not restricted
    ],
"conversations": [
      {
"role": "system",
"content": "<system prompt text>"
      },
      {
"role": "user",
"content": "<user prompt text>"
      },
      {
"role": "assistant",
"content": "<assistant thought to text>"
      },
      {
"role": "tool",
"name": "<name of the tool to be called",
"parameters": {
"<parameter_name>": "<parameter_value>"
        },
"observation": "<observation>"
        // don't have to be string
      },
      {
"role": "assistant",
"content": "<assistant response to observation>"
      },
      // ... Muti Turn
      {
"role": "user",
"content": "<user prompt text>"
      },
      {
"role": "assistant",
"content": "<assistant response text>"
      }
    ]
  }
  // ...
]

训练大约13G

基于5000次结果

测试微调结果

代码语言:javascript
复制
python inference_hf.py output/checkpoint-3000/ --prompt '类型#裙*材质#网纱*颜色#粉红色*图案#线条*图案#刺绣*裙腰型# 高腰*裙长#连衣裙*裙袖长#短袖*裙领型#圆领'
代码语言:javascript
复制
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [01:10<00:00, 10.02s/it]
粉红色连衣裙,高腰版型,让穿着者显得更加娇俏可爱。短袖袖型,让穿着者显得十分精神。圆领设计,修饰颈部线条,更显优雅。网纱刺绣,让连衣裙更加浪漫优雅。

3000次结果

5000次的结果

对比结果

训练数据:
代码语言:javascript
复制

代码语言:javascript
复制
{
    "conversations": [
        {
            "role": "user",
            "content": "类型#裙*版型#显瘦*材质#蕾丝*颜色#纯色*风格#知性*风格#高贵*风格#性感*图案#纯色*图案#蕾丝*裙型#背带裙*裙型#包臀裙*裙型#鱼尾裙*裙长#连衣裙*裙袖型#喇叭袖"
        },
        {
            "role": "assistant",
            "content": "蕾丝喇叭袖上衣,搭配鱼尾包臀背带裙,整体造型给人甜美可人的感觉。偏爱蕾丝的浪漫柔情,流露别致女人味。喇叭袖的设计凸显别样浪漫,透露隐约小性感。两件套连衣裙,平添视觉层次感。鱼尾的设计修身显瘦,喇叭袖时尚减龄,纯色设计更加凸显女性知性高贵的气质。"
        }
    ]
}
代码语言:javascript
复制
原版(未训练)
训练后

注:由于微调受限于参数以及存在随机性,所以若想实现FAQ的场景可以考虑使用向量数据库+大模型比如chatchat

基于P-TuningV2微调方法

由于GPU需要18426MiB后续支持

基于SFT 全量微调方法

由于需要GPU4张显卡平均分配,每张显卡占用 48346MiB 显存,后续支持

训练过程中遇到的问题

找不到对应文件:special_tokens_map.json

代码语言:javascript
复制
Loading checkpoint shards: 100%|██████████████████| 7/7 [01:31<00:00, 13.11s/it]
╭───────────────────── Traceback (most recent call last) ──────────────────────╮
│ /data/llm/llama/ChatGLM3/finetune_demo/finetune_hf.py:446 in main            │
│                                                                              │
│   443 │   │   ),                                                             │
│   444 ):                                                                     │
│   445 │   ft_config = FinetuningConfig.from_file(config_file)                │
│ ❱ 446 │   tokenizer, model = load_tokenizer_and_model(model_dir, peft_config │
│   447 │   data_manager = DataManager(data_dir, ft_config.data_config)        │
│   448 │                                                                      │
│   449 │   train_dataset = data_manager.get_dataset(                          │
│                                                                              │
│ /data/llm/llama/ChatGLM3/finetune_demo/finetune_hf.py:388 in                 │
│ load_tokenizer_and_model                                                     │
│                                                                              │
│   385 │   │   │   │   config=config,                                         │
│   386 │   │   │   )                                                          │
│   387 │   │   if peft_config.peft_type.name == "LORA":                       │
│ ❱ 388 │   │   │   model = AutoModelForCausalLM.from_pretrained(              │
│   389 │   │   │   │   model_dir,                                             │
│   390 │   │   │   │   trust_remote_code=True,                                │
│   391 │   │   │   │   empty_init=False,                                      │
│                                                                              │
│ /home/tools/yes/envs/chatGLM/lib/python3.10/site-packages/transformers/model │
│ s/auto/auto_factory.py:561 in from_pretrained                                │
│                                                                              │
│   558 │   │   │   │   model_class.register_for_auto_class(cls.__name__)      │
│   559 │   │   │   else:                                                      │
│   560 │   │   │   │   cls.register(config.__class__, model_class, exist_ok=T │
│ ❱ 561 │   │   │   return model_class.from_pretrained(                        │
│   562 │   │   │   │   pretrained_model_name_or_path, *model_args, config=con │
│   563 │   │   │   )                                                          │
│   564 │   │   elif type(config) in cls._model_mapping.keys():                │
│                                                                              │
│ /home/tools/yes/envs/chatGLM/lib/python3.10/site-packages/transformers/model │
│ ing_utils.py:3931 in from_pretrained                                         │
│                                                                              │
│   3928 │   │   │   model = quantizer.post_init_model(model)                  │
│   3929 │   │                                                                 │
│   3930 │   │   if _adapter_model_path is not None:                           │
│ ❱ 3931 │   │   │   model.load_adapter(                                       │
│   3932 │   │   │   │   _adapter_model_path,                                  │
│   3933 │   │   │   │   adapter_name=adapter_name,                            │
│   3934 │   │   │   │   token=token,                                          │
│                                                                              │
│ /home/tools/yes/envs/chatGLM/lib/python3.10/site-packages/transformers/integ │
│ rations/peft.py:180 in load_adapter                                          │
│                                                                              │
│   177 │   │   │   │   │   "adapter model."                                   │
│   178 │   │   │   │   )                                                      │
│   179 │   │   │                                                              │
│ ❱ 180 │   │   │   peft_config = PeftConfig.from_pretrained(                  │
│   181 │   │   │   │   peft_model_id,                                         │
│   182 │   │   │   │   token=token,                                           │
│   183 │   │   │   │   **adapter_kwargs,                                      │
│                                                                              │
│ /home/tools/yes/envs/chatGLM/lib/python3.10/site-packages/peft/config.py:137 │
│ in from_pretrained                                                           │
│                                                                              │
│   134 │   │   │   config_cls = cls                                           │
│   135 │   │                                                                  │
│   136 │   │   kwargs = {**class_kwargs, **loaded_attributes}                 │
│ ❱ 137 │   │   config = config_cls(**kwargs)                                  │
│   138 │   │   return config                                                  │
│   139 │                                                                      │
│   140 │   @classmethod                                                       │
╰──────────────────────────────────────────────────────────────────────────────╯
TypeError: PrefixTuningConfig.__init__() got an unexpected keyword argument 
'_name_or_path'

原因:训练过程中才会生成special_tokens_map.json,所以要把原来的去掉。

最后

通过训练可以扩展当前模型的泛化能力以及特定知识支持,官方默认为3000可以根据自已进行调整到5000或8000,由于本人的机器配置有点低所以效果,还过得去。当然如果想实现动态数据可以通过function calling方式进行实现,此块在请关注本公众号,后续继续输出。

参考文献:

https://zhuanlan.zhihu.com/p/646791309

https://github.com/THUDM/ChatGLM3/blob/main/finetune_demo/README.md

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2024-03-19,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 技术趋势 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 为什么要微调?
  • 啥是泛化能力?
  • 泛化能力是指机器学习算法对新鲜样本的适应能力,即算法在从原有数据集学习后,对具有同一规律的学习集以外的数据也能给出合适的输出。
  • 有哪些微调方式?
  • 模型精度
  • 微调实战
    • 数据集准备
      • 基于LORA微调方法
        • 硬件要求
        • 多论对话模板
      • 对比结果
      • 基于P-TuningV2微调方法
      • 基于SFT 全量微调方法
      • 训练过程中遇到的问题
      • 找不到对应文件:special_tokens_map.json
      相关产品与服务
      向量数据库
      腾讯云向量数据库(Tencent Cloud VectorDB)是一款全托管的自研企业级分布式数据库服务,专用于存储、检索、分析多维向量数据。该数据库支持多种索引类型和相似度计算方法,单索引支持千亿级向量规模,可支持百万级 QPS 及毫秒级查询延迟。腾讯云向量数据库不仅能为大模型提供外部知识库,提高大模型回答的准确性,还可广泛应用于推荐系统、自然语言处理等 AI 领域。
      领券
      问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档