我有一个包含260个显微图像的数据集。我想为逻辑回归算法生成学习曲线，但我得到了这个错误："'module' object is not iterable"。也许我不懂一些基本的东西，因为我是一个刚刚学习Python语言的初学者。
from sklearn.cross_validation import train_test_split
from imutils import paths
from scipy import misc
import numpy as np
import argparse
import imutils
import cv2
import os
from matplotlib import pyplot as plt
from sklearn.model_selection import learning_curve
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import recall_score
from sklearn.metrics import precision_score
from sklearn.model_selection import cross_val_score
def plot_learning_curve(estimator, title, X, y, ylim=None, cv=None,
                        n_jobs=None, train_sizes=np.linspace(0.1, 1.0, 5)):
    """
    Generate a simple plot of the test and training learning curve.

    Parameters
    ----------
    estimator : object type that implements the "fit" and "predict" methods
        An object of that type which is cloned for each validation.
    title : string
        Title for the chart.
    X : array-like, shape (n_samples, n_features)
        Training vector, where n_samples is the number of samples and
        n_features is the number of features.
    y : array-like, shape (n_samples) or (n_samples, n_features), optional
        Target relative to X for classification or regression;
        None for unsupervised learning.
    ylim : tuple, shape (ymin, ymax), optional
        Limits of the y axis; when given it is unpacked into ``plt.ylim``.
    cv : int, cross-validation generator or an iterable, optional
        Determines the cross-validation splitting strategy.
        Possible inputs for cv are:
        - None, to use the default 3-fold cross-validation,
        - integer, to specify the number of folds.
        - :term:`CV splitter`,
        - An iterable yielding (train, test) splits as arrays of indices.
        For integer/None inputs, if ``y`` is binary or multiclass,
        :class:`StratifiedKFold` used. If the estimator is not a classifier
        or if ``y`` is neither binary nor multiclass, :class:`KFold` is used.
        Refer :ref:`User Guide <cross_validation>` for the various
        cross-validators that can be used here.
    n_jobs : int or None, optional (default=None)
        Number of jobs to run in parallel.
        ``None`` means 1 unless in a :obj:`joblib.parallel_backend` context.
        ``-1`` means using all processors. See :term:`Glossary <n_jobs>`
        for more details.
    train_sizes : array-like, shape (n_ticks,), dtype float or int
        Relative or absolute numbers of training examples that will be used to
        generate the learning curve. If the dtype is float, it is regarded as a
        fraction of the maximum size of the training set (that is determined
        by the selected validation method), i.e. it has to be within (0, 1].
        Otherwise it is interpreted as absolute sizes of the training sets.
        Note that for classification the number of samples usually have to
        be big enough to contain at least one sample from each class.
        (default: np.linspace(0.1, 1.0, 5))

    Returns
    -------
    plt : module
        The ``matplotlib.pyplot`` module the figure was drawn with. This is
        NOT the ``(train_sizes, train_scores, test_scores)`` tuple, so the
        result must not be unpacked; call ``.show()`` on it instead.
    """
    plt.figure()
    plt.title(title)
    if ylim is not None:
        plt.ylim(*ylim)
    plt.xlabel("Training examples")
    plt.ylabel("Score")

    train_sizes, train_scores, test_scores = learning_curve(
        estimator, X, y, cv=cv, n_jobs=n_jobs, train_sizes=train_sizes)

    # Scores have one row per training size and one column per CV fold,
    # so aggregating over axis=1 gives one mean/std per training size.
    train_scores_mean = np.mean(train_scores, axis=1)
    train_scores_std = np.std(train_scores, axis=1)
    test_scores_mean = np.mean(test_scores, axis=1)
    test_scores_std = np.std(test_scores, axis=1)

    plt.grid()

    # Shaded bands: mean +/- one standard deviation across the folds.
    plt.fill_between(train_sizes, train_scores_mean - train_scores_std,
                     train_scores_mean + train_scores_std, alpha=0.1,
                     color="r")
    plt.fill_between(train_sizes, test_scores_mean - test_scores_std,
                     test_scores_mean + test_scores_std, alpha=0.1, color="g")
    plt.plot(train_sizes, train_scores_mean, 'o-', color="r",
             label="Training score")
    plt.plot(train_sizes, test_scores_mean, 'o-', color="g",
             label="Cross-validation score")

    plt.legend(loc="best")
    return plt
# Fit a logistic-regression classifier on the training split and report its
# mean accuracy on the held-out test split.
# NOTE(review): trainFeat/trainLabels/testFeat/testLabels are not defined in
# the visible part of the file -- presumably produced by train_test_split.
clfLR = LogisticRegression(random_state=0, solver='lbfgs').fit(
    trainFeat, trainLabels)  # fit() returns the fitted estimator itself
acc = clfLR.score(testFeat, testLabels)
print("accuracy of Logistic regression ", acc)
只有当我想要绘制曲线时，我才会遇到这个问题。代码的其余部分运行正常。
# Plot the learning curve for a fresh logistic-regression estimator.
estimator = LogisticRegression()
# Absolute training-set sizes (ints) at which the curve is evaluated.
train_sizes = [50, 80, 110]
# plot_learning_curve returns the matplotlib.pyplot module, not a
# (train_sizes, train_scores, valid_scores) tuple. Unpacking its result is
# what raised "'module' object is not iterable", so the call is made purely
# for its plotting side effect.
plot_learning_curve(
    estimator, 'logistic learning curve ', trainFeat, trainLabels,
    cv=5, n_jobs=4, train_sizes=train_sizes)
print(train_sizes)
plt.show()
错误的屏幕截图
发布于 2018-12-04 04:11:23
尝试在Jupyter online IDE上运行代码。如果您将"%matplotlib"行添加到导入部分，则会自动绘制。
如果你想继续在这个IDE上工作,请分享你的错误信息,也许我可以帮助你。您可能缺少一个导入,或者它可能是一个Python2/3问题。
https://stackoverflow.com/questions/53600669
复制相似问题