Creates a binary classification model that predicts whether a white wine is of good quality (a quality score of 6 or higher) from 11 physicochemical features. The training data is read and prepared with the DataFrame API.
#r "nuget:Microsoft.ML, 1.4.0"
#r "nuget:XPlot.Plotly, 3.0.1"
using Microsoft.ML;
using Microsoft.ML.Data;
using Microsoft.ML.Trainers;
using Microsoft.ML.Transforms;
using XPlot.Plotly;
/// <summary>
/// One row of the raw wine-quality file: eleven measurement columns (indices 0-10)
/// followed by the quality score (index 11).
/// </summary>
public class BinaryClassificationData
{
    [LoadColumn(0)] public float FixedAcidity;
    [LoadColumn(1)] public float VolatileAcidity;
    [LoadColumn(2)] public float CitricAcid;
    [LoadColumn(3)] public float ResidualSugar;
    [LoadColumn(4)] public float Chlorides;
    [LoadColumn(5)] public float FreeSulfurDioxide;
    [LoadColumn(6)] public float TotalSulfurDioxide;
    [LoadColumn(7)] public float Density;
    [LoadColumn(8)] public float Ph;
    [LoadColumn(9)] public float Sulphates;
    [LoadColumn(10)] public float Alcohol;

    // The score the binary label is derived from; the pipeline's Features vector
    // does not include it.
    [LoadColumn(11)] public float Quality;
}
/// <summary>
/// A <see cref="BinaryClassificationData"/> row extended with a computed boolean
/// <c>Label</c> that is true when the quality score is greater than 5.
/// </summary>
/// <remarks>
/// Assuming Quality holds whole-number scores, this agrees with the
/// "Quality >= 6" rule used to build the DataFrame label column.
/// </remarks>
public class RichBinaryClassificationData : BinaryClassificationData
{
    public bool Label
    {
        get { return Quality > 5; }
    }
}
/// <summary>
/// Row shape read back from the trained model for a single prediction.
/// </summary>
public class BinaryClassificationPrediction
{
    // Maps to the scored row's Label column.
    public bool Label;

    // Maps to the model's "PredictedLabel" output column.
    [ColumnName("PredictedLabel")]
    public bool PredictedLabel;

    /// <summary>The predicted label as a number: 1 for true, 0 for false.</summary>
    public int LabelAsNumber
    {
        get
        {
            if (PredictedLabel)
            {
                return 1;
            }
            return 0;
        }
    }
}
#r "nuget:Microsoft.Data.Analysis,0.2.0"
using Microsoft.Data.Analysis;
using Microsoft.AspNetCore.Html;
// Custom HTML formatter for DataFrame: renders an "index" column plus every DataFrame
// column, showing only the first few rows.
Formatter<DataFrame>.Register((df, writer) =>
{
    const int take = 5;

    // Header: italic "index" followed by one header cell per column name.
    var headers = new List<IHtmlContent> { th(i("index")) };
    foreach (var column in df.Columns)
    {
        headers.Add(th(column.Name));
    }

    // Body: at most `take` rows, each prefixed with its row number.
    // (The loop index is deliberately not called `i`, so it cannot be confused with the `i` tag above.)
    var rows = new List<List<IHtmlContent>>();
    var rowCount = Math.Min(take, df.Rows.Count);
    for (var rowIndex = 0; rowIndex < rowCount; rowIndex++)
    {
        var cells = new List<IHtmlContent> { td(rowIndex) };
        cells.AddRange(df.Rows[rowIndex].Select(value => (IHtmlContent) td(value)));
        rows.Add(cells);
    }

    var t = table(thead(headers), tbody(rows.Select(r => tr(r))));
    writer.Write(t);
}, "text/html");
// Load the raw training data into a DataFrame. Column names are given explicitly so
// they match the members of BinaryClassificationData.
string[] wineColumnNames =
{
    "FixedAcidity",
    "VolatileAcidity",
    "CitricAcid",
    "ResidualSugar",
    "Chlorides",
    "FreeSulfurDioxide",
    "TotalSulfurDioxide",
    "Density",
    "Ph",
    "Sulphates",
    "Alcohol",
    "Quality"
};
var trainingData = DataFrame.LoadCsv(
    "./WineQuality_White_Train.csv",
    separator: ';',
    columnNames: wineColumnNames);
