我们有一个简单的ML模型,编译并保存为SavedModel/*.pb
格式。我们使用Java1.5 (Java)加载SavedModel/
来进行推理。
我们使用以下方法加载模型:
String path = 'models/SavedModel'
File modelFile = new File(getClass().getClassLoader().getResource(path).getPath());
model = SavedModelBundle.loader(modelFile.getAbsolutePath())
.withTags("serve")
.load();
Graph g = model.graph();
...
我们能够推断模型并使用IDE
获得输出,但是一旦构建了Jar,它就不起作用了。
重要说明:
SavedModel
保存在资源dir中,比如src/main/resources/models/SavedModel/*.pb
。SavedModelBundle.loader
将String exportDir
作为第一个参数。exportDir
是包含已保存模型的目录路径。models/SavedModel/*
。我们无法在jar中引用exportDir
的正确路径。有人能帮帮我吗?我是爪哇世界的新手!
发布于 2020-05-17 06:08:22
不幸的是,你不能做到这一点,至少不是那么容易。
TensorFlow不从JVM中读取保存的模型目录,而是从它的本机C++库中读取。这个库理解常规文件路径,但不理解Java资源路径。它从IDE中工作的原因是因为大多数IDE直接处理编译和存储在文件系统上的类,它们不必像命令行那样处理归档。
例如,如果我运行这段代码片段:
public class Main {
public static void main(String[] args) throws Exception {
String path = Main.class.getResource("model").getPath();
System.out.println("Loading model at " + path);
SavedModelBundle.load(path, "serve");
}
}
从IDE,我得到:
Loading model at /Users/klessard/Documents/Projects/MachineLearning/Sources/quick-java-test/target/classes/model
2020-05-17 01:53:31.355176: I tensorflow/cc/saved_model/reader.cc:31] Reading SavedModel from: /Users/klessard/Documents/Projects/MachineLearning/Sources/quick-java-test/target/classes/model
2020-05-17 01:53:31.428273: I tensorflow/cc/saved_model/reader.cc:54] Reading meta graph with tags { serve }
2020-05-17 01:53:31.787746: I tensorflow/cc/saved_model/loader.cc:182] Restoring SavedModel bundle.
2020-05-17 01:53:33.762491: I tensorflow/cc/saved_model/loader.cc:132] Running initialization op on SavedModel bundle.
2020-05-17 01:53:34.178367: I tensorflow/cc/saved_model/loader.cc:285] SavedModel load for tags { serve }; Status: success. Took 2823199 microseconds.
但是,使用以下命令行java -jar target/quick-java-test-1.0-SNAPSHOT-jar-with-dependencies.jar
从JAR运行将给出如下结果:
Loading model at file:/Users/klessard/Documents/Projects/MachineLearning/Sources/quick-java-test/target/quick-java-test-1.0-SNAPSHOT-jar-with-dependencies.jar!/model
2020-05-17 01:55:25.701852: I tensorflow/cc/saved_model/reader.cc:31] Reading SavedModel from: file:/Users/klessard/Documents/Projects/MachineLearning/Sources/quick-java-test/target/quick-java-test-1.0-SNAPSHOT-jar-with-dependencies.jar!/model
2020-05-17 01:55:25.701937: I tensorflow/cc/saved_model/loader.cc:285] SavedModel load for tags { serve }; Status: fail. Took 101 microseconds.
Exception in thread "main" org.tensorflow.TensorFlowException: Could not find SavedModel .pb or .pbtxt at supplied export directory path: file:/Users/klessard/Documents/Projects/MachineLearning/Sources/quick-java-test/target/quick-java-test-1.0-SNAPSHOT-jar-with-dependencies.jar!/model
at org.tensorflow.SavedModelBundle.load(Native Method)
at org.tensorflow.SavedModelBundle.access$000(SavedModelBundle.java:27)
at org.tensorflow.SavedModelBundle$Loader.load(SavedModelBundle.java:32)
at org.tensorflow.SavedModelBundle.load(SavedModelBundle.java:95)
at Main.main(Main.java:10)
比较这两种路径:在IDE中,它指向文件系统上的target/classes
文件夹(这是使用Maven编译类时的默认输出文件夹),而在命令行中,它指向JAR (quick-java-test-1.0-SNAPSHOT-jar-with-dependencies.jar!/model
)中的一个目录,这是特定于Java的。C++本机库不知道如何解析此路径并失败加载模型。
因此,我建议您将保存的模型存储在文件系统上,并使用常规的文件路径加载它(如果使用Docker,则在构建映像时可以这样做)。
另一个解决方案也是使用遗留格式导入图形,因为它直接作为序列化的proto消息传递,而不是从目录中读取。但是,我不确定TensorFlow仍然支持这种格式多长时间。
最后,还可以将保存的模型目录存档在参考资料下的zip文件中,在使用SavedModelBundle.load
加载模型之前将其解压缩到Java中的一个tmp文件夹中。
https://stackoverflow.com/questions/61833774
复制相似问题