首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >当我输入单个句子时,ok运行时-c++推断时间的使用是浮动的,但是输入文本文件是可以的。

当我输入单个句子时,ok运行时-c++推断时间的使用是浮动的,但是输入文本文件是可以的。
EN

Stack Overflow用户
提问于 2022-07-28 06:46:23
回答 1查看 147关注 0票数 0

我试图在onnxruntime++上部署一个bert模型,但是推断时间的使用让我很困惑。当输入为控制台上的单个句子时,的时间使用时间较长,且波动较大,而不是输入包含大量句子的文本文件。

会话的初始代码如下:

代码语言:javascript
运行
复制
class BertModel
{
    public:
        BertModel(){};
        BertModel(const char* path)
        {              
            // initial tokenizer
            string vocab_path = join(path, "vocab.txt");
            pTokenizer = new FullTokenizer(vocab_path);

            // onnxruntime setup
            Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "small_bert_onnx"); //Set a  Env for this session,the Env holds the logging state used by all other objects.
            session_options.SetIntraOpNumThreads(1); //Sets the number of threads used to parallelize the execution within nodes. Default is 0 to let onnxruntime choose.
            session_options.SetInterOpNumThreads(1); //Sets the number of threads used to parallelize the execution of the graph (across nodes). Default is 0 to let onnxruntime choose.

            string model_path = join(path, "bert_model_quant.onnx");
            session = new Ort::Session(env, model_path.c_str(), session_options); //create a session,session is
            // session = new Ort::Session(env, model_path.c_str(), Ort::SessionOptions{ nullptr }); //don't do anyOptions

            size_t num_input_nodes = session->GetInputCount(); //  num_input_nodes size of model need,eg:(ids,mask,labels),your will get 3;
            char* input_name = session -> GetInputName(0, allocator);
            input_node_names = {input_name};
            output_node_names = {"logits"};
            // print input node types
            Ort::TypeInfo type_info = session -> GetInputTypeInfo(0);
            auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
            ONNXTensorElementDataType type = tensor_info.GetElementType();
            // print input shapes/dims
            input_node_dims = tensor_info.GetShape();
            cout << "session初始化成功" << endl;
        }

        string join(const char *a, const char *b);
        vector<long> textTokenizer(string text);
        int predicts(string text);

    private:
        FullTokenizer* pTokenizer;
        Ort::SessionOptions session_options;
        std::vector<int64_t> input_node_dims;
        std::vector<const char*> output_node_names;
        Ort::AllocatorWithDefaultOptions allocator; // allocator
        std::vector<const char*> input_node_names;
        Ort::Session* session;
};

我的预测函数如下:

代码语言:javascript
运行
复制
int BertModel::predicts(string text)
{   
    vector<long> input_tensor_values = textTokenizer(text);

    auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);

    input_node_dims[0]=1;
    input_node_dims[1]=input_tensor_values.size();

    Ort::Value input_tensor = Ort::Value::CreateTensor<long>(memory_info,input_tensor_values.data(), 
                              input_tensor_values.size(), input_node_dims.data(), input_node_dims.size());
    assert(input_tensor.IsTensor());

    //outputs from session run is vector<Value>
    auto output_tensors = session -> Run(Ort::RunOptions{nullptr}, 
                                      input_node_names.data(), 
                                      &input_tensor,
                                      1,
                                      output_node_names.data(),
                                      1);

    // output_tensors 2, logitspreds
    // onnlogitsfloatfloat
    float* floatarr = output_tensors[0].GetTensorMutableData<float>();
    int res = max_element(floatarr, floatarr + 3) - floatarr;
    return res;
}

我的代码推断单句如下所示,输入是从控制台实时得到的

代码语言:javascript
运行
复制
    string text;
    while(true)
    {
        cout << "enter your input" << endl;
        getline(cin, text);
        high_resolution_clock::time_point beginTime = high_resolution_clock::now();
        int res = model.predicts(text);
        high_resolution_clock::time_point endTime = high_resolution_clock::now();
        milliseconds timeInterval = std::chrono::duration_cast<milliseconds>(endTime - beginTime);
        cout << "predict result:" << res << endl;
        cout << "time spent:" << timeInterval.count() << "ms" << endl;
    }
