decForestPredict

Purpose

Predicts responses from a matrix of independent variables, using a model fit by decForestCFit() or decForestRFit().

Format

predictions = decForestPredict(dfm, x_test)
Parameters:
  • dfm (struct) –

    An instance of the dfModel structure filled by decForestRFit() or decForestCFit() and containing the following relevant members:

    dfm.variableImportance   Matrix, 1 x P, variable importance measure if computation of variable importance is specified, zero otherwise.
    dfm.oobError             Scalar, out-of-bag error if OOB error computation is specified, zero otherwise.
    dfm.numClasses           Scalar, number of classes if classification model, zero otherwise.
    dfm.opaqueModel          Matrix, contains model details for internal use only.
  • x_test (NxP matrix) – The independent variables to predict from, with the same P columns, in the same order, as the data used to fit the model.
Returns:

predictions (Nx1 numeric or string vector) – The predicted response for each row of x_test.
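
Because dfm may come from either fitting function, the same call also covers regression. The following is a minimal sketch, not taken from the library's own examples: the data are simulated, the seed and coefficients are arbitrary, and it assumes that decForestRFit() accepts the same (y_train, X_train) argument order with default settings as decForestCFit() and that trainTestSplit() accepts plain matrices.

```gauss
new;
library gml;

rndseed 8724;

// Simulated data (hypothetical): 500 observations, 3 predictors
X = rndn(500, 3);
y = 2*X[., 1] - X[., 2] + 0.5*rndn(500, 1);

// Split data into 70% training and 30% test set
{ X_train, X_test, y_train, y_test } = trainTestSplit(X, y, 0.7);

// Declare 'df_mdl' to be a 'dfModel' structure
struct dfModel df_mdl;

// Train the decision forest regressor with default settings
// (assumed call pattern, mirroring decForestCFit)
df_mdl = decForestRFit(y_train, X_train);

// Predict on the test set; y_hat is an Nx1 numeric vector here
y_hat = decForestPredict(df_mdl, X_test);

// Test-set mean squared error, computed directly
mse = meanc((y_test - y_hat).^2);
print "Test MSE: " mse;
```

The error is computed by hand rather than through a library helper so that the sketch depends only on functions that appear on this page, plus decForestRFit().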

Examples

new;
library gml;

rndseed 23423;

// Create file name with full path
fname = getGAUSSHome() $+ "pkgs/gml/examples/breastcancer.csv";

// Load all variables from dataset, except for 'ID'
X = loadd(fname, ". -ID");

// Separate dependent and independent variables
y = X[.,cols(X)];
X = delcols(X, cols(X));

// Split data into 70% training and 30% test set
{ X_train, X_test, y_train, y_test } = trainTestSplit(X, y, 0.7);

// Declare 'df_mdl' to be a 'dfModel' structure
// to hold the trained model
struct dfModel df_mdl;

// Train the decision forest classifier with default settings
df_mdl = decForestCFit(y_train, X_train);

// Make predictions on the test set, from our trained model
y_hat = decForestPredict(df_mdl, X_test);

// Print out model quality evaluation statistics
call binaryClassMetrics(y_test, y_hat);

The code above will print the following output:

            Confusion matrix
            ----------------

    Class +       54       2
    Class -        1     153

   Accuracy           0.9857
  Precision           0.9643
     Recall           0.9818
    F-score            0.973
Specificity           0.9871
        AUC           0.9845
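
The reported statistics can be checked by hand from the four confusion matrix counts. The sketch below assumes that rows are predicted classes and columns are observed classes; that orientation is inferred from the reported precision and recall, not stated in the output. The variable names are illustrative only.

```gauss
// Counts read from the confusion matrix printed above.
// Assumption: rows are predicted classes, columns are observed classes.
n_tp = 54;     // predicted +, observed +
n_fp = 2;      // predicted +, observed -
n_fn = 1;      // predicted -, observed +
n_tn = 153;    // predicted -, observed -

m_acc  = (n_tp + n_tn) / (n_tp + n_fp + n_fn + n_tn);   // 207/210
m_prec = n_tp / (n_tp + n_fp);                          // 54/56
m_rec  = n_tp / (n_tp + n_fn);                          // 54/55
m_spec = n_tn / (n_tn + n_fp);                          // 153/155
m_f1   = 2 * m_prec * m_rec / (m_prec + m_rec);         // 108/111

// Average of recall and specificity
m_bal  = (m_rec + m_spec) / 2;

print "Accuracy    " m_acc;
print "Precision   " m_prec;
print "Recall      " m_rec;
print "F-score     " m_f1;
print "Specificity " m_spec;
print "Balanced    " m_bal;
```

Each ratio agrees with the printed value to four decimal places. The last quantity, the average of recall and specificity, matches the reported AUC; this is what the area under a ROC curve reduces to when predictions are hard class labels, although whether binaryClassMetrics() computes it this way is an assumption.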