Training Set and Test Set in Machine Learning
By: Karen Tao, UX Researcher
April 8, 2020
In my last blog post, we learned that machine learning aims to create a model by learning from its input data, and that the trained model should generalize well to new data. We will now discuss the difference between a training set, a validation set, and a test set, using our recently published Workforce Retention of Graduates report as an example. We first split the data into training, validation, and test data; more specifically, we divide the 22,073 unique graduates into these three subsets. A training set is a set of examples used for learning (Ripley, 1996). Given the records of students in the training set, which include attributes such as gender, age, and work history prior to beginning postsecondary education, we allow the model to see, or learn from, the outcome of whether each of these students remained in the Utah workforce.
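A three-way split like the one described above can be sketched with two calls to scikit-learn's `train_test_split`: first hold back a test set, then split what remains into training and validation sets. This is a minimal sketch, not the report's actual pipeline. The data frame below is random placeholder data with 22,073 rows; the column names `age`, `prior_work_years`, and `outcome`, the 60% / 20% / 20% proportions, and the `random_state` values are all assumptions chosen for illustration.

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

# Hypothetical stand-in data: 22,073 rows of random values, NOT the report's data.
rng = np.random.default_rng(0)
n_graduates = 22_073
df = pd.DataFrame({
    "age": rng.integers(18, 60, size=n_graduates),              # hypothetical feature
    "prior_work_years": rng.integers(0, 20, size=n_graduates),  # hypothetical feature
    "outcome": rng.integers(0, 2, size=n_graduates),            # hypothetical label: 1 = remained in Utah
})

X = df.drop(columns="outcome")
y = df["outcome"]

# Step 1: hold back 20% of the records as the test set (assumed proportion).
X_temp, X_test, y_temp, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Step 2: split the remaining 80% into training (75% of it) and validation (25% of it),
# which gives roughly 60% / 20% / 20% of the full data overall.
X_train, X_val, y_train, y_val = train_test_split(
    X_temp, y_temp, test_size=0.25, random_state=42
)

print(len(X_train), len(X_val), len(X_test))
```

The three subsets share no rows, so every graduate's record lands in exactly one of them.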
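The validation set is used to evaluate the model while tuning its parameters, for example to choose the number of hidden units in a neural network. The model sees this data occasionally, when it is scored, but it does not learn from it. This sample of data provides an evaluation of a model fit on the training set while we tune the model's hyperparameters, the settings that control how the model is learned (Ripley, 1996). Because the validation data is kept out of training, this evaluation is not biased by what the model memorized; however, since the hyperparameters are chosen to score well on it, the validation score alone can be somewhat optimistic, which is why a separate test set is still needed.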
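To make the role of the validation set concrete, here is a minimal sketch of choosing the number of hidden units in a neural network, the example given above. It assumes synthetic data from scikit-learn's `make_classification` in place of the student records, and the candidate sizes (2, 8, 32), `max_iter`, and `random_state` values are arbitrary choices for illustration. Each candidate model learns only from the training set; the validation set is used only to score it, and the test set is not touched at all.

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.neural_network import MLPClassifier

# Synthetic data stands in for the student records (an assumption for illustration).
X, y = make_classification(n_samples=2000, n_features=10, random_state=0)

# Hold back a test set first, then carve a validation set out of the remainder.
X_temp, X_test, y_temp, y_test = train_test_split(X, y, test_size=0.2, random_state=0)
X_train, X_val, y_train, y_val = train_test_split(
    X_temp, y_temp, test_size=0.25, random_state=0
)

val_scores = {}
for n_hidden in (2, 8, 32):  # hypothetical candidate numbers of hidden units
    model = MLPClassifier(hidden_layer_sizes=(n_hidden,), max_iter=1000, random_state=0)
    model.fit(X_train, y_train)                        # learns from the training set only
    val_scores[n_hidden] = model.score(X_val, y_val)   # validation set is only used for scoring

# Pick the hyperparameter value with the best validation accuracy.
best_n_hidden = max(val_scores, key=val_scores.get)
print(val_scores, "best:", best_n_hidden)
```

Only after `best_n_hidden` is fixed would the final model be evaluated once on `X_test`, `y_test`.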
Finally, our model is evaluated on the test set to estimate its accuracy in the real world. After a model is trained, a test set is used to assess the performance of the classifier (Ripley, 1996). Unlike training data and validation data, test data is used only once a model is completely trained, and these records are never seen by the model during training. In our example, the trained model now predicts whether each student in the test data remained in the Utah workforce by examining the attributes of that student's record. The model's output, a prediction for each student in the test set, is then compared with the ground truth: the actual retention records of those students. This allows us to evaluate how well the trained model generalizes to previously unseen data.
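Comparing predictions with the ground truth can be sketched in a few lines. This example assumes synthetic data from `make_classification` and uses logistic regression as an arbitrary stand-in classifier; neither is the model used in the report. The accuracy is computed once by hand, as the fraction of test records where the prediction matches the actual outcome, and once with scikit-learn's `accuracy_score`, and the two agree.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split

# Synthetic stand-in data (assumption): 1,000 records with 8 features and a 0/1 outcome.
X, y = make_classification(n_samples=1000, n_features=8, random_state=1)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=1)

# Train only on the training set.
model = LogisticRegression(max_iter=1000).fit(X_train, y_train)

# Predict outcomes for records the model has never seen.
y_pred = model.predict(X_test)

# Compare predictions with the ground truth: fraction of records predicted correctly.
manual_accuracy = np.mean(y_pred == y_test)
library_accuracy = accuracy_score(y_test, y_pred)
print(f"test accuracy: {library_accuracy:.3f}")
```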
When we train a model using training data, one common mistake is overfitting. Overfitting means that we have fit the model too closely to the training data. As a result, the model may perform well on the training data but fail to generalize to new data. When this happens, the model has learned the “noise” in the training data instead of the actual relationships between variables in the data.
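One way to see overfitting is to compare a model's accuracy on the training set with its accuracy on the test set. The sketch below assumes synthetic data in which `flip_y=0.2` randomly reassigns about 20% of the labels, a stand-in for noise; a decision tree with no depth limit can memorize every training record, noise included, while a depth-limited tree cannot. The tree model and the depth of 3 are illustrative choices, not part of the original analysis.

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Synthetic data (assumption); flip_y=0.2 randomly reassigns ~20% of labels as noise.
X, y = make_classification(n_samples=1000, n_features=10, flip_y=0.2, random_state=2)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=2)

results = {}
for max_depth in (None, 3):  # None = grow until every training record is fit
    tree = DecisionTreeClassifier(max_depth=max_depth, random_state=2).fit(X_train, y_train)
    train_acc = tree.score(X_train, y_train)
    test_acc = tree.score(X_test, y_test)
    results[max_depth] = (train_acc, test_acc)
    print(f"max_depth={max_depth}: train={train_acc:.3f}, test={test_acc:.3f}")
```

The unlimited tree reaches perfect training accuracy, yet its test accuracy is noticeably lower; that gap between training and test performance is the signature of overfitting.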
Let’s now get coding to solidify our understanding of the split. We can easily implement the split in Python using the Scikit-Learn library, more specifically its `train_test_split` function. Below is an example:
```python
import pandas as pd
from sklearn.model_selection import train_test_split

# Load the dataset
df = pd.read_csv("students")

# Define the dependent variable (whether the student remained in Utah) as y;
# outcome is a 0/1 label for classification
y = df["outcome"]

# Define all other variables as the features X
X = df.drop("outcome", axis=1)

# Create train/test variables using the train_test_split function.
# Our test set is 20% of the data; the model will not see these 20% during training.
# random_state fixes the shuffle so the split is reproducible.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
```
Now that we understand the difference between training and test data, we will move on to other machine learning topics next. In the meantime, don’t forget to check out the Workforce Retention of Graduates report.
Ripley, B. D. (1996). Pattern recognition and neural networks. Cambridge University Press.