首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >Tensorflow笔记:高级封装——tf.Estimator

Tensorflow笔记:高级封装——tf.Estimator

作者头像
共产主义搬砖人
发布2021-09-24 15:06:01
1.7K0
发布2021-09-24 15:06:01
举报
文章被收录于专栏:算法私房菜算法私房菜

前言

Google官方给出了两个tensorflow的高级封装——keras和Estimator,本文主要介绍tf.Estimator的内容。tf.Estimator的特点是:既能在model_fn中灵活的搭建网络结构,也不至于像原生tensorflow那样复杂繁琐。相比于原生tensorflow更便捷、相比与keras更灵活,属于二者的中间态。

实现一个tf.Estimator主要分三个部分:input_fn、model_fn、main三个函数。其中input_fn负责处理输入数据、model_fn负责构建网络结构、main来决定要进行什么样的任务(train、eval、earlystop等等)。本文我们就通过MNIST数据集的例子,介绍一下tf.Estimator是怎么用的。

1. input_fn

读过我的另一篇文章:Tensorflow笔记:TFRecord的制作与读取 的同学应该记得那里面的read_and_decode函数,其实就和这里的input_fn逻辑是类似的,都是通过tf.data每次调用会产生一个batch的数据。

def input_fn(filenames, batch_size=32, num_epochs=None, perform_shuffle=False):
    """
    每次调用,从TFRecord文件中读取一个大小为batch_size的batch
    Args:
        filenames: TFRecord文件
        batch_size: batch_size大小
        num_epochs: 将TFRecord中的数据重复几遍,如果是None,则永远循环读取不会停止
        perform_shuffle: 是否乱序

    Returns:
        tensor格式的,一个batch的数据
    """
    def _parse_fn(record):
        features = {
            "label": tf.FixedLenFeature([], tf.int64),
            "image": tf.FixedLenFeature([], tf.string),
        }
        parsed = tf.parse_single_example(record, features)
        # image
        image = tf.decode_raw(parsed["image"], tf.uint8)
        image = tf.reshape(image, [28, 28])
        # label
        label = tf.cast(parsed["label"], tf.int64)
        return {"image": image}, label

    # Extract lines from input files using the Dataset API, can pass one filename or filename list
    dataset = tf.data.TFRecordDataset(filenames).map(_parse_fn, num_parallel_calls=10).prefetch(500000)    # multi-thread pre-process then prefetch

    # Randomizes input using a window of 256 elements (read into memory)
    if perform_shuffle:
        dataset = dataset.shuffle(buffer_size=256)

    # epochs from blending together.
    dataset = dataset.repeat(num_epochs)
    dataset = dataset.batch(batch_size) # Batch size to use

    iterator = dataset.make_one_shot_iterator()
    batch_features, batch_labels = iterator.get_next()
    return batch_features, batch_labels

如果遇到不同的问题,其实只需要改动tf.data.TFRecordDataset这一行和_parse_fn函数即可。比如如果输入数据不是TFRecord格式,而是一个LIBSVM格式:

def input_fn(filenames, batch_size=32, num_epochs=1, perform_shuffle=False):
    def _parse_fn(line):
        columns = tf.string_split([line], ' ')
        labels = tf.string_to_number(columns.values[0], out_type=tf.float32)
        splits = tf.string_split(columns.values[5:], ':')  # filed_size=280 feature_size=6500000
        id_vals = tf.reshape(splits.values, splits.dense_shape)
        feat_ids, feat_vals = tf.split(id_vals, num_or_size_splits=2, axis=1)
        feat_ids = tf.string_to_number(feat_ids, out_type=tf.int32)
        feat_vals = tf.string_to_number(feat_vals, out_type=tf.float32)
        # feat_vals = tf.sign(feat_vals) * tf.math.log(tf.abs(feat_vals) + 1)  # do log manual
        return {"feat_ids": feat_ids, "feat_vals": feat_vals}, labels

    # Extract lines from input files using the Dataset API, can pass one filename or filename list
    dataset = tf.data.TextLineDataset(filenames).map(_parse_fn, num_parallel_calls=10).prefetch(500000)

    # Randomizes input using a window of 256 elements (read into memory)
    if perform_shuffle:
        dataset = dataset.shuffle(buffer_size=256)

    # epochs from blending together.
    dataset = dataset.repeat(num_epochs)
    dataset = dataset.batch(batch_size)  # Batch size to use

    iterator = dataset.make_one_shot_iterator()
    batch_features, batch_labels = iterator.get_next()
    return batch_features, batch_labels

只是修改了_parse_fn的内容,并用tf.data.TextLineDataset替换tf.data.TFRecordDataset即可。总之这种形式的input_fn其实类似一种迭代器,每次调用都会返回一个batch的数据。但是这里面的_parse_fn函数的内容,就要根据实际情况来编写了。

2. model_fn

model_fn是Estimator中最核心,也是最复杂的一个部分,在这里面需要定义网络结构、损失、train_op、评估结果等各种与网路结构有关的内容。下面依然通过《Tensorflow笔记:TFRecord的制作与读取》中的例子:通过简单的DNN网络来预测label来说明(这一段代码虽然长,但是也是结构化的,不要嫌麻烦一个part一个part的看,其实不复杂的)。

