trainTestSplit

Purpose

Returns test and training splits for a given set of dependent and independent variables.

Format

{ y_train, y_test, X_train, X_test } = trainTestSplit(y, X, train_pct[, shuffle])
Parameters:
  • y (Nx1 vector or NxK matrix) – The dependent variables.

  • X (Nx1 vector or NxP matrix) – The independent variables.

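  • train_pct (Scalar) – The proportion of observations to include in the training set, between 0 and 1 (for example, 0.67 for roughly two-thirds).

  • shuffle (String) – Optional input. If “True” (default), the observations are randomly shuffled before splitting; if “False”, the original row order is kept.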

Returns:
  • y_train – The (train_pct * N) observations from the original y that correspond to the rows selected for X_train.

  • y_test – The remaining observations from the original y that were not selected for the training set.

  • X_train – The (train_pct * N) x P matrix of observations from the original X selected for the training set.

  • X_test – The remaining observations from the original X that were not selected for the training set.

Examples

Basic example

library gml;

// Set seed for repeatable sampling
rndseed 23324;

y = { 7, 2, 5, 1, 3, 4 };

X = { 1   3,
      9   6,
      6   1,
      8   4,
      9   5,
      1   8 };

// Shuffle data and create training set with 2/3 of
// the observations and 1/3 for the test set
{ y_train, y_test, X_train, X_test } = trainTestSplit(y, X, 0.67);

After the above code:

y_train = 3   X_train = 9    5
          7             1    3
          1             8    4
          4             1    8

y_test =  2   X_test =  9    6
          5             6    1

Example without shuffling

Sometimes, for example with time series data, you may want to preserve the original order of the observations rather than shuffle them before creating the training and test splits.

library gml;

y = { 7, 2, 5, 1, 3, 4 };

X = { 1   3,
      9   6,
      6   1,
      8   4,
      9   5,
      1   8 };

// Create training set in the original order with 2/3 of
// the observations and 1/3 for the test set
{ y_train, y_test, X_train, X_test } = trainTestSplit(y, X, 0.67, "False");

This time, the split data keep the original order: the first four observations form the training set and the last two form the test set.

y_train = 7   X_train = 1    3
          2             9    6
          5             6    1
          1             8    4

y_test =  3   X_test =  9    5
          4             1    8
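Because no shuffling takes place, the result above is the same as plain row indexing. The minimal sketch below reproduces it with standard GAUSS indexing. The training-set size is hard-coded as 4 here, on the assumption that 0.67 * 6 = 4.02 is taken as 4 observations, which agrees with the output above; the variable n_train is introduced only for this illustration.

```gauss
// Same data as the example above
y = { 7, 2, 5, 1, 3, 4 };
X = { 1 3, 9 6, 6 1, 8 4, 9 5, 1 8 };

// Assumed training-set size: 0.67 * 6 observations taken as 4
n_train = 4;
n = rows(y);

// First n_train rows go to the training set, the rest to the test set
y_train = y[1:n_train];
y_test = y[n_train+1:n];
X_train = X[1:n_train, .];
X_test = X[n_train+1:n, .];

// y_train and X_train hold rows 1 to 4; y_test and X_test hold rows 5 and 6
print y_train;
print X_test;
```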

Remarks

If shuffle is enabled, the observations from X and y are first randomly shuffled such that the corresponding rows of X and y are kept together. For repeatable shuffling, use the rndseed keyword before calling trainTestSplit().
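The shuffle-then-split behaviour described above can be sketched in plain GAUSS as follows. This is a hypothetical illustration, not the gml implementation: the procedure name splitSketch, the numeric shuffle flag (1 or 0 instead of the “True”/“False” string), and rounding train_pct * N to the nearest integer are all assumptions. The key idea is that one random permutation of the row indices is drawn and applied to both y and X, which keeps corresponding rows together. Because gml may generate its permutation differently, the shuffled rows will not necessarily match the Basic example output, even with the same seed.

```gauss
// Set seed so the sketch's shuffle is repeatable
rndseed 23324;

y = { 7, 2, 5, 1, 3, 4 };
X = { 1 3, 9 6, 6 1, 8 4, 9 5, 1 8 };

// Shuffled split (flag 1) and split in the original order (flag 0)
{ y_train, y_test, X_train, X_test } = splitSketch(y, X, 0.67, 1);
{ y_tr0, y_te0, X_tr0, X_te0 } = splitSketch(y, X, 0.67, 0);

// Hypothetical procedure: NOT the gml trainTestSplit() implementation.
// Assumes 0 < train_pct < 1 and rounds train_pct * N to the nearest integer.
proc (4) = splitSketch(y, X, train_pct, shuffle);
    local n, n_train, idx, tr, te;

    n = rows(y);
    n_train = round(train_pct * n);

    // Require at least one observation in each set
    if n_train < 1 or n_train >= n;
        errorlog "splitSketch: train_pct leaves an empty set";
        end;
    endif;

    if shuffle;
        // Random permutation of 1..n: sort uniform draws, keep the sort order
        idx = sortind(rndu(n, 1));
    else;
        // Keep the original row order
        idx = seqa(1, 1, n);
    endif;

    tr = idx[1:n_train];
    te = idx[n_train+1:n];

    // The same row indices are applied to y and X, so rows stay paired
    retp(y[tr, .], y[te, .], X[tr, .], X[te, .]);
endp;
```

Sorting a vector of uniform draws with sortind() is one common way to obtain a random permutation in GAUSS; it is used here only because it is short and easy to follow.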

See also

Functions cvSplit(), rndi(), sampleData(), splitData()