BGE M3-Embedding来自BAAI和中国科学技术大学,是BAAI开源的模型。相关论文在https://arxiv.org/abs/2402.03216,论文提出了一种新的embedding模型,称为M3-Embedding,它在多语言性(Multi-Linguality)、多功能性(Multi-Functionality)和多粒度性(Multi-Granularity)方面表现出色。M3-Embedding支持超过100种工作语言,支持8192长度的输入文本,同时支持密集检索(Dense Retrieval)、多向量检索(Multi-Vector Retrieval)和稀疏检索(Sparse Retrieval),为现实世界中的信息检索(IR)应用提供了统一的模型基础,通过这几种检索方式的组合,取得了良好的混合召回效果。
我们可以查看官方与openai模型的对比,整体来看,采用三种方式联合检索的BGE-M3(ALL)在三项评测中全面领先,而 BGE-M3(Dense)稠密检索在多语言、跨语言检索中具有明显优势。
1. 多语言(Multi-Linguality),训练集包含100+种以上语言 2. 多功能(Multi-Functionality),支持稠密检索(Dense Retrieval),还支持稀疏检索(Sparse Retrieval)与多向量检索(Multi-vector Retrieval) 3. 多粒度(Multi-Granularity) BGE-M3目前可以处理最大长度为8192 的输入文本,支持“句子”、“段落”、“篇章”、“文档”等不同粒度的输入文本
M3-Embedding模型的训练数据组成是模型效果较好的一个关键创新点,因为它旨在支持多语言性、多功能性和多粒度性,训练数据分为三部分:
这些训练数据的创新之处在于:
通过这种创新的训练数据组成,M3-Embedding
能够有效地学习并支持超过100种语言的文本嵌入,同时处理从短句到长达8192个词符的长文档,实现了在多语言、多功能和多粒度方面的突破。
从业界来看,M3-Embedding
、E5-mistral-7b
,都是利用GPT这样的LLM来合成了大量多语言数据,这个应该是后续的主流方案。
M3-Embedding统一了嵌入模型的三种常见检索功能,即密集检索(Dense retrieval)、词汇(稀疏)检索(Lexical retrieval)和多向量检索(Multi-vector retrieval)。以下是这些方法的公式化描述:
def dense_embedding(self, hidden_state, mask):
if self.sentence_pooling_method == 'cls':
return hidden_state[:, 0]
elif self.sentence_pooling_method == 'mean':
s = torch.sum(hidden_state * mask.unsqueeze(-1).float(), dim=1)
d = mask.sum(axis=1, keepdim=True).float()
return s / d
def sparse_embedding(self, hidden_state, input_ids, return_embedding: bool = True):
# sparse_linear 线性层= torch.nn.Linear(in_features=self.model.config.hidden_size, out_features=1)
# 通过relu计算token weight
token_weights = torch.relu(self.sparse_linear(hidden_state))
if not return_embedding: return token_weights
# 形状为(input_ids.size(0), input_ids.size(1), self.vocab_size)的零张量
sparse_embedding = torch.zeros(input_ids.size(0), input_ids.size(1), self.vocab_size,
dtype=token_weights.dtype,
device=token_weights.device)
# 将token_weights中的值分散scatter到sparse_embedding的相应位置,索引位置根据input_ids提供
sparse_embedding = torch.scatter(sparse_embedding, dim=-1, index=input_ids.unsqueeze(-1), src=token_weights)
# CLS,PAD 等无用token
unused_tokens = [self.tokenizer.cls_token_id, self.tokenizer.eos_token_id, self.tokenizer.pad_token_id,
self.tokenizer.unk_token_id]
sparse_embedding = torch.max(sparse_embedding, dim=1).values
# 无用token weight设置为0
sparse_embedding[:, unused_tokens] *= 0.
return sparse_embedding
由于嵌入模型的多功能性,检索过程可以在混合过程中进行。首先,可以通过每种方法单独检索候选结果(由于其高成本,可以免去多向量方法的这一步)。最终的检索结果是根据集成的相关性得分重新排序的:s_{\text{rank}} \leftarrow s_{\text{dense}} + s_{\text{lex}} + s_{\text{mul}}。
BGE-M3模型训练分为三个阶段:
1. 自学习蒸馏
人类可以利用多种不同的方式计算结果,矫正误差。模型也可以,通过联合多种检索方式的输出,可以取得比单检索模式更好的效果。因此,BGE-M3使用了一种自激励蒸馏方法来提高检索性能。具体来说,合并三种检索模式的输出,得到新的文本相似度分数,将其作为激励信号,让各单模式学习该信号,以提高单检索模式的效果。
2. 训练效率优化
通过根据长度对文本数据进行分组,确保一个batch内文本长度相对相似,从而减少填充。为了减少文本建模时的显存消耗,将一批数据分成多个小批。对于每个小批,利用模型编码文本,收集输出的向量同时丢弃所有前向传播中的中间状态,最后汇总向量计算损失,可以显著增加训练的batch size。
3. 长文本优化
BGE-M3提出了一种简单而有效的方法:MCLS(Multiple CLS)来增强模型的能力,而无需对长文本进行微调。
MCLS方法旨在利用多个CLS令牌来联合捕获长文本的语义。为每个固定数量的令牌插入一个cls令牌,每个cls令牌可以从相邻的令牌获取语义信息,最后通过对所有cls令牌的最后隐藏状态求平均值来获得最终的文本嵌入。
多语言检索任务,稀疏检索(Sparse)大幅超过了传统的稀疏匹配算法BM25。多向量检索(multi-vector)则获得了三种检索方式中的最佳效果。
跨语言检索能力(MKQA)
BGE-M3在跨语言任务上依然具备最佳的检索效果。稀疏检索并不擅长应对跨语言检索这种词汇重合度很小的场景。因此,稀疏检索的自身效果以及与其他方法混搭所带来的收益相对较小。
长文档检索能力 (MLRB: Multi-Lingual Long Retrieval Benchmark)
BGE-M3可以支持长达8192的输入文档,从实验结果可以观察到,稀疏检索(Sparse)的效果要显著高于稠密检索(Dense),这说明关键词信息对于长文档检索极为重要。
需要先安装,
pip install -U FlagEmbedding
git clone https://github.com/FlagOpen/FlagEmbedding.git
cd FlagEmbedding
pip install -e .
模型微调的数据集格式是json line格式文件,json格式如下:
{"query": str, "pos": List[str], "neg":List[str]}
query 是查询,pos 是正文本列表,neg 是负文本列表。
模型训练:
torchrun --nproc_per_node {number of gpus} \
-m FlagEmbedding.BGE_M3.run \
--output_dir {path to save model} \
--model_name_or_path BAAI/bge-m3 \
--train_data ./toy_train_data \
--learning_rate 1e-5 \
--fp16 \
--num_train_epochs 5 \
--per_device_train_batch_size {large batch size; set 1 for toy data} \
--dataloader_drop_last True \
--normlized True \
--temperature 0.02 \
--query_max_len 64 \
--passage_max_len 256 \
--train_group_size 2 \
--negatives_cross_device \
--logging_steps 10 \
--same_task_within_batch True \
--unified_finetuning True \
--use_self_distill True