ML.NET 似乎还比较新,相关文档不多。在学习使用它的过程中,我接连遇到各种问题,最终总算摸索到能让程序无错误运行的程度;然而我的模型似乎有问题:它总是返回 0,且概率为 50%。我的代码如下。有谁知道关于最新版 ML.NET 的好的学习资源吗?下面的代码本应构建一个二元分类模型,用来预测一支球队能否进入季后赛。数据只是上个赛季的最终战绩,并删掉了大部分列,只保留平均年龄、胜场、负场和季后赛状态(1 = 进入季后赛,0 = 未进入季后赛)。
Program.cs
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using Microsoft.ML;
using Microsoft.ML.Core.Data;
using Microsoft.ML.Runtime.Api;
using Microsoft.ML.Runtime.Data;
namespace MachineLearning2
{
    /// <summary>
    /// Console entry point: trains a FastTree binary classifier that predicts whether an NHL team
    /// made the playoffs from its average age, wins and losses, evaluates it on a held-out file,
    /// saves it to disk, and runs sample predictions (including one using the reloaded model).
    /// </summary>
    class Program
    {
        static readonly string _trainDataPath = Path.Combine(Environment.CurrentDirectory, "trainingNHL.txt");
        static readonly string _testDataPath = Path.Combine(Environment.CurrentDirectory, "testingNHL.txt");
        static readonly string _modelPath = Path.Combine(Environment.CurrentDirectory, "Model.zip");

        // Shared reader for both the training and the test file; configured once in Main.
        static TextLoader _textLoader;

        static void Main(string[] args)
        {
            // Fixed seed so that repeated runs train the same model.
            MLContext mlContext = new MLContext(seed: 0);

            // Headerless CSV rows: Age, Wins, Losses, Playoffs (1 = made playoffs, 0 = did not).
            _textLoader = mlContext.Data.TextReader(new TextLoader.Arguments()
            {
                Separator = ",",
                HasHeader = false,
                Column = new[]
                {
                    new TextLoader.Column("Age", DataKind.R4, 0),
                    new TextLoader.Column("Wins", DataKind.R4, 1),
                    new TextLoader.Column("Losses", DataKind.R4, 2),
                    new TextLoader.Column("Label", DataKind.R4, 3)
                }
            });

            var model = Train(mlContext, _trainDataPath);
            // Evaluate also writes the model to _modelPath, which PredictWithModelLoadedFromFile reads back.
            Evaluate(mlContext, model);
            Predict(mlContext, model);
            PredictWithModelLoadedFromFile(mlContext);
        }

        /// <summary>
        /// Reads <paramref name="dataPath"/> and fits a pipeline that concatenates Age, Wins and Losses
        /// into a "Features" column and trains a FastTree binary classifier on "Label".
        /// </summary>
        /// <param name="mlContext">ML.NET context used to build and fit the pipeline.</param>
        /// <param name="dataPath">Path of the headerless training CSV.</param>
        /// <param name="numLeaves">Upper bound on the number of leaves per tree.</param>
        /// <param name="numTrees">Number of trees (boosting iterations) to train.</param>
        /// <param name="minDatapointsInLeafs">
        /// Minimum number of training rows every leaf must contain. This has to be small relative to the
        /// training set. The previous hard-coded value of 20 exceeded the 16 rows in trainingNHL.txt, so no
        /// split was possible, FastTree grew no trees, and every prediction came out the same
        /// (0 at ~50% probability).
        /// </param>
        /// <returns>The fitted transformer (feature concatenation followed by the trained classifier).</returns>
        public static ITransformer Train(MLContext mlContext, string dataPath,
            int numLeaves = 50, int numTrees = 50, int minDatapointsInLeafs = 2)
        {
            IDataView dataView = _textLoader.Read(dataPath);
            var pipeline = mlContext.Transforms.Concatenate("Features", "Age", "Wins", "Losses")
                .Append(mlContext.BinaryClassification.Trainers.FastTree(
                    numLeaves: numLeaves,
                    numTrees: numTrees,
                    minDatapointsInLeafs: minDatapointsInLeafs));

            Console.WriteLine("=============== Create and Train the Model ===============");
            var model = pipeline.Fit(dataView);
            Console.WriteLine("=============== End of training ===============");
            Console.WriteLine();
            return model;
        }

