最近在用Tensorflow做神经网络训练,把问题记录下来,以备回顾复习。
Trainning所用的数据是图片,Label数据是NumPy数据。
def parser(image_file, label_file):
image_data = tf.read_file(image_file)
image_data = tf.image.decode_jpeg(image_data)
image_data = tf.image.convert_image_dtype(image_data, tf.float32)
...
label_data = np.load(label_file)
...
return image_data, label_data
image_holder = tf.placeholder(tf.string)
label_holder = tf.placeholder(tf.string)
dataset = tf.data.Dataset.from_tensor_slices((image_holder, label_holder))
dataset = dataset.map(parser)
代码在np.load()行报错:
AttributeError: 'Tensor' object has no attribute 'read'
原因是因为label_file是Tensor,而不是string,但是np.load需要string类型的参数,如何解决呢?
可以使用tf.py_fun函数间接调用np.load函数。
label_data = tf.py_func(np.load, [label_file], [tf.float32])