机器学习:使用逻辑回归进行电影评论的情感分析完整实现

此文章由“火柴人爱科技”原创,首发于企鹅号。 文章相关图片请看文字下方

import matplotlib.pyplot as pltimport mglearnmglearn.tools.visualize_coefficients(grid.best_estimator_.coef_, feature_names, n_top_features=25)plt.show()

本文中,我们将重点分析IMDb电影评论数据,并尝试预测评论是正面还是负面。熟悉一些机器学习概念将有助于理解所使用的Python代码和机器学习算法。我们将使用流行的scikit-learn机器学习框架。

数据集准备:

我们将使用此处的数据集 - http://ai.stanford.edu/~amaas/data/sentiment/

下载数据集后,删除了不必要的文件/文件夹,以便文件夹结构如下所示 -

数据集的文件夹结构

将数据加载到Python程序中:

我们将加载并查看训练和测试数据,以了解数据的性质。训练和测试数据的格式相似。

from sklearn.datasets import load_filesreviews_train = load_files("aclImdb/train/")text_train, y_train = reviews_train.data, reviews_train.targetprint("Number of documents in train data: {}".format(len(text_train)))print("Samples per class (train): {}".format(np.bincount(y_train)))reviews_test = load_files("aclImdb/test/")text_test, y_test = reviews_test.data, reviews_test.targetprint("Number of documents in test data: {}".format(len(text_test)))print("Samples per class (test): {}".format(np.bincount(y_test)))

scikit - learn提供load_files来读取这种数据。加载数据后,我们打印了文件数量(训练/测试)和每类样品(pos / neg),如下所示 -

Number of documents in train data: 25000

Samples per class (train): [12500 12500]

Number of documents in test data: 25000

Samples per class (test): [12500 12500]

我们可以看到总共25000个训练和测试数据样本,每类pos和neg为12500。

将文本数据表示为Bag of Words:

我们希望将单词的出现次数计数为Bag of Words,其中包括图表中的以下步骤

Bag of words processing

为了将输入数据集表示为Bag of words,我们将使用CountVectorizer并调用它的transform方法。CountVectorizer是一个转换器,它将输入文档转换为稀疏的特征矩阵。Python代码如下:

from sklearn.feature_extraction.text import CountVectorizervect = CountVectorizer(min_df=5, ngram_range=(2, 2))X_train = vect.fit(text_train).transform(text_train)X_test = vect.transform(text_test)print("Vocabulary size: {}".format(len(vect.vocabulary_)))print("X_train:{}".format(repr(X_train)))print("X_test: {}".format(repr(X_test)))feature_names = vect.get_feature_names()print("Number of features: {}".format(len(feature_names)))

CountVectorizer使用两个参数 -

min_df(= 5):定义一个单词的最小频率,以便将其计为一个特征

ngram_range(=(2,2)):ngram_range参数是一个元组。它定义了所考虑的令牌序列的最小和最大长度。在这种情况下,这个长度是2.所以,这将找到2个令牌的序列,如''but the','wise man'等。

结果矩阵中的每个条目都被视为一个特征。上面代码片段的输出如下 -

Vocabulary size: 129549

X_train:

'

with 3607330 stored elements in Compressed Sparse Row format>

X_test:

'

with 3392376 stored elements in Compressed Sparse Row format>

Number of features: 129549

总共找到了129549个特征。

机器学习模型:

我们将使用LogisticRegression进行模型开发,就像我们的高维稀疏数据一样,LogisticRegression通常效果最好。

在开发模型时,我们需要做另外两件事 -

网格搜索:用于逻辑回归的参数调优。我们想要确定coefficeint ' C '的值能提供更好的准确性。

交叉验证:为了避免过度拟合数据。

from sklearn.model_selection import GridSearchCVfrom sklearn.linear_model import LogisticRegressionparam_grid = {'C': [0.001, 0.01, 0.1, 1, 10]}grid = GridSearchCV(LogisticRegression(), param_grid, cv=5)grid.fit(X_train, y_train)print("Best cross-validation score: {:.2f}".format(grid.best_score_))print("Best parameters: ", grid.best_params_)print("Best estimator: ", grid.best_estimator_)

在这里,我们使用GridSearchCV进行5-fold交叉验证。在拟合列车数据之后,我们看到best_score_、best_params_(“C”)和best_estimator_(我们将使用的模型)。

上面代码片段的输出如下所示

Best cross-validation score: 0.88Best parameters: {'C': 1}Best estimator: LogisticRegression(C=1, class_weight=None, dual=False, fit_intercept=True, intercept_scaling=1, max_iter=100, multi_class='warn', n_jobs=None, penalty='l2', random_state=None, solver='warn', tol=0.0001, verbose=0, warm_start=False)

我们有一个'C'= 1且准确率为88%的模型。

我们想要绘制最好和最差的前25个特征。Python代码如下:

import matplotlib.pyplot as pltimport mglearnmglearn.tools.visualize_coefficients(grid.best_estimator_.coef_, feature_names, n_top_features=25)plt.show()

mglearn:下载地址—https://github.com/amueller/mglearn

做出预测:

现在我们将使用训练模型对我们的测试数据进行预测。

lr = grid.best_estimator_lr.fit(X_train, y_train)lr.predict(X_test)print("Score: {:.2f}".format(lr.score(X_test, y_test)))

预测输出显示测试数据的分数为88%。

Score: 0.88

为了检查我们的模型如何对单个数据执行,我们将使用正面电影评论进行一次预测,而使用负面进行预测。

这个输出 -

Pos prediction: [1]

在这里,1意味着它预测了一个积极的评论。

这个输出 -

Neg prediction: [0]

在这里,0意味着它预测了负面评论

完整Python代码:

我们将使用此处的数据集 - http://ai.stanford.edu/~amaas/data/sentiment/

此文章由“火柴人爱科技”原创,首发于企鹅号。

  • 发表于:
  • 原文链接https://kuaibao.qq.com/s/20181014A171OS00?refer=cp_1026
  • 腾讯「云+社区」是腾讯内容开放平台帐号(企鹅号)传播渠道之一,根据《腾讯内容开放平台服务协议》转载发布内容。
  • 如有侵权,请联系 yunjia_community@tencent.com 删除。

扫码关注云+社区

领取腾讯云代金券