def model_fn(features, labels, mode, params):
    # ==========  解析参数部分  ========== #
    learning_rate = params["learning_rate"]

    # ==========  网络结构部分  ========== #
    # input
    X = tf.cast(features["image"], tf.float32, name="input_image")
    X = tf.reshape(X, [-1, 28*28]) / 255
    # DNN
    deep_inputs = X
    deep_inputs = tf.contrib.layers.fully_connected(inputs=deep_inputs, num_outputs=128)
    deep_inputs = tf.contrib.layers.fully_connected(inputs=deep_inputs, num_outputs=64)
    y_deep = tf.contrib.layers.fully_connected(inputs=deep_inputs, num_outputs=10)
    # output
    y = tf.reshape(y_deep, shape=[-1, 10])
    pred = tf.nn.softmax(y, name="soft_max")

    
    # ==========  如果是 predict 任务  ========== #
    predictions={"prob": pred}
    export_outputs = {tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: tf.estimator.export.PredictOutput(predictions)}
    # Provide an estimator spec for `ModeKeys.PREDICT`
    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(
                mode=mode,
                predictions=predictions,
                export_outputs=export_outputs)
    

    # ==========  如果是 eval 任务  ========== #
    one_hot_label = tf.one_hot(tf.cast(labels, tf.int32, name="input_label"), depth=10, name="label")
    # 构建损失
    loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=y, labels=one_hot_label))
    eval_metric_ops = {
        "accuracy": tf.metrics.accuracy(tf.math.argmax(one_hot_label, axis=1), tf.math.argmax(pred, axis=1))
    }
    if mode == tf.estimator.ModeKeys.EVAL:
        return tf.estimator.EstimatorSpec(
                mode=mode,
                predictions=predictions,
                loss=loss,
                eval_metric_ops=eval_metric_ops)
    

    # ==========  如果是 train 任务  ========== #
    # 构建train_op
    train_op = tf.train.AdamOptimizer(learning_rate=learning_rate, beta1=0.9, beta2=0.999, epsilon=1e-8).minimize(loss, global_step=tf.train.get_global_step())
    # Provide an estimator spec for `ModeKeys.TRAIN` modes
    if mode == tf.estimator.ModeKeys.TRAIN:
        return tf.estimator.EstimatorSpec(
                mode=mode,
                predictions=predictions,
                loss=loss,
                train_op=train_op)

介绍一下model_fn的结构:

  • Part1:解析参数部分,本例中以learning_rate为例,展示如何通过param来将参数传递进来,其他参数为了简便,直接用了数值型。
  • Part2:网络结构部分。这部分只是负责构建网络结构,从input到pred,不涉及label部分,所以不要把对labels的处理写在这里,因为如果在predict任务中,可能没有label的数据,就会报错。(在这里其实是支持通过tf.keras来构造网络结构,关于tf.keras的用法我在《Tensorflow笔记:高级封装——Keras》中有详细介绍)
  • Part3:predict任务部分。如果任务目的是predict,那么可以直接通过网络结构计算pred,不需要其他操作。设置好export_outputs,并以tf.estimator.EstimatorSpec形式返回即可。
  • Part4:eval任务部分。如果是eval任务,除了网络结构以外还需要计算此时的损失、正确率等指标,所以对于loss的定义要放在这一部分。同时设置好评价指标eval_metric_ops,并以tf.estimator.EstimatorSpec形式返回。
  • Part5:train任务部分。最后如果是train任务,除了网络结构、loss,还需要优化器、学习率等内容,所以定义train_op的部分在这里进行。最后以tf.estimator.EstimatorSpec形式返回。

model_fn部分虽然看起来长,但是对于不同的任务,只需要改动网络结构部分、loss以及train_op就可以了,说白了还是复制粘贴那点事。

3. main

最后就到了main函数这里,已经有了input_fn负责数据,model_fn负责模型,main这部分管的就是,我要怎么用这个模型。

