
Example: Registering an AI Model as an MLSQL UDF

Author: User 2936994
Published: 2022-07-21 14:11:33
This article is included in the column: 祝威廉

Training a TensorFlow model

The following code can only be run in the Console's notebook mode.

First, prepare the MNIST dataset:

include lib.`github.com/allwefantasy/lib-core` where  
force="true" and 
libMirror="gitee.com" and -- proxy configuration. 
alias="libCore"; 
 
-- dump the MNIST data to object storage
include local.`libCore.dataset.mnist`; 
!dumpData /tmp/mnist; 
 
 
load parquet.`/tmp/mnist` as mnist_data; 

Next, train the model:

#%python
#%input=mnist_data
#%schema=file
#%output=mnist_model
#%env=source /Users/allwefantasy/opt/anaconda3/bin/activate ray1.3.0
#%cache=true

import ray
import os
from tensorflow.keras import models, layers
from tensorflow.keras import utils as np_utils
from pyjava.api.mlsql import RayContext
from pyjava.storage import streaming_tar
from pyjava import rayfix
import numpy as np

# Connect to the Ray cluster and get the data servers that stream the input table (mnist_data)
ray_context = RayContext.connect(globals(), "127.0.0.1:10001")
data_servers = ray_context.data_servers()


def data():
    # Collect every row of the input table from the data servers
    temp_data = [item for item in RayContext.collect_from(data_servers)]
    train_images = np.array([np.array(item["image"]) for item in temp_data])
    # One-hot encode the digit labels
    train_labels = np_utils.to_categorical(np.array([item["label"] for item in temp_data]))
    # Flatten each 28x28 image into a 784-element vector
    train_images = train_images.reshape((len(temp_data), 28 * 28))
    return train_images, train_labels


@ray.remote
@rayfix.last
def train():
    train_images, train_labels = data()
    # A simple two-layer fully connected network
    network = models.Sequential()
    network.add(layers.Dense(512, activation="relu", input_shape=(28 * 28,)))
    network.add(layers.Dense(10, activation="softmax"))
    network.compile(optimizer="rmsprop", loss="categorical_crossentropy", metrics=["accuracy"])
    network.fit(train_images, train_labels, epochs=6, batch_size=128)
    # Save the model locally, then convert the model directory into table rows
    model_path = os.path.join("tmp", "mnist_model")
    network.save(model_path)
    model_binary = [item for item in streaming_tar.build_rows_from_file(model_path)]
    return model_binary


# Train on a Ray worker and return the serialized model as this cell's output (mnist_model)
model_binary = ray.get(train.remote())
ray_context.build_result(model_binary)
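
The `data()` function above does two things before training: it flattens each image into a 784-element vector and one-hot encodes the labels with `to_categorical`. The following is a minimal pure-Python sketch of that preprocessing, so the shapes can be checked without NumPy or TensorFlow. It assumes rows shaped like `{"image": ..., "label": ...}` (the fields `data()` reads) and fixes `num_classes` at 10; note that `to_categorical` itself infers `max(label) + 1` when `num_classes` is omitted, as it is in the original code. The helper names are hypothetical.

```python
from typing import Dict, List, Sequence, Tuple

IMAGE_SIZE = 28 * 28  # MNIST images are 28x28 pixels


def flatten_image(image) -> List[float]:
    """Return the image as a flat list of 784 floats (accepts flat or 28x28 lists)."""
    flat: List[float] = []
    for item in image:
        if isinstance(item, (list, tuple)):
            flat.extend(float(p) for p in item)  # one row of a 28x28 image
        else:
            flat.append(float(item))  # already-flat pixel
    if len(flat) != IMAGE_SIZE:
        raise ValueError(f"expected {IMAGE_SIZE} pixels, got {len(flat)}")
    return flat


def one_hot(labels: Sequence[int], num_classes: int = 10) -> List[List[float]]:
    """One-hot encode integer labels (num_classes=10 is an assumption for MNIST digits)."""
    encoded = []
    for label in labels:
        if not 0 <= label < num_classes:
            raise ValueError(f"label {label} out of range")
        row = [0.0] * num_classes
        row[label] = 1.0
        encoded.append(row)
    return encoded


def prepare(rows: Sequence[Dict]) -> Tuple[List[List[float]], List[List[float]]]:
    """Turn collected rows into (train_images, train_labels), like data() does."""
    images = [flatten_image(row["image"]) for row in rows]
    labels = one_hot([int(row["label"]) for row in rows])
    return images, labels
```

For example, `prepare([{"image": [[0] * 28 for _ in range(28)], "label": 3}])` yields one 784-element image and the label vector with a `1.0` at index 3.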

Save the model:

save overwrite mnist_model as delta.`ai_model.mnist_model`; 
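
What gets saved here is not a model file but a table: in the training cell, `streaming_tar.build_rows_from_file` turned the saved model directory into rows. The sketch below illustrates that general idea with only the standard library: tar the directory in memory, split the archive into fixed-size chunks (one row each), and later rebuild the directory from those rows. The function names, chunk size, and row fields (`start`, `offset`, `value`) are hypothetical and are not pyjava's actual format.

```python
import io
import os
import tarfile
from typing import Dict, List

CHUNK_SIZE = 1024 * 1024  # hypothetical: 1 MiB of archive bytes per row


def dir_to_rows(model_dir: str, chunk_size: int = CHUNK_SIZE) -> List[Dict]:
    """Pack a directory into an in-memory tar archive and split it into rows.

    Hypothetical row layout; this is not pyjava's streaming_tar format.
    """
    buffer = io.BytesIO()
    with tarfile.open(fileobj=buffer, mode="w") as tar:
        # Store paths relative to the model directory's own name
        tar.add(model_dir, arcname=os.path.basename(os.path.normpath(model_dir)))
    data = buffer.getvalue()
    rows = []
    for start in range(0, len(data), chunk_size):
        chunk = data[start:start + chunk_size]
        rows.append({"start": start, "offset": len(chunk), "value": chunk})
    return rows


def rows_to_dir(rows: List[Dict], target_dir: str) -> None:
    """Reassemble the chunks in order and extract the archive into target_dir."""
    ordered = sorted(rows, key=lambda row: row["start"])  # rows may arrive unordered
    data = b"".join(row["value"] for row in ordered)
    with tarfile.open(fileobj=io.BytesIO(data), mode="r") as tar:
        if hasattr(tarfile, "data_filter"):
            tar.extractall(target_dir, filter="data")  # safer extraction where supported
        else:
            tar.extractall(target_dir)
```

Sorting by `start` on the way back matters because a distributed table gives no ordering guarantee when the rows are loaded again.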

Registering the model as a UDF

The following code can be run in the Console either as a script or as a notebook.

!python env "PYTHON_ENV=source /Users/allwefantasy/opt/anaconda3/bin/activate ray1.3.0"; 
!python conf "schema=st(field(content,string))"; 
!python conf "mode=model"; 
!python conf "runIn=driver"; 
!python conf "rayAddress=127.0.0.1:10001"; 
 
 
-- load the TF model trained above
load delta.`ai_model.mnist_model` as mnist_model; 
 
-- register the model as a UDF
register Ray.`mnist_model` as model_predict where  
maxConcurrency="8" 
and debugMode="true" 
and registerCode=''' 
 
import ray 
import numpy as np 
from pyjava.api.mlsql import RayContext 
from pyjava.udf import UDFMaster,UDFWorker,UDFBuilder,UDFBuildInFunc 
 
ray_context = RayContext.connect(globals(), context.conf["rayAddress"]) 
 
def predict_func(model,v): 
    train_images = np.array([v]) 
    train_images = train_images.reshape((1,28*28)) 
    predictions = model.predict(train_images) 
    return {"value":[[float(np.argmax(item)) for item in predictions]]} 
 
UDFBuilder.build(ray_context,UDFBuildInFunc.init_tf,predict_func) 
 
''' and  
predictCode=''' 
 
import ray 
from pyjava.api.mlsql import RayContext 
from pyjava.udf import UDFMaster,UDFWorker,UDFBuilder,UDFBuildInFunc 
 
ray_context = RayContext.connect(globals(), context.conf["rayAddress"]) 
UDFBuilder.apply(ray_context) 
 
''' 
; 
 
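-- Optional: the commented-out code below rewrites the data into 8 files, reducing the
-- number of partitions and avoiding the queueing that too much concurrency would cause
-- load parquet.`/tmp/mnist` as mnist_data;
-- save mnist_data as parquet.`/tmp/mnist-8` where fileNum="8";

-- load the MNIST data dumped earlier (load `/tmp/mnist-8` instead if you ran the step above)
load parquet.`/tmp/mnist` as mnist_data;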
 
select cast(image as array<double>) as image from mnist_data limit 100 as new_mnist_data; 
 
select model_predict(array(image)) as predicted  from  new_mnist_data as output; 
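
The `predict_func` passed to `UDFBuilder.build` receives the loaded model and one input value, reshapes the value into a single 784-feature row, and returns the index of the highest-scoring class as a float inside `{"value": [[...]]}`. The sketch below reproduces that logic in plain Python so the shape handling and return structure can be checked without TensorFlow or Ray. `FakeModel` is a hypothetical stand-in for the Keras model, and the sketch assumes, as the query's `array(image)` suggests, that `v` arrives as a list wrapping one 784-element image.

```python
from typing import Dict, List, Sequence


class FakeModel:
    """Hypothetical stand-in for the Keras model: returns fixed class scores."""

    def __init__(self, scores: List[float]):
        self.scores = scores

    def predict(self, batch: List[List[float]]) -> List[List[float]]:
        # One row of class scores per input row, like model.predict
        return [list(self.scores) for _ in batch]


def flatten(value) -> List[float]:
    """Recursively flatten nested lists; plays the role of reshape((1, 28*28))."""
    if isinstance(value, (list, tuple)):
        return [x for item in value for x in flatten(item)]
    return [float(value)]


def argmax(scores: Sequence[float]) -> int:
    """Index of the largest score (first one on ties, like np.argmax)."""
    return max(range(len(scores)), key=lambda i: scores[i])


def predict_func(model, v) -> Dict[str, List[List[float]]]:
    pixels = flatten(v)
    if len(pixels) != 28 * 28:
        raise ValueError(f"expected 784 pixels, got {len(pixels)}")
    predictions = model.predict([pixels])  # a batch containing a single image
    return {"value": [[float(argmax(item)) for item in predictions]]}
```

With a `FakeModel` whose highest score sits at index 7, `predict_func(model, [[0.0] * 784])` returns `{"value": [[7.0]]}`.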

How it works

Slides: 4. Using Ray as the Execution Engine for Spark SQL UDFs - 祝威廉

Video: 4. Using Ray as the Execution Engine for Spark SQL UDFs - 祝威廉
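
As the talk title says, Ray acts as the execution engine for the Spark SQL UDF: Spark partitions call `model_predict`, and the model runs in Ray workers, with the registration capping concurrency via `maxConcurrency="8"`. As a rough, single-process analogy only (this is not how pyjava's `UDFMaster`/`UDFWorker` are implemented), the sketch below uses a bounded thread pool: each partition submits its batch, at most `max_concurrency` batches run at once, and the rest wait in the queue. This is why the script suggests reducing the number of partitions. `run_partitions`, `PeakCounter`, and `slow_predict` are hypothetical names.

```python
import threading
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, List, Sequence


def run_partitions(
    partitions: Sequence[Sequence[float]],
    predict: Callable[[Sequence[float]], List[int]],
    max_concurrency: int = 8,
) -> List[List[int]]:
    """Apply `predict` to every partition with at most `max_concurrency` in flight.

    Hypothetical analogy for a capped pool of model workers; extra batches queue.
    """
    with ThreadPoolExecutor(max_workers=max_concurrency) as pool:
        futures = [pool.submit(predict, batch) for batch in partitions]
        # Results are returned in partition order, whatever the completion order
        return [future.result() for future in futures]


class PeakCounter:
    """Records the largest number of calls that were running at the same time."""

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self.active = 0
        self.peak = 0

    def wrap(self, fn):
        def wrapped(*args, **kwargs):
            with self._lock:
                self.active += 1
                self.peak = max(self.peak, self.active)
            try:
                return fn(*args, **kwargs)
            finally:
                with self._lock:
                    self.active -= 1
        return wrapped


def slow_predict(batch: Sequence[float]) -> List[int]:
    """Hypothetical model call: sleeps briefly, then maps each value to a digit."""
    time.sleep(0.01)
    return [int(x) % 10 for x in batch]
```

Running 20 partitions through `run_partitions(parts, counter.wrap(slow_predict), max_concurrency=4)` never has more than four batches executing at once; the remaining ones simply wait their turn.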

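This article is part of the Tencent Cloud self-media sharing program and was shared from the author's personal site/blog.
Originally published: 2021-08-10. For any infringement concerns, contact cloudcommunity@tencent.com to have it removed.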
