专栏首页算法私房菜Tensorflow笔记:高级封装——tf.Estimator

Tensorflow笔记:高级封装——tf.Estimator

前言

Google官方给出了两个tensorflow的高级封装——keras和Estimator,本文主要介绍tf.Estimator的内容。tf.Estimator的特点是:既能在model_fn中灵活的搭建网络结构,也不至于像原生tensorflow那样复杂繁琐。相比于原生tensorflow更便捷、相比与keras更灵活,属于二者的中间态。

实现一个tf.Estimator主要分三个部分:input_fn、model_fn、main三个函数。其中input_fn负责处理输入数据、model_fn负责构建网络结构、main来决定要进行什么样的任务(train、eval、earlystop等等)。本文我们就通过MNIST数据集的例子,介绍一下tf.Estimator是怎么用的。

1. input_fn

读过我的另一篇文章:Tensorflow笔记:TFRecord的制作与读取 的同学应该记得那里面的read_and_decode函数,其实就和这里的input_fn逻辑是类似的,都是通过tf.data每次调用会产生一个batch的数据。

def input_fn(filenames, batch_size=32, num_epochs=None, perform_shuffle=False):
    """
    每次调用,从TFRecord文件中读取一个大小为batch_size的batch
    Args:
        filenames: TFRecord文件
        batch_size: batch_size大小
        num_epochs: 将TFRecord中的数据重复几遍,如果是None,则永远循环读取不会停止
        perform_shuffle: 是否乱序

    Returns:
        tensor格式的,一个batch的数据
    """
    def _parse_fn(record):
        features = {
            "label": tf.FixedLenFeature([], tf.int64),
            "image": tf.FixedLenFeature([], tf.string),
        }
        parsed = tf.parse_single_example(record, features)
        # image
        image = tf.decode_raw(parsed["image"], tf.uint8)
        image = tf.reshape(image, [28, 28])
        # label
        label = tf.cast(parsed["label"], tf.int64)
        return {"image": image}, label

    # Extract lines from input files using the Dataset API, can pass one filename or filename list
    dataset = tf.data.TFRecordDataset(filenames).map(_parse_fn, num_parallel_calls=10).prefetch(500000)    # multi-thread pre-process then prefetch

    # Randomizes input using a window of 256 elements (read into memory)
    if perform_shuffle:
        dataset = dataset.shuffle(buffer_size=256)

    # epochs from blending together.
    dataset = dataset.repeat(num_epochs)
    dataset = dataset.batch(batch_size) # Batch size to use

    iterator = dataset.make_one_shot_iterator()
    batch_features, batch_labels = iterator.get_next()
    return batch_features, batch_labels

如果遇到不同的问题,其实只需要改动tf.data.TFRecordDataset这一行和_parse_fn函数即可。比如如果输入数据不是TFRecord格式,而是一个LIBSVM格式:

def input_fn(filenames, batch_size=32, num_epochs=1, perform_shuffle=False):
    def _parse_fn(line):
        columns = tf.string_split([line], ' ')
        labels = tf.string_to_number(columns.values[0], out_type=tf.float32)
        splits = tf.string_split(columns.values[5:], ':')  # filed_size=280 feature_size=6500000
        id_vals = tf.reshape(splits.values, splits.dense_shape)
        feat_ids, feat_vals = tf.split(id_vals, num_or_size_splits=2, axis=1)
        feat_ids = tf.string_to_number(feat_ids, out_type=tf.int32)
        feat_vals = tf.string_to_number(feat_vals, out_type=tf.float32)
        # feat_vals = tf.sign(feat_vals) * tf.math.log(tf.abs(feat_vals) + 1)  # do log manual
        return {"feat_ids": feat_ids, "feat_vals": feat_vals}, labels

    # Extract lines from input files using the Dataset API, can pass one filename or filename list
    dataset = tf.data.TextLineDataset(filenames).map(_parse_fn, num_parallel_calls=10).prefetch(500000)

    # Randomizes input using a window of 256 elements (read into memory)
    if perform_shuffle:
        dataset = dataset.shuffle(buffer_size=256)

    # epochs from blending together.
    dataset = dataset.repeat(num_epochs)
    dataset = dataset.batch(batch_size)  # Batch size to use

    iterator = dataset.make_one_shot_iterator()
    batch_features, batch_labels = iterator.get_next()
    return batch_features, batch_labels

只是修改了_parse_fn的内容,并用tf.data.TextLineDataset替换tf.data.TFRecordDataset即可。总之这种形式的input_fn其实类似一种迭代器,每次调用都会返回一个batch的数据。但是这里面的_parse_fn函数的内容,就要根据实际情况来编写了。

2. model_fn

