前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >改善大型语言模型的3种简单方法

改善大型语言模型的3种简单方法

作者头像
磐创AI
发布2023-11-27 14:04:05
4630
发布2023-11-27 14:04:05
举报

大型语言模型(LLMs)已经成为现实。随着最近发布的Llama 2,开源LLMs正在接近ChatGPT的性能,并且经过适当调整,甚至可以超越它。

使用这些LLMs通常并不像看起来那么简单,特别是如果你想将LLM进行精细调整以适应特定用例。

在本文中,我们将介绍3种改善任何LLM性能的最常见方法:

  1. 提示工程
  2. 检索增强生成(RAG)
  3. 参数高效微调(PEFT)

还有许多其他方法,但这些是最简单的方法,可以在不多的工作量下带来重大改进。

这3种方法从最简单的方法开始,即所谓的低挂果,到更复杂的改进LLM的方法之一。

要充分利用LLMs,甚至可以将这三种方法结合起来使用!

在开始之前,这里是更详细的方法概述,以便更容易参考。

你还可以在Google Colab Notebook中跟随操作,以确保一切都按预期工作。

加载Llama 2

在开始之前,我们需要加载一个LLM,以便在这些示例中使用。我们选择基本的Llama 2,因为它展现出令人难以置信的性能,而且我也喜欢在教程中坚持使用基础模型。

在开始之前,我们首先需要接受许可协议。请按照以下步骤操作:

  1. 在此处创建一个HuggingFace帐户。
  2. 在此处申请Llama 2的访问权限。
  3. 在此处获取你的HuggingFace令牌。

完成后,我们可以使用HuggingFace凭据登录,以便此环境知道我们有权限下载我们感兴趣的Llama 2模型:

代码语言:javascript
复制
from huggingface_hub import notebook_login
notebook_login()

接下来,我们可以加载Llama 2的13B变体。

代码语言:javascript
复制
from torch import cuda, bfloat16
import transformers

model_id = 'meta-llama/Llama-2-13b-chat-hf'
pyt
# 4-bit Quanityzation to load Llama 2 with less GPU memory
bnb_config = transformers.BitsAndBytesConfig(
    load_in_4bit=True,  
    bnb_4bit_quant_type='nf4',  
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=bfloat16
)

# Llama 2 Tokenizer
tokenizer = transformers.AutoTokenizer.from_pretrained(model_id)

# Llama 2 Model
model = transformers.AutoModelForCausalLM.from_pretrained(
    model_id,
    trust_remote_code=True,
    quantization_config=bnb_config,
    device_map='auto',
)
model.eval()

# Our text generator
generator = transformers.pipeline(
    model=model, tokenizer=tokenizer,
    task='text-generation',
    temperature=0.1,
    max_new_tokens=500,
    repetition_penalty=1.1
)

大多数开源LLMs在创建提示时都必须遵循某种模板。就Llama 2而言,以下内容有助于引导提示的编写:

这意味着我们必须按以下方式使用提示来正确生成文本:

代码语言:javascript
复制
basic_prompt = """
<s>[INST] <<SYS>>

You are a helpful assistant

<</SYS>>

What is 1 + 1? [/INST]
"""
print(generator(basic_prompt)[0]["generated_text"])

然后,生成以下输出:

代码语言:javascript
复制
"""
Oh my, that's a simple one! 
The answer to 1 + 1 is... (drumroll please)... 2! 😄
"""

这个模板并没有看起来那么复杂,但稍加练习,你很快就能掌握它。

现在,让我们深入探讨改进LLM输出的第一种方法,即提示工程。

1.提示工程 ⚙️

我们询问LLM某事的方式对我们获得的输出质量有重大影响。我们需要明确、完整,并提供我们感兴趣的输出的示例。

这种定制提示的过程称为提示工程。

提示工程是一种非常出色的“调整”模型的方式。它不需要更新模型,你可以快速迭代。

提示工程有两个主要概念:

  • 基于示例的
  • 基于思考的
基于示例的提示工程

