I have successfully trained a two-tower model on Google Vertex AI, following the guide here.
I now want to download the model and try some local inference on my own machine. I have been fighting various errors for a while, and I am currently stuck on the following:
Code:
import tensorflow as tf
import tensorflow_text
load_options = tf.saved_model.LoadOptions(experimental_io_device='/job:localhost')
tf.saved_model.load('model_path', options=load_options)
Error:
---------------------------------------------------------------------------
KeyError Traceback (most recent call last)
~/.pyenv/versions/3.7.8/lib/python3.7/site-packages/tensorflow/python/framework/ops.py in _get_op_def(self, type)
3957 try:
-> 3958 return self._op_def_cache[type]
3959 except KeyError:
KeyError: 'IO>DecodeJSON'
During handling of the above exception, another exception occurred:
NotFoundError Traceback (most recent call last)
~/.pyenv/versions/3.7.8/lib/python3.7/site-packages/tensorflow/python/saved_model/load.py in load_internal(export_dir, tags, options, loader_cls, filters)
905 loader = loader_cls(object_graph_proto, saved_model_proto, export_dir,
--> 906 ckpt_options, filters)
907 except errors.NotFoundError as err:
~/.pyenv/versions/3.7.8/lib/python3.7/site-packages/tensorflow/python/saved_model/load.py in __init__(self, object_graph_proto, saved_model_proto, export_dir, ckpt_options, filters)
133 function_deserialization.load_function_def_library(
--> 134 meta_graph.graph_def.library))
135 self._checkpoint_options = ckpt_options
~/.pyenv/versions/3.7.8/lib/python3.7/site-packages/tensorflow/python/saved_model/function_deserialization.py in load_function_def_library(library, load_shared_name_suffix)
357 with graph.as_default():
--> 358 func_graph = function_def_lib.function_def_to_graph(copy)
359 _restore_gradient_functions(func_graph, renamed_functions)
~/.pyenv/versions/3.7.8/lib/python3.7/site-packages/tensorflow/python/framework/function_def_to_graph.py in function_def_to_graph(fdef, input_shapes)
63 graph_def, nested_to_flat_tensor_name = function_def_to_graph_def(
---> 64 fdef, input_shapes)
65
~/.pyenv/versions/3.7.8/lib/python3.7/site-packages/tensorflow/python/framework/function_def_to_graph.py in function_def_to_graph_def(fdef, input_shapes)
227 else:
--> 228 op_def = default_graph._get_op_def(node_def.op) # pylint: disable=protected-access
229
~/.pyenv/versions/3.7.8/lib/python3.7/site-packages/tensorflow/python/framework/ops.py in _get_op_def(self, type)
3962 pywrap_tf_session.TF_GraphGetOpDef(self._c_graph, compat.as_bytes(type),
-> 3963 buf)
3964 # pylint: enable=protected-access
NotFoundError: Op type not registered 'IO>DecodeJSON' in binary running on 192.168.1.105. Make sure the Op and Kernel are registered in the binary running in this process. Note that if you are loading a saved graph which used ops from tf.contrib, accessing (e.g.) `tf.contrib.resampler` should be done before importing the graph, as contrib ops are lazily registered when the module is first accessed.
During handling of the above exception, another exception occurred:
FileNotFoundError Traceback (most recent call last)
<ipython-input-2-39fe5910a28b> in <module>
5
6 load_options = tf.saved_model.LoadOptions(experimental_io_device='/job:localhost')
----> 7 tf.saved_model.load('query_model/20220219125209', options=load_options)
~/.pyenv/versions/3.7.8/lib/python3.7/site-packages/tensorflow/python/saved_model/load.py in load(export_dir, tags, options)
867 ValueError: If `tags` don't match a MetaGraph in the SavedModel.
868 """
--> 869 return load_internal(export_dir, tags, options)["root"]
870
871
~/.pyenv/versions/3.7.8/lib/python3.7/site-packages/tensorflow/python/saved_model/load.py in load_internal(export_dir, tags, options, loader_cls, filters)
907 except errors.NotFoundError as err:
908 raise FileNotFoundError(
--> 909 str(err) + "\n If trying to load on a different device from the "
910 "computational device, consider using setting the "
911 "`experimental_io_device` option on tf.saved_model.LoadOptions "
FileNotFoundError: Op type not registered 'IO>DecodeJSON' in binary running on 192.168.1.105. Make sure the Op and Kernel are registered in the binary running in this process. Note that if you are loading a saved graph which used ops from tf.contrib, accessing (e.g.) `tf.contrib.resampler` should be done before importing the graph, as contrib ops are lazily registered when the module is first accessed.
If trying to load on a different device from the computational device, consider using setting the `experimental_io_device` option on tf.saved_model.LoadOptions to the io_device such as '/job:localhost'.
The problem seems to be that the model was trained using albert-base, and running it requires some additional ops and packages, which is why I import tensorflow_text.
I have also tried running import tensorflow_io, but just attempting to load that package raises an error saying an S3 file system has already been registered.
Any help would be greatly appreciated!
Posted on 2022-04-15 01:28:59
The two-tower model was trained with tensorflow 2.3 and tensorflow-io 0.15.0; you need to use matching versions locally, otherwise the model cannot be loaded. Also run import tensorflow_io before actually loading the model.
https://stackoverflow.com/questions/71185602