前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >tf.py_func()

tf.py_func()

作者头像
狼啸风云
修改2022-09-04 21:39:19
1.3K0
修改2022-09-04 21:39:19
举报

tensorflow中所有的tensor只是占位符,在没有用tf.Session().run接口填充值之前是没有实际值的,不能对其进行判值操作,如if ... else...等,在实际问题中,我们可能需要将一个tensor转换成numpy array 然后进行一些 np的运算,然后返回tensor这样可以加强tensorflow的灵活性。在目标检测算法Faster R-CNN中,需要计算各种ground truth,接口比较复杂。因此,使用tf.py_func是一个比较好的途径。对于tf.py_func的使用,可以参见计算RPN的ground truth计算proposals的ground truth时的使用方法。可以看到,都是将tensor转化成numpy array,再使用np.操作完成复杂运算。封装一个python函数并将其用作TensorFlow op。

代码语言:javascript
复制
tf.py_func(
    func,
    inp,
    Tout,
    stateful=True,
    name=None
)

参数:

  • func: 一个Python函数,它接受ndarray对象作为参数并返回一个ndarray对象列表(或单个ndarray)。这个函数必须接受inp中有多少张量就有多少个参数,这些参数类型将匹配相应的tf.inp中的tf.tensor。返回的ndarrays必须匹配已定义的Tout的数字和类型。重要提示: func的输入和输出numpy ndarrays不能保证是副本。在某些情况下,它们的底层内存将与相应的TensorFlow张量共享。在没有显式(np.)复制的python数据结构中,就地修改或存储func输入或返回值可能会产生不确定的结果。
  • inp: 一个张量对象的列表。
  • Tout: tensorflow数据类型的列表或元组,如果只有一个tensorflow数据类型,则使用单个tensorflow数据类型,指示func返回什么。
  • stateful: (布尔)。如果为真,则应该认为该函数是有状态的。如果一个函数是无状态的,当给定相同的输入时,它将返回相同的输出,并且没有可观察到的副作用。诸如公共子表达式消除之类的优化只在无状态操作上执行。
  • name: 操作的名称(可选)。

返回值:

  • func计算的张量或单个张量的列表。

例:

代码语言:javascript
复制
def my_func(array1,array2):
    return array1 + array2, array1 - array2

if __name__ =='__main__':
    array1 = np.array([[1, 2], [3, 4]])
    array2 = np.array([[1, 2], [3, 4]])

    a1 = tf.placeholder(tf.float32,[2,2],name = 'array1')
    a2 = tf.placeholder(tf.float32,[2,2],name = 'array2')
    y1,y2 = tf.py_func(my_func,[a1,a2],[tf.float32, tf.float32])

    with tf.Session() as sess:
        y1_,y2_ = sess.run([y1,y2],feed_dict={a1:array1,a2:array2})
        print(y1_)
        print('*'*10)
        print(y2_)


Output:
-----------
[[2. 4.]
[6. 8.]]
**********
[[0. 0.]
[0. 0.]]
-----------

直接用array的方式操作:

代码语言:javascript
复制
import tensorflow as tf
import numpy as np

def my_func(array1,array2):
    return array1 + array2, array1 - array2

with tf.Session() as sess:
  array1 = np.array([[1, 2], [3, 4]])
  array2 = np.array([[1, 2], [3, 4]])
  y1, y2 = my_func(array1, array2)
  print(y1)
  print('*' * 10)
  print(y2)


Output:
-----------
[[2 4]
 [6 8]]
**********
[[0 0]
 [0 0]]
-----------

原链接:https://tensorflow.google.cn/api_docs/python/tf/py_func?hl=en

本文参与 腾讯云自媒体分享计划,分享自作者个人站点/博客。
原始发表:2019年07月15日,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档