在基于示例的提示工程中,例如一次性或少量示例学习,我们向LLM提供了一些我们寻找的示例。

这通常生成更符合我们期望的文本。

例如,让我们对一个简短的评论应用情感分类:

代码语言:javascript
复制
prompt = """
<s>[INST] <<SYS>>

You are a helpful assistant.

<</SYS>>

Classify the text into neutral, negative or positive. 
Text: I think the food was okay. [/INST]
"""
print(generator(prompt)[0]["generated_text"])

然后生成以下输出:

代码语言:javascript
复制
"""
Positive. The word "okay" is a mildly positive word, 
indicating that the food was satisfactory or acceptable.
"""

个人而言,我对这个答案并不太满意。我认为它更中性而不是积极。此外,我们必须在文本中搜索答案。

相反,让我们给出一个我们希望生成答案的示例:

代码语言:javascript
复制
prompt = """
<s>[INST] <<SYS>>

You are a helpful assistant.

<</SYS>>

Classify the text into neutral, negative or positive. 
Text: I think the food was alright.
Sentiment: 
[/INST] 

Neutral</s><s>

[INST]
Classify the text into neutral, negative or positive. 
Text: I think the food was okay.
Sentiment: 
[/INST] 
"""
print(generator(prompt)[0]["generated_text"])

然后查看输出,我们得到了预期结果:

代码语言:javascript
复制
"""
Neutral
"""

现在,LLM言之有物,只提供我们感兴趣的答案。好多了!

基于思考的提示工程

我们可以进一步要求LLM“思考”其答案。

通过让LLM将其思考分解为较小的步骤,允许每个步骤都有更多的计算。这些较小的步骤通常被称为LLM的“思考”。

有许多使用基于思考的提示的方式,我们可以使其相当复杂,如下所示:

然而,让我们从最基本的示例开始。如果我们要求模型通过将问题分解为步骤来解决问题呢?

代码语言:javascript
复制
prompt = """
<s>[INST] <<SYS>>

You are a helpful assistant.

<</SYS>>

Do the odd numbers in this group add up to an even number? 3, 5, 15, 32.
Solve by breaking the problem into steps. 
Identify the odd numbers, add them, and indicate whether the result is odd or even.[/INST]
"""
print(generator(prompt)[0]["generated_text"])

输出

代码语言:javascript
复制
"""
Hello! I'd be happy to help you with that. Let's break down the problem into steps:

Step 1: Identify the odd numbers in the group.

The odd numbers in the group are:

3, 5, 15

Step 2: Add the odd numbers.

3 + 5 + 15 = 23

Step 3: Is the result odd or even?

23 is an odd number.

Therefore, the sum of the odd numbers in the group is an odd number.
"""

正如我们之前看到的,这也被称为思维链,LLM有一系列单独的思维或步骤。

这些单独的步骤还有助于模型在计算过程中保持责任感。因为它对每个步骤都进行了“推理”,所以它的“思考”过程有结构。

2.检索增强生成(RAG) 🗃️

尽管提示工程可以带来改进,但它不能使LLM知道它之前没有学到的事情。

当一个LLM在2022年进行训练时,它对2023年发生的事情一无所知。

这就是检索增强生成(RAG)的用武之地。这是一种为LLM提供外部知识以便利用的方法。

在RAG中,知识库,如维基百科,被转化为数值表示以捕捉其含义,称为嵌入。这些嵌入存储在矢量数据库中,以便可以轻松检索信息。

然后,当你向LLM提供某个提示时,将在矢量数据库中搜索与提示相关的信息。

最相关的信息然后作为附加上下文传递给LLM,以便它可以生成其响应。

在实践中,RAG有助于LLM“查找”外部知识库中的信息,以改善其响应。

使用LangChain创建RAG管道

要创建RAG管道或系统,我们可以使用众所周知且易于使用的框架LangChain。

我们将首先创建有关Llama 2的小型知识库,并将其写入文本文件:

