classificationMetrics#
Purpose#
Computes statistics to assess the quality of classification predictions and prints out a report.
Format#
- out = classificationMetrics(y_true, y_predict)#
- Parameters:
y_true (Nx1 vector or dataframe) – The true class labels.
y_predict (Nx1 vector or dataframe) – The predicted class labels.
- Returns:
out (struct) – An instance of a classQuality structure. For an instance named out, the members are:
out.confusionMatrix
\(k \times k\) matrix, containing the computed confusion matrix.
out.precision
\(k \times 1\) dataframe, where each row contains the precision for the corresponding class, \(\frac{tp}{tp + fp}\).
out.recall
\(k \times 1\) dataframe, where each row contains the recall for the corresponding class, \(\frac{tp}{tp + fn}\).
out.fScore
\(k \times 1\) dataframe, where each row contains the F1-score for the corresponding class, \(\frac{(b^2 + 1)\, tp}{(b^2 + 1)\, tp + b^2\, fn + fp}\) with \(b = 1\).
out.support
\(k \times 1\) dataframe, where each row contains the number of observations in the corresponding class.
out.macroPrecision
Scalar, the unweighted average of the per-class precision values.
out.macroRecall
Scalar, the unweighted average of the per-class recall values.
out.macroFScore
Scalar, the unweighted average of the per-class F1-scores.
out.macroSupport
Scalar, the total number of observations.
out.accuracy
Scalar in the range 0–1, the fraction of labels predicted correctly across all classes.
out.classLabels
\(k \times 1\) string array, containing the label for each class.
out.classes
\(k \times 1\) dataframe, containing the numeric key and, if given, the label name for each class.
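Given the definitions above, the per-class precision and recall can be recovered directly from out.confusionMatrix. The sketch below assumes the common orientation of rows as true classes and columns as predicted classes; that orientation is an assumption for illustration, not something this page specifies:

```
// Hypothetical sketch: per-class precision and recall from a
// confusion matrix, ASSUMING rows = true classes and
// columns = predicted classes.
cm = { 4 0,
       1 3 };

tp = diag(cm);            // True positives per class
prec = tp ./ sumc(cm);    // Column sums = total predicted per class
rec = tp ./ sumr(cm);     // Row sums = total actual per class

print prec ~ rec;
```

The macro averages are then simply `meanc(prec)` and `meanc(rec)`.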
Example#
Example 1: Basic use with binary labels#
new;
library gml;
y_true = { 0, 0, 1, 0, 1, 1, 1, 0 };
y_pred = { 0, 0, 1, 0, 1, 0, 1, 0 };
call classificationMetrics(y_true, y_pred);
After the above code, the following report will be printed:
===================================================
Classification metrics
===================================================
Class Precision Recall F1-score Support
0 0.80 1.00 0.89 4
1 1.00 0.75 0.86 4
Macro avg 0.90 0.88 0.87 8
Weighted avg 0.90 0.88 0.87 8
Accuracy 0.88 8
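Example 1 discards the return value with call. To keep the computed statistics for later use, assign the output to a classQuality structure instead (the same pattern used in Example 3 below):

```
new;
library gml;

y_true = { 0, 0, 1, 0, 1, 1, 1, 0 };
y_pred = { 0, 0, 1, 0, 1, 0, 1, 0 };

// Capture the statistics in a classQuality structure
struct classQuality out;
out = classificationMetrics(y_true, y_pred);

// Members can then be accessed with the dot operator
print out.accuracy;
```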
Example 2: Dataframe inputs#
new;
library gml;
// Strings
string true_label = { "cat", "cat", "zebra", "zebra", "dog", "dog", "dog", "cat", "cat" };
string pred_label = { "cat", "cat", "zebra", "cat", "zebra", "cat", "dog", "cat", "dog" };
// Create dataframes
df_true = asDF(true_label, "Observed");
df_pred = asDF(pred_label, "Prediction");
call classificationMetrics(df_true, df_pred);
After the above code, the following report will be printed:
===================================================
Classification metrics
===================================================
Class Precision Recall F1-score Support
cat 0.60 0.75 0.67 4
dog 0.50 0.33 0.40 3
zebra 0.50 0.50 0.50 2
Macro avg 0.53 0.53 0.52 9
Weighted avg 0.54 0.56 0.54 9
Accuracy 0.56 9
Example 3: KNN classification model assessment#
new;
library gml;
rndseed 790837;
/*
** Load data and prepare data
*/
// Get file name with full path
fname = getGAUSSHome("pkgs/gml/examples/iris.csv");
// Get predictors
X = loadd(fname, ". -Species");
// Load labels
species = loadd(fname, "Species");
// Split data into (70%) train and (30%) test sets
{ y_train, y_test, X_train, X_test } = trainTestSplit(species, X, 0.7);
/*
** Train the model
*/
// Specify number of neighbors
k = 3;
struct knnModel mdl;
mdl = knnFit(y_train, X_train, k);
/*
** Predictions on the test set
*/
y_hat = knnClassify(mdl, X_test);
// Declare 'q' to be a classQuality structure
// to hold the statistics
struct classQuality q;
// Compute statistics and print diagnostic report
q = classificationMetrics(y_test, y_hat);
After the code above and the KNN training printouts, the following report is printed:
===================================================
Classification metrics
===================================================
Class Precision Recall F1-score Support
setosa 1.00 1.00 1.00 13
versicolor 0.94 1.00 0.97 15
virginica 1.00 0.94 0.97 17
Macro avg 0.98 0.98 0.98 45
Weighted avg 0.98 0.98 0.98 45
Accuracy 0.98 45
We can access any of the structure members from the classQuality
structure using the dot operator:
print "Macro precision =";
print q.macroPrecision;
Macro precision =
0.97916667
print (q.classes ~ q.precision);
Class Precision
setosa 1.0000000
versicolor 0.93750000
virginica 1.0000000
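The "Weighted avg" row of the printed report does not appear in the member list above, but a support-weighted average can be computed from the per-class members. This sketch assumes elementwise arithmetic on these dataframe members behaves as it does on matrices:

```
// Sketch: support-weighted average of the per-class F1-scores,
// matching the "Weighted avg" F1-score of the printed report.
w = q.support ./ sumc(q.support);   // Class weights summing to 1
weightedF1 = sumc(w .* q.fScore);

print weightedF1;
```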