I want to use PyTorch's nn.MultiheadAttention, but it doesn't work. I just want to use PyTorch's module on a hand-calculated attention example. I keep getting an error when I try to run this example.
import torch
import torch.nn as nn
embed_dim = 4
num_heads = 1
x = [
[1, 0, 1, 0], # Input 1
[0, 2, 0, 2], # Input 2
[1, 1, 1, 1] # Input 3
]
x = torch.tensor(x, dtype=torch.float32)
w_key = [
[0, 0, 1],
[1, 1, 0],
[0, 1, 0],
[1, 1, 0]
]
w_query = [
[1, 0, 1],
[1, 0, 0],
[0, 0, 1],
[0, 1, 1]
]
w_value = [
[0, 2, 0],
[0, 3, 0],
[1, 0, 3],
[1, 1, 0]
]
w_key = torch.tensor(w_key, dtype=torch.float32)
w_query = torch.tensor(w_query, dtype=torch.float32)
w_value = torch.tensor(w_value, dtype=torch.float32)
keys = x @ w_key
querys = x @ w_query
values = x @ w_value
multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
attn_output, attn_output_weights = multihead_attn(querys, keys, values)
Answered on 2022-03-31 07:51:20
Try this.

First, x is a (3x4) matrix, so you need (4x4) weight matrices instead, so that the projected queries, keys, and values have a last dimension matching embed_dim = 4.

Also, nn.MultiheadAttention seems to only support batched input here, even though the documentation says it accepts unbatched input. So let's put your single data point into batch mode with .unsqueeze(0).
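A quick shape check makes the mismatch concrete (a minimal sketch; it reuses x and the original (4x3) w_query from the question):

import torch

x = torch.tensor([[1., 0., 1., 0.],
                  [0., 2., 0., 2.],
                  [1., 1., 1., 1.]])
w_query = torch.tensor([[1., 0., 1.],
                        [1., 0., 0.],
                        [0., 0., 1.],
                        [0., 1., 1.]])
querys = x @ w_query
print(querys.shape)  # torch.Size([3, 3]) -- 2-D, and the last dim is 3, not embed_dim=4

Here is the corrected version: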
import torch
import torch.nn as nn
embed_dim = 4
num_heads = 1
x = [
[1, 0, 1, 0], # Seq 1
[0, 2, 0, 2], # Seq 2
[1, 1, 1, 1] # Seq 3
]
x = torch.tensor(x, dtype=torch.float32)
w_key = [
[0, 0, 1, 1],
[1, 1, 0, 1],
[0, 1, 0, 1],
[1, 1, 0, 1]
]
w_query = [
[1, 0, 1, 1],
[1, 0, 0, 1],
[0, 0, 1, 1],
[0, 1, 1, 1]
]
w_value = [
[0, 2, 0, 1],
[0, 3, 0, 1],
[1, 0, 3, 1],
[1, 1, 0, 1]
]
w_key = torch.tensor(w_key, dtype=torch.float32)
w_query = torch.tensor(w_query, dtype=torch.float32)
w_value = torch.tensor(w_value, dtype=torch.float32)
keys = (x @ w_key).unsqueeze(0)      # add a batch dimension -> shape (1, 3, 4)
querys = (x @ w_query).unsqueeze(0)  # (1, 3, 4)
values = (x @ w_value).unsqueeze(0)  # (1, 3, 4)
multihead_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)  # expects (batch, seq, embed)
attn_output, attn_output_weights = multihead_attn(querys, keys, values)
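To check the result, you can print the output shapes. With batch_first=True, attn_output has shape (batch, target_len, embed_dim) and attn_output_weights (averaged over heads by default) has shape (batch, target_len, source_len):

print(attn_output.shape)          # torch.Size([1, 3, 4])
print(attn_output_weights.shape)  # torch.Size([1, 3, 3])

Note that nn.MultiheadAttention applies its own learned input and output projections on top of the tensors you pass in, so the numbers will not match a pure hand calculation of attention over your w_query/w_key/w_value.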
https://stackoverflow.com/questions/71464582