前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >优化内存使用:TensorRT-LLM和StreamingLLM在Mistral上提升推理效率

优化内存使用:TensorRT-LLM和StreamingLLM在Mistral上提升推理效率

作者头像
GPUS Lady
发布2024-03-25 14:06:14
1450
发布2024-03-25 14:06:14
举报
文章被收录于专栏:GPUS开发者GPUS开发者

内存是否不足以支持长时间聊天内容的#LLM应用?NVIDIA工程师Song Han 开发了StreamingLLM,集成了TensorRT LLM v0.8。让我们看看StreamingLLM在中的应用吧!

Song Han 是NVIDIA杰出工程师,也是麻省理工学院电气工程与计算机科学系的副教授。他在深度学习领域取得了许多进展,并创办了多家人工智能公司。

在他的笔记里,介绍如何使用StreamingLLM框架在Mistral上运行推理。TensorRT-LLM为用户提供了一个易于使用的Python API,用于定义大型语言模型(LLM)并构建包含最先进优化的TensorRT引擎,以在NVIDIA GPU上高效进行推理。StreamingLLM是在MIT-Han-Lab开发的一种新型框架,并在TensorRT-LLM中得到支持。查看Github仓库获取更多示例和文档!

StreamingLLM简介

使用LLM处理无限长度文本存在挑战。特别是,存储所有先前的Key和Value(KV)状态需要大量内存,并且模型可能难以生成超出其训练序列长度的文本。StreamingLLM通过仅保留最近的标记和注意力汇聚,丢弃中间标记,来解决这个问题。这使得模型能够从最近的标记生成连贯的文本,而无需重置缓存 —— 这是以前方法中没有看到的能力。

StreamingLLM针对流式应用进行了优化,例如多轮对话。它非常适用于模型需要持续运行而不需要大量内存或依赖于过去数据的场景。一个示例是基于LLM的每日助手。StreamingLLM将让模型持续运行,根据最近的对话生成响应,而无需刷新其缓存。以前的方法在对话长度超过训练长度时可能需要重置缓存(丢失最近的上下文),或者重新计算来自最近文本历史的KV状态,这可能是耗时的。

代码语言:javascript
复制
!nvidia-smi

安装 TensorRT-LLM

代码语言:javascript
复制
!pip install -q ipywidgets
!pip install tensorrt_llm -U -q --extra-index-url https://pypi.nvidia.com

!wget https://raw.githubusercontent.com/NVIDIA/TensorRT-LLM/main/tensorrt_llm/models/llama/convert.py
!mv convert.py /usr/local/lib/python3.10/dist-packages/tensorrt_llm/models/llama/
!wget https://raw.githubusercontent.com/NVIDIA/TensorRT-LLM/main/examples/llama/convert_checkpoint.py -P .
!wget https://raw.githubusercontent.com/NVIDIA/TensorRT-LLM/main/examples/run.py -P .
!wget https://raw.githubusercontent.com/NVIDIA/TensorRT-LLM/main/examples/utils.py -P .

将Mistral转换为TensorRT格式。为了启用StreamingLLM,我们需要在检查点转换中传递两个额外的标志。

  • dense_context_fmha - 在上下文阶段使用密集上下文fmha
  • enable_pos_shift - 允许我们在KV缓存中使用位置以进行RoPE
代码语言:javascript
复制
# Build the model model with StreamingLLM feature using a single GPU and FP16.
!python convert_checkpoint.py --model_dir mistralai/Mistral-7B-v0.1 \
                         --output_dir ./tllm_checkpoint_1gpu_streamingllm \
                         --dtype float16 \
                         --dense_context_fmha \
                         --enable_pos_shift

# Build the model model with StreamingLLM feature using a single GPU and FP16.
!python convert_checkpoint.py --model_dir mistralai/Mistral-7B-v0.1 \
                         --output_dir ./tllm_checkpoint_1gpu_nostream \
                         --dtype float16

为模型构建TensorRT引擎

代码语言:javascript
复制
# Streaming 
!trtllm-build --checkpoint_dir ./tllm_checkpoint_1gpu_streamingllm \
            --output_dir ./mistralengine_streaming \
            --gemm_plugin float16

使用大型输入序列运行推理

我们使用一个开源的莎士比亚数据集进行演示。我们将125,000个字符作为我们的输入。

代码语言:javascript
复制
import requests
import re

url = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'

response = requests.get(url)

if response.status_code == 200:
    story = response.text
    story = re.sub('\s+', ' ', story).strip()
else:
    story = None
    print("Failed to retrieve the document.")
代码语言:javascript
复制
%%time 

# Use the streaming engine with a sliding window/cache size 2048 and sink token length 4 
!python3 ./run.py --max_output_len=150 \
                  --tokenizer_dir mistralai/Mistral-7B-v0.1 \
                  --engine_dir=./mistralengine_streaming \
                  --max_attention_window_size=4096 \
                  --sink_token_length=4 \
                  --input_text f"{story[983152:]}"

原文代码链接:https://console.brev.dev/notebook/streamingllm-tensorrt-llm?=&linkId=100000248965640

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2024-03-18,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 GPUS开发者 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 安装 TensorRT-LLM
  • 为模型构建TensorRT引擎
  • 使用大型输入序列运行推理
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档