首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >FastText分类器模型的混淆矩阵计算

FastText分类器模型的混淆矩阵计算
EN

Stack Overflow用户
提问于 2017-10-27 15:06:33
回答 1查看 1.8K关注 0票数 4

我正在为Facebook FastText分类器计算方法中的混淆矩阵模型:

代码语言:javascript
运行
复制
#!/usr/local/bin/python3

import argparse
import numpy as np
from sklearn.metrics import confusion_matrix


def parse_labels(path):
    with open(path, 'r') as f:
        return np.array(list(map(lambda x: int(x[9:]), f.read().split())))


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Display confusion matrix.')
    parser.add_argument('test', help='Path to test labels')
    parser.add_argument('predict', help='Path to predictions')
    args = parser.parse_args()
    test_labels = parse_labels(args.test)
    pred_labels = parse_labels(args.predict)

    print(test_labels)
    print(pred_labels)

    eq = test_labels == pred_labels
    print("Accuracy: " + str(eq.sum() / len(test_labels)))
    print(confusion_matrix(test_labels, pred_labels))

我的预测和测试集就像

代码语言:javascript
运行
复制
$ head -n10 /root/pexp 
__label__spam
__label__verified
__label__verified
__label__spam
__label__verified
__label__verified
__label__verified
__label__verified
__label__verified
__label__verified

$ head -n10 /root/dataset_test.csv 
__label__spam
__label__verified
__label__verified
__label__spam
__label__verified
__label__verified
__label__verified
__label__verified
__label__verified
__label__verified

模型的预测是通过这种方式在测试集上计算出来的:

代码语言:javascript
运行
复制
./fasttext predict /root/my_model.bin /root/dataset_test.csv > /root/pexp

然后进行计算FastText混淆矩阵:

代码语言:javascript
运行
复制
$ ./confusion.py /root/dataset_test.csv /root/pexp

但我被这个错误困住了:

代码语言:javascript
运行
复制
Traceback (most recent call last):
  File "./confusion.py", line 18, in <module>
    test_labels = parse_labels(args.test)
  File "./confusion.py", line 10, in parse_labels
    return np.array(list(map(lambda x: int(x[9:]), f.read().split())))
  File "./confusion.py", line 10, in <lambda>
    return np.array(list(map(lambda x: int(x[9:]), f.read().split())))
ValueError: invalid literal for int() with base 10: 'spam'

我按照建议修改了脚本,以处理非数字标签:

代码语言:javascript
运行
复制
def parse_labels(path):
    with open(path, 'r') as f:
        return np.array(list(map(lambda x: x[9:], f.read().split())))

此外,在FastText的情况下,测试集可能在某个时候具有规范化的标签(没有前缀__label__),因此要将其转换回前缀,您可以这样做:

代码语言:javascript
运行
复制
awk 'BEGIN{FS=OFS="\t"}{ $1 = "__label__" tolower($1) }1' /root/dataset_test.csv  > /root/dataset_test_norm.csv 

关于这个问题,请参见这里

此外,输入测试文件必须从标签列以外的其他列中剪切:

代码语言:javascript
运行
复制
cut -f 1 -d$'\t' /root/dataset_test_norm.csv > /root/dataset_test_norm_label.csv

最后,我们得到了混淆矩阵:

代码语言:javascript
运行
复制
$ ./confusion.py /root/dataset_test_norm_label.csv /root/pexp
Accuracy: 0.998852852227
[[9432    21]
 [    3 14543]]

我的最后一个解决方案是这里

更新

该脚本现在运行良好。我已经在我的FastText Node.js实现FastText.js 这里中直接添加了混淆矩阵计算脚本。

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2020-05-25 02:50:17

代码语言:javascript
运行
复制
from sklearn.metrics import confusion_matrix

# predict the data
df["predicted"] = df["text"].apply(lambda x: model.predict(x)[0][0])

# Create the confusion matrix
confusion_matrix(df["labeled"], df["predicted"])


## OutPut:
# array([[5823,    8,  155,    1],
#        [ 199,   51,   22,    0],
#        [ 561,    2,  764,    0],
#        [  48,    0,    4,    4]], dtype=int64)
票数 3
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/46978152

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档