前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >ChatGLM3 源码解析(二)

ChatGLM3 源码解析(二)

作者头像
ApacheCN_飞龙
发布2024-03-08 09:11:15
1820
发布2024-03-08 09:11:15
举报
文章被收录于专栏:信数据得永生

SelfAttention

代码语言:python
复制
class SelfAttention(torch.nn.Module):
    """Self-attention layer operating on sequence-first hidden states.

    The forward pass has four stages:

    1. project the input to one fused QKV tensor,
    2. split that tensor into per-head Q, K and V,
    3. let ``CoreAttention`` turn Q/K/V into the context tensor O,
    4. project O back to the model's hidden size.
    """

    def __init__(self, config: ChatGLMConfig, layer_number, device=None):
        super(SelfAttention, self).__init__()
        # Layer index, clamped so that it is never below 1.
        self.layer_number = max(1, layer_number)

        head_count = config.num_attention_heads
        # ProjSize = NHead * HeadSize. It is derived from kv_channels, so it
        # does not have to equal config.hidden_size.
        self.projection_size = config.kv_channels * head_count
        # HeadSize = ProjSize // NHead
        self.hidden_size_per_attention_head = self.projection_size // head_count
        # NHead
        self.num_attention_heads_per_partition = head_count

        # Switch for multi-query attention (MQA).
        self.multi_query_attention = config.multi_query_attention
        if self.multi_query_attention:
            # NGroup: number of K/V groups shared among the query heads.
            group_count = config.multi_query_group_num
            self.num_multi_query_groups_per_partition = group_count
            # Q keeps ProjSize columns; K and V take NGroup * HeadSize each.
            self.qkv_hidden_size = (
                self.projection_size
                + 2 * self.hidden_size_per_attention_head * group_count
            )
        else:
            # Without MQA, Q, K and V are each ProjSize wide.
            self.qkv_hidden_size = 3 * self.projection_size

        # Linear layer mapping the input to the fused QKV tensor.
        self.query_key_value = nn.Linear(
            config.hidden_size,
            self.qkv_hidden_size,
            bias=config.add_bias_linear or config.add_qkv_bias,
            device=device,
            **_config_to_kwargs(config),
        )
        # Core module that computes O from Q, K and V.
        self.core_attention = CoreAttention(config, self.layer_number)
        # Linear layer mapping O to the layer output.
        self.dense = nn.Linear(
            self.projection_size,
            config.hidden_size,
            bias=config.add_bias_linear,
            device=device,
            **_config_to_kwargs(config),
        )

    def _allocate_memory(self, inference_max_sequence_len, batch_size, device=None, dtype=None):
        """Return an uninitialised [MaxSeqLen, BatchSize, NGroup or NHead, HeadSize] tensor."""
        kv_heads = (
            self.num_multi_query_groups_per_partition
            if self.multi_query_attention
            else self.num_attention_heads_per_partition
        )
        return torch.empty(
            inference_max_sequence_len,
            batch_size,
            kv_heads,
            self.hidden_size_per_attention_head,
            dtype=dtype,
            device=device,
        )

    def _expand_kv_groups(self, grouped):
        """Repeat every K/V group so there is one slice per query head.

        [Len, BatchSize, NGroup, HeadSize] -> [Len, BatchSize, NHead, HeadSize]
        """
        heads_per_group = (
            self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition
        )
        # [..., NGroup, HeadSize] -> [..., NGroup, 1, HeadSize]
        #                         -> [..., NGroup, NHead // NGroup, HeadSize]
        widened = grouped.unsqueeze(-2).expand(-1, -1, -1, heads_per_group, -1)
        # Materialise the broadcast, then merge the two head axes into NHead.
        return widened.contiguous().view(
            widened.size()[:2]
            + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
        )

    def forward(
            self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True
    ):
        """Run self-attention on ``hidden_states``.

        Args:
            hidden_states: input of shape [SeqLen, BatchSize, HidSize].
            attention_mask: forwarded unchanged to the core attention module.
            rotary_pos_emb: rotary embedding applied to Q and K; skipped when None.
            kv_cache: optional ``(cache_k, cache_v)`` pair that is prepended to
                the new K and V along dim 0.
            use_cache: when true, the (concatenated) K and V are returned as
                the new cache.

        Returns:
            ``(output, kv_cache)``; ``kv_cache`` is None when ``use_cache`` is
            false. The cached K/V are stored before the MQA head broadcast.
        """
        n_heads = self.num_attention_heads_per_partition
        head_dim = self.hidden_size_per_attention_head

        # Project the input to the fused QKV tensor.
        fused = self.query_key_value(hidden_states)

        if self.multi_query_attention:
            n_groups = self.num_multi_query_groups_per_partition
            # Cut the last axis into Q [.., ProjSize] and K/V [.., NGroup * HeadSize].
            query_layer, key_layer, value_layer = fused.split(
                [n_heads * head_dim, n_groups * head_dim, n_groups * head_dim],
                dim=-1,
            )
            # Q -> [SeqLen, BatchSize, NHead, HeadSize]
            query_layer = query_layer.view(query_layer.size()[:-1] + (n_heads, head_dim))
            # K -> [SeqLen, BatchSize, NGroup, HeadSize]
            key_layer = key_layer.view(key_layer.size()[:-1] + (n_groups, head_dim))
            # V -> [SeqLen, BatchSize, NGroup, HeadSize]
            value_layer = value_layer.view(value_layer.size()[:-1] + (n_groups, head_dim))
        else:
            # Reshape to [SeqLen, BatchSize, NHead, 3 * HeadSize] ...
            per_head_shape = fused.size()[:-1] + (n_heads, 3 * head_dim)
            fused = fused.view(*per_head_shape)
            # ... and cut the last axis into three equal parts: Q, K, V,
            # each [SeqLen, BatchSize, NHead, HeadSize].
            query_layer, key_layer, value_layer = split_tensor_along_last_dim(fused, 3)

        # Apply rotary position embedding (RoPE) to Q and K.
        if rotary_pos_emb is not None:
            query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb)
            key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb)

        # Prepend the cached K/V (each [CacheLen, BatchSize, NGroup, HeadSize])
        # to the freshly computed ones along the sequence axis.
        if kv_cache is not None:
            cached_k, cached_v = kv_cache
            key_layer = torch.cat((cached_k, key_layer), dim=0)
            value_layer = torch.cat((cached_v, value_layer), dim=0)
        # Hand K/V back to the caller only when caching is requested.
        kv_cache = (key_layer, value_layer) if use_cache else None

        # Under MQA, broadcast K and V from NGroup to NHead so they line up with Q.
        if self.multi_query_attention:
            key_layer = self._expand_kv_groups(key_layer)
            value_layer = self._expand_kv_groups(value_layer)

        # Core attention yields O of shape [SeqLen, BatchSize, ProjSize].
        context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask)

        # Output projection: [SeqLen, BatchSize, HidSize].
        output = self.dense(context_layer)

        return output, kv_cache

