knnFit
Purpose
Creates a K-D tree model from training data for efficient KNN predictions.
Format
- mdl = knnFit(y, X, k)
- Parameters:
y (Nx1 vector or string array) – The dependent, or target, variable.
X (NxP matrix) – The independent, or feature, variables.
k (Scalar) – The number of nearest neighbors to search.
- Returns:
mdl (struct) – An instance of a knnModel structure. For an instance named mdl, the members will be:
mdl.opaqueModel – Column vector, containing the K-D tree in opaque form.
mdl.classIndices – Px1 matrix, where P is the number of classes in the target vector y.
mdl.classNames – Px1 string array, where P is the number of classes in the target vector y, containing the class names if the target vector was a string array.
mdl.k – Scalar, the number of neighbors to search.
Remarks
The K-D tree algorithm is an optimized approximate nearest-neighbors search. It provides high computational performance and high accuracy, but may not always return the exact nearest neighbors.
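As a minimal sketch of working with the returned structure (assuming the iris example data shipped with GML), the knnModel members documented above can be inspected directly after fitting:

```gauss
new;
library gml;

// Load the iris example data shipped with GML
fname = getGAUSSHome("pkgs/gml/examples/iris.csv");

// Features and string labels
X = loadd(fname, ". -Species");
species = loadd(fname, "Species");

// Fit a 5-neighbor model on the full dataset
struct knnModel mdl;
mdl = knnFit(species, X, 5);

// Inspect the fitted model's members
print "Neighbors searched: " mdl.k;
print "Class names:";
print mdl.classNames;
```

Because the K-D tree itself is stored in opaque form in mdl.opaqueModel, it is not meant to be read directly; pass the whole structure to knnClassify for predictions.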
Examples
new;
library gml;
// Set seed for repeatable train/test sampling
rndseed 423432;
/*
** Load data and prepare
*/
// Get file name with full path
fname = getGAUSSHome("pkgs/gml/examples/iris.csv");
// Get predictors
X = loadd(fname, ". -Species");
// Load labels
species = loadd(fname, "Species");
// Split data into (70%) train and (30%) test sets
{ y_train, y_test, X_train, X_test } = trainTestSplit(species, X, 0.7);
/*
** Train the model
*/
// Specify number of neighbors
k = 3;
struct knnModel mdl;
mdl = knnFit(y_train, X_train, k);
/*
** Predictions on the test set
*/
y_hat = knnClassify(mdl, X_test);
/*
** Model assessment
*/
call classificationMetrics(y_test, y_hat);
The above code will print the following output:
===========================================================================
Model: KNN Target Variable: Species
Number observations: 105 Number features: 4
Num. Neighbors: 3 Number of Classes: 3
===========================================================================
KNN Classification Prediction Frequencies:
=============================================
Label Count Total % Cum. %
setosa 14 31.11 31.11
versicolor 19 42.22 73.33
virginica 12 26.67 100
Total 45 100
=============================================
Observed Test Data Frequencies:
=============================================
Label Count Total % Cum. %
setosa 14 31.11 31.11
versicolor 19 42.22 73.33
virginica 12 26.67 100
Total 45 100
=============================================
===================================================
Classification metrics
===================================================
Class Precision Recall F1-score Support
setosa 1.00 1.00 1.00 14
versicolor 0.95 0.95 0.95 19
virginica 0.92 0.92 0.92 12
Macro avg 0.95 0.95 0.95 45
Weighted avg 0.96 0.96 0.96 45
Accuracy 0.96 45
See also