model_fn是Estimator中最核心,也是最复杂的一个部分,在这里面需要定义网络结构、损失、train_op、评估结果等各种与网路结构有关的内容。下面依然通过《Tensorflow笔记:TFRecord的制作与读取》中的例子:通过简单的DNN网络来预测label来说明(这一段代码虽然长,但是也是结构化的,不要嫌麻烦一个part一个part的看,其实不复杂的)。

def model_fn(features, labels, mode, params):
    # ==========  解析参数部分  ========== #
    learning_rate = params["learning_rate"]

    # ==========  网络结构部分  ========== #
    # input
    X = tf.cast(features["image"], tf.float32, name="input_image")
    X = tf.reshape(X, [-1, 28*28]) / 255
    # DNN
    deep_inputs = X
    deep_inputs = tf.contrib.layers.fully_connected(inputs=deep_inputs, num_outputs=128)
    deep_inputs = tf.contrib.layers.fully_connected(inputs=deep_inputs, num_outputs=64)
    y_deep = tf.contrib.layers.fully_connected(inputs=deep_inputs, num_outputs=10)
    # output
    y = tf.reshape(y_deep, shape=[-1, 10])
    pred = tf.nn.softmax(y, name="soft_max")

    
    # ==========  如果是 predict 任务  ========== #
    predictions={"prob": pred}
    export_outputs = {tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: tf.estimator.export.PredictOutput(predictions)}
    # Provide an estimator spec for `ModeKeys.PREDICT`
    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(
                mode=mode,
                predictions=predictions,
                export_outputs=export_outputs)
    

    # ==========  如果是 eval 任务  ========== #
    one_hot_label = tf.one_hot(tf.cast(labels, tf.int32, name="input_label"), depth=10, name="label")
    # 构建损失
    loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=y, labels=one_hot_label))
    eval_metric_ops = {
        "accuracy": tf.metrics.accuracy(tf.math.argmax(one_hot_label, axis=1), tf.math.argmax(pred, axis=1))
    }
    if mode == tf.estimator.ModeKeys.EVAL:
        return tf.estimator.EstimatorSpec(
                mode=mode,
                predictions=predictions,
                loss=loss,
                eval_metric_ops=eval_metric_ops)
    

    # ==========  如果是 train 任务  ========== #
    # 构建train_op
    train_op = tf.train.AdamOptimizer(learning_rate=learning_rate, beta1=0.9, beta2=0.999, epsilon=1e-8).minimize(loss, global_step=tf.train.get_global_step())
    # Provide an estimator spec for `ModeKeys.TRAIN` modes
    if mode == tf.estimator.ModeKeys.TRAIN:
        return tf.estimator.EstimatorSpec(
                mode=mode,
                predictions=predictions,
                loss=loss,
                train_op=train_op)

介绍一下model_fn的结构:

  • Part1:解析参数部分,本例中以learning_rate为例,展示如何通过param来将参数传递进来,其他参数为了简便,直接用了数值型。
  • Part2:网络结构部分。这部分只是负责构建网络结构,从input到pred,不涉及label部分,所以不要把对labels的处理写在这里,因为如果在predict任务中,可能没有label的数据,就会报错。(在这里其实是支持通过tf.keras来构造网络结构,关于tf.keras的用法我在《Tensorflow笔记:高级封装——Keras》中有详细介绍)
  • Part3:predict任务部分。如果任务目的是predict,那么可以直接通过网络结构计算pred,不需要其他操作。设置好export_outputs,并以tf.estimator.EstimatorSpec形式返回即可。
  • Part4:eval任务部分。如果是eval任务,除了网络结构以外还需要计算此时的损失、正确率等指标,所以对于loss的定义要放在这一部分。同时设置好评价指标eval_metric_ops,并以tf.estimator.EstimatorSpec形式返回。
  • Part5:train任务部分。最后如果是train任务,除了网络结构、loss,还需要优化器、学习率等内容,所以定义train_op的部分在这里进行。最后以tf.estimator.EstimatorSpec形式返回。

model_fn部分虽然看起来长,但是对于不同的任务,只需要改动网络结构部分、loss以及train_op就可以了,说白了还是复制粘贴那点事。

3. main

最后就到了main函数这里,已经有了input_fn负责数据,model_fn负责模型,main这部分管的就是,我要怎么用这个模型。