代码语言:javascript
复制
# Our tiny knowledge base
knowledge_base = [
    "On July 18, 2023, in partnership with Microsoft, Meta announced LLaMA-2, the next generation of LLaMA." ,
    "Llama 2, a collection of pretrained and fine-tuned large language models (LLMs) ",
    "The fine-tuned LLMs, called Llama 2-Chat, are optimized for dialogue use cases.",
    "Meta trained and released LLaMA-2 in three model sizes: 7, 13, and 70 billion parameters.",
    "The model architecture remains largely unchanged from that of LLaMA-1 models, but 40% more data was used to train the foundational models.",
    "The accompanying preprint also mentions a model with 34B parameters that might be released in the future upon satisfying safety targets."
]
with open(r'knowledge_base.txt', 'w') as fp:
    fp.write('\n'.join(knowledge_base))

完成后,我们需要创建一个嵌入模型,可以将文本转换为数值表示,即嵌入。

我们将选择一个众所周知的句子嵌入模型,即sentence-transformers/all-MiniLM-L6-v2。

🔥提示🔥你可以在大规模文本嵌入基准(MTEB)排行榜上找到许多出色的模型。

代码语言:javascript
复制
from langchain.embeddings.huggingface import HuggingFaceEmbeddings

# Embedding Model for converting text to numerical representations
embedding_model = HuggingFaceEmbeddings(
    model_name='sentence-transformers/all-MiniLM-L6-v2'
)

现在,我们有了一个嵌入模型和一个小型知识库,可以开始开发我们的矢量数据库。有许多选择,但让我们选择一个可以在本地使用的,即FAISS。

🔥提示🔥其他出色的矢量数据库包括Pinecone和Weaviate。

代码语言:javascript
复制
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import FAISS
from langchain.document_loaders import TextLoader

# Load documents and split them
documents = TextLoader("knowledge_base.txt").load()
text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
docs = text_splitter.split_documents(documents)

# Create local vector database
db = FAISS.from_documents(docs, embedding_model)

db变量包含了数据库所需的一切,但仍然需要将其与LLM组合在一起。LangChain使这一过程非常简单和直接:

代码语言:javascript
复制
from langchain.chains import RetrievalQA
from langchain.llms import HuggingFacePipeline

# Load LLM into LangChain
llm = HuggingFacePipeline(pipeline=generator)

# RAG Pipeline
rag = RetrievalQA.from_chain_type(
    llm=llm, chain_type='stuff',
    retriever=db.as_retriever()
)

我们创建的llm和rag两个变量的伟大之处在于,我们可以在不使用RAG管道的情况下使用llm进行一些提示,使用RAG管道时则使用rag。

让我们首先尝试不使用RAG。如果我们问LLM关于自身Llama 2的问题,会发生什么?

代码语言:javascript
复制
>>> llm('What is Llama 2?')

"""
Llama 2 is a cryptocurrency that was created as a parody of the popular cryptocurrency Dogecoin. It was designed to be a more serious and less meme-focused alternative to Dogecoin, with a focus on privacy and security.
Llama 2 uses a proof-of-work consensus algorithm and has a block time of 1 minute. It also features a decentralized governance system that allows holders of the llama2 token to vote on proposals for the development and growth of the project.
One of the key features of Llama 2 is its use of "ring signatures" to enhance privacy and security. Ring signatures are a type of digital signature that can be used to hide the identity of the sender in a transaction. This makes it more difficult for hackers or other malicious actors to trace the source of a transaction.
Overall, Llama 2 is a privacy-focused cryptocurrency that is designed to provide users with greater control over their financial data and more secure transactions.
"""

加密货币?这不是我们要找的答案…它似乎对自己一无所知。

让我们尝试使用RAG管道:

代码语言:javascript
复制
>>> rag('What is Llama 2?')

"""
Llama 2 is a collection of pretrained and fine-tuned large language models 
(LLMs) announced by Meta in partnership with Microsoft on July 18, 2023.
"""

这好多了!

由于我们为Llama 2提供了关于自身的外部知识,它可以利用这些信息生成更准确的答案。

🔥提示🔥提示可能会很快变得复杂。如果你想知道LMM实际收到的提示,请在运行LMM之前运行以下代码:

