下面的代码仅支持Console notebook模式下运行
首先,准备mnist数据集
-- Pull in lib-core, which provides the MNIST dataset helper used below.
include lib.`github.com/allwefantasy/lib-core` where
force="true" and
libMirror="gitee.com" and -- proxy configuration.
alias="libCore";
-- dump mnist data to object storage
include local.`libCore.dataset.mnist`;
-- Materialize the dataset under /tmp/mnist, then load it back as a table.
!dumpData /tmp/mnist;
load parquet.`/tmp/mnist` as mnist_data;
接着,可以进行训练了:
#%python
#%input=mnist_data
#%schema=file
#%output=mnist_model
#%env=source /Users/allwefantasy/opt/anaconda3/bin/activate ray1.3.0
#%cache=true
import ray
import os
from tensorflow.keras import models,layers
from tensorflow.keras import utils as np_utils
from pyjava.api.mlsql import RayContext
from pyjava.storage import streaming_tar
from pyjava import rayfix
import numpy as np

# Connect to the Ray cluster; the mnist_data table is streamed in via data servers.
ray_context = RayContext.connect(globals(),"127.0.0.1:10001")
data_servers = ray_context.data_servers()


def data():
    """Collect MNIST rows from the data servers and shape them for Keras.

    Returns:
        (train_images, train_labels): images flattened to (n, 784) arrays,
        labels one-hot encoded with to_categorical.
    """
    temp_data = [item for item in RayContext.collect_from(data_servers)]
    train_images = np.array([np.array(item["image"]) for item in temp_data])
    train_labels = np_utils.to_categorical(np.array([item["label"] for item in temp_data]))
    # Flatten each image to a 784-length vector; assumes 28x28 inputs —
    # TODO confirm against the dataset schema.
    train_images = train_images.reshape((len(temp_data), 28 * 28))
    return train_images, train_labels


@ray.remote
@rayfix.last
def train():
    """Train a small dense network on MNIST and return the saved model
    as a list of binary rows suitable for ray_context.build_result."""
    train_images, train_labels = data()
    network = models.Sequential()
    network.add(layers.Dense(512, activation="relu", input_shape=(28 * 28,)))
    network.add(layers.Dense(10, activation="softmax"))
    network.compile(optimizer="rmsprop", loss="categorical_crossentropy", metrics=["accuracy"])
    network.fit(train_images, train_labels, epochs=6, batch_size=128)
    # Save to a local temp path, then stream the directory back as binary rows.
    # NOTE(review): "minist_model" is a typo but self-consistent (only used here).
    model_path = os.path.join("tmp", "minist_model")
    network.save(model_path)
    model_binary = [item for item in streaming_tar.build_rows_from_file(model_path)]
    return model_binary


# Run training remotely on Ray and hand the serialized model back to Byzer
# as the mnist_model output table.
model_binary = ray.get(train.remote())
ray_context.build_result(model_binary)
保存模型:
save overwrite mnist_model as delta.`ai_model.mnist_model`;
下面代码可以在Console中以脚本或者Notebook形态运行
-- Python runtime configuration for the prediction UDF registered below.
!python env "PYTHON_ENV=source /Users/allwefantasy/opt/anaconda3/bin/activate ray1.3.0";
-- Output schema of the UDF result.
!python conf "schema=st(field(content,string))";
!python conf "mode=model";
-- Run the Python worker inside the driver process.
!python conf "runIn=driver";
-- Ray cluster address, read by registerCode/predictCode via context.conf.
!python conf "rayAddress=127.0.0.1:10001";
-- Load the TF model trained earlier
load delta.`ai_model.mnist_model` as mnist_model;
-- Register the model as the UDF function model_predict.
-- registerCode runs once to set up UDF workers with the model loaded;
-- predictCode serves each prediction request.
register Ray.`mnist_model` as model_predict where
maxConcurrency="8"
and debugMode="true"
and registerCode='''
import ray
import numpy as np
from pyjava.api.mlsql import RayContext
from pyjava.udf import UDFMaster,UDFWorker,UDFBuilder,UDFBuildInFunc

ray_context = RayContext.connect(globals(), context.conf["rayAddress"])

def predict_func(model, v):
    # Reshape the incoming array to a single 28*28 sample and return the
    # argmax class index wrapped in the expected {"value": [[...]]} envelope.
    train_images = np.array([v])
    train_images = train_images.reshape((1, 28 * 28))
    predictions = model.predict(train_images)
    return {"value":[[float(np.argmax(item)) for item in predictions]]}

UDFBuilder.build(ray_context, UDFBuildInFunc.init_tf, predict_func)
''' and
predictCode='''
import ray
from pyjava.api.mlsql import RayContext
from pyjava.udf import UDFMaster,UDFWorker,UDFBuilder,UDFBuildInFunc

ray_context = RayContext.connect(globals(), context.conf["rayAddress"])
UDFBuilder.apply(ray_context)
'''
;
-- Optional: reduce the partition count to avoid queuing caused by high concurrency.
-- load parquet.`/tmp/mnist` as mnist_data;
-- save mnist_data as parquet.`/tmp/minst-8` where fileNum="8";
-- NOTE: the dataset was dumped to /tmp/mnist earlier; /tmp/minst is never
-- written, so the load path below is corrected to /tmp/mnist.
load parquet.`/tmp/mnist` as mnist_data;
-- Cast image to array<double> so the UDF receives numeric input; limit keeps the demo fast.
select cast(image as array<double>) as image from mnist_data limit 100 as new_mnist_data;
-- Batch prediction with the registered model UDF.
select model_predict(array(image)) as predicted from new_mnist_data as output;