Creates a K-D tree model from training data for efficient k-nearest neighbor (KNN) predictions.

Format

mdl = knnFit(y, X, k)

Parameters

  • y (Nx1 vector or string array) – The dependent, or target, variable.
  • X (NxP matrix) – The independent variables.
  • k (Scalar) – The number of neighbors.

Returns

mdl (struct) – An instance of a knnModel structure. For an instance named mdl, the members are:

mdl.opaqueModel – Column vector containing the K-D tree in opaque form.
mdl.classIndices – Cx1 matrix, where C is the number of classes in the target vector y.
mdl.classNames – Cx1 string array, where C is the number of classes in the target vector y, containing the class names if y was a string array.
mdl.k – Scalar, the number of neighbors to search.

Examples

library gml;

// Get file name with full path
fname = getGAUSSHome() $+ "pkgs/gml/examples/iris.csv";

// Load numeric predictors
X = loadd(fname, ". -Species");

// Load string labels
species = loaddSA(fname, "Species");

// Set seed for repeatable train/test sampling
rndseed 423432;

// Split data into (70%) train and (30%) test sets
{ y_train, y_test, X_train, X_test } = trainTestSplit(species, X, 0.7);

// Train the model

k = 3;

struct knnModel mdl;
mdl = knnFit(y_train, X_train, k);

// Predictions on the test set

y_hat = knnClassify(mdl, X_test);

print "prediction accuracy = " meanc(y_hat .$== y_test);

The code above prints the following output:

prediction accuracy = 0.956
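The example above uses string labels. Because y may also be an Nx1 numeric vector, the following minimal sketch fits a model to numeric class labels instead. The data are made up for illustration only, and the sketch assumes that knnClassify handles a model fitted on numeric labels the same way it handles one fitted on string labels. It also prints mdl.k to show that the number of neighbors is stored in the returned structure.

```gauss
library gml;

// Made-up toy data (illustration only): 6 observations, 2 predictors,
// forming two well-separated clusters
X = { 1.0 1.1,
      1.2 0.9,
      0.8 1.0,
      5.0 5.2,
      5.1 4.9,
      4.8 5.0 };

// Numeric class labels instead of a string array
y = { 1, 1, 1, 2, 2, 2 };

// Fit the K-D tree model using 3 neighbors
struct knnModel mdl;
mdl = knnFit(y, X, 3);

// The number of neighbors is stored in the model
print "k = " mdl.k;

// Classify two new observations, one near each cluster
// (assumes knnClassify accepts a model fitted on numeric labels)
X_new = { 1.1 1.0,
          5.0 5.0 };
y_hat = knnClassify(mdl, X_new);
print y_hat;
```

With k = 3 and two well-separated clusters, the three nearest neighbors of each new observation all come from a single cluster, so the predictions should match that cluster's label.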

See also