
InvalidArgumentError: Conv2DCustomBackpropFilterOp only supports NHWC - pruning a neural network

Stack Overflow user
Asked on 2021-03-19 12:48:22
2 Answers · 755 Views · 0 Followers · 0 Votes

I am having a problem with the tensorflow_model_optimization library. I am writing code to prune an already trained neural network: I imported the weights from an h5 file and am using tensorflow_model_optimization to prune the network. When I call the fit method, the following error occurs:

InvalidArgumentError                      Traceback (most recent call last)
<ipython-input-26-ce9759e4dd53> in <module>
----> 1 model_for_pruning.fit_generator(base_treinamento, steps_per_epoch = 6000 /64, epochs = 5, validation_data = base_teste, validation_steps = 30, callbacks=callbacks)

~\anaconda3\envs\supernova\lib\site-packages\tensorflow\python\keras\engine\training.py in fit_generator(self, generator, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, validation_freq, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch)
   1859         use_multiprocessing=use_multiprocessing,
   1860         shuffle=shuffle,
-> 1861         initial_epoch=initial_epoch)
   1862 
   1863   def evaluate_generator(self,

~\anaconda3\envs\supernova\lib\site-packages\tensorflow\python\keras\engine\training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_batch_size, validation_freq, max_queue_size, workers, use_multiprocessing)
   1098                 _r=1):
   1099               callbacks.on_train_batch_begin(step)
-> 1100               tmp_logs = self.train_function(iterator)
   1101               if data_handler.should_sync:
   1102                 context.async_wait()

~\anaconda3\envs\supernova\lib\site-packages\tensorflow\python\eager\def_function.py in __call__(self, *args, **kwds)
    826     tracing_count = self.experimental_get_tracing_count()
    827     with trace.Trace(self._name) as tm:
--> 828       result = self._call(*args, **kwds)
    829       compiler = "xla" if self._experimental_compile else "nonXla"
    830       new_tracing_count = self.experimental_get_tracing_count()

~\anaconda3\envs\supernova\lib\site-packages\tensorflow\python\eager\def_function.py in _call(self, *args, **kwds)
    853       # In this case we have created variables on the first call, so we run the
    854       # defunned version which is guaranteed to never create variables.
--> 855       return self._stateless_fn(*args, **kwds)  # pylint: disable=not-callable
    856     elif self._stateful_fn is not None:
    857       # Release the lock early so that multiple threads can perform the call

~\anaconda3\envs\supernova\lib\site-packages\tensorflow\python\eager\function.py in __call__(self, *args, **kwargs)
   2941        filtered_flat_args) = self._maybe_define_function(args, kwargs)
   2942     return graph_function._call_flat(
-> 2943         filtered_flat_args, captured_inputs=graph_function.captured_inputs)  # pylint: disable=protected-access
   2944 
   2945   @property

~\anaconda3\envs\supernova\lib\site-packages\tensorflow\python\eager\function.py in _call_flat(self, args, captured_inputs, cancellation_manager)
   1917       # No tape is watching; skip to running the function.
   1918       return self._build_call_outputs(self._inference_function.call(
-> 1919           ctx, args, cancellation_manager=cancellation_manager))
   1920     forward_backward = self._select_forward_and_backward_functions(
   1921         args,

~\anaconda3\envs\supernova\lib\site-packages\tensorflow\python\eager\function.py in call(self, ctx, args, cancellation_manager)
    558               inputs=args,
    559               attrs=attrs,
--> 560               ctx=ctx)
    561         else:
    562           outputs = execute.execute_with_cancellation(

~\anaconda3\envs\supernova\lib\site-packages\tensorflow\python\eager\execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
     58     ctx.ensure_initialized()
     59     tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
---> 60                                         inputs, attrs, num_outputs)
     61   except core._NotOkStatusException as e:
     62     if name is not None:

InvalidArgumentError:  Conv2DCustomBackpropFilterOp only supports NHWC.
     [[node gradient_tape/sequential_3/prune_low_magnitude_conv2d_18/Conv2D/Conv2DBackpropFilter (defined at <ipython-input-24-fc85f8818d30>:1) ]] [Op:__inference_train_function_10084]

Errors may have originated from an input operation.
Input Source operations connected to node gradient_tape/sequential_3/prune_low_magnitude_conv2d_18/Conv2D/Conv2DBackpropFilter:
 sequential_3/prune_low_magnitude_activation_18/Relu (defined at C:\Users\Pichau\anaconda3\envs\supernova\lib\site-packages\tensorflow_model_optimization\python\core\sparsity\keras\pruning_wrapper.py:270)

Function call stack:
train_function

My code:

from keras.models import model_from_json
from keras.preprocessing.image import ImageDataGenerator
import numpy as np

arquivo = open('model.json', 'r')
estrutura_rede = arquivo.read()
arquivo.close()
model = model_from_json(estrutura_rede)
model.load_weights('model.h5')

gerador_treinamento = ImageDataGenerator(rescale=None)
base_treinamento = gerador_treinamento.flow_from_directory('data/train', target_size = (51,51), batch_size = 64, class_mode = 'binary')

gerador_teste = ImageDataGenerator(rescale=None)
base_teste = gerador_teste.flow_from_directory('data/test', target_size = (51,51), batch_size = 64, class_mode = 'binary')

import tempfile
import tensorflow as tf
import tensorflow_model_optimization as tfmot

prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude

batch_size = 128
epochs = 2
validation_split = 0.1

num_images = int(len(base_treinamento) * (1 - validation_split))
end_step = np.ceil(num_images / batch_size).astype(np.int32) * epochs

# NOTE: model_for_pruning is never defined in the code as posted;
# presumably the pruning wrapper was applied here, e.g.:
model_for_pruning = prune_low_magnitude(model)

model_for_pruning.compile(optimizer='adam',
              loss='binary_crossentropy',
              metrics=['accuracy'])

model_for_pruning.summary()

logdir = tempfile.mkdtemp()

callbacks = [
  tfmot.sparsity.keras.UpdatePruningStep(),
  tfmot.sparsity.keras.PruningSummaries(log_dir=logdir),
]

model_for_pruning.fit_generator(base_treinamento, steps_per_epoch = 6000 /64, epochs = 5, validation_data = base_teste, validation_steps = 30, callbacks=callbacks)
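
Note: end_step is computed above but never used. In the tensorflow_model_optimization examples it typically parameterizes a pruning schedule passed to the wrapper, roughly as in this sketch (the sparsity values here are assumptions, not from the question):

# Hypothetical pruning configuration making use of the otherwise unused end_step.
pruning_params = {
    'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(
        initial_sparsity=0.50,
        final_sparsity=0.80,
        begin_step=0,
        end_step=end_step)
}
model_for_pruning = prune_low_magnitude(model, **pruning_params)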

python: 3.6.12

tensorflow: 2.2.0

tensorflow-model-optimization: 0.5.0

Can anyone help me?


2 Answers

Stack Overflow user

Posted on 2022-03-03 20:16:08

Changing the runtime from CPU to GPU worked for me.

If you are running this on a CPU, consider switching to a GPU.
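
For reference, a quick way to confirm whether TensorFlow actually sees a GPU (a minimal sketch, not part of the original answer):

import tensorflow as tf

# An empty list means training falls back to the CPU kernels, which
# only support the NHWC (channels_last) layout for this op.
print('GPUs visible to TensorFlow:', tf.config.list_physical_devices('GPU'))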

Votes: 1

Stack Overflow user

Posted on 2021-03-25 16:14:50

This does not look like a pruning problem but rather a problem with the model's data format. It looks like a Conv2D layer should be using the NHWC format (data_format="channels_last").

Can you share the model's code?
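
As an illustration of that suggestion, here is a minimal sketch for inspecting and defaulting the data format (it assumes the Keras model is named model, as in the question):

import tensorflow as tf

# Make NHWC (channels_last) the default for newly created layers.
tf.keras.backend.set_image_data_format('channels_last')

# Inspect existing Conv2D layers: any layer reporting 'channels_first'
# (NCHW) would trigger this error when its gradient runs on the CPU.
for layer in model.layers:
    if isinstance(layer, tf.keras.layers.Conv2D):
        print(layer.name, layer.data_format)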

Votes: 0
Original page content provided by Stack Overflow.
Original link:

https://stackoverflow.com/questions/66708492
