前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >windows使用c_api调用tflite 2.3 dll

windows使用c_api调用tflite 2.3 dll

原创
作者头像
superhua
发布2021-01-04 12:23:17
3.2K2
发布2021-01-04 12:23:17
举报
文章被收录于专栏:CNNCNN

在上一篇文章【Win10系统编译Tensorflow Lite 2.3为动态链接库tensorflowlite_c.dll】介绍了如何在Windows平台下编译tflite为动态链接库tensorflowlite_c.dll,接下来介绍如何使用tensorflowlite_c.dll。上一篇文章中我们编译的tflite库为c语言接口,即c_api,在使用过程中,只需下面一条include语句即可:

代码语言:txt
复制
#include "tensorflow/lite/c/c_api.h"

注意,如果不想亲自动手编译,请直接将上一篇文章拉到最后,直接下载作者已编译好的库即可。

最近看到一个巨牛的人工智能教程,分享一下给大家。教程不仅是零基础,通俗易懂,而且非常风趣幽默,像看小说一样!觉得太牛了,所以分享给大家。平时碎片时间可以当小说看,【点这里可以去膜拜一下大神的“小说”】

0 准备tflite模型

前往【https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/README.md】找到如下mobilenet v3模型下载:

下载mobilenet v3 tflite
下载mobilenet v3 tflite

如果无法打开链接或者是无法下载,请到附件中下载。

1 加载模型

封装函数initModel,传入tflite模型路径,代码如下:

代码语言:txt
复制
void initModel(string path ) {
	  
	TfLiteModel* model = TfLiteModelCreateFromFile(path.c_str()); 
	TfLiteInterpreterOptions* options = TfLiteInterpreterOptionsCreate();
	interpreter = TfLiteInterpreterCreate(model, options); 
	if (interpreter == nullptr) {
		printf("Failed to create interpreter");
		cout << (path) << endl;
		return  ;
	}
	// Allocate tensor buffers.
	if (TfLiteInterpreterAllocateTensors(interpreter) != kTfLiteOk) {
		printf("Failed to allocate tensors!");
		return  ;
	} 

	input_tf = getInputTensorByName(interpreter, "input"); 
	output_tf = getOutputTensorByName(interpreter, "MobilenetV3/Predictions/Softmax"); 
  
}

上述代码中,主要使用了如下几个接口:

TfLiteModelCreateFromFile: 创建TfLiteModel对象 TfLiteInterpreterOptionsCreate: 设置一些选项,这里暂时没有设置更多的参数。 TfLiteInterpreterCreate:创建TfLiteInterpreter对象,PS: 这个对象有点Session的感觉。 TfLiteInterpreterAllocateTensors: 为所有的Tensor分配空间,用于向系统请求分配空间。 getOutputTensorByName和getInputTensorByName这两个函数是我这边单独封装。

getOutputTensorByNamegetInputTensorByName代码如下:

代码语言:txt
复制
TfLiteTensor * getOutputTensorByName(TfLiteInterpreter * interpreter, const char * name)
{
	int count = TfLiteInterpreterGetOutputTensorCount(interpreter);
	for (int i = 0; i < count; ++i) {
		TfLiteTensor* ts = (TfLiteTensor*)TfLiteInterpreterGetOutputTensor(interpreter, i);
		if (!strcmp(ts->name, name)) {
			return ts;
		}
	}
	return nullptr;
}
TfLiteTensor * getInputTensorByName(TfLiteInterpreter * interpreter, const char * name)
{
	int count = TfLiteInterpreterGetInputTensorCount(interpreter);
	for (int i = 0; i < count; ++i) {
		TfLiteTensor* ts = TfLiteInterpreterGetInputTensor(interpreter, i);
		if (!strcmp(ts->name, name)) {
			return ts;
		}
	}
	return nullptr;
}

2 前向推理

前向推理主要包括3步:

向输入Tensor拷贝输入数据 执行推理 从输出Tensor将运算结果拷贝出来

示例代码如下:

代码语言:txt
复制
void forward(float* data, int len) {
	TfLiteTensorCopyFromBuffer(input_tf, data, len*sizeof(float));
	TfLiteInterpreterInvoke(interpreter);
	float logits[1001];
	TfLiteTensorCopyToBuffer(output_tf, logits, 1001*sizeof(float));
	float maxV = -1;
	int maxIdx = -1;
	for (int i = 0; i < 1001; ++i) {
		if (logits[i] > maxV) {
			maxV = logits[i];
			maxIdx = i;
		}
		//printf("%d->%f\n", i, logits[i]);
	}
	cout << "类别:" << maxIdx << ",概率:" << maxV << endl;
}