def main():
    # ==========  准备参数 ========== #
    task_type = "train"
    model_params = {
        "learning_rate": 0.001,
    }

    # ==========  构建Estimator  ========== #
    config = tf.estimator.RunConfig().replace(
        session_config=tf.ConfigProto(device_count={'GPU': 0, 'CPU': 1}),
        log_step_count_steps=100,
        save_summary_steps=100,
        save_checkpoints_secs=None,
        save_checkpoints_steps=500,
        keep_checkpoint_max=1
    )
    estimator = tf.estimator.Estimator(model_fn=model_fn, model_dir="./model_ckpt/", params=model_params, config=config)

    # ==========  执行任务  ========== #
    if task_type == "train":
        # early_stop_hook 是控制模型早停的控件,下面两个分别是 tf 1.x 和 tf 2.x 的写法
        # early_stop_hook = tf.contrib.estimator.stop_if_no_increase_hook(estimator, metric_name="accuracy",
        early_stop_hook=tf.estimator.experimental.stop_if_no_increase_hook(estimator, metric_name="accuracy", max_steps_without_increase=1000, min_steps=500)
        train_spec = tf.estimator.TrainSpec(input_fn=lambda: input_fn(tr_files, num_epochs=10, batch_size=32), hooks=[early_stop_hook])
        eval_spec = tf.estimator.EvalSpec(input_fn=lambda: input_fn(va_files, num_epochs=1, batch_size=32), steps=None, start_delay_secs=1000, throttle_secs=1)
        tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
    elif task_type == "eval":
        estimator.evaluate(input_fn=lambda: input_fn(va_files, num_epochs=1, batch_size=32))
    elif task_type == "infer":
        preds = estimator.predict(input_fn=lambda: input_fn(te_files, num_epochs=1, batch_size=32), predict_keys="prob")
        with open("./pred.txt", "w") as fo:
            for prob in preds:
                fo.write("%f\n" % (np.argmax(prob['prob'])))
    if task_type == "export":
        feature_spec = {
            "image": tf.placeholder(dtype=tf.float32, shape=[None, 28, 28], name="image"),
        }
        serving_input_receiver_fn = tf.estimator.export.build_raw_serving_input_receiver_fn(feature_spec)
        Estimator.export_savedmodel("./saved_model/", serving_input_receiver_fn)

其实main中主要做三件事:1. 通过tf.estimator.RunConfig()配置构建Estimator对象;2. 初始化estimator(model_dir如果非空则自动热启动);3. 执行train/eval/infer/export任务。

  • train任务中初始化好TrainSpec和EvalSpec之后可以直接调用tf.estimator.train。也可以使用train_and_evaluate来一边训练一边输出验证集效果。hook可以看作是在训练验证基础上可以实现其他复杂功能的“插件”,比如本例中的early___stop,其他功能还包括热启动、Fine-tune等等,关于hook的用法比较复杂,以后单独写一篇文章。
  • eval任务输出的就是在model_fn函数中eval_metric_ops定义的指标。
  • infer任务就是调用estimator.predict获取在model_fn中定义的export_outputs作为预测值。
  • export就是将定义Estimator时候模型路径 model_dir="./model_ckpt/" 下的模型导出为可部署模型,也就是常说的saved_model。关于saved_model和模型部署方面,我也会单独写一篇文章来介绍。另外feature_spec指的是一个请求过来所带的数据应该长什么样,对应了model_fn里面的features(即features"image"),所以这里feature_spec用的是字典的形式,建议model_fn中的features也用字典形式,哪怕是只有一个元素。

最后,直接跑main函数,或者通过tf.app.run()来运行脚本都可以:

# 直接运行 main 函数
main()

# 通过 tf.app.run() 来运行
if __name__ == "__main__":
    tf.app.run()

4. 分布式训练

对于单机单卡和单机多卡的情况,可以通过tf.device('/gpu:0')来手动控制,这里介绍一下在多机分布式情况下Estimator如何进行分布式训练。Estimator的分布式训练和原生Tensorflow的分布式训练类似,都需要提供一份“集群名单”,并且告诉每一台机器他是名单中的谁,并在每台机器上运行脚本。下面看一个例子

import os
import json
import numpy as np
import tensorflow as tf


FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string("job_name", "worker", "chief/ps/worker")
tf.app.flags.DEFINE_integer("task_id", 0, "Task ID of the worker running the train")

os.environ['TF_CONFIG'] = json.dumps({
    'cluster': {
        'chief': ["localhost:2221"],
        'ps':  ["localhost:2222"],
        'worker': ["localhost:2223", "localhost:2224"]
    },
    'task': {'type': FLAGS.job_name, 'index': FLAGS.task_id}
})

本例采用本地机的两个端口模拟集群中的两个机器,"cluster"表示集群的“名单”信息。"task"表示该机器的信息,"type"表示该机器的角色,"index"表示该机器是列表中的第几个。tf.Estimator中需要指定一个chief机器,ps机也只是在特定的策略下才需要指定(这一点下文介绍)。

除此之外,只需要在tf.ConfigProto中配置train_distribute就可以了:

strategy = tf.distribute.experimental.ParameterServerStrategy()
config = tf.estimator.RunConfig().replace(
    session_config=tf.ConfigProto(device_count={'GPU': 0, 'CPU': 1}),
    log_step_count_steps=100,
    save_summary_steps=100,
    save_checkpoints_secs=None,
    save_checkpoints_steps=500,
    keep_checkpoint_max=1,
    train_distribute=strategy
)
estimator = tf.estimator.Estimator(model_fn=model_fn, model_dir="./model_ckpt/", params=model_params, config=config)

接下来只需要在每台机器上运行脚本,就可以完成Esitmator的分布式训练了。实际上可以声明不同的strategy,来实现不同的并行策略:

本文参与 腾讯云自媒体分享计划,分享自作者个人站点/博客。
如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 前言
  • 1. input_fn
  • 2. model_fn
  • 3. main
  • 4. 分布式训练
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档