首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >Tensorflowlite on android缓冲区大小错误

Tensorflowlite on android缓冲区大小错误
EN

Stack Overflow用户
提问于 2021-07-23 14:34:29
回答 1查看 90关注 0票数 1

我正在尝试建立一个图像分类器android应用程序。我已经使用keras构建了我的模型。模型如下:

代码语言:javascript
运行
复制
model.add(MobileNetV2(include_top=False, weights='imagenet',input_shape=(224, 224, 3)))
model.add(GlobalAveragePooling2D())
model.add(Dropout(0.5))
model.add(Dense(3, activation='softmax'))

model.layers[0].trainable = False     
model.compile(optimizer='adam',  loss='categorical_crossentropy', metrics=['accuracy'])
model.summary()

输出:

代码语言:javascript
运行
复制
Model: "sequential_3"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
mobilenetv2_1.00_224 (Functi (None, 7, 7, 1280)        2257984   
_________________________________________________________________
global_average_pooling2d_2 ( (None, 1280)              0         
_________________________________________________________________
dropout_2 (Dropout)          (None, 1280)              0         
_________________________________________________________________
dense_1 (Dense)              (None, 3)                 3843      
=================================================================
Total params: 2,261,827
Trainable params: 3,843
Non-trainable params: 2,257,984

训练完成后,我使用以下命令转换模型

代码语言:javascript
运行
复制
model = tf.keras.models.load_model('model.h5')
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()
open(f"myModel.tflite", "wb").write(tflite_model)

对于android,代码如下:

代码语言:javascript
运行
复制
        make_prediction.setOnClickListener(View.OnClickListener {
            var resized = Bitmap.createScaledBitmap(bitmap, 224, 224, true)
            val model = MyModel.newInstance(this)
            var tbuffer = TensorImage.fromBitmap(resized)
            var byteBuffer = tbuffer.buffer

// Creates inputs for reference.
            val inputFeature0 = TensorBuffer.createFixedSize(intArrayOf(1, 224, 224, 3), DataType.FLOAT32)
            inputFeature0.loadBuffer(byteBuffer)

// Runs model inference and gets result.
            val outputs = model.process(inputFeature0)
            val outputFeature0 = outputs.outputFeature0AsTensorBuffer

            var max = getMax(outputFeature0.floatArray)

            text_view.setText(labels[max])

// Releases model resources if no longer used.
            model.close()
        })

但是每当我尝试运行我的应用程序时,它就会关闭,并且我在logcat中得到这个错误。

代码语言:javascript
运行
复制
java.lang.IllegalArgumentException: The size of byte buffer and the shape do not match.

如果我将我的图像的输入形状从224改为300,并在300输入形状上训练我的模型,并将其插入android,我得到了anthor错误。

代码语言:javascript
运行
复制
java.lang.IllegalArgumentException: Cannot convert between a TensorFlowLite buffer with 1080000 bytes and a Java Buffer with 150528 bytes

任何形式的帮助都将是非常感谢的。

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2021-07-26 12:22:49

像这样使用它:

代码语言:javascript
运行
复制
make_prediction.setOnClickListener(View.OnClickListener {
            var resized = Bitmap.createScaledBitmap(bitmap, 224, 224, true)
            val model = MyModel.newInstance(this)
            var tImage = TensorImage(DataType.FLOAT32)
            var tensorImage = tImage.load(resized)
            var byteBuffer = tensorImage.buffer

// Creates inputs for reference.
            //val inputFeature0 = TensorBuffer.createFixedSize(intArrayOf(1, 224, 224, 3), DataType.FLOAT32)
            //inputFeature0.loadBuffer(byteBuffer)

// Runs model inference and gets result.
            val outputs = model.process(byteBuffer)
            val outputFeature0 = outputs.outputFeature0AsTensorBuffer

            var max = getMax(outputFeature0.floatArray)

            text_view.setText(labels[max])

// Releases model resources if no longer used.
            model.close()
        })

如果问题仍然存在,请使用调试器进行检查,或者

val outputFeature0 = outputs.outputFeature0AsTensorBuffer导致了另一个问题。

如果您需要更多帮助,请联系我

票数 2
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/68494971

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档