我想测试一下我加载到一个安卓项目中的TensorFlow Lite模型的推断。
我在Python环境中生成了一些输入,我想将其保存到文件中,加载到我的安卓应用程序中,并用于TFLite推断。我的输入有点大,一个例子是:
<class 'numpy.ndarray'>, dtype: float32, shape: (1, 596, 80)
我需要一些方法来序列化这个ndarray并将其加载到Android中。
有关TFLite推理的更多信息可以找到这里。本质上,这应该是一个多维度的原始浮点数数组,或者一个ByteBuffer。
最简单的方法是:
谢谢!
发布于 2020-12-09 07:53:56
最后,我发现了一个名为JavaNpy的方便的Java库,它允许您在Java中打开.npy文件,从而使用Android。
在Python方面,我以正常的方式保存了一个扁平的.npy
:
data_flat = data.flatten()
print(data_flat.shape)
np.save(file="data.npy", arr=data_flat)
在安卓系统中,我把这个放在assets
文件夹中。
然后我将其加载到JavaNpy中:
InputStream stream = context.getAssets().open("data.npy")
Npy npy = new Npy(stream);
float[] npyData = npy.floatElements();
并最终将其转换为一个TensorBuffer:
int[] inputShape = new int[]{1, 596, 80}; //the data shape before I flattened it
TensorBuffer tensorBuffer = TensorBuffer.createFixedSize(inputShape, DataType.FLOAT32);
tensorBuffer.loadArray(npyData);
然后,我使用这个tensorBuffer来推断我的TFLite模型。
https://stackoverflow.com/questions/65203358
复制