前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >Albert 源码解析:分组复用

Albert 源码解析:分组复用

作者头像
ApacheCN_飞龙
发布2023-10-13 09:28:40
1190
发布2023-10-13 09:28:40
举报
文章被收录于专栏:信数据得永生信数据得永生
代码语言:javascript
复制
class AlbertGroup(nn.Module):
    def __init__(self, config):
        super(AlbertGroup, self).__init__()
        self.inner_group_num = config.inner_group_num
        self.inner_group = nn.ModuleList([AlbertLayer(config) for _ in range(config.inner_group_num)])

    def forward(self, hidden_states, attention_mask, head_mask):
        layer_attentions = ()
        layer_hidden_states = ()
        for inner_group_idx in range(self.inner_group_num): # [1]
            layer_module = self.inner_group[inner_group_idx]
            layer_outputs = layer_module(hidden_states, attention_mask, head_mask)
            hidden_states = layer_outputs[0]
            layer_attentions = layer_attentions + (layer_outputs[1],)
            layer_hidden_states = layer_hidden_states + (hidden_states,)
        return (layer_hidden_states, layer_attentions)

class AlbertTransformer(nn.Module):
    def __init__(self, config):
        super(AlbertTransformer, self).__init__()
        self.output_attentions = config.output_attentions
        self.output_hidden_states = config.output_hidden_states
        self.num_hidden_layers = config.num_hidden_layers
        self.num_hidden_groups = config.num_hidden_groups
        self.group = nn.ModuleList([AlbertGroup(config) for _ in range(config.num_hidden_groups)])

    def forward(self, hidden_states, attention_mask, head_mask):
        all_hidden_states = ()
        all_attentions = ()
        for layer_idx in range(self.num_hidden_layers):
            if self.output_hidden_states and layer_idx == 0:
                all_hidden_states = all_hidden_states + (hidden_states,)
			# [2]
            group_idx = int(layer_idx / self.num_hidden_layers * self.num_hidden_groups)
            layer_module = self.group[group_idx]
            layer_outputs = layer_module(hidden_states, attention_mask, head_mask[layer_idx])
            hidden_states = layer_outputs[0][-1]
            if self.output_attentions:
                all_attentions = all_attentions + layer_outputs[1]
            if self.output_hidden_states:
                all_hidden_states = all_hidden_states + layer_outputs[0]
        outputs = (hidden_states,)
        if self.output_hidden_states:
            outputs = outputs + (all_hidden_states,)
        if self.output_attentions:
            outputs = outputs + (all_attentions,)
        return outputs  # last-layer hidden state, (all hidden states), (all attentions)

config.inner_group_num是组内TFBlock数量,这里记为 GS,num_hidden_layers是总的 TFBlock 层数,记为 LC,num_hidden_groups是分组数量,记为 GC。

可以看出来AlbertGroup含有 GS 个AlbertLayer(也就是TFBlock),逻辑就是依次调用它里面的AlbertLayer([1])。 GS 应该等于LC // GC,但是源码里面三个值都能自由设定,没有校验。

然后AlbertTransformer含有 GC 个AlbertGroup,这些AlbertGroup都是重复使用的。在正向传播期间,程序遍历每一层,根据层序号得到分组序号,然后调用整个分组,也就是把隐藏向量传入层里面所有的组([2])。

直观演示,假设:

代码语言:javascript
复制
LC = 12
GC = 3
GS = LC // GC = 4

那么第一轮迭代,Li = 0, Gi = 0,隐藏向量以此传入G0L0, G0L1, G0L2, G0L3

之后的三轮迭代,Li = 1, 2, 3,还是Gi = 0,隐藏向量以相同方式传入G0L0, G0L1, G0L2, G0L3三次。

之后,Li = 4, 5, 6, 7Gi = 1,隐藏向量传入G1L0, G1L1, G1L2, G1L3四次。

之后,Li = 8, 9, 10, 11Gi = 2,隐藏向量传入G2L0, G2L1, G2L2, G2L3四次。

整体的模块调用路径是这样:

代码语言:javascript
复制
G0L0, G0L1, G0L2, G0L3
G0L0, G0L1, G0L2, G0L3
G0L0, G0L1, G0L2, G0L3
G0L0, G0L1, G0L2, G0L3
G1L0, G1L1, G1L2, G1L3
G1L0, G1L1, G1L2, G1L3
G1L0, G1L1, G1L2, G1L3
G1L0, G1L1, G1L2, G1L3
G2L0, G2L1, G2L2, G2L3
G2L0, G2L1, G2L2, G2L3
G2L0, G2L1, G2L2, G2L3
G2L0, G2L1, G2L2, G2L3

也就是层数是 12 没错,组数是 3 没错,但是每个组被复用了 4 次。

关于复用来看,有三个关键参数,第一个是每个组的容量,也就是 GS,第二个是一共有多少组,也就是GC,第三个是每个组复用多少次,实际上等于LC // GC,但这里面没有任何一个参数直接设置这个,你只能设置 LC。估计是为了和前代保持一致,但非常非常不好用。

假设我们把这个参数开放出来,叫做 GR(分组副本),那么并让 LC = GS * GR * GC,一切就合理了。我们完全可以将重复的层看作新的一层,只不过参数是和其他层共享的。

而且这套设置同时兼容跨层复用和相邻层的复用:

代码语言:javascript
复制
L0, L0, L1, L1, ..., LN, LN

对于相邻层复用,我们只需要把GS设成 1,GR设成 2,GC设成 N。

还有一种是跨层复用:

代码语言:javascript
复制
L0, L1, L2, ..., LN, L0, L1, L2, ... LN

我们只需要把GC设成 1,GS设成 N,然后GR设成 2 。

本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
原始发表:2023-08-21,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档