ridgeCFit
Purpose
Fit a binary classification model with an L2 penalty.
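For orientation, an L2 (ridge) penalty is conventionally written as a term added to the training loss. The sketch below assumes that usual form. This page does not state the loss function $\mathcal{L}$ that ridgeCFit minimizes, how it is scaled, or whether the intercept is penalized, so treat those details as assumptions. The symbols $\alpha$ and $\beta$ correspond to the alpha_hat and beta_hat members described under Returns.

```latex
\min_{\alpha,\,\beta}\; \mathcal{L}\bigl(y,\ \alpha + X\beta\bigr) \;+\; \lambda \sum_{j=1}^{P} \beta_j^{2}
```

Under this form, a larger lambda shrinks the coefficients more strongly toward zero, and each lambda supplied yields its own set of estimates.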
Format
- mdl = ridgeCFit(y_train, X_train, lambda)
- Parameters:
y_train (Nx1 vector) – The dependent variable.
X_train (NxP matrix) – The independent variables.
lambda (Scalar, or Kx1 vector) – The L2 penalty parameter(s).
- Returns:
mdl (struct) – An instance of a ridgeModel structure. An instance named mdl will have the following members:
mdl.alpha_hat
(1 x nlambdas vector) The estimated value for the intercept for each provided lambda.
mdl.beta_hat
(P x nlambdas matrix) The estimated parameter values for each provided lambda.
mdl.mse_train
(nlambdas x 1 vector) The mean squared error for each set of parameters, computed on the training set.
mdl.lambda
(nlambdas x 1 vector) The lambda values used in the estimation.
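The members listed above can be read directly from the returned structure. The following is a minimal sketch, assuming mdl has already been returned by ridgeCFit and that the member dimensions match those documented above. The variable names nlambdas and beta_1 are illustrative and not part of the library.

```gauss
// Assumes 'mdl' is a ridgeModel structure returned by ridgeCFit

// Number of lambda values used in the fit
nlambdas = rows(mdl.lambda);

// Report the intercept and training MSE for each lambda
for i(1, nlambdas, 1);
    print sprintf("lambda = %.2f: intercept = %.4f, training MSE = %.4f", mdl.lambda[i], mdl.alpha_hat[1, i], mdl.mse_train[i]);
endfor;

// Coefficients for the first lambda: column 1 of the P x nlambdas matrix
beta_1 = mdl.beta_hat[., 1];
print beta_1;
```

Column k of mdl.beta_hat and element k of mdl.alpha_hat belong to the k-th entry of mdl.lambda, so results for different penalty values can be compared side by side.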
Examples
new;
library gml;
rndseed 23423;
// Create file name with full path
fname = getGAUSSHome("pkgs/gml/examples/breastcancer.csv");
// Load all variables from dataset, except for 'ID'
X = packr(loadd(fname, ". -ID"));
// Separate dependent and independent variables
y = X[., "class"];
X = delcols(X, "class");
// Split data into 70% training and 30% test set
{ X_train, X_test, y_train, y_test } = trainTestSplit(X, y, 0.7);
// Declare 'mdl' to be a 'ridgeModel' structure
// to hold the trained model
struct ridgeModel mdl;
// Set lambda vector
lambda = { 90, 36, 14 };
// Train the ridge classifier for each lambda value
mdl = ridgeCFit(y_train, X_train, lambda);
// Make predictions on the test set, from our trained model
y_hat = ridgeCPredict(mdl, X_test);
print "";
fmt = "Classification report for lambda = %.2f";
sprintf(fmt, lambda[1]);
call classificationMetrics(y_test, y_hat[.,1]);
print "";
sprintf(fmt, lambda[3]);
call classificationMetrics(y_test, y_hat[.,3]);
The code above will print the following output:
Classification report for lambda = 90.00
===================================================
              Classification metrics
===================================================
       Class  Precision  Recall  F1-score  Support
           0       0.64    1.00      0.78      131
           1       0.00    0.00      0.00       74
   Macro avg       0.32    0.50      0.39      205
Weighted avg       0.41    0.64      0.50      205
    Accuracy                         0.64      205
Classification report for lambda = 14.00
===================================================
              Classification metrics
===================================================
       Class  Precision  Recall  F1-score  Support
           0       0.82    1.00      0.90      131
           1       1.00    0.61      0.76       74
   Macro avg       0.91    0.80      0.83      205
Weighted avg       0.88    0.86      0.85      205
    Accuracy                         0.86      205
See also
Functions ridgeCPredict(), ridgeFit()