前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >ChatGLM-6B微调推理实战

ChatGLM-6B微调推理实战

作者头像
@小森
发布2024-08-24 14:09:52
20
发布2024-08-24 14:09:52
举报
文章被收录于专栏:xiaosen

ChatGLM-6B是一个由清华大学和智谱AI联合研发的开源对话语言模型,它基于General Language Model(GLM)架构,具有62亿参数,并支持中英双语问答。结合模型量化技术,用户可以在消费级的显卡上进行本地部署。在INT4量化级别下,最低只需6GB显存即可运行。

运行环境:

⭐️环境安装:

使用 pip 安装依赖:`pip install -r requirements.txt`,其中 `transformers` 库版本推荐为 `4.27.1`,但理论上不低于 `4.23.1` 即可。

⭐️代码调用方式:

代码语言:javascript
复制
>>> from transformers import AutoTokenizer, AutoModel
>>> tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
>>> model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
>>> model = model.eval()
>>> response, history = model.chat(tokenizer, "你好", history=[])
>>> print(response)
你好👋!我是人工智能助手 ChatGLM-6B,很高兴见到你,欢迎问我任何问题。
>>> response, history = model.chat(tokenizer, "晚上睡不着应该怎么办", history=history)
>>> print(response)
晚上睡不着可能会让你感到焦虑或不舒服,但以下是一些可以帮助你入睡的方法:

1. 制定规律的睡眠时间表:保持规律的睡眠时间表可以帮助你建立健康的睡眠习惯,使你更容易入睡。尽量在每天的相同时间上床,并在同一时间起床。
2. 创造一个舒适的睡眠环境:确保睡眠环境舒适,安静,黑暗且温度适宜。可以使用舒适的床上用品,并保持房间通风。
3. 放松身心:在睡前做些放松的活动,例如泡个热水澡,听些轻柔的音乐,阅读一些有趣的书籍等,有助于缓解紧张和焦虑,使你更容易入睡。
4. 避免饮用含有咖啡因的饮料:咖啡因是一种刺激性物质,会影响你的睡眠质量。尽量避免在睡前饮用含有咖啡因的饮料,例如咖啡,茶和可乐。
5. 避免在床上做与睡眠无关的事情:在床上做些与睡眠无关的事情,例如看电影,玩游戏或工作等,可能会干扰你的睡眠。
6. 尝试呼吸技巧:深呼吸是一种放松技巧,可以帮助你缓解紧张和焦虑,使你更容易入睡。试着慢慢吸气,保持几秒钟,然后缓慢呼气。

如果这些方法无法帮助你入睡,你可以考虑咨询医生或睡眠专家,寻求进一步的建议。
⭐️下载ChatGLM-6B

1️⃣Modelscope下载:ChatGLM下载

2️⃣ 阿里云OSS存储:(加快下载速度)其他云也一样

代码语言:javascript
复制
import os
dsw_region = os.environ.get("dsw_region")
# 从环境变量中获取dsw_region的值
url_link = {
    "cn-shanghai": "https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/tutorials/chatGLM/ChatGLM-6B-main.tar.gz",
    "cn-hangzhou": "https://atp-modelzoo.oss-cn-hangzhou.aliyuncs.com/release/tutorials/chatGLM/ChatGLM-6B-main.tar.gz",
    "cn-shenzhen": "https://atp-modelzoo-sz.oss-cn-shenzhen-internal.aliyuncs.com/release/tutorials/chatGLM/ChatGLM-6B-main.tar.gz",
    "cn-beijing": "https://atp-modelzoo-bj.oss-cn-beijing.aliyuncs.com/release/tutorials/chatGLM/ChatGLM-6B-main.tar.gz", 
}

🌈我离北京较近,选择cn- beijing

代码语言:javascript
复制
path = url_link["cn-beijing"]
os.environ['LINK_CHAT'] = path
# 选择cn-beijing区域对应的URL,并将其赋值给path变量
!wget $LINK_CHAT
# wget命令下载LINK_CHAT环境变量指向的URL处的文件
!tar -xvf ChatGLM-6B-main.tar.gz

