内存是否不足以支持长时间聊天内容的#LLM应用?NVIDIA工程师Song Han 开发了StreamingLLM,集成了TensorRT LLM v0.8。让我们看看StreamingLLM在中的应用吧!
Song Han 是NVIDIA杰出工程师,也是麻省理工学院电气工程与计算机科学系的副教授。他在深度学习领域取得了许多进展,并创办了多家人工智能公司。
在他的笔记里,介绍如何使用StreamingLLM框架在Mistral上运行推理。TensorRT-LLM为用户提供了一个易于使用的Python API,用于定义大型语言模型(LLM)并构建包含最先进优化的TensorRT引擎,以在NVIDIA GPU上高效进行推理。StreamingLLM是在MIT-Han-Lab开发的一种新型框架,并在TensorRT-LLM中得到支持。查看Github仓库获取更多示例和文档!
StreamingLLM简介
使用LLM处理无限长度文本存在挑战。特别是,存储所有先前的Key和Value(KV)状态需要大量内存,并且模型可能难以生成超出其训练序列长度的文本。StreamingLLM通过仅保留最近的标记和注意力汇聚,丢弃中间标记,来解决这个问题。这使得模型能够从最近的标记生成连贯的文本,而无需重置缓存 —— 这是以前方法中没有看到的能力。
StreamingLLM针对流式应用进行了优化,例如多轮对话。它非常适用于模型需要持续运行而不需要大量内存或依赖于过去数据的场景。一个示例是基于LLM的每日助手。StreamingLLM将让模型持续运行,根据最近的对话生成响应,而无需刷新其缓存。以前的方法在对话长度超过训练长度时可能需要重置缓存(丢失最近的上下文),或者重新计算来自最近文本历史的KV状态,这可能是耗时的。
!nvidia-smi
!pip install -q ipywidgets
!pip install tensorrt_llm -U -q --extra-index-url https://pypi.nvidia.com
!wget https://raw.githubusercontent.com/NVIDIA/TensorRT-LLM/main/tensorrt_llm/models/llama/convert.py
!mv convert.py /usr/local/lib/python3.10/dist-packages/tensorrt_llm/models/llama/
!wget https://raw.githubusercontent.com/NVIDIA/TensorRT-LLM/main/examples/llama/convert_checkpoint.py -P .
!wget https://raw.githubusercontent.com/NVIDIA/TensorRT-LLM/main/examples/run.py -P .
!wget https://raw.githubusercontent.com/NVIDIA/TensorRT-LLM/main/examples/utils.py -P .
将Mistral转换为TensorRT格式。为了启用StreamingLLM,我们需要在检查点转换中传递两个额外的标志。
# Build the model model with StreamingLLM feature using a single GPU and FP16.
!python convert_checkpoint.py --model_dir mistralai/Mistral-7B-v0.1 \
--output_dir ./tllm_checkpoint_1gpu_streamingllm \
--dtype float16 \
--dense_context_fmha \
--enable_pos_shift
# Build the model model with StreamingLLM feature using a single GPU and FP16.
!python convert_checkpoint.py --model_dir mistralai/Mistral-7B-v0.1 \
--output_dir ./tllm_checkpoint_1gpu_nostream \
--dtype float16
# Streaming
!trtllm-build --checkpoint_dir ./tllm_checkpoint_1gpu_streamingllm \
--output_dir ./mistralengine_streaming \
--gemm_plugin float16
我们使用一个开源的莎士比亚数据集进行演示。我们将125,000个字符作为我们的输入。
import requests
import re
url = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'
response = requests.get(url)
if response.status_code == 200:
story = response.text
story = re.sub('\s+', ' ', story).strip()
else:
story = None
print("Failed to retrieve the document.")
%%time
# Use the streaming engine with a sliding window/cache size 2048 and sink token length 4
!python3 ./run.py --max_output_len=150 \
--tokenizer_dir mistralai/Mistral-7B-v0.1 \
--engine_dir=./mistralengine_streaming \
--max_attention_window_size=4096 \
--sink_token_length=4 \
--input_text f"{story[983152:]}"
原文代码链接:https://console.brev.dev/notebook/streamingllm-tensorrt-llm?=&linkId=100000248965640