首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >图像失真返回错误: Reshape_5:0返回的张量无效

图像失真返回错误: Reshape_5:0返回的张量无效
EN

Stack Overflow用户
提问于 2017-01-08 16:23:19
回答 1查看 279关注 0票数 0

我试图将图像失真添加到我的ConvNet模型中,并得到一个非常奇怪的错误。我的数据是TFRecords格式的,我正在使用来自CFIAR10代码的color_distorter()函数。下面是一些我拼凑在一起的虚拟代码,以确保所有的事情都在做我希望它做的事情。当我看到图像被扭曲后,就没有问题了。这个问题似乎是在张量变平后或在运行之后出现的。由于某种原因,它将执行一次,但是第二次它会抛出一个错误。下面是我的代码和它返回的错误。我目前的怀疑是它可能是tf.map_fn(),但我不知道。

代码语言:javascript
运行
复制
def color_distorer(image, thread_id=0, scope=None):
    with tf.op_scope([image], scope, 'distort_color'):
        color_ordering = thread_id % 2

       if color_ordering == 0:
           image = tf.image.random_brightness(image, max_delta=32. / 255.)
           image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
           image = tf.image.random_hue(image, max_delta=0.2)
           image = tf.image.random_contrast(image, lower=0.5, upper=1.5)
    elif color_ordering == 1:
           image = tf.image.random_brightness(image, max_delta=32. / 255.)
           image = tf.image.random_contrast(image, lower=0.5, upper=1.5)
           image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
           image = tf.image.random_hue(image, max_delta=0.2)

           # The random_* ops do not necessarily clamp.
        image = tf.clip_by_value(image, 0.0, 1.0)
        return image

X_test_batch, y_test_batch = inputs(FLAGS.train_dir,
                                FLAGS.test_file,
                                FLAGS.batch_size,
                                FLAGS.n_epochs,
                                FLAGS.n_classes,
                                one_hot_labels=True,
                                imshape=160*160*3)

with tf.Session() as sess:

    init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
    sess.run(init_op)
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)
    # X = sess.run([X_test_batch])
    for i in range(5): #epochs
        image = tf.reshape(X_test_batch, [-1, 160, 160, 3])
        result = tf.map_fn(lambda img: color_distorer(img), image)
        X, y = sess.run([result, y_test_batch]) #to see what the distortion did
        for i in range(50): #all the images look distorted..
            if i%25 ==0:
                plt.title(y[i])
                plt.imshow(X[i])
                plt.show()
        print('******************') #This runs once... then breaks. Why?
        result = tf.reshape(result, [-1, 76800])
        dX, dy = sess.run([result, y_test_batch])
        print(dX)

错误:

代码语言:javascript
运行
复制
******************
[[ 1.  1.  1. ...,  1.  1.  1.]
[ 1.  1.  1. ...,  1.  1.  1.]
[ 1.  1.  1. ...,  1.  1.  1.]
..., 
[ 1.  1.  1. ...,  1.  1.  1.]
[ 1.  1.  1. ...,  1.  1.  1.]
[ 1.  1.  1. ...,  1.  1.  1.]]
WARNING:tensorflow:tf.op_scope(values, name, default_name) is deprecated, use tf.name_scope(name, default_name, values)
Traceback (most recent call last):
File "/home/mcamp/anaconda3/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1021, in _do_call
return fn(*args)
File "/home/mcamp/anaconda3/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1003, in _run_fn
status, run_metadata)
File "/home/mcamp/anaconda3/lib/python3.5/contextlib.py", line 66, in __exit__
next(self.gen)
File "/home/mcamp/anaconda3/lib/python3.5/site-packages/tensorflow/python/framework/errors_impl.py", line 469, in raise_exception_on_not_ok_status
pywrap_tensorflow.TF_GetCode(status))
tensorflow.python.framework.errors_impl.InvalidArgumentError: The tensor returned for map_1/TensorArrayPack_1/TensorArrayGatherV2:0 was not valid.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
File "/media/mcamp/Local SSHD/Python Projects/Garage Door   Project/FreshStart/ReadTFREcords.py", line 80, in <module>
  dX, dy = sess.run([result, y_test_batch])
File "/home/mcamp/anaconda3/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 766, in run
run_metadata_ptr)
 File "/home/mcamp/anaconda3/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 964, in _run
feed_dict_string, options, run_metadata)
File "/home/mcamp/anaconda3/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1014, in _do_run
target_list, options, run_metadata)
File "/home/mcamp/anaconda3/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1034, in _do_call
raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.InvalidArgumentError: The tensor returned for map_1/TensorArrayPack_1/TensorArrayGatherV2:0 was not valid.
EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2017-01-08 19:34:18

好的,我认为正在发生的是我调整了X_test_batch的大小,使之回到图像的形状。然后我在它上运行我的失真,并将其命名为result。这就是我跑过去看的。在查看之后,我调整了模型的大小,使之成为平面图像。在此期间,我更改了一些名为imageresult的东西的名称。如果代码从上到下都是红色的,我认为它会很好,但是Tensorflow有它工作的图形,所以当它到达第二次迭代时,它期望结果是一个三维张量,但是它实际上是一个平面张量,试图通过tf.map_fntf.map_fn是驻留在图形上的东西,直到sess.run()才会被执行,因此这将是我看到的错误的原因。我希望这是有意义的,并帮助其他人的道路上。

代码语言:javascript
运行
复制
with tf.name_scope('TestingData'):
    X_test_batch, y_test_batch = inputs(FLAGS.train_dir,
                                       FLAGS.test_file,
                                       FLAGS.batch_size,
                                       FLAGS.n_epochs,
                                       FLAGS.n_classes,
                                       one_hot_labels=True,
                                       imshape=160*160*3)

image = tf.reshape(X_test_batch, [-1, 160, 160, 3])

image = tf.map_fn(lambda img: color_distorer(img), image)
with tf.Session() as sess:

    init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
     sess.run(init_op)
     coord = tf.train.Coordinator()
     threads = tf.train.start_queue_runners(coord=coord)
     # X = sess.run([X_test_batch])
     for i in range(5):
         # dX, dy = sess.run([result, y_test_batch])
         print('******************')
         image = tf.reshape(image, [-1, 76800])
         dX, dy = sess.run([image, y_test_batch])
         distorted = np.reshape(dX, (-1, 160, 160, 3))
         print(dX.shape)
         print(distorted.shape)
         for i in range(50):
             if i%25 ==0:
                  print(distorted[i], dy[i])
                  plt.title(dy[i])
                  plt.imshow(distorted[i])
                  plt.show()
                  print(dX.shape)
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/41534866

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档