def main():
    # ==========  准备参数 ========== #
    task_type = "train"
    model_params = {
        "learning_rate": 0.001,
    }

    # ==========  构建Estimator  ========== #
    config = tf.estimator.RunConfig().replace(
        session_config=tf.ConfigProto(device_count={'GPU': 0, 'CPU': 1}),
        log_step_count_steps=100,
        save_summary_steps=100,
        save_checkpoints_secs=None,
        save_checkpoints_steps=500,
        keep_checkpoint_max=1
    )
    estimator = tf.estimator.Estimator(model_fn=model_fn, model_dir="./model_ckpt/", params=model_params, config=config)

    # ==========  执行任务  ========== #
    if task_type == "train":
        # early_stop_hook 是控制模型早停的控件,下面两个分别是 tf 1.x 和 tf 2.x 的写法
        # early_stop_hook = tf.contrib.estimator.stop_if_no_increase_hook(estimator, metric_name="accuracy",
        early_stop_hook=tf.estimator.experimental.stop_if_no_increase_hook(estimator, metric_name="accuracy", max_steps_without_increase=1000, min_steps=500)
        train_spec = tf.estimator.TrainSpec(input_fn=lambda: input_fn(tr_files, num_epochs=10, batch_size=32), hooks=[early_stop_hook])
        eval_spec = tf.estimator.EvalSpec(input_fn=lambda: input_fn(va_files, num_epochs=1, batch_size=32), steps=None, start_delay_secs=1000, throttle_secs=1)
        tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
    elif task_type == "eval":
        estimator.evaluate(input_fn=lambda: input_fn(va_files, num_epochs=1, batch_size=32))
    elif task_type == "infer":
        preds = estimator.predict(input_fn=lambda: input_fn(te_files, num_epochs=1, batch_size=32), predict_keys="prob")
        with open("./pred.txt", "w") as fo:
            for prob in preds:
                fo.write("%f\n" % (np.argmax(prob['prob'])))
    if task_type == "export":
        feature_spec = {
            "image": tf.placeholder(dtype=tf.float32, shape=[None, 28, 28], name="image"),
        }
        serving_input_receiver_fn = tf.estimator.export.build_raw_serving_input_receiver_fn(feature_spec)
        Estimator.export_savedmodel("./saved_model/", serving_input_receiver_fn)

其实main中主要做三件事:1. 通过tf.estimator.RunConfig()配置构建Estimator对象;2. 初始化estimator(model_dir如果非空则自动热启动);3. 执行train/eval/infer/export任务。

  • train任务中初始化好TrainSpec和EvalSpec之后可以直接调用tf.estimator.train。也可以使用train_and_evaluate来一边训练一边输出验证集效果。hook可以看作是在训练验证基础上可以实现其他复杂功能的“插件”,比如本例中的early_stop,其他功能还包括热启动、Fine-tune等等,关于hook的用法比较复杂,以后单独写一篇文章。
  • eval任务输出的就是在model_fn函数中eval_metric_ops定义的指标。
  • infer任务就是调用estimator.predict获取在model_fn中定义的export_outputs作为预测值。
  • export就是将定义Estimator时候模型路径 model_dir="./model_ckpt/" 下的模型导出为可部署模型,也就是常说的saved_model。关于saved_model和模型部署方面,我也会单独写一篇文章来介绍。另外feature_spec指的是一个请求过来所带的数据应该长什么样,对应了model_fn里面的features(即features["image"]),所以这里feature_spec用的是字典的形式,建议model_fn中的features也用字典形式,哪怕是只有一个元素。

最后,直接跑main函数,或者通过tf.app.run()来运行脚本都可以:

# 直接运行 main 函数
main()

# 通过 tf.app.run() 来运行
if __name__ == "__main__":
    tf.app.run()

4. 分布式训练

对于单机单卡和单机多卡的情况,可以通过tf.device('/gpu:0')来手动控制,这里介绍一下在多机分布式情况下Estimator如何进行分布式训练。Estimator的分布式训练和原生Tensorflow的分布式训练类似,都需要提供一份“集群名单”,并且告诉每一台机器他是名单中的谁,并在每台机器上运行脚本。下面看一个例子

import os
import json
import numpy as np
import tensorflow as tf


FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string("job_name", "worker", "chief/ps/worker")
tf.app.flags.DEFINE_integer("task_id", 0, "Task ID of the worker running the train")

os.environ['TF_CONFIG'] = json.dumps({
    'cluster': {
        'chief': ["localhost:2221"],
        'ps':  ["localhost:2222"],
        'worker': ["localhost:2223", "localhost:2224"]
    },
    'task': {'type': FLAGS.job_name, 'index': FLAGS.task_id}
})

本例采用本地机的两个端口模拟集群中的两个机器,"cluster"表示集群的“名单”信息。"task"表示该机器的信息,"type"表示该机器的角色,"index"表示该机器是列表中的第几个。tf.Estimator中需要指定一个chief机器,ps机也只是在特定的策略下才需要指定(这一点下文介绍)。

除此之外,只需要在tf.ConfigProto中配置train_distribute就可以了:

strategy = tf.distribute.experimental.ParameterServerStrategy()
config = tf.estimator.RunConfig().replace(
    session_config=tf.ConfigProto(device_count={'GPU': 0, 'CPU': 1}),
    log_step_count_steps=100,
    save_summary_steps=100,
    save_checkpoints_secs=None,
    save_checkpoints_steps=500,
    keep_checkpoint_max=1,
    train_distribute=strategy
)
estimator = tf.estimator.Estimator(model_fn=model_fn, model_dir="./model_ckpt/", params=model_params, config=config)

