7分钟
绘图API-示例
class PlotTest:
    """Train a small XGBoost booster on the iris CSV and plot feature importance.

    The class relies on names that must already be in scope at module level:
    ``pd``, ``xgt``, ``train_test_split``, ``plt`` and ``_label_map``.
    """

    def __init__(self):
        """Load ``./data/iris.csv`` and build the train/validation matrices.

        The four measurement columns are used as features and the ``Class``
        column, translated through ``_label_map``, is used as the label.
        The rows are split 70/30 (stratified on the label, fixed seed) and
        stored as ``self._train_matrix`` and ``self._validate_matrix``.
        """
        frame = pd.read_csv('./data/iris.csv')
        feature_cols = ['Sepal Length', 'Sepal Width', 'Petal Length', 'Petal Width']
        features = frame[feature_cols]
        # Indexing (rather than .get) means an unmapped class name raises KeyError.
        labels = frame['Class'].map(lambda class_name: _label_map[class_name])

        train_X, test_X, train_Y, test_Y = train_test_split(
            features, labels,
            test_size=0.3, stratify=labels, shuffle=True, random_state=1)

        def as_dmatrix(data, label):
            # Both matrices share the same feature names and all-float types.
            return xgt.DMatrix(
                data=data, label=label,
                feature_names=feature_cols,
                feature_types=['float'] * len(feature_cols))

        self._train_matrix = as_dmatrix(train_X, train_Y)
        self._validate_matrix = as_dmatrix(test_X, test_Y)

    def plot_test(self):
        """Train for up to 20 rounds, then show the feature-importance plot.

        Both matrices are passed as evaluation sets; the training matrix is
        registered under the name ``'valid1'`` and the validation matrix
        under ``'valid2'``. Early stopping is requested with a patience of
        5 rounds. Returns nothing; the plot is displayed via ``plt.show()``.
        """
        # NOTE(review): 'binary:logistic' assumes _label_map yields 0/1
        # labels only -- confirm against its definition.
        train_params = {
            'booster': 'gbtree',
            'eta': 0.01,
            'max_depth': 5,
            'tree_method': 'exact',
            'objective': 'binary:logistic',
            'eval_metric': ['logloss', 'error', 'auc'],
        }
        watchlist = [
            (self._train_matrix, 'valid1'),
            (self._validate_matrix, 'valid2'),
        ]
        # Filled in by xgt.train with the per-round metric values.
        history = {}
        model = xgt.train(
            train_params,
            self._train_matrix,
            num_boost_round=20,
            evals=watchlist,
            early_stopping_rounds=5,
            evals_result=history,
            verbose_eval=True,
        )
        xgt.plot_importance(model)
        plt.show()
学员评价