课程评价 (0)

请对课程作出评价:
0/300

学员评价

暂无精选评价
7分钟

绘图API-示例

class PlotTest:
    """Train a small XGBoost model on the iris CSV and plot feature importance.

    NOTE(review): this class relies on module-level names that are not
    defined in this snippet -- ``pd`` (pandas), ``train_test_split``
    (scikit-learn), ``xgt`` (xgboost), ``plt`` (matplotlib.pyplot) and
    ``_label_map`` (presumably a dict from 'Class' strings to integer
    labels) -- confirm they are imported/defined elsewhere in the file.
    """

    def __init__(self, data_path='./data/iris.csv'):
        """Load the CSV, split it 70/30 and wrap both splits in DMatrix objects.

        Args:
            data_path: CSV file with the four feature columns listed below
                and a 'Class' column. Defaults to the original hard-coded
                path so existing callers are unaffected.

        Raises:
            KeyError: if a 'Class' value has no entry in ``_label_map``
                (assuming ``_label_map`` is a plain dict).
        """
        df = pd.read_csv(data_path)
        feature_names = ['Sepal Length', 'Sepal Width', 'Petal Length', 'Petal Width']
        features = df[feature_names]
        # Explicit lookup (instead of Series.map(dict)) so an unknown class
        # label fails loudly rather than silently becoming NaN.
        labels = df['Class'].map(lambda name: _label_map[name])

        # Stratified, reproducible split: 30% held out for validation.
        train_x, test_x, train_y, test_y = train_test_split(
            features, labels,
            test_size=0.3, stratify=labels, shuffle=True, random_state=1)

        def to_matrix(data, label):
            # Both splits share the same feature names and all-float types.
            return xgt.DMatrix(data=data, label=label,
                               feature_names=feature_names,
                               feature_types=['float'] * len(feature_names))

        self._train_matrix = to_matrix(train_x, train_y)
        self._validate_matrix = to_matrix(test_x, test_y)

    def plot_test(self):
        """Train a gradient-boosted tree model and display its feature importance.

        Trains for at most 20 rounds with early stopping (patience 5), prints
        per-round evaluation metrics, then shows the importance plot.

        Returns:
            The booster returned by ``xgt.train`` so callers can inspect or
            reuse the trained model.
        """
        params = {
            'booster': 'gbtree',
            'eta': 0.01,
            'max_depth': 5,
            'tree_method': 'exact',
            # binary:logistic requires labels in [0, 1].
            # NOTE(review): iris has three classes -- confirm _label_map
            # only yields 0/1 for the rows in the CSV.
            'objective': 'binary:logistic',
            # With several metrics, early stopping tracks the LAST one ('auc').
            'eval_metric': ['logloss', 'error', 'auc'],
        }
        # Name the training split 'train' so the verbose log does not present
        # it as validation data; early stopping watches the LAST entry, i.e.
        # the held-out split.
        watchlist = [(self._train_matrix, 'train'),
                     (self._validate_matrix, 'valid')]
        booster = xgt.train(params, self._train_matrix,
                            num_boost_round=20, evals=watchlist,
                            early_stopping_rounds=5, verbose_eval=True)
        xgt.plot_importance(booster)
        plt.show()
        return booster