
TensorFlow Object Detection aborts with InvalidArgumentError: indices[0] = 2 is not in [0, 1)

Stack Overflow user
Asked 2017-10-25 18:55:40
1 answer, 5.2K views, 0 followers, score 1

I'm trying to train the TensorFlow Object Detection API on my own dataset.

What did I do?

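  • Used ssd_mobilenet_v1_pets.config as the basis for my own pipeline config. Adjusted num_classes and all other path-specific parts to match my environment.
  • Used the ssd_mobilenet_v1_coco checkpoint from the TensorFlow model zoo
  • Created a label map file with all labels (first index starting at 1)
  • Created a TFRecord file from my dataset (script based on the TensorFlow example script)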
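The label map mentioned in the third bullet is a text file of `item { id: ... name: ... }` entries; id 0 is reserved for the background class, which is why ids start at 1. A minimal sketch that generates it from a category list, assuming the same `CATEGORIES_TO_TRAIN` order used by the TFRecord script in the question's later edit, so that label-map ids and the `index(...) + 1` class ids agree (`label_map_text` and `write_label_map` are hypothetical helpers, not part of the API):

```python
# Hypothetical helpers: write a label map in the pbtxt "item { id: ... name: ... }"
# format. Ids start at 1 because id 0 is reserved for the background class.
CATEGORIES_TO_TRAIN = ["apple", "dog", "cat"]


def label_map_text(categories):
    """Return label map text with ids 1..len(categories), in list order."""
    items = []
    for idx, name in enumerate(categories, start=1):
        items.append("item {\n  id: %d\n  name: '%s'\n}\n" % (idx, name))
    return "\n".join(items)


def write_label_map(path, categories):
    """Write the label map to `path` (overwrites an existing file)."""
    with open(path, "w") as f:
        f.write(label_map_text(categories))
```

Keeping a single category list as the source for both the label map and the TFRecord class ids avoids the two drifting apart.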

What went wrong?

When I start training with: python tensorflow_models/research/object_detection/train.py --pipeline_config_path=/home/playground/ssd_mobilenet_v1.config --train_dir=/tmp/bla/ I get the following traceback:

Traceback (most recent call last):
  File "tensorflow_models/research/object_detection/train.py", line 198, in <module>
    tf.app.run()
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/platform/app.py", line 48, in run
    _sys.exit(main(_sys.argv[:1] + flags_passthrough))
  File "tensorflow_models/research/object_detection/train.py", line 194, in main
    worker_job_name, is_chief, FLAGS.train_dir)
  File "/home/playground/tensorflow_models/research/object_detection/trainer.py", line 296, in train
    saver=saver)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/contrib/slim/python/slim/learning.py", line 767, in train
    sv.stop(threads, close_summary_writer=True)
  File "/usr/lib/python2.7/contextlib.py", line 35, in __exit__
    self.gen.throw(type, value, traceback)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/supervisor.py", line 964, in managed_session
    self.stop(close_summary_writer=close_summary_writer)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/supervisor.py", line 792, in stop
    stop_grace_period_secs=self._stop_grace_secs)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/coordinator.py", line 389, in join
    six.reraise(*self._exc_info_to_raise)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/queue_runner_impl.py", line 238, in _run
    enqueue_callable()
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.py", line 1235, in _single_operation_run
    target_list_as_strings, status, None)
  File "/usr/lib/python2.7/contextlib.py", line 24, in __exit__
    self.gen.next()
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/errors_impl.py", line 466, in raise_exception_on_not_ok_status
    pywrap_tensorflow.TF_GetCode(status))
tensorflow.python.framework.errors_impl.InvalidArgumentError: indices[0] = 2 is not in [0, 1)
         [[Node: cond/RandomCropImage/PruneCompleteleyOutsideWindow/Gather/Gather_1 = Gather[Tindices=DT_INT64, Tparams=DT_INT64, validate_indices=true, _device="/job:localhost/replica:0/task:0/cpu:0"](cond/RandomCropImage/PruneCompleteleyOutsideWindow/Gather/Gather_1/Switch:1, cond/RandomCropImage/PruneCompleteleyOutsideWindow/Reshape)]]

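Unfortunately, I don't understand what TensorFlow is trying to tell me with this traceback, nor where I should start looking for my mistake. I have checked every step for possible errors but haven't found any so far.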
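Read literally, the message says a Gather op asked for element 2 of a tensor whose first dimension has length 1, so the only valid index range is [0, 1), i.e. just 0. In the second log below, the failing op is `tf.gather(groundtruth_labels, matched_gt_indices)` in `target_assigner.py`, which suggests an anchor was matched to ground-truth box 2 while only one class label was available; this is consistent with the accepted answer. The snippet is a pure-Python imitation of that bounds check (the `gather` function is hypothetical, not TensorFlow's implementation), fed three boxes but a single label:

```python
def gather(params, indices):
    """Imitates the bounds check tf.gather applies (validate_indices=true)."""
    n = len(params)
    for i, idx in enumerate(indices):
        if not 0 <= idx < n:
            raise IndexError("indices[%d] = %d is not in [0, %d)" % (i, idx, n))
    return [params[idx] for idx in indices]


# Three boxes but only one class label, as produced by the question's TFRecord script.
boxes = [[0.1, 0.1, 0.4, 0.4], [0.5, 0.5, 0.9, 0.9], [0.2, 0.6, 0.3, 0.8]]
labels = [1]
# Suppose an anchor was matched to the last (third) box, i.e. index 2.
matched_gt_indices = [len(boxes) - 1]
try:
    gather(labels, matched_gt_indices)
except IndexError as e:
    print(e)  # indices[0] = 2 is not in [0, 1)
```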

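Edit: I also tried the configuration file proposed by @eshirima. Again I changed the num_classes parameter and all other parameters marked PATH_TO_BE_CONFIGURED. However, it now fails with the following error message: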

INFO:tensorflow:Starting Queues.
INFO:tensorflow:global_step/sec: 0
INFO:tensorflow:Error reported to Coordinator: <class 'tensorflow.python.framework.errors_impl.InvalidArgumentError'>, indices[0] = 2 is not in [0, 1)
         [[Node: Loss/Gather_29 = Gather[Tindices=DT_INT32, Tparams=DT_FLOAT, validate_indices=true, _device="/job:localhost/replica:0/task:0/cpu:0"](Loss/Pad_5, Loss/Reshape_47)]]

Caused by op u'Loss/Gather_29', defined at:
  File "tensorflow_models/research/object_detection/train.py", line 198, in <module>
    tf.app.run()
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/platform/app.py", line 48, in run
    _sys.exit(main(_sys.argv[:1] + flags_passthrough))
  File "tensorflow_models/research/object_detection/train.py", line 194, in main
    worker_job_name, is_chief, FLAGS.train_dir)
  File "/home/playground/tensorflow_models/research/object_detection/trainer.py", line 192, in train
    clones = model_deploy.create_clones(deploy_config, model_fn, [input_queue])
  File "/home/playground/tensorflow_models/research/slim/deployment/model_deploy.py", line 193, in create_clones
    outputs = model_fn(*args, **kwargs)
  File "/home/playground/tensorflow_models/research/object_detection/trainer.py", line 133, in _create_losses
    losses_dict = detection_model.loss(prediction_dict)
  File "/home/playground/tensorflow_models/research/object_detection/meta_architectures/ssd_meta_arch.py", line 411, in loss
    self.groundtruth_lists(fields.BoxListFields.classes))
  File "/home/playground/tensorflow_models/research/object_detection/meta_architectures/ssd_meta_arch.py", line 485, in _assign_targets
    groundtruth_classes_with_background_list)
  File "/home/playground/tensorflow_models/research/object_detection/core/target_assigner.py", line 438, in batch_assign_targets
    anchors, gt_boxes, gt_class_targets)
  File "/home/playground/tensorflow_models/research/object_detection/core/target_assigner.py", line 154, in assign
    match)
  File "/home/playground/tensorflow_models/research/object_detection/core/target_assigner.py", line 250, in _create_classification_targets
    matched_cls_targets = tf.gather(groundtruth_labels, matched_gt_indices)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/array_ops.py", line 2409, in gather
    validate_indices=validate_indices, name=name)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/gen_array_ops.py", line 1219, in gather
    validate_indices=validate_indices, name=name)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/op_def_library.py", line 767, in apply_op
    op_def=op_def)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/ops.py", line 2630, in create_op
    original_op=self._default_original_op, op_def=op_def)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/ops.py", line 1204, in __init__
    self._traceback = self._graph._extract_stack()  # pylint: disable=protected-access

InvalidArgumentError (see above for traceback): indices[0] = 2 is not in [0, 1)
         [[Node: Loss/Gather_29 = Gather[Tindices=DT_INT32, Tparams=DT_FLOAT, validate_indices=true, _device="/job:localhost/replica:0/task:0/cpu:0"](Loss/Pad_5, Loss/Reshape_47)]]

Traceback (most recent call last):
  File "tensorflow_models/research/object_detection/train.py", line 198, in <module>
    tf.app.run()
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/platform/app.py", line 48, in run
    _sys.exit(main(_sys.argv[:1] + flags_passthrough))
  File "tensorflow_models/research/object_detection/train.py", line 194, in main
    worker_job_name, is_chief, FLAGS.train_dir)
  File "/home/playground/tensorflow_models/research/object_detection/trainer.py", line 296, in train
    saver=saver)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/contrib/slim/python/slim/learning.py", line 767, in train
    sv.stop(threads, close_summary_writer=True)
  File "/usr/lib/python2.7/contextlib.py", line 35, in __exit__
    self.gen.throw(type, value, traceback)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/supervisor.py", line 964, in managed_session
    self.stop(close_summary_writer=close_summary_writer)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/supervisor.py", line 792, in stop
    stop_grace_period_secs=self._stop_grace_secs)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/coordinator.py", line 389, in join
    six.reraise(*self._exc_info_to_raise)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/coordinator.py", line 296, in stop_on_exception
    yield
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/coordinator.py", line 494, in run
    self.run_loop()
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/supervisor.py", line 994, in run_loop
    self._sv.global_step])
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.py", line 895, in run
    run_metadata_ptr)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.py", line 1124, in _run
    feed_dict_tensor, options, run_metadata)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.py", line 1321, in _do_run
    options, run_metadata)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.py", line 1340, in _do_call
    raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.InvalidArgumentError: indices[0] = 2 is not in [0, 1)
         [[Node: Loss/Gather_29 = Gather[Tindices=DT_INT32, Tparams=DT_FLOAT, validate_indices=true, _device="/job:localhost/replica:0/task:0/cpu:0"](Loss/Pad_5, Loss/Reshape_47)]]

Caused by op u'Loss/Gather_29', defined at:
  File "tensorflow_models/research/object_detection/train.py", line 198, in <module>
    tf.app.run()
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/platform/app.py", line 48, in run
    _sys.exit(main(_sys.argv[:1] + flags_passthrough))
  File "tensorflow_models/research/object_detection/train.py", line 194, in main
    worker_job_name, is_chief, FLAGS.train_dir)
  File "/home/playground/tensorflow_models/research/object_detection/trainer.py", line 192, in train
    clones = model_deploy.create_clones(deploy_config, model_fn, [input_queue])
  File "/home/playground/tensorflow_models/research/slim/deployment/model_deploy.py", line 193, in create_clones
    outputs = model_fn(*args, **kwargs)
  File "/home/playground/tensorflow_models/research/object_detection/trainer.py", line 133, in _create_losses
    losses_dict = detection_model.loss(prediction_dict)
  File "/home/playground/tensorflow_models/research/object_detection/meta_architectures/ssd_meta_arch.py", line 411, in loss
    self.groundtruth_lists(fields.BoxListFields.classes))
  File "/home/playground/tensorflow_models/research/object_detection/meta_architectures/ssd_meta_arch.py", line 485, in _assign_targets
    groundtruth_classes_with_background_list)
  File "/home/playground/tensorflow_models/research/object_detection/core/target_assigner.py", line 438, in batch_assign_targets
    anchors, gt_boxes, gt_class_targets)
  File "/home/playground/tensorflow_models/research/object_detection/core/target_assigner.py", line 154, in assign
    match)
  File "/home/playground/tensorflow_models/research/object_detection/core/target_assigner.py", line 250, in _create_classification_targets
    matched_cls_targets = tf.gather(groundtruth_labels, matched_gt_indices)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/array_ops.py", line 2409, in gather
    validate_indices=validate_indices, name=name)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/gen_array_ops.py", line 1219, in gather
    validate_indices=validate_indices, name=name)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/op_def_library.py", line 767, in apply_op
    op_def=op_def)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/ops.py", line 2630, in create_op
    original_op=self._default_original_op, op_def=op_def)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/ops.py", line 1204, in __init__
    self._traceback = self._graph._extract_stack()  # pylint: disable=protected-access

InvalidArgumentError (see above for traceback): indices[0] = 2 is not in [0, 1)
         [[Node: Loss/Gather_29 = Gather[Tindices=DT_INT32, Tparams=DT_FLOAT, validate_indices=true, _device="/job:localhost/replica:0/task:0/cpu:0"](Loss/Pad_5, Loss/Reshape_47)]]

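Edit: Added some code showing how I generate the TFRecord file. The whole script is a bit long, so I tried to trim it down to the relevant parts. If anything you're interested in is missing, please let me know.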

import tensorflow as tf
from object_detection.utils import dataset_util

CATEGORIES_TO_TRAIN = ["apple", "dog", "cat"]

def createTFExample(img):
    imageFormat = ""
    if img.format == 'JPEG':
        imageFormat = b'jpeg'
    elif img.format == 'PNG':
        imageFormat = b'png'
    else:
        print 'Unknown Image format %s' %(img.format,)
        return None

    width, height = img.size
    filename = str(img.filename)
    encodedImageData = img.bytesIO

    xmins = []
    xmaxs = []
    ymins = []
    ymaxs = []

    for annotation in img.annotations:
        xmins.append((annotation.left / width))
        xmaxs.append((annotation.left + annotation.width) / width)
        ymins.append((annotation.top / height))
        ymaxs.append((annotation.top + annotation.height) / height)

    # some images in our dataset may have no annotations; skip those
    if((len(xmins) == 0) or (len(xmaxs) == 0) or (len(ymins) == 0) or (len(ymaxs) == 0)):
        return None

    label = [img.label.encode('utf8')]
    classes = [(CATEGORIES_TO_TRAIN.index(img.label) + 1)] #class indexes start with 1


    tf_example = tf.train.Example(features=tf.train.Features(feature={
      'image/height': dataset_util.int64_feature(height),
      'image/width': dataset_util.int64_feature(width),
      'image/filename': dataset_util.bytes_feature(filename),
      'image/source_id': dataset_util.bytes_feature(filename),
      'image/encoded': dataset_util.bytes_feature(encodedImageData),
      'image/format': dataset_util.bytes_feature(imageFormat),
      'image/object/bbox/xmin': dataset_util.float_list_feature(xmins),
      'image/object/bbox/xmax': dataset_util.float_list_feature(xmaxs),
      'image/object/bbox/ymin': dataset_util.float_list_feature(ymins),
      'image/object/bbox/ymax': dataset_util.float_list_feature(ymaxs),
      'image/object/class/text': dataset_util.bytes_list_feature(label),
      'image/object/class/label': dataset_util.int64_list_feature(classes),
    }))
    return tf_example


def createTfRecordFile(images):
    writer = tf.python_io.TFRecordWriter(TFRECORD_OUTPUT_PATH)
    for img in images:
        t = createTFExample(img)
        if t is not None:
            writer.write(t.SerializeToString())

    writer.close()
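Each per-object feature (`image/object/bbox/*`, `image/object/class/text`, `image/object/class/label`) is stored as a separate list, so a cheap sanity check before writing can catch mismatches like the one in `createTFExample` above (four box lists with one entry per annotation, but `label` and `classes` always of length 1). The sketch below is a hypothetical helper, not part of the Object Detection API; it assumes the lists are collected into a plain dict with the keys shown, and it also flags box coordinates outside [0, 1] and class ids outside [1, num_classes]:

```python
# Hypothetical pre-write check for the per-object lists of one example.
PER_OBJECT_KEYS = ["xmins", "xmaxs", "ymins", "ymaxs", "labels", "classes"]
BOX_KEYS = ["xmins", "xmaxs", "ymins", "ymaxs"]


def check_object_lists(lists, num_classes):
    """Return a list of problems; an empty list means the example looks consistent."""
    problems = []
    lengths = {key: len(lists[key]) for key in PER_OBJECT_KEYS}
    # Every per-object list must have one entry per object.
    if len(set(lengths.values())) != 1:
        problems.append("per-object lists differ in length: %s" % sorted(lengths.items()))
    # Box coordinates are expected to be normalized to [0, 1].
    for key in BOX_KEYS:
        for value in lists[key]:
            if not 0.0 <= value <= 1.0:
                problems.append("%s contains %r, expected a value in [0, 1]" % (key, value))
    # Class ids start at 1; 0 is the background class.
    for class_id in lists["classes"]:
        if not 1 <= class_id <= num_classes:
            problems.append("class id %r outside [1, %d]" % (class_id, num_classes))
    return problems
```

Inside `createTFExample`, one could call it just before building the `tf.train.Example` and return `None` (skipping the image) whenever the returned list is non-empty.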

Any help pointing me in the right direction is really appreciated!


1 answer

Stack Overflow user

Accepted answer

Posted 2018-04-10 04:55:41

I had a similar problem, but making the label list and the classes list the same length as the bounding-box lists fixed it for me.

Specifically, in createTFExample(), the elements of label = [img.label.encode('utf8')] and classes = [(CATEGORIES_TO_TRAIN.index(img.label) + 1)] should correspond to the elements of the bounding-box annotation lists:

xmins = []
xmaxs = []
ymins = []
ymaxs = []

for annotation in img.annotations:
    xmins.append((annotation.left / width))
    xmaxs.append((annotation.left + annotation.width) / width)
    ymins.append((annotation.top / height))
    ymaxs.append((annotation.top + annotation.height) / height)

Judging from the code structure, I assume each img object contains only one object type; in that case, write

label = [img.label.encode('utf8')] * len(xmins)  
classes = [(CATEGORIES_TO_TRAIN.index(img.label) + 1)] * len(xmins)

or use whatever gives you the number of objects in the image, so that the label, class, and bounding-box lists all have the same length.

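If one img object contains several types of objects, build a list of object names and a list of class IDs whose indices match the indices of the annotation list.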

The resulting lists should look like this:

xmins = [a_xmin, b_xmin, c_xmin]
ymins = [a_ymin, b_ymin, c_ymin]
xmaxs = [a_xmax, b_xmax, c_xmax]
ymaxs = [a_ymax, b_ymax, c_ymax]
labels = [a_label, b_label, c_label]
classes = [a_classid, b_classid, c_classid]
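Both cases in this answer can be handled by building every per-object list inside the same loop over the annotations, so that index i refers to the same object in all six lists. A minimal sketch, assuming each annotation carries its own `label` attribute (an assumption: the question's annotations only show `left`, `top`, `width` and `height`), with a hypothetical `Annotation` namedtuple standing in for the real annotation objects:

```python
from collections import namedtuple

# Hypothetical annotation record: pixel box plus a per-object label.
# The `label` field is assumed; the question's annotations only show the box.
Annotation = namedtuple("Annotation", ["left", "top", "width", "height", "label"])

CATEGORIES_TO_TRAIN = ["apple", "dog", "cat"]


def object_lists(annotations, img_width, img_height):
    """Build per-object lists in one pass so index i always refers to the same object."""
    xmins, xmaxs, ymins, ymaxs, labels, classes = [], [], [], [], [], []
    for a in annotations:
        # float() avoids integer truncation if this runs under Python 2.
        xmins.append(a.left / float(img_width))
        xmaxs.append((a.left + a.width) / float(img_width))
        ymins.append(a.top / float(img_height))
        ymaxs.append((a.top + a.height) / float(img_height))
        labels.append(a.label.encode("utf8"))
        classes.append(CATEGORIES_TO_TRAIN.index(a.label) + 1)  # class ids start at 1
    return xmins, xmaxs, ymins, ymaxs, labels, classes


anns = [Annotation(0, 0, 50, 50, "dog"), Annotation(100, 20, 60, 80, "cat")]
xmins, xmaxs, ymins, ymaxs, labels, classes = object_lists(anns, 200, 100)
```

For an image with a single object type this gives the same result as the `* len(xmins)` fix above; the difference is that mixed-class images also stay aligned.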

This solved my problem; I hope it helps!

Score: 3
Original page content provided by Stack Overflow; translation support by Tencent Cloud Xiaowei's IT-domain engine.
Original link:

https://stackoverflow.com/questions/46940073