接下来只需要在每台机器上运行脚本,就可以完成Esitmator的分布式训练了。实际上可以声明不同的strategy,来实现不同的并行策略:

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

我来说两句

0 条评论
登录 后参与评论

相关文章

  • Tensorflow笔记:通过tf.Serving+Docker部署

    很多时候仅仅是线下跑一个模型,对特定一批数据进行预测并不够,需要随时来一个或几个样本都能输出结果。这时候就需要起一个服务,然后随时一个包含数据的请求过来,就返回...

    共产主义搬砖人
  • tf46:再议tf.estimator之便利

    版权声明:本文为博主原创文章,未经博主允许不得转载。有问题可以加微信:lp9628(注明CSDN)。 ...

    MachineLP
  • 使用BERT和TensorFlow构建搜索引擎

    基于神经概率语言模型的特征提取器,例如与多种下游NLP任务相关的BERT提取特征。因此它们有时被称为自然语言理解(NLU)模块。

    代码医生工作室
  • 一看就懂的Tensorflow实战(多层感知机)

    这里定义含有两个隐含层的模型,隐含层输出均为256个节点,输入784(MNIST数据集图片大小28*28),输出10。

    AI异构
  • Reddit引爆框架决战!TensorFlow遭疯狂吐槽,PyTorch被捧上神坛

    对于不同人群可能有不同的答案,科研人员可能更偏爱pyTorch,因其简单易用,能够快速验证idea来抢占先机发论文。

    新智元
  • TensorFlow 资源大全–中文版

    jtoy 发起整理的 TensorFlow 资源,包含一些很棒的 TensorFlow 工程、库、项目等。

    我在鹅厂做安全
  • TensorFlow 资源大全中文版

    编译:伯乐在线 - Yalye,英文:jtoy http://blog.jobbole.com/110558/ jtoy 发起整理的 TensorFlow 资源...

    企鹅号小编
  • 从零开始学TensorFlow【01-搭建环境、HelloWorld篇】

    最近在学习TensorFlow的相关知识,了解了TensorFlow一些基础的知识,现在周末有空了,就写写一些笔记,记录一下自己的成长~

    Java3y
  • 机器学习&人工智能博文链接汇总

    ? 争取每天更新 ? 126 ? ---- 蜗牛的历程: [入门问题] [机器学习] [聊天机器人] [好玩儿的人工智能应用实例] [Tensor...

    杨熹
  • 安装 TensorFlow安装 TensorFlow

    我们已在如下配置的 64 位笔记本电脑/台式机操作系统中构建并测试过 TensorFlow:

    一个会写诗的程序员
  • 动态 | TensorFlow 2.0 新特性来啦,部分模型、库和 API 已经可以使用

    由于令人难以置信的多样化社区,TensorFlow 已经发展成为世界上最受欢迎和广泛采用的 ML 平台之一。这个社区包括:

    AI科技评论
  • 课程 |《深度学习原理与TensorFlow实践》学习笔记(二)

    作者 | 王清 TensorFlow基础使用 环境准备 TensorFlow安装 常用Python库介绍 实例解析 Kaggle平台及Titanic题目介绍 代...

    AI科技大本营
  • 官方解读:TensorFlow 2.0中即将到来的所有新特性

    作为最流行的深度学习框架,TensorFlow 已经成长为全球使用最广泛的机器学习平台。目前,TensorFlow 的开发者社区包括研究者、开发者和企业等。

    机器之心
  • google 机器学习速成课程字幕和视频

    iOSDevLog
  • 资源 | GitHub万星:适用于初学者的TensorFlow代码资源集

    机器之心
  • 官方推荐!用TensorFlow 2.0做深度学习入门教程 | 资源

    最近,TensorFlow 2.0版的开发者预览版发布没多久,这不,又有一篇优质教程来了。

    量子位
  • 官方解读:TensorFlow 2.0中即将到来的所有新特性

    本文经机器之心(微信公众号:almosthuman2014)授权转载,禁止二次转载

    小小詹同学
  • TensorFlow 2.0 新功能 | 官方详解

    TensorFlow 已经发展为世界上最受欢迎和被广泛采用的机器学习平台之一,我们衷心感谢一直以来支持我们的各界的开发者和他们的贡献:

    量子位
  • TensorFlow 2.0 的新功能

    2018 年 11 月,TensorFlow 迎来了它的 3 岁生日,我们回顾了几年来它增加的功能,进而对另一个重要里程碑 TensorFlow 2.0 感到兴...

    磐创AI

扫码关注云+社区

领取腾讯云代金券