Welcome to OGeek Q&A Community for programmer and developer-Open, Learning and Share
Welcome To Ask or Share your Answers For Others

Categories

0 votes
231 views
in Technique[技术] by (71.8m points)

请问为什么tf.py_function()中自定义的函数未被调用?

TF2.0中,在最大池化层的源码当中使用了tf.py_function(),自定义了一个函数,想将Tensor转为Numpy矩阵从而进行一些操作,然而程序运行时,却将自定义的函数忽略掉了,自定义函数中的print()函数未作输出,返回值为<unknow>,如果py_function()中的参数写错了,系统还是会报错的,说明py_function()这个函数运行了,但是自定义的函数未运行,这是为什么呢,是缺少修饰器么。

def Myshow_all(self,inputs):
    def showTensor(inputs):
        a=inputs.numpy()
        print(a)
        return a
    y=tf.py_function(showTensor,[inputs],tf.float32)
    print(y.shape)
    return y
def _pooling_function(self,inputs,pool_size,strides,padding,data_format):
    output=K.pool2d(inputs,pool_size,strides,padding,data_format,pool_mode='max')
    a=self.Myshow_all(inputs)
    print(a)
    print(a.shape)
    return output

与恶龙缠斗过久,自身亦成为恶龙;凝视深渊过久,深渊将回以凝视…
Welcome To Ask or Share your Answers For Others

1 Reply

0 votes
by (71.8m points)
import tensorflow as tf

def Myshow_all(inputs):
    def showTensor(inputs):
        a = inputs.numpy()
        print(a)
        return a
    y = tf.py_function(showTensor, [inputs], tf.float32)
    print(y.shape)
    return y


if __name__ == "__main__":
    inputs = tf.random.uniform(shape=(2, 2))
    Myshow_all(inputs)

我使用上面的这些代码测试了下 Myshow_all 函数是没问题的.在Myshow_all上面加入@tf.function修饰器,出现了你说的那个问题.
所以我觉得你这个报错的原因是因为Myshow_all 本身有@tf.function,或者调用它的函数或者间接调用它的函数加了 @tf.function 修饰器


与恶龙缠斗过久,自身亦成为恶龙;凝视深渊过久,深渊将回以凝视…
OGeek|极客中国-欢迎来到极客的世界,一个免费开放的程序员编程交流平台!开放,进步,分享!让技术改变生活,让极客改变未来! Welcome to OGeek Q&A Community for programmer and developer-Open, Learning and Share
Click Here to Ask a Question

...