前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >【云+社区年度征文】tensorflow 2.0 Estimator Keras读取saved model并预测

【云+社区年度征文】tensorflow 2.0 Estimator Keras读取saved model并预测

原创
作者头像
大鹅
发布2020-12-07 17:35:03
7750
发布2020-12-07 17:35:03
举报

背景

使用tensorflow2.0以上版本框架用Keras或者Estimator方式保存模型有两种方式加载模型并预测。

Keras框架保存模型后可以直接加载并调用predict方法预测;

estimator将比较麻烦,需要签名并传入tensor才可以预测;

Keras模型预测

代码语言:txt
复制
import tensorflow as tf
from tensorflow import keras
model = tf.keras.models.load_model(export_dir)

# dataframe 特征读取与处理
X = dict(dataframe)
c = model.predict(X)
output = np.argmax(c, axis=1)

Estimator模型预测

代码语言:txt
复制
import tensorflow as tf
# 加载模型 & 签名
imported = tf.saved_model.load(export_dir)
f = imported.signatures["predict"]
代码语言:txt
复制
# 转换为tensor并预测
out_df = pd.DataFrame()
def predict(dataframe):
    examples = []
    for row in dataframe.itertuples():
        feature_map = {}
        # 特征处理 将特征放入dict中
        example = tf.train.Example(
            features=tf.train.Features(
                feature = feature_map
            )
        )
        examples.append(example.SerializeToString())
            
    ex = tf.constant(examples)
    result = f(examples=ex)
    out_df['high_rank_score'] = np.max(result["probabilities"].numpy(), axis=1)
    out_df['tag'] = np.argmax(result["probabilities"].numpy(), axis=1)
    return out_df

Ref

  1. http://d0evi1.com/tensorflow/custom_estimators/
  2. https://www.tensorflow.org/guide/saved_model?hl=zh-cn#%E5%8A%A0%E8%BD%BD%E5%92%8C%E4%BD%BF%E7%94%A8%E8%87%AA%E5%AE%9A%E4%B9%89%E6%A8%A1%E5%9E%8B
  3. https://zhuanlan.zhihu.com/p/66872472
  4. https://yinguobing.com/load-savedmodel-of-estimator-by-keras/

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 背景
  • Keras模型预测
  • Estimator模型预测
  • Ref
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档