display(trainingData);

// Derive the boolean Label column (true when Quality >= 6) and append it to the frame.
var labelCol = trainingData["Quality"].ElementwiseGreaterThanOrEqual(6);
labelCol.SetName("Label");
trainingData.Columns.Add(labelCol);
// Dropping Quality would also work, but it is kept because later cells still read it
// (RichBinaryClassificationData computes its Label from Quality):
// trainingData.Columns.Remove(trainingData["Quality"]);
display(trainingData);
// No fixed seed is supplied to the context.
var mlContext = new MLContext(seed: null);

// Pipeline step 1: fill missing FixedAcidity values with the column mean.
var fillMissingFixedAcidity = mlContext.Transforms.ReplaceMissingValues(
    outputColumnName: "FixedAcidity",
    replacementMode: MissingValueReplacingEstimator.ReplacementMode.Mean);

// Pipeline step 2: gather the eleven measurement columns into one "Features" column.
string[] featureColumnNames =
{
    "FixedAcidity",
    "VolatileAcidity",
    "CitricAcid",
    "ResidualSugar",
    "Chlorides",
    "FreeSulfurDioxide",
    "TotalSulfurDioxide",
    "Density",
    "Ph",
    "Sulphates",
    "Alcohol"
};
var assembleFeatures = mlContext.Transforms.Concatenate("Features", featureColumnNames);

// Pipeline step 3: L-BFGS logistic regression, called with the trainer's default options.
var trainer = mlContext.BinaryClassification.Trainers.LbfgsLogisticRegression();

var pipeline = fillMissingFixedAcidity
    .Append(assembleFeatures)
    .Append(trainer);
var model = pipeline.Fit(trainingData);
// Load the raw test data. BinaryClassificationData has no Label member, so this view
// contains only the twelve raw columns.
var testData = mlContext.Data.LoadFromTextFile<BinaryClassificationData>(
    "./WineQuality_White_Test.csv",
    separatorChar: ';',
    hasHeader: true);

// Add the Label column by round-tripping the TEST rows through RichBinaryClassificationData,
// whose computed Label property is derived from Quality (IDataView -> IEnumerable -> IDataView).
// This must enumerate testData: enumerating trainingData here would score the model on the
// same rows it was fitted on and report optimistic metrics.
var stronglyTypedTestData = mlContext.Data.CreateEnumerable<RichBinaryClassificationData>(testData, reuseRowObject: false);
testData = mlContext.Data.LoadFromEnumerable(stronglyTypedTestData);

// Score the held-out rows and compute the binary classification metrics.
var scoredData = model.Transform(testData);
var qualityMetrics = mlContext.BinaryClassification.Evaluate(scoredData);
display(qualityMetrics);
// Display name and value of each metric to chart. Keeping them side by side means a
// label can never drift out of step with its value.
var metricPairs = new (string Name, double Value)[]
{
    ("Log Loss", qualityMetrics.LogLoss),
    ("Log Loss Reduction", qualityMetrics.LogLossReduction),
    ("Entropy", qualityMetrics.Entropy),
    ("Area Under Curve", qualityMetrics.AreaUnderRocCurve),
    ("Accuracy", qualityMetrics.Accuracy),
    ("Positive Recall", qualityMetrics.PositiveRecall),
    ("Negative Recall", qualityMetrics.NegativeRecall),
    ("F1 Score", qualityMetrics.F1Score)
};
string[] metricNames = metricPairs.Select(pair => pair.Name).ToArray();
double[] metricValues = metricPairs.Select(pair => pair.Value).ToArray();