代码语言:javascript
运行
复制
你给我想没有包子铺的你也灯关的水都关了新的利润都被人骨的肌
predict result:1
time spent:16ms
enter your input
你给我想没有包子铺的你也灯关的水都关了新的利润都被人骨的肌
predict result:1
time spent:16ms
enter your input
你给我想没有包子铺的你也灯关的水都关了新的利润都被人骨的肌
predict result:1
time spent:14ms
enter your input
你给我想没有包子铺的你也灯关的水都关了新的利润都被人骨的肌
predict result:1
time spent:8ms
enter your input
你给我想没有包子铺的你也灯关的水都关了新的利润都被人骨的肌
predict result:1
time spent:15ms
enter your input
你给我想没有包子铺的你也灯关的水都关了新的利润都被人骨的肌
predict result:1
time spent:8ms
enter your input
你给我想没有包子铺的你也灯关的水都关了新的利润都被人骨的肌
predict result:1
time spent:13ms

我推断文本文件的代码如下所示:

代码语言:javascript
运行
复制
    string input_path = "../../test0711.txt";
    string output_path = "../../test0711_result.txt";
    ifstream input_file(input_path);
    ofstream output_file(output_path);
    if (!input_file.is_open()) {
        cerr << "Could not open the file - '"
             << input_path << "'" << endl;
        return EXIT_FAILURE;
    }

    if (!output_file.is_open()) {
        cerr << "Could not open the file - '"
             << output_path << "'" << endl;
        return EXIT_FAILURE;
    }

    int time_spent = 0;
    int seq_nums = 0;

    string line;
    while (getline(input_file, line))
    {   
        high_resolution_clock::time_point beginTime = high_resolution_clock::now(); //start time
        int res = model.predicts(line); //predicts single sentence
        high_resolution_clock::time_point endTime = high_resolution_clock::now(); //end time
        milliseconds timeInterval = std::chrono::duration_cast<milliseconds>(endTime - beginTime); //spent time
        cout << "bytes length of this sentence:" << line.size()/3 << endl;
        cout << "predict result:" << res << endl;
        cout << "time spent:" << timeInterval.count() << "ms" << endl;
        output_file << line << '\t' << res << '\t' << timeInterval.count() << "ms" << endl;
        time_spent += timeInterval.count();
        seq_nums++;
    }
    input_file.close();
    output_file.close();
代码语言:javascript
运行
复制
你知道什么是版权问题吗就是他们就是这个 1   6ms
北石店 2   3ms
我要去新街口  0   4ms
导航到向阳小区 0   4ms
只想守护你   0   3ms
将车道偏离预警开关打开 0   4ms
导航到南海意库 0   4ms
导航去1号公馆 0   4ms
1米制的恭喜发财    1   4ms
你给我想没有包子铺的你也灯关的水都关了新的利润都被人骨的肌   1   8ms
你吃不吃粑粑  1   4ms
导航去深圳湾创新科技中心    0   4ms
个性也没看就行了    1   4ms
三好听你就三个1390这个都是套餐5万双送给您的    1   6ms

显然,当我输入文本文件时,时间的使用与句子的长度成正比。那么,为什么会出现bug,我如何修复它呢?

EN

回答 1

Stack Overflow用户

发布于 2022-09-12 05:18:00

第一步是只测量要运行的调用所需的时间。您在计时中包括了前后处理,比如令牌化,所以这并不是衡量ONNX运行时花费多长时间的准确度量。

ONNX有一些逻辑来跟踪请求所需的内存使用情况,以便下次收到相同形状的输入时,它可以分配单个块。这可能会影响延迟,因为可以使用单个块的请求应该更快。可以用https://onnxruntime.ai/docs/api/c/struct_ort_1_1_session_options.html#a85495cc117b54771cb4d7632753532f0关闭

您是否在文件中使用与从控制台提供的内容完全相同的输入进行测试?两者之间的差异越小越好。

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/73148362

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档