PyTorch - 有效地应用注意力

内容来源于 Stack Overflow,并遵循CC BY-SA 3.0许可协议进行翻译与使用

  • 回答 (1)
  • 关注 (0)
  • 查看 (178)

我已经建立了一个注意力的RNN语言模型,我通过参加所有以前的隐藏状态(只有一个方向)为输入的每个元素创建上下文向量。

在我看来,最直接的解决方案是在RNN输出上使用for循环,这样每个上下文向量一个接一个地计算。

import torch
import torch.nn as nn
import torch.nn.functional as F

class RNN_LM(nn.Module):
    def __init__(self, hidden_size, vocab_size, embedding_dim=None, droprate=0.5):
        super().__init__()
        if not embedding_dim:
            embedding_dim = hidden_size
        self.embedding_matrix = nn.Embedding(vocab_size, embedding_dim)

        self.lstm = nn.LSTM(input_size=embedding_dim, hidden_size=hidden_size, batch_first=False)
        self.attn = nn.Linear(hidden_size, hidden_size)
        self.vocab_dist = nn.Linear(hidden_size, vocab_size)
        self.dropout = nn.Dropout(droprate)

    def forward(self, x):
        x = self.dropout(self.embedding_matrix(x.view(-1, 1)))
        x, states = self.lstm(x)
        #print(x.size())
        x = x.squeeze()
        content_vectors = [x[0].view(1, -1)]
        # for-loop over hidden states and attention
        for i in range(1, x.size(0)):
            prev_states = x[:i]
            current_state = x[i].view(1, -1)

            attn_prod = torch.mm(self.attn(current_state), prev_states.t())
            attn_weights = F.softmax(attn_prod, dim=1)
            context = torch.mm(attn_weights, prev_states)
            content_vectors.append(context)

        return self.vocab_dist(self.dropout(torch.cat(content_vectors)))

注意:forward此处的方法仅用于培训。

然而,该解决方案不是非常有效,因为代码与后续计算每个上下文向量不能很好地并行化。但由于上下文向量不依赖于彼此,我想知道是否存在非连续的计算方法。

那么有没有一种方法可以在没有for循环的情况下计算上下文向量从而可以并行化更多的计算?

提问于
用户回答回答于

好的,为了清楚起见:我假设我们只关心矢量化for循环。形状是x什么?假设x是2维的,我有以下代码,v1执行你的循环并且v2是矢量化版本:

import torch
import torch.nn.functional as F

torch.manual_seed(0)

x = torch.randn(3, 6)

def v1():
    for i in range(1, x.size(0)):
        prev = x[:i]
        curr = x[i].view(1, -1)

        prod = torch.mm(curr, prev.t())
        attn = prod # same shape
        context = torch.mm(attn, prev)
        print(context)

def v2():
    # we're going to unroll the loop by vectorizing over the new,
    # 0-th dimension of `x`. We repeat it as many times as there
    # are iterations in the for loop
    repeated = x.unsqueeze(0).repeat(x.size(0), 1, 1)

    # we're looking to build a `prevs` tensor such that
    # prevs[i, x, y] == prev[x, y] at i-th iteration of the loop in v1,
    # up to 0-padding necessary to make them all the same size.
    # We need to build a higher-dimensional equivalent of torch.triu
    xs = torch.arange(x.size(0)).reshape(1, -1, 1)
    zs = torch.arange(x.size(0)).reshape(-1, 1, 1)
    prevs = torch.where(zs < xs, torch.tensor(0.), repeated)

    # this is an equivalent of the above iteration starting at 1
    prevs = prevs[:-1]
    currs = x[1:]

    # a batched matrix multiplication
    prod = torch.matmul(currs, prevs.transpose(1, 2))
    attn = prod # same shape
    context = torch.matmul(attn, prevs)
    # equivalent of a higher dimensional torch.diagonal
    contexts = torch.einsum('iij->ij', (context))
    print(contexts)

print(x)

print('\n------ v1 -------\n')
v1()
print('\n------ v2 -------\n')
v2()

它会对你的循环进行矢量化,但有一些注意事项。首先,我假设x是二维的。其次,我跳过softmax声称它不会改变输入的大小,因此不会影响矢量化。这是真的,但不幸的是,0填充向量的softmax v不等于未填充的0填充softmax v。这可以通过重整化来修复。如果我的假设是正确的,请告诉我这是否帮助了您。

热门问答

腾讯会议共享屏幕,其他人收到的是黑屏?

AI学习社一个人工智能的死忠粉,让我们一起了解人工智能

你分享给谁 让谁看下腾讯会议应用的权限 是否都开启了

腾讯云音视频 支持 移动端h5吗( 不是小程序的)?

shixin

腾讯 · 高级产品经理 (已认证)

推荐

实时音视频TRTC的Web版是基于WebRTC的方案,需要浏览器的对WebRTC的支持,支持WebRTC的浏览器就可以。但是,移动端浏览器对WebRTC支持的情况并不好,建议使用小程序版。

如何用命令修改腾讯云解析目标ip?

氧化先生道可道 非常道 名可名 非常名
推荐
可以,参考: https://cns.api.qcloud.com/v2/index.php? &<公共请求参数> &Action=RecordCreate &domain=qcloud.com &subDomain=www &recordType=A &recordLine=默...... 展开详请

组队匹配完整流程是怎样的?感觉缺少API支持?

您好,matchgroup匹配成功后,小组成员会进入同一个房间和同一个队伍,这个API需要传玩家ID,通过邀请好友进房间就能拿到玩家的id,解散房间后再调用matchgroup,在没有解散房间不能调用matchGroup 接口。

SCF使用了k8s或docker容器技术吗?

Mason-Serverless

腾讯 · 产品经理 (已认证)

推荐

SCF的新架构使用的轻量化虚拟机技术,同时MVM里内嵌的有docker,但是没有使用K8S

腾讯云IoT物联平台中如何自定义Topic?

DylanRichard

腾讯 · 产品经理 (已认证)

万物互联的时代,欢迎来到IoT的世界
推荐已采纳
第二个是物联网通信平台(IoT Hub)的,https://cloud.tencent.com/document/product/634/32546。目前物联网开发平台(IoT explorer)只支持基于数据模板协议的接入(文档 https://cloud.tencent.co...... 展开详请

所属标签

扫码关注云+社区

领取腾讯云代金券