        /// <summary>
        /// Scores the test file with <paramref name="model"/>, prints accuracy, AUC and F1,
        /// and then saves the model to <c>_modelPath</c>.
        /// </summary>
        /// <param name="mlContext">ML.NET context used for evaluation and saving.</param>
        /// <param name="model">Trained model to evaluate and persist.</param>
        public static void Evaluate(MLContext mlContext, ITransformer model)
        {
            IDataView dataView = _textLoader.Read(_testDataPath);
            Console.WriteLine("=============== Evaluating Model accuracy with Test data===============");
            var predictions = model.Transform(dataView);
            var metrics = mlContext.BinaryClassification.Evaluate(predictions, "Label");

            Console.WriteLine();
            Console.WriteLine("Model quality metrics evaluation");
            Console.WriteLine("--------------------------------");
            Console.WriteLine($"Accuracy: {metrics.Accuracy:P2}");
            Console.WriteLine($"Auc: {metrics.Auc:P2}");
            Console.WriteLine($"F1Score: {metrics.F1Score:P2}");
            Console.WriteLine("=============== End of model evaluation ===============");

            // Side effect relied on by PredictWithModelLoadedFromFile.
            SaveModelAsFile(mlContext, model);
        }

        /// <summary>Writes <paramref name="model"/> to <c>_modelPath</c>, overwriting any existing file.</summary>
        private static void SaveModelAsFile(MLContext mlContext, ITransformer model)
        {
            // FileShare.None: no other process may open the file while it is being rewritten.
            using (var fs = new FileStream(_modelPath, FileMode.Create, FileAccess.Write, FileShare.None))
            {
                mlContext.Model.Save(model, fs);
            }
            Console.WriteLine("The model is saved to {0}", _modelPath);
        }

        /// <summary>Predicts a single hand-built team with an in-memory prediction function and prints the result.</summary>
        public static void Predict(MLContext mlContext, ITransformer model)
        {
            var predictionFunction = model.MakePredictionFunction<NHLData, NHLPrediction>(mlContext);
            NHLData sampleTeam = new NHLData
            {
                Age = 29,
                Wins = 60,
                Losses = 22
            };
            var resultprediction = predictionFunction.Predict(sampleTeam);

            Console.WriteLine();
            Console.WriteLine("=============== Prediction Test of model with a single sample and test dataset ===============");
            Console.WriteLine();
            // Prediction is already a bool; no conversion needed.
            Console.WriteLine($"Age: {sampleTeam.Age} | Wins: {sampleTeam.Wins} | Losses: {sampleTeam.Losses} | Prediction: {(resultprediction.Prediction ? "Yes" : "No")} | Probability: {resultprediction.Probability} ");
            Console.WriteLine("=============== End of Predictions ===============");
            Console.WriteLine();
        }

