首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >如何使用PyTorch的nn.MultiheadAttention

如何使用PyTorch的nn.MultiheadAttention
EN

Stack Overflow用户
提问于 2022-03-14 08:05:06
回答 1查看 806关注 0票数 0

我想使用PyTorch的nn.MultiheadAttention,但它不起作用。

我只想在手工计算的注意示例中使用py手电的功能。

我在尝试运行这个例子时总是会出错。

代码语言:javascript
运行
复制
import torch.nn as nn

embed_dim = 4
num_heads = 1

x = [
  [1, 0, 1, 0], # Input 1
  [0, 2, 0, 2], # Input 2
  [1, 1, 1, 1]  # Input 3
 ]
x = torch.tensor(x, dtype=torch.float32)

w_key = [
  [0, 0, 1],
  [1, 1, 0],
  [0, 1, 0],
  [1, 1, 0]
]
w_query = [
  [1, 0, 1],
  [1, 0, 0],
  [0, 0, 1],
  [0, 1, 1]
]
w_value = [
  [0, 2, 0],
  [0, 3, 0],
  [1, 0, 3],
  [1, 1, 0]
]
w_key = torch.tensor(w_key, dtype=torch.float32)
w_query = torch.tensor(w_query, dtype=torch.float32)
w_value = torch.tensor(w_value, dtype=torch.float32)


keys = x @ w_key
querys = x @ w_query
values = x @ w_value

multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
attn_output, attn_output_weights = multihead_attn(querys, keys, values)
EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2022-03-31 07:51:20

尝尝这个。

首先,x是(3x4)矩阵。因此,您需要一个权重矩阵(4x4)代替。

nn.MultiheadAttention似乎只支持批处理模式,尽管文档表示它支持取消批处理输入。因此,让我们通过.unsqueeze(0)使您的一个数据点处于批处理模式。

代码语言:javascript
运行
复制
embed_dim = 4
num_heads = 1

x = [
  [1, 0, 1, 0], # Seq 1
  [0, 2, 0, 2], # Seq 2
  [1, 1, 1, 1]  # Seq 3
 ]
x = torch.tensor(x, dtype=torch.float32)

w_key = [
  [0, 0, 1, 1],
  [1, 1, 0, 1],
  [0, 1, 0, 1],
  [1, 1, 0, 1]
]
w_query = [
  [1, 0, 1, 1],
  [1, 0, 0, 1],
  [0, 0, 1, 1],
  [0, 1, 1, 1]
]
w_value = [
  [0, 2, 0, 1],
  [0, 3, 0, 1],
  [1, 0, 3, 1],
  [1, 1, 0, 1]
]
w_key = torch.tensor(w_key, dtype=torch.float32)
w_query = torch.tensor(w_query, dtype=torch.float32)
w_value = torch.tensor(w_value, dtype=torch.float32)


keys = (x @ w_key).unsqueeze(0)     # to batch mode
querys = (x @ w_query).unsqueeze(0)
values = (x @ w_value).unsqueeze(0)

multihead_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
attn_output, attn_output_weights = multihead_attn(querys, keys, values)
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/71464582

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档