CoreAttention

代码语言:python
复制
class CoreAttention(torch.nn.Module):
    """Compute the attention context O from Q, K and V that are already split into heads.

    On PyTorch >= 2 the work is delegated to
    ``torch.nn.functional.scaled_dot_product_attention``; on older versions an
    explicit matmul / softmax / matmul implementation is used. Attention
    dropout is applied on both paths, and only in training mode.
    """

    def __init__(self, config: ChatGLMConfig, layer_number):
        super(CoreAttention, self).__init__()

        # Whether the QK product is additionally scaled by the layer index.
        self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling
        # Whether the attention scores are cast to FP32 before the softmax.
        self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32
        # Layer scaling forces the FP32 softmax.
        if self.apply_query_key_layer_scaling:
            self.attention_softmax_in_fp32 = True
        # Layer index, clamped so that it is never below 1.
        self.layer_number = max(1, layer_number)

        # ProjSize = HeadSize * NHead
        projection_size = config.kv_channels * config.num_attention_heads

        # ProjSize
        self.hidden_size_per_partition = projection_size
        # HeadSize = ProjSize // NHead
        self.hidden_size_per_attention_head = projection_size // config.num_attention_heads
        # NHead
        self.num_attention_heads_per_partition = config.num_attention_heads

        # Legacy-path scaling: scores are divided by norm_factor and, when
        # layer scaling is on, multiplied back by coeff afterwards.
        #   with layer scaling:    norm_factor = sqrt(HeadSize) * layer_number,
        #                          coeff = layer_number
        #   without layer scaling: norm_factor = sqrt(HeadSize), coeff = None
        coeff = None
        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
        if self.apply_query_key_layer_scaling:
            coeff = self.layer_number
            self.norm_factor *= coeff
        self.coeff = coeff

        # Dropout applied to the attention probabilities.
        self.attention_dropout = torch.nn.Dropout(config.attention_dropout)

    def forward(self, query_layer, key_layer, value_layer, attention_mask):
        """Return the context tensor O of shape [SeqLen, BatchSize, ProjSize].

        Args:
            query_layer: [SeqLen, BatchSize, NHead, HeadSize].
            key_layer: [CacheSeqLen, BatchSize, NHead, HeadSize].
            value_layer: [CacheSeqLen, BatchSize, NHead, HeadSize].
            attention_mask: boolean mask in which True marks positions that
                must NOT be attended to, or None. With None, a causal mask is
                used when the query and key lengths match, and no mask
                otherwise.
        """
        # On PyTorch 2 and later, use the built-in fused attention.
        pytorch_major_version = int(torch.__version__.split('.')[0])
        if pytorch_major_version >= 2:
            # [Len, BatchSize, NHead, HeadSize] -> [BatchSize, NHead, Len, HeadSize]
            query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]]
            # Mirror the legacy path, which applies self.attention_dropout to
            # the attention probabilities: active only while training.
            dropout_p = self.attention_dropout.p if self.training else 0.0
            if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
                context_layer = torch.nn.functional.scaled_dot_product_attention(
                    query_layer, key_layer, value_layer,
                    dropout_p=dropout_p, is_causal=True,
                )
            else:
                if attention_mask is not None:
                    # The incoming mask marks blocked positions with True; the
                    # built-in expects True for positions that take part.
                    attention_mask = ~attention_mask
                context_layer = torch.nn.functional.scaled_dot_product_attention(
                    query_layer, key_layer, value_layer, attention_mask,
                    dropout_p=dropout_p,
                )
            # [BatchSize, NHead, SeqLen, HeadSize] -> [SeqLen, BatchSize, NHead, HeadSize]
            context_layer = context_layer.permute(2, 0, 1, 3)
            # Merge the last two axes: [SeqLen, BatchSize, ProjSize]
            new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
            context_layer = context_layer.reshape(*new_context_layer_shape)
        else:
            # Otherwise compute attention by hand.

            # Shape of the attention matrix:
            # [BatchSize, NHead, SeqLen, CacheSeqLen]
            output_size = (query_layer.size(1), query_layer.size(2), query_layer.size(0), key_layer.size(0))

            # Merge the two middle axes of Q: [SeqLen, BatchSize * NHead, HeadSize]
            query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1)
            # Merge the two middle axes of K: [CacheSeqLen, BatchSize * NHead, HeadSize]
            key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1)

            # Scratch tensor with the attention matrix's shape:
            # [BatchSize * NHead, SeqLen, CacheSeqLen]
            matmul_input_buffer = torch.empty(
                output_size[0] * output_size[1], output_size[2], output_size[3], dtype=query_layer.dtype,
                device=query_layer.device
            )

            # Q with its first two axes swapped: [BatchSize * NHead, SeqLen, HeadSize]
            # K swapped twice:                   [BatchSize * NHead, HeadSize, CacheSeqLen]
            # Raw scores A = Q @ K / norm_factor. beta=0, so the uninitialised
            # buffer contents do not contribute to the result.
            matmul_result = torch.baddbmm(
                matmul_input_buffer,
                query_layer.transpose(0, 1),  # [b * np, sq, hn]
                key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
                beta=0.0,
                alpha=(1.0 / self.norm_factor),
            )

            # Split the first axis of A: [BatchSize, NHead, SeqLen, CacheSeqLen]
            attention_scores = matmul_result.view(*output_size)

            # Cast to FP32 for the softmax when configured.
            if self.attention_softmax_in_fp32:
                attention_scores = attention_scores.float()
            # Undo the extra layer scaling that was folded into norm_factor.
            if self.coeff is not None:
                attention_scores = attention_scores * self.coeff
            # No mask was passed and the score matrix is square (i.e. there
            # is no KV cache): build a causal mask.
            if attention_mask is None and attention_scores.shape[2] == attention_scores.shape[3]:
                # Start from all ones, [BatchSize, 1, SeqLen, CacheSeqLen] ...
                attention_mask = torch.ones(output_size[0], 1, output_size[2], output_size[3],
                                            device=attention_scores.device, dtype=torch.bool)
                # ... keep the lower triangle, zero the upper one ...
                attention_mask.tril_()
                # ... and invert, so True marks the (future) positions to block.
                attention_mask = ~attention_mask
            # Blocked positions get -inf so the softmax assigns them zero weight.
            if attention_mask is not None:
                attention_scores = attention_scores.masked_fill(attention_mask, float("-inf"))
            # Softmax over the key axis.
            attention_probs = F.softmax(attention_scores, dim=-1)
            # Cast back to the dtype of V.
            attention_probs = attention_probs.type_as(value_layer)

            # Dropout on the attention probabilities.
            attention_probs = self.attention_dropout(attention_probs)

            # Shape of O: [BatchSize, NHead, SeqLen, HeadSize]
            output_size = (value_layer.size(1), value_layer.size(2), query_layer.size(0), value_layer.size(3))
            # Merge the two middle axes of V: [CacheSeqLen, BatchSize * NHead, HeadSize]
            value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1)
            # Merge the first two axes of A: [BatchSize * NHead, SeqLen, CacheSeqLen]
            attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)
            # V with its first two axes swapped: [BatchSize * NHead, CacheSeqLen, HeadSize]
            # O = A @ V
            context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
            # Split the first axis of O: [BatchSize, NHead, SeqLen, HeadSize]
            context_layer = context_layer.view(*output_size)
            # Permute O to [SeqLen, BatchSize, NHead, HeadSize]
            context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
            # Merge the last two axes of O: [SeqLen, BatchSize, ProjSize]
            new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
            context_layer = context_layer.view(*new_context_layer_shape)

        # Return O.
        return context_layer
本文参与 腾讯云自媒体分享计划,分享自作者个人站点/博客。
原始发表:2024-03-08,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • SelfAttention
  • CoreAttention
领券
问题归档 | 专栏文章 | 快讯文章归档 | 关键词归档 | 开发者手册归档 | 开发者手册 Section 归档