前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >Tensorflow实战(1)-Load NumPy Array In Dataset Parser Function

Tensorflow实战(1)-Load NumPy Array In Dataset Parser Function

作者头像
YoungTimes
发布2022-04-28 12:59:24
1980
发布2022-04-28 12:59:24
举报
文章被收录于专栏:半杯茶的小酒杯

最近在用Tensorflow做神经网络训练,把问题记录下来,以备回顾复习。

Trainning所用的数据是图片,Label数据是NumPy数据。

代码语言:javascript
复制
def parser(image_file, label_file):
    image_data = tf.read_file(image_file)
    image_data = tf.image.decode_jpeg(image_data)
    image_data = tf.image.convert_image_dtype(image_data, tf.float32)
    ...
    label_data = np.load(label_file)
    ...
    return image_data, label_data
 
image_holder = tf.placeholder(tf.string)
label_holder = tf.placeholder(tf.string)
dataset = tf.data.Dataset.from_tensor_slices((image_holder, label_holder))
dataset = dataset.map(parser)

代码在np.load()行报错:

代码语言:javascript
复制
AttributeError: 'Tensor' object has no attribute 'read'

解决方法

原因是因为label_file是Tensor,而不是string,但是np.load需要string类型的参数,如何解决呢?

可以使用tf.py_fun函数间接调用np.load函数。

代码语言:javascript
复制
label_data = tf.py_func(np.load, [label_file], [tf.float32])
本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2019-05-11,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 半杯茶的小酒杯 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 解决方法
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档