
tf.data.Dataset.from_generator: ValueError "not enough values to unpack (expected 2, got 1)"

Stack Overflow user
Asked on 2022-01-15 07:21:06
1 answer · 428 views · 0 followers · score 0

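I am currently trying to use a pre-trained transformer model for a classification problem. I wrote a custom generator and wrapped it with tf.data.Dataset.from_generator. The model takes two inputs: input_id and attn_mask. When I call model.fit, I get the ValueError "not enough values to unpack (expected 2, got 1)", even though the "Call arguments received" list shows that the layer got both input_id and attn_mask. Can anyone help me solve this?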

```python
import numpy as np
import tensorflow as tf
import tensorflow.keras as keras
from tensorflow.keras.models import Model
from transformers import TFBertModel, BertConfig

# train_data (a pandas DataFrame) and attn_mask (a length-512 array) are defined elsewhere.
def _input_fn():
    x = train_data.iloc[:, 0:512].to_numpy()    # token ids
    y = train_data.iloc[:, 512:516].to_numpy()  # one-hot labels
    attn = np.tile(attn_mask, x.shape[0]).reshape(-1, 512)  # one copy of the mask per row

    def generator():
        for s1, s2, l in zip(x, attn, y):
            yield {"input_id": s1, "attn_mask": s2}, l

    dataset = tf.data.Dataset.from_generator(
        generator,
        output_types=({"input_id": tf.int32, "attn_mask": tf.int32}, tf.int32),
    )
    # dataset = dataset.batch(2)
    # dataset = dataset.shuffle(...)
    return dataset
```

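train_data is a DataFrame holding the training data (16000 × 516). The last four columns are the one-hot encoded labels. Because I am not using the tokenizer's automatic encoding, I pass the attention mask myself as attn_mask.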
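A minimal sketch of the data preparation described above, using NumPy only. It assumes a hypothetical `train_array` standing in for `train_data` (8 random rows instead of 16000, with 512 token-id columns followed by 4 one-hot label columns) and a hypothetical shared `attn_mask` whose first 100 positions are real tokens. It shows that the generator yields one example at a time, so each `input_id` has shape `(512,)` with no batch dimension.

```python
import numpy as np

SEQ_LEN = 512      # token-id columns per row (from the question)
NUM_CLASSES = 4    # one-hot label columns at the end of each row

# Hypothetical stand-in for train_data: 8 random rows, 516 columns.
rng = np.random.default_rng(0)
ids = rng.integers(0, 30522, size=(8, SEQ_LEN))
labels = np.eye(NUM_CLASSES, dtype=np.int64)[rng.integers(0, NUM_CLASSES, size=8)]
train_array = np.hstack([ids, labels])

# Hypothetical shared attention mask: first 100 positions are real tokens.
attn_mask = np.zeros(SEQ_LEN, dtype=np.int64)
attn_mask[:100] = 1

x = train_array[:, :SEQ_LEN]                          # (8, 512)
y = train_array[:, SEQ_LEN:SEQ_LEN + NUM_CLASSES]     # (8, 4)
# Same trick as in _input_fn: one copy of the mask per row.
attn = np.tile(attn_mask, x.shape[0]).reshape(-1, SEQ_LEN)  # (8, 512)

def generator():
    # Yields a single (features, label) pair per row: no batch dimension.
    for s1, s2, label in zip(x, attn, y):
        yield {"input_id": s1, "attn_mask": s2}, label

features, label = next(generator())
print(features["input_id"].shape, features["attn_mask"].shape, label.shape)
# (512,) (512,) (4,)
```

Each yielded element is rank 1, which is exactly the `(512,)` shape that appears in the "Call arguments received" list of the traceback.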

My model:

```python
bert = 'bert-base-uncased'

# BERT reads hidden_dropout_prob / attention_probs_dropout_prob; the names
# dropout / attention_dropout belong to DistilBertConfig and have no effect here.
config = BertConfig.from_pretrained(
    bert, hidden_dropout_prob=0.2, attention_probs_dropout_prob=0.2
)
config.output_hidden_states = False
transformer_model = TFBertModel.from_pretrained(bert, config=config)

# shape=(512,) is a one-element tuple; Keras prepends the batch axis -> (None, 512).
input_ids_in = tf.keras.layers.Input(shape=(512,), name='input_id', dtype='int32')
input_masks_in = tf.keras.layers.Input(shape=(512,), name='attn_mask', dtype='int32')

# [0] is the last hidden state: (batch, 512, 768)
embedding_layer = transformer_model(input_ids_in, attention_mask=input_masks_in)[0]
# cls_token = embedding_layer[:, 0, :]
# X = tf.keras.layers.BatchNormalization()(cls_token)
X = tf.keras.layers.Bidirectional(
    tf.keras.layers.LSTM(50, return_sequences=True, dropout=0.1, recurrent_dropout=0.1)
)(embedding_layer)
X = tf.keras.layers.GlobalMaxPool1D()(X)
# X = tf.keras.layers.BatchNormalization()(X)
X = tf.keras.layers.Dense(50, activation='relu')(X)
X = tf.keras.layers.Dropout(0.2)(X)
X = tf.keras.layers.Dense(4, activation='softmax')(X)  # 4 classes, matching the one-hot labels
model = tf.keras.Model(inputs=[input_ids_in, input_masks_in], outputs=X)

# Freeze the first three layers (the two Input layers and the BERT encoder).
for layer in model.layers[:3]:
    layer.trainable = False
```
```python
optimizer = tf.keras.optimizers.Adam(0.001, beta_1=0.9, beta_2=0.98, epsilon=1e-9)
model.compile(optimizer=optimizer, loss='categorical_crossentropy',
              metrics=['categorical_accuracy'])

epochs = 1
batch_size = 2
# Keras expects a tf.data.Dataset to yield batches already; batch_size does not batch it.
history = model.fit(_input_fn(), epochs=epochs, batch_size=batch_size, verbose=2)
```
```
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
/tmp/ipykernel_16908/300834086.py in <module>
      2 batch_size =2
      3 #history = model.fit(trainDataGenerator(batch_size), epochs= epochs, validation_data=valDataGenerator(batch_size), verbose=2) #
----> 4 history = model.fit(_input_fn(), epochs= epochs, batch_size= batch_size, verbose=2) #validation_data=val_ds,

~/.local/lib/python3.8/site-packages/keras/utils/traceback_utils.py in error_handler(*args, **kwargs)
     65     except Exception as e:  # pylint: disable=broad-except
     66       filtered_tb = _process_traceback_frames(e.__traceback__)
---> 67       raise e.with_traceback(filtered_tb) from None
     68     finally:
     69       del filtered_tb

~/.local/lib/python3.8/site-packages/transformers/models/bert/modeling_tf_bert.py in call(self, input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, encoder_hidden_states, encoder_attention_mask, past_key_values, use_cache, output_attentions, output_hidden_states, return_dict, training, **kwargs)
   1124             kwargs_call=kwargs,
   1125         )
-> 1126         outputs = self.bert(
   1127             input_ids=inputs["input_ids"],
   1128             attention_mask=inputs["attention_mask"],

~/.local/lib/python3.8/site-packages/transformers/models/bert/modeling_tf_bert.py in call(self, input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, encoder_hidden_states, encoder_attention_mask, past_key_values, use_cache, output_attentions, output_hidden_states, return_dict, training, **kwargs)
    771             raise ValueError("You have to specify either input_ids or inputs_embeds")
    772 
--> 773         batch_size, seq_length = input_shape
    774 
    775         if inputs["past_key_values"] is None:

ValueError: Exception encountered when calling layer "bert" (type TFBertMainLayer).

not enough values to unpack (expected 2, got 1)

Call arguments received:
  • input_ids=tf.Tensor(shape=(512,), dtype=int32)
  • attention_mask=tf.Tensor(shape=(512,), dtype=int32)
  • token_type_ids=None
  • position_ids=None
  • head_mask=None
  • inputs_embeds=None
  • encoder_hidden_states=None
  • encoder_attention_mask=None
  • past_key_values=None
  • use_cache=True
  • output_attentions=False
  • output_hidden_states=False
  • return_dict=True
  • training=True
  • kwargs=<class 'inspect._empty'>
```
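The failing line in the traceback is `batch_size, seq_length = input_shape`, which expects a rank-2 shape `(batch, sequence)`. A pure-Python reproduction of the same tuple unpacking (no TensorFlow needed; the helper name `unpack_like_bert` is hypothetical) shows why an unbatched `(512,)` shape triggers exactly this message while a batched `(2, 512)` shape does not:

```python
def unpack_like_bert(input_shape):
    # Mirrors line 773 of modeling_tf_bert.py shown in the traceback.
    batch_size, seq_length = input_shape
    return batch_size, seq_length

print(unpack_like_bert((2, 512)))  # batched shape: (2, 512)

try:
    unpack_like_bert((512,))       # unbatched shape: only one value
    message = None
except ValueError as err:
    message = str(err)
print(message)  # not enough values to unpack (expected 2, got 1)
```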

Edit: added the output of calling _input_fn():

```
<FlatMapDataset shapes: ({input_id: <unknown>, attn_mask: <unknown>}, <unknown>), types: ({input_id: tf.int32, attn_mask: tf.int32}, tf.int32)>
```

Stack Overflow user

Answered on 2022-10-11 17:56:43

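I solved this error by batching my tf.data.Dataset. Batching gives each TensorSpec in the dataset a two-dimensional shape, (batch_size, sequence_length), so there are now two values to unpack: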

```
TensorSpec(shape=(16, 200)...
```

This is what the error is referring to.

Solution:
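Batching stacks consecutive elements along a new leading axis. A NumPy sketch (an analogy for what `Dataset.batch` does, not TensorFlow itself; the `batch` helper and the dataset size of 40 are hypothetical) using the answer's sizes, 16 elements of shape `(200,)`, shows how the batch dimension appears and how `drop_remainder=True` discards a short final batch:

```python
import numpy as np

SEQ_LEN = 200
BATCH = 16
NUM_EXAMPLES = 40  # hypothetical dataset size, not a multiple of 16

examples = [np.full(SEQ_LEN, i, dtype=np.float64) for i in range(NUM_EXAMPLES)]

def batch(elements, batch_size, drop_remainder):
    # Group consecutive elements and stack each group into one array.
    batches = []
    for start in range(0, len(elements), batch_size):
        group = elements[start:start + batch_size]
        if drop_remainder and len(group) < batch_size:
            break  # the final partial batch is dropped
        batches.append(np.stack(group))
    return batches

batches = batch(examples, BATCH, drop_remainder=True)
print(len(batches), batches[0].shape)  # 2 (16, 200)

# The unpacking from the traceback now succeeds on a batched shape.
batch_size, seq_length = batches[0].shape
```

Without `drop_remainder`, the last batch here would have shape `(8, 200)`; dropping it keeps every batch the same size, which is also why TensorFlow can report a static batch dimension of 16 in the answer's output.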

```python
print(train_ds)  # Before batching
new_train_ds = train_ds.batch(16, drop_remainder=True)
print(new_train_ds)  # After batching

# Before batching
# <MapDataset element_spec=({'input_ids': TensorSpec(shape=(200,), dtype=tf.float64, name=None), 'attention_mask': TensorSpec(shape=(200,), dtype=tf.float64, name=None)}, TensorSpec(shape=(11,), dtype=tf.float64, name=None))>

# After batching
# <BatchDataset element_spec=({'input_ids': TensorSpec(shape=(16, 200), dtype=tf.float64, name=None), 'attention_mask': TensorSpec(shape=(16, 200), dtype=tf.float64, name=None)}, TensorSpec(shape=(16, 11), dtype=tf.float64, name=None))>
```
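Applied to the question's pipeline, a minimal sketch could batch inside the input function and swap the deprecated `output_types` argument for `output_signature`, so the element spec has static shapes instead of `<unknown>`. The function name `make_dataset`, the random arrays, and the batch size of 2 are assumptions for illustration; `output_signature` requires TensorFlow 2.4 or later.

```python
import numpy as np
import tensorflow as tf

SEQ_LEN = 512
NUM_CLASSES = 4

def make_dataset(x, attn, y, batch_size=None):
    def generator():
        for s1, s2, label in zip(x, attn, y):
            yield {"input_id": s1, "attn_mask": s2}, label

    # Static per-example shapes; the dict keys match the Keras Input layer names.
    signature = (
        {
            "input_id": tf.TensorSpec(shape=(SEQ_LEN,), dtype=tf.int32),
            "attn_mask": tf.TensorSpec(shape=(SEQ_LEN,), dtype=tf.int32),
        },
        tf.TensorSpec(shape=(NUM_CLASSES,), dtype=tf.int32),
    )
    dataset = tf.data.Dataset.from_generator(generator, output_signature=signature)
    if batch_size is not None:
        # Add the leading batch axis that the BERT layer unpacks: (batch, seq_len).
        dataset = dataset.batch(batch_size, drop_remainder=True)
    return dataset

# Hypothetical data: 6 rows of random token ids, an all-ones mask, one-hot labels.
rng = np.random.default_rng(0)
x = rng.integers(0, 30522, size=(6, SEQ_LEN)).astype(np.int32)
attn = np.ones((6, SEQ_LEN), dtype=np.int32)
y = np.eye(NUM_CLASSES, dtype=np.int32)[rng.integers(0, NUM_CLASSES, size=6)]

unbatched = make_dataset(x, attn, y)
batched = make_dataset(x, attn, y, batch_size=2)
print(unbatched.element_spec[0]["input_id"].shape.as_list())  # [512]
print(batched.element_spec[0]["input_id"].shape.as_list())    # [2, 512]

features, labels = next(iter(batched))
print(features["input_id"].shape.as_list(), labels.shape.as_list())  # [2, 512] [2, 4]
```

With a dataset built this way, `model.fit(batched, epochs=1, verbose=2)` would be called without a `batch_size` argument, since batching already happens in the pipeline.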
Score: 0

Original question on Stack Overflow:

https://stackoverflow.com/questions/70719567
