首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >__init__()在tensorflow中缺少1个必需的位置参数:“sess”

__init__()在tensorflow中缺少1个必需的位置参数:“sess”
EN

Stack Overflow用户
提问于 2018-07-09 04:45:43
回答 0查看 2K关注 0票数 0

我正在尝试使用这个脚本中的类,它对目录'test_ images‘中的多个图像执行图像分类。我以前不经常使用类,所以在这种情况下我有点困惑如何正确地应用它们。错误是:TypeError: __init__() missing 1 required positional argument: 'sess'。任何帮助都将不胜感激!

代码如下:

代码语言:javascript
运行
复制
def image_recognition_algorithm():

def load_graph(model_file):
    graph = tf.Graph()
    graph_def = tf.GraphDef()

    with open(model_file, "rb") as f:
        graph_def.ParseFromString(f.read())
    with graph.as_default():
        tf.import_graph_def(graph_def)

    return graph

def read_tensor_from_image_file(file_name, input_height=299, input_width=299,
                input_mean=0, input_std=255):
    input_name = "file_reader"
    output_name = "normalized"
    file_reader = tf.read_file(file_name, input_name)
    image_reader = tf.image.decode_jpeg(file_reader, channels = 3, name='jpeg_reader')
    float_caster = tf.cast(image_reader, tf.float32)
    dims_expander = tf.expand_dims(float_caster, 0);
    resized = tf.image.resize_bilinear(dims_expander, [input_height, input_width])
    normalized = tf.divide(tf.subtract(resized, [input_mean]), [input_std])
    sess = tf.Session()
    result = sess.run(normalized)

    return result

def load_labels(label_file):
    label = []
    proto_as_ascii_lines = tf.gfile.GFile(label_file).readlines()
    for l in proto_as_ascii_lines:
        label.append(l.rstrip())
    return label

class initiate_session():

def __init__(self, sess):
    self.sess = sess
    graph = load_graph(model_file)
    input_name = "import/" + input_layer
    output_name = "import/" + output_layer
    input_operation = graph.get_operation_by_name(input_name);
    output_operation = graph.get_operation_by_name(output_name);

    config = tf.ConfigProto(device_count={"CPU": 4},
                            inter_op_parallelism_threads=1,
                            intra_op_parallelism_threads=4)
    self.sess = tf.Session(graph=graph, config = config)
    start = time.time()
    results = self.sess.run(output_operation.outputs[0],
                      {input_operation.outputs[0]: t})
    end=time.time()
    results = np.squeeze(results)

    top_k = results.argsort()[-5:][::-1]
    labels = load_labels(label_file)


    print('\nEvaluation time (1-image): {:.3f}s\n'.format(end-start))


    for i in top_k:
        print(file_name, labels[i], results[i])

    return [file_name] + list(results)

    image_list = [f for f in listdir('test_images') if isfile(join('test_images', f))]

    res_list = []
    for image in image_list:
        if image.lower().endswith(('.png', '.jpg', '.jpeg', '.gif')):
            res_list.append(main(join('test_images', image)))

def main(self, file_name):
    model_file = "tf_files/retrained_graph.pb"
    label_file = "tf_files/retrained_labels.txt"
    input_height = 299
    input_width = 299
    input_mean = 128
    input_std = 128
    input_layer = "Mul"
    output_layer = "final_result"

    t = read_tensor_from_image_file(file_name,
                                    input_height=input_height,
                                    input_width=input_width,
                                    input_mean=input_mean,
                                    input_std=input_std)

if __name__ == '__main__':
    initiate_session().main()
EN

回答

页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/51235868

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档