TinyML (Part 5): Training data, validation data, test data

Imagine a scenario where you want to train a neural network to recognize shoes. It's like teaching somebody who has never seen a shoe before what a shoe actually is, so that in the future, when they see an object, they can decide whether what they're seeing is a shoe or not. Now, we know there's a huge variety of shoes, and there's no hard and fast rule about what makes a shoe a shoe.

Typically, the steps that you'd follow are to get as many examples of shoes as possible, train a neural network using those examples, and then profit.

And you might end up with something like this: an accuracy number that keeps increasing as you train. By the time it finishes training, it has a very high accuracy, maybe even 100%. Now that might mean you think you've written an amazing AI that can recognize shoes, and it's time to profit.

But then you show it a shoe like this, and it fails. You thought you were 100% accurate in recognizing shoes. But the reality is that you are 100% accurate at recognizing the types of shoes that you trained the neural network on, and that 100% figure gave you a false sense of security. Your perfect neural network architecture may not be so perfect after all.

There are a number of ways to help prevent this, and we'll look at the simplest first. Consider our shoe scenario where we had a lot of shoes, and we can represent them using the yellow rectangle. We used all of our data to train our neural network, and why not? The more we train, the better, right? In principle this sounds fine, but we have no way of testing it against previously unseen data because we used all of our data to train it. So what if we hold back some of the data? Don't train with all of the data. Save some for validating that the neural network is training well.

How about holding back a little more that we can use to test our neural network after it's done training? Our first subset, the yellow rectangle, is the data that we'll train with. Our second subset, the red rectangle, is the data that we'll use to validate the training by testing after each step how the network is doing.

But we don't use this data to train the network. Our third and final subset is the blue rectangle. And this can be used as a clean set to test the network after we're done training it.
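To make the split concrete, here's a minimal sketch of how you might carve one dataset into those three subsets. The array shapes and the 70/20/10 proportions are made up purely for illustration; they aren't prescribed by the scenario above.

```python
import numpy as np

# Hypothetical dataset: 10,000 examples with matching labels
# (the shapes here are invented for illustration).
examples = np.random.rand(10000, 28, 28)
labels = np.random.randint(0, 10, size=10000)

# Shuffle so each subset gets a representative mix of examples.
indices = np.random.permutation(len(examples))
examples, labels = examples[indices], labels[indices]

# Hold back 20% for validation and 10% for a final, clean test set.
n_train = int(0.7 * len(examples))
n_val = int(0.2 * len(examples))

train_x, train_y = examples[:n_train], labels[:n_train]
val_x, val_y = examples[n_train:n_train + n_val], labels[n_train:n_train + n_val]
test_x, test_y = examples[n_train + n_val:], labels[n_train + n_val:]
```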

Following this methodology, we could, for example, pick a neural network architecture and train it. On the training data, it's 99.9% accurate. But it does worse on the other two sets. For example, say it's only 80% on the test set. This is just like our shoe scenario earlier. We designed a neural network that's great on the training data but not so great on the other data.

So why don't we redesign our neural network architecture and try again?

Our accuracy on the training data might go down, but what's most important is keeping the accuracy of the network on the training data as close to the validation and test accuracy as possible. And that will give us a true signal of the real accuracy of the network.

So returning to the code that we've been using to train on the MNIST handwritten digits, recall that we had training images and labels and validation images and labels in the dataset.

We could load them with one line of code in TensorFlow.
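For example, with the copy of MNIST built into Keras, that one line (plus a bit of pixel scaling) might look like the sketch below. The variable names are my own; Keras itself labels the second split as a test set, but we'll treat it as our validation data to match the description above.

```python
import tensorflow as tf

# MNIST ships with TensorFlow; load_data() returns it already split
# into two sets of images and labels.
(train_images, train_labels), (val_images, val_labels) = tf.keras.datasets.mnist.load_data()

# Scale the pixel values from 0-255 down to 0-1 before training.
train_images, val_images = train_images / 255.0, val_images / 255.0
```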

But when we trained, we just used the training images and labels. And we could see our training accuracy epoch by epoch. It got to 97.5% in just 20 epochs, and that's pretty good. But that was just on the training data. We have no context for how good or how bad it is with previously unseen data.
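As a sketch, training on just that data might look something like the following. The model architecture here is an assumption for illustration, not necessarily the exact one used earlier in the series.

```python
# A simple dense classifier for 28x28 images with 10 classes.
model = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax')
])

model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

# Training with the training images and labels only: the accuracy
# printed each epoch says nothing about previously unseen data.
model.fit(train_images, train_labels, epochs=20)
```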

So we can update our model.fit to use the validation data like this. When you train, you then get reports on the accuracy of the network on the training data step-by-step as well as the validation data. I've highlighted them in this diagram. And while the accuracy on the training data gets up to 97.59%, the accuracy on the validation data is at 96.26%. They'll almost never be the same. And your goal is to strive to get them as close to each other as possible. This is a sign that your neural network is good at generalizing and not over-specializing to the training data.
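Here's a sketch of that updated call, reusing the model and data from the snippets above:

```python
# Passing validation_data makes Keras report val_loss and val_accuracy
# alongside the training loss and accuracy after every epoch.
model.fit(train_images, train_labels,
          epochs=20,
          validation_data=(val_images, val_labels))
```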

If you have a test set, you can then perform model.evaluate on it after training is complete to see the results. And you'll get an accuracy value for that. You'd expect it to be similar to the validation one, but maybe a little bit lower if your network is well-architected. Your confidence in your network's ability to classify data should be based on this value and not the training one.
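Assuming you held back a separate clean test split (test_images and test_labels below are placeholders for whatever split you kept), that final check might look like this:

```python
# Evaluate on data the network never saw during training, and judge
# the model by this number rather than the training accuracy.
test_loss, test_acc = model.evaluate(test_images, test_labels)
print('Test accuracy:', test_acc)
```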

Done!