⭐️安装依赖
代码语言:javascript
复制
!cd ChatGLM-6B-main && pip install -r requirements.txt && \
pip install rouge_chinese nltk jieba datasets gradio==3.37.0

数据准备

数据文件为json文件。json文件中每条数据是一个字典,记录输入文本和输出文本。

我们使用官方的数据集:(AdvertiseGen_Simple)

代码语言:javascript
复制
!cd ChatGLM-6B-main/ptuning && wget https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/tutorials/chatGLM/AdvertiseGen_Simple.zip  && unzip AdvertiseGen_Simple.zip
  • content 字段包含了用于生成文本的输入信息(如服装的类型、版型、风格等),而 summary 字段则包含了与输入信息相对应的输出文本(即描述服装的文本)

微调模型

train.sh:

代码语言:javascript
复制
PRE_SEQ_LEN=8
LR=1e-2

CUDA_VISIBLE_DEVICES=0 python main.py \
    --do_train \
    --train_file AdvertiseGen_Simple/train.json \
    --validation_file AdvertiseGen_Simple/dev.json \
    --prompt_column content \
    --response_column summary \
    --overwrite_cache \
    --model_name_or_path chatglm-6b \
    --output_dir output/adgen-chatglm-6b-pt-$PRE_SEQ_LEN-$LR \
    --overwrite_output_dir \
    --max_source_length 64 \
    --max_target_length 64 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 16 \
    --predict_with_generate \
    --logging_steps 10 \
    --save_steps 6 \
    --learning_rate $LR \
    --pre_seq_len $PRE_SEQ_LEN \
    --quantization_bit 4 \
    --num_train_epochs 1
  • CUDA_VISIBLE_DEVICES=0:指定使用编号为0的GPU设备进行训练。
  • --quantization_bit 4:指定量化位数,这里设置为4位。
代码语言:javascript
复制
!cd ChatGLM-6B-main/ptuning && bash train.sh

模型推理

evaluate.sh:

代码语言:javascript
复制
PRE_SEQ_LEN=8
CHECKPOINT=adgen-chatglm-6b-pt-8-1e-2
STEP=6

CUDA_VISIBLE_DEVICES=0 python main.py \
    --do_predict \
    --validation_file AdvertiseGen_Simple/dev.json \
    --test_file AdvertiseGen_Simple/dev.json \
    --overwrite_cache \
    --prompt_column content \
    --response_column summary \
    --model_name_or_path ./output/$CHECKPOINT/checkpoint-$STEP  \
    --output_dir ./output/$CHECKPOINT \
    --overwrite_output_dir \
    --max_source_length 64 \
    --max_target_length 64 \
    --per_device_eval_batch_size 1 \
    --predict_with_generate \
    --pre_seq_len $PRE_SEQ_LEN \
    --quantization_bit 4
  • --do_predict:表示要进行预测
  • --overwrite_cache:如果缓存文件已存在,则覆盖它们。
  • --overwrite_output_dir:如果输出目录已存在,则覆盖它。
代码语言:javascript
复制
!cd ChatGLM-6B-main/ptuning &&  bash evaluate.sh

🍺生成的结果:

🍺完成后可以启动web_demo.py 启动网页对话:

代码语言:javascript
复制
!cd ChatGLM-6B-main/ && python web_demo.py

一般我们下载大模型后,会有cli_demo.py和 web_demo.py:

cli_demo.py

  • 用户可以通过命令行参数与脚本交互,输入文本并获取模型的生成结果。

web_demo.py

  • 用户可以通过浏览器访问一个网页,输入文本并获取模型的生成结果。

💯访问本机(服务器)的端口7860就可以访问到 :

本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
原始发表:2024-08-22,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • ⭐️下载ChatGLM-6B
  • ⭐️安装依赖
  • 微调模型
  • 模型推理
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档