首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >Keras MultiHeadAttention层抛出IndexError:元组索引超出范围

Keras MultiHeadAttention层抛出IndexError:元组索引超出范围
EN

Stack Overflow用户
提问于 2022-01-25 14:34:02
回答 1查看 576关注 0票数 1

我一次又一次地得到这个错误,当我试图对一维向量进行自我关注时,我真的不明白为什么会发生这种情况,任何帮助都是非常感谢的。

代码语言:javascript
运行
复制
layer = layers.MultiHeadAttention(num_heads=2, key_dim=2)
target = tf.ones(shape=[1, 16])
source = tf.ones(shape=[1, 16])
output_tensor, weights = layer(target, source)

错误:

代码语言:javascript
运行
复制
~/anaconda3/envs/tf/lib/python3.9/site-packages/tensorflow/python/keras/layers/multi_head_attention.py in _masked_softmax(self, attention_scores, attention_mask)
    399         attention_mask = array_ops.expand_dims(
    400             attention_mask, axis=mask_expansion_axes)
--> 401     return self._softmax(attention_scores, attention_mask)
    402 
    403   def _compute_attention(self,

~/anaconda3/envs/tf/lib/python3.9/site-packages/tensorflow/python/keras/engine/base_layer.py in __call__(self, *args, **kwargs)
   1010         with autocast_variable.enable_auto_cast_variables(
   1011             self._compute_dtype_object):
-> 1012           outputs = call_fn(inputs, *args, **kwargs)
   1013 
   1014         if self._activity_regularizer:

~/anaconda3/envs/tf/lib/python3.9/site-packages/tensorflow/python/keras/layers/advanced_activations.py in call(self, inputs, mask)
    332             inputs, axis=self.axis, keepdims=True))
    333       else:
--> 334         return K.softmax(inputs, axis=self.axis[0])
    335     return K.softmax(inputs, axis=self.axis)
    336 

IndexError: tuple index out of range
EN

回答 1

Stack Overflow用户

发布于 2022-01-25 14:40:48

您忘记了批处理维度,这是必要的。另外,如果您想要输出张量和相应的权重,则必须将参数return_attention_scores设置为True。试着做这样的事情:

代码语言:javascript
运行
复制
import tensorflow as tf

layer = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=2)
samples = 5
target = tf.ones(shape=[samples, 1, 16])
source = tf.ones(shape=[samples, 1, 16])
output_tensor, weights = layer(target, source, return_attention_scores=True)

也是根据文档

查询:形状的查询张量(B,T,dim) 价值:形状的价值张量(B,S,dim)

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/70850506

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档