// Horizontal bar chart: one bar per metric, with the value on the x axis.
var graph = new Graph.Bar
{
    x = metricValues,
    y = metricNames,
    orientation = "h",
    marker = new Graph.Marker { color = "darkred" }
};
var chart = Chart.Plot(graph);
var layout = new Layout.Layout { title = "Quality Metrics" };
chart.WithLayout(layout);
display(chart);

// Confusion matrix, shown with whatever formatter is registered at this point
// (the custom ConfusionMatrix formatter is registered further down).
display(qualityMetrics.ConfusionMatrix);
// Custom HTML view of a binary ConfusionMatrix: a 2x2 grid of counts with "Predicted"
// across the top, "Actual" down the side, and the total count n in the corner.
// Counts[actual][predicted] is drawn with index 0 labelled "True" and index 1 "False".
Formatter<ConfusionMatrix>.Register((matrix, writer) =>
{
    const string rowStyle = "background-color: transparent";
    const string axisTitleStyle = "border: 1px solid black; text-align: center; padding: 24px; background-color: lightsteelblue";
    const string predictedHeaderStyle = "border: 1px solid black; padding: 24px; background-color: #E3EAF3";
    const string actualHeaderStyle = "border: 1px solid black; text-align: center; padding: 24px; background-color: #E3EAF3";
    const string countStyle = "border: 1px solid black; padding: 24px";

    var counts = matrix.Counts;
    var total = counts[0][0] + counts[0][1] + counts[1][0] + counts[1][1];

    // Row 1: corner cell with the total, then the "Predicted" title over both count columns.
    var titleCells = new List<IHtmlContent>
    {
        td[rowspan: 2, colspan: 2, style: "text-align: center; background-color: transparent"]("n = " + total),
        td[colspan: 2, style: axisTitleStyle](b("Predicted"))
    };
    // Row 2: predicted-class headers.
    var predictedHeaderCells = new List<IHtmlContent>
    {
        td[style: predictedHeaderStyle](b("True")),
        td[style: predictedHeaderStyle](b("False"))
    };
    // Row 3: "Actual" title spanning both actual rows, then the actual = True counts.
    var actualTrueCells = new List<IHtmlContent>
    {
        td[rowspan: 2, style: axisTitleStyle](b("Actual")),
        td[style: actualHeaderStyle](b("True")),
        td[style: countStyle](counts[0][0]),
        td[style: countStyle](counts[0][1])
    };
    // Row 4: actual = False counts (this row is emitted without a style attribute).
    var actualFalseCells = new List<IHtmlContent>
    {
        td[style: actualHeaderStyle](b("False")),
        td[style: countStyle](counts[1][0]),
        td[style: countStyle](counts[1][1])
    };

    var rows = new List<IHtmlContent>
    {
        tr[style: rowStyle](titleCells),
        tr[style: rowStyle](predictedHeaderCells),
        tr[style: rowStyle](actualTrueCells),
        tr(actualFalseCells)
    };
    var t = table(tbody(rows));
    writer.Write(t);
}, "text/html");
display(qualityMetrics.ConfusionMatrix);
// Single-row prediction engine over the trained model.
var predictionEngine = mlContext.Model.CreatePredictionEngine<RichBinaryClassificationData, BinaryClassificationPrediction>(model);

// Pick one row at random from the training data: shuffle, keep one row, and
// materialize it as a RichBinaryClassificationData object.
var shuffledData = mlContext.Data.ShuffleRows(trainingData);
var rawSample = mlContext.Data.TakeRows(shuffledData, count: 1);
var sample = mlContext.Data
    .CreateEnumerable<RichBinaryClassificationData>(rawSample, reuseRowObject: false)
    .First();
display(sample);

// Predict the quality label for that row and show the result.
var prediction = predictionEngine.Predict(sample);
display(prediction);