上面代码写的比较粗糙,用起来不灵活,但是足够作为一个示例来使用了。

3 完整代码

接下来看完整代码,如下:

代码语言:txt
复制
#include "pch.h"
#include <map>
#include <iostream>
#include <sstream>
#include <fstream>
#include <string>
#include "tensorflow/lite/c/c_api.h"
#pragma comment( lib, "tensorflowlite_c.dll.if.lib" )
using namespace std;
TfLiteTensor* input_tf;
TfLiteTensor* output_tf;
TfLiteInterpreter* interpreter; 

TfLiteTensor * getOutputTensorByName(TfLiteInterpreter * interpreter, const char * name)
{
	int count = TfLiteInterpreterGetOutputTensorCount(interpreter);
	for (int i = 0; i < count; ++i) {
		TfLiteTensor* ts = (TfLiteTensor*)TfLiteInterpreterGetOutputTensor(interpreter, i);
		if (!strcmp(ts->name, name)) {
			return ts;
		}
	}
	return nullptr;
}
TfLiteTensor * getInputTensorByName(TfLiteInterpreter * interpreter, const char * name)
{
	int count = TfLiteInterpreterGetInputTensorCount(interpreter);
	for (int i = 0; i < count; ++i) {
		TfLiteTensor* ts = TfLiteInterpreterGetInputTensor(interpreter, i);
		if (!strcmp(ts->name, name)) {
			return ts;
		}
	}
	return nullptr;
}

void initModel(string path ) {
	  
	TfLiteModel* model = TfLiteModelCreateFromFile(path.c_str()); 
	TfLiteInterpreterOptions* options = TfLiteInterpreterOptionsCreate();
	interpreter = TfLiteInterpreterCreate(model, options); 
	if (interpreter == nullptr) {
		printf("Failed to create interpreter");
		cout << (path) << endl;
		return  ;
	}
	// Allocate tensor buffers.
	if (TfLiteInterpreterAllocateTensors(interpreter) != kTfLiteOk) {
		printf("Failed to allocate tensors!");
		return  ;
	} 

	input_tf = getInputTensorByName(interpreter, "input"); 
	output_tf = getOutputTensorByName(interpreter, "MobilenetV3/Predictions/Softmax"); 
  
}
void forward(float* data, int len) {
	TfLiteTensorCopyFromBuffer(input_tf, data, len*sizeof(float));
	TfLiteInterpreterInvoke(interpreter);
	float logits[1001];
	TfLiteTensorCopyToBuffer(output_tf, logits, 1001*sizeof(float));
	float maxV = -1;
	int maxIdx = -1;
	for (int i = 0; i < 1001; ++i) {
		if (logits[i] > maxV) {
			maxV = logits[i];
			maxIdx = i;
		}
		//printf("%d->%f\n", i, logits[i]);
	}
	cout << "类别:" << maxIdx << ",概率:" << maxV << endl;
}
long getSize(string path) {
	ifstream file(path, ios::in | ios::binary); 
	long l, m;
	l = file.tellg();
	file.seekg(0, ios::end);
	m = file.tellg();
	file.close();
	return m - l;
}
float* readBmp(string path, int& len) {
	len = getSize(path);
	unsigned char* buff = (unsigned char*)calloc(len, sizeof(unsigned char*));
	ifstream fin(path, std::ifstream::binary);
	fin.read(reinterpret_cast<char*>(buff), len *sizeof(unsigned char*));
	fin.close();
	float* data = (float*)calloc(len, sizeof(float));
	for (int i = 0; i < len; ++i) {
		data[i] = (buff[i]/255.0-0.5)*2;
	}
	free(buff);
	return data;
}
 
int main()
{    
	initModel("v3-small_224_0.75_float.tflite" );
	int size=0;
	float* bmp = readBmp("input.bin", size);
	forward(bmp, size );
}
 

4 运行结果

将下面图片作为输入:

输入图片
输入图片

运行上面代码,控制台输出如下:

代码语言:txt
复制
类别:896,概率:0.92355

这里没有对具体类别转为中文查看,因为作为demo不想再添加其他不太相关的代码(主要是懒)。label文件可以在附件中下载,打开label文件可以看到第896类:

输出结果
输出结果

5 附件

  1. mobilenet v3 tflite模型下载地址:http://askonline.tech/download/4.html
  2. 完整打包下载:http://askonline.tech/download/5.html

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 0 准备tflite模型
  • 1 加载模型
  • 2 前向推理
  • 3 完整代码
  • 4 运行结果
  • 5 附件
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档