        /// <summary>
        /// Reloads the model from <c>_modelPath</c> (written by Evaluate) and predicts a small batch
        /// of hand-built teams, printing each team alongside its prediction.
        /// </summary>
        public static void PredictWithModelLoadedFromFile(MLContext mlContext)
        {
            IEnumerable<NHLData> teams = new[]
            {
                new NHLData
                {
                    Age = 29,
                    Wins = 30,
                    Losses = 52
                },
                new NHLData
                {
                    Age = 35,
                    Wins = 80,
                    Losses = 2
                }
            };

            ITransformer loadedModel;
            using (var stream = new FileStream(_modelPath, FileMode.Open, FileAccess.Read, FileShare.Read))
            {
                loadedModel = mlContext.Model.Load(stream);
            }

            // Wrap the in-memory teams as a data view and score them in one pass.
            var nhlStreamingDataView = mlContext.CreateStreamingDataView(teams);
            var predictions = loadedModel.Transform(nhlStreamingDataView);
            // Predict whether each team makes the playoffs (true / 1) or not (false / 0).
            // reuseRowObject: false so every enumerated prediction is a distinct object, safe to pair up below.
            var predictedResults = predictions.AsEnumerable<NHLPrediction>(mlContext, reuseRowObject: false);

            Console.WriteLine();
            Console.WriteLine("=============== Prediction Test of loaded model with a multiple samples ===============");
            var teamsAndPredictions = teams.Zip(predictedResults, (team, prediction) => (team, prediction));
            foreach (var item in teamsAndPredictions)
            {
                Console.WriteLine($"Age: {item.team.Age} | Wins: {item.team.Wins} | Losses: {item.team.Losses} | Prediction: {(item.prediction.Prediction ? "Yes" : "No")} | Probability: {item.prediction.Probability} ");
            }
            Console.WriteLine("=============== End of predictions ===============");
        }
    }
}
NHLData.cs
using System;
using System.Collections.Generic;
using System.Text;
using Microsoft.ML.Runtime.Api;
namespace MachineLearning2
{
    /// <summary>
    /// One input row of the NHL CSV files; each ordinal is the field's column position in the file.
    /// </summary>
    public class NHLData
    {
        /// <summary>Average roster age (column 0).</summary>
        [Column("0")] public float Age;

        /// <summary>Wins (column 1).</summary>
        [Column("1")] public float Wins;

        /// <summary>Losses (column 2).</summary>
        [Column("2")] public float Losses;

        /// <summary>Playoff status, exposed to ML.NET as "Label": 1 = made the playoffs, 0 = did not (column 3).</summary>
        [Column("3", "Label")] public float Playoffs;
    }

    /// <summary>One output row produced by the trained binary classifier.</summary>
    public class NHLPrediction
    {
        /// <summary>Predicted class, bound to the "PredictedLabel" output column.</summary>
        [ColumnName("PredictedLabel")] public bool Prediction { get; set; }

        /// <summary>Probability of the positive class, bound to the "Probability" output column.</summary>
        [ColumnName("Probability")] public float Probability { get; set; }

        /// <summary>Raw model score, bound to the "Score" output column.</summary>
        [ColumnName("Score")] public float Score { get; set; }
    }
}
trainingNHL.txt(列:年龄, 胜场, 负场, 季后赛)
28.4,53,18,1
27.5,54,23,1
28,51,24,1
28.3,49,26,1
29.5,45,26,1
28.8,45,27,1
29.1,45,29,1
27.7,44,29,1
26.4,43,30,1
28.5,42,32,0
27,36,35,0
26.8,36,40,0
28,33,39,0
30.2,30,39,0
26.5,29,41,0
27.1,25,45,0
testingNHL.txt(列:年龄, 胜场, 负场, 季后赛)
26.8,52,20,1
28.6,50,20,1
28.4,49,26,1
28.7,44,25,1
27.7,47,29,1
27.4,42,26,1
26.4,45,30,1
27.8,44,30,0
28.5,44,32,0
28.4,37,35,0
28.4,35,37,0
28.7,34,39,0
28.2,31,40,0
27.8,29,40,0
29.3,28,43,0
发布于 2018-11-10 15:32:54
trainingNHL.txt 是您使用的完整数据集,还是只是其中的一部分样本?我刚用 FastTree 在它上面训练了一下,看到了这样的警告:“50 次提升(boosting)迭代均未能生成任何树。这通常是因为‘叶节点最少样本数’超参数相对于该数据集设置得过高。”
按照您为 FastTree 设置的参数,需要更多数据才能训练出有意义的模型:训练集只有 16 行,而 minDatapointsInLeafs: 20 要求每个叶节点至少包含 20 个样本,因此一次分裂都做不了,模型对所有输入都给出相同的结果。当我把 minDatapointsInLeafs 改为 2 时,就能训练出一个非平凡的模型(不过由于数据量太少,结果仍不太可靠)。您也可以试试 AveragedPerceptron 或 SDCA 之类的训练器。
https://stackoverflow.com/questions/53244185
复制