代码语言:javascript
复制
import langchain
langchain.debug = True

3.参数高效微调 🛠️

无论是提示工程还是RAG,通常不会改变LLM本身。它的参数保持不变,模型不会“学习”任何新知识,它只是进行利用。

我们可以使用领域特定的数据对LLM进行精细调整,以使其学到新的东西。

与其微调模型的数十亿个参数,不如使用参数高效微调(PEFT)。正如其名称所示,它是一个子领域,专注于使用尽可能少的参数有效地微调LLM。

其中最常使用的方法之一被称为低秩适应(LoRA)。LoRA找到原始参数的一个小子集,无需触及基础模型。

这些参数可以看作是完整模型的较小表示,只对最重要或最有影响的参数进行训练。其美妙之处在于所得到的权重可以添加到基础模型中,因此可以单独保存。

使用AutoTrain对Llama 2进行微调

使用众多参数微调Llama 2的过程可能会很困难。幸运的是,AutoTrain可以帮助你解决大部分问题,使你只需一行代码即可进行微调!

首先,数据是最重要的,它对结果性能的影响最大!

我们将使基本的Llama 2模型成为一个聊天模型,并将使用OpenAssistant Guanaco数据集:

代码语言:javascript
复制
import pandas as pd
from datasets import load_dataset

# Load dataset in pandas
dataset = load_dataset("timdettmers/openassistant-guanaco")
df = pd.DataFrame(dataset["train"][:1000]).dropna()
df.to_csv("train.csv")

数据集包含许多问题/回答方案,你可以在上面对Llama 2进行训练。它用### Human标签区分用户,用### Assistant标签区分LLM的回应。

为了说明,我们只从该数据集中取了1000个样本,但质量更高的数据点肯定会提高性能。

注意:数据集需要一个文本列,AutoTrain将自动使用它。

训练本身非常简单,只需安装AutoTrain,然后运行以下代码:

代码语言:javascript
复制
autotrain llm --train \
--project_name Llama-Chat \
--model abhishek/llama-2-7b-hf-small-shards \
--data_path . \
--use_peft \
--use_int4 \
--learning_rate 2e-4 \
--num_train_epochs 1 \
--trainer sft \
--merge_adapter

有一些重要的参数:

data_path:数据的路径。我们在本地保存了一个包含文本列的train.csv,AutoTrain在训练期间将使用它。

model:我们要微调的基础模型。它是基础模型的分片版本,便于训练。

use_peft和use_int4:这些参数启用了对模型的高效微调,减少了所需的VRAM。它部分地利用了LoRA。

merge_adapter:为了更容易使用模型,我们将LoRA与基础模型合并,创建一个新模型。

运行训练代码时,你应该会得到类似以下内容的输出:

就是这样!以这种方式微调Llama 2模型非常简单,因为我们将LoRA权重与原始模型合并,所以可以像之前一样加载更新后的模型。

🔥提示🔥尽管一行代码进行微调令人惊叹,但强烈建议你自己查看参数。通过深入的指南学习精细调整的确切含义,有助于你了解何时出现问题。

更新:我上传了一份更详细介绍如何使用这些方法的视频版本到YouTube。

https://youtu.be/Rqu5Hjsbq6A

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2023-11-26,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 磐创AI 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 加载Llama 2
  • 1.提示工程 ⚙️
    • 基于示例的提示工程
      • 基于思考的提示工程
      • 2.检索增强生成(RAG) 🗃️
        • 使用LangChain创建RAG管道
        • 3.参数高效微调 🛠️
          • 使用AutoTrain对Llama 2进行微调
          相关产品与服务
          数据库
          云数据库为企业提供了完善的关系型数据库、非关系型数据库、分析型数据库和数据库生态工具。您可以通过产品选择和组合搭建,轻松实现高可靠、高可用性、高性能等数据库需求。云数据库服务也可大幅减少您的运维工作量,更专注于业务发展,让企业一站式享受数据上云及分布式架构的技术红利!
          领券
          问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档