It’s time to get your hands on training your Neural Network, because you are going to be able to work through parameters with real data coming from the MNIST Kaggle competition.
Hands on real data
If you haven’t been following the previous steps, you can check out how Kaggle works and how to start with the MNIST competition. After getting comfortable with the data, you can look at how to create basic Neural Networks and how to train them. If you are already up to date and ready for a challenge, buckle up and let’s go for it.
Some tips for working with PyTorch
When working with Neural Networks in PyTorch, some functions expect specific types of tensors; you can look up the types here. What caught me on this specific Neural Net was the difference between torch.Tensor and torch.tensor. The images we pass in the forward pass are expected to be a torch.FloatTensor, but the labels are expected to be a torch.LongTensor. You can create the first type with torch.Tensor, and the second one with torch.tensor.
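To see the difference in action, here is a minimal sketch (the example values are placeholders, not real MNIST data):

```python
import torch

# torch.Tensor is the class constructor: it always produces a
# FloatTensor, regardless of the values you pass in.
images = torch.Tensor([[0, 255], [128, 64]])
print(images.dtype)  # torch.float32

# torch.tensor is a factory function: it infers the dtype from the
# data, so integer labels come out as a LongTensor.
labels = torch.tensor([3, 7])
print(labels.dtype)  # torch.int64
```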
After creating your model you need to train it. This is where things get interesting: we have 42,000 images in total, and we’re going to split them into simple batches of 64 images and loop through them. The number of epochs is how many times we repeat this process.
- Choose your number of epochs (usually between 30 and 70)
- Choose your batch size
- Iterate through your batch of images
- Calculate the outputs from the initial model
- Calculate the overall loss
- Use the optimizer to adjust the weights and biases
- Do it all over for the number of epochs
Testing your model
We can see that our model can predict that the handwritten digit is a zero.
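Reading a prediction out of the model looks roughly like this. The model below is an untrained placeholder standing in for the network you just finished training:

```python
import torch
from torch import nn

# Placeholder model with a LogSoftmax output, as in training.
model = nn.Sequential(nn.Linear(784, 10), nn.LogSoftmax(dim=1))
image = torch.rand(1, 784)  # stand-in for one flattened MNIST image

with torch.no_grad():  # no gradient tracking needed at inference time
    log_probs = model(image)
    prediction = log_probs.argmax(dim=1).item()  # most likely digit
print(f"Predicted digit: {prediction}")
```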
There are countless learning opportunities here. I started with a Softmax function in the final layer and discovered that LogSoftmax worked way better. The loss function and the optimizer could also be chosen differently. PyTorch also has some cool functions to normalize and clean the data. Another important topic is validation and testing datasets; I’ll be making more posts on this.
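One likely reason LogSoftmax works better (an assumption on my part, the illustrative values below are made up) is numerical stability: computing log(Softmax(x)) in two steps can underflow to -inf, while LogSoftmax does it in one stable step.

```python
import torch
from torch import nn

# Extreme logits to make the underflow visible
logits = torch.tensor([[1000.0, 0.0, -1000.0]])

two_step = torch.log(nn.Softmax(dim=1)(logits))  # log(0) -> -inf
one_step = nn.LogSoftmax(dim=1)(logits)          # stays finite

print(two_step)  # contains -inf entries
print(one_step)  # finite log-probabilities
```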
- Cleaning and normalizing the data
- 15% of your dataset should be for validation
- 30% of your dataset should be for testing
- Play with the parameters of your Neural Network
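The split suggested above (15% validation, 30% testing, the remaining 55% for training) can be sketched like this; the random tensors stand in for the 42,000 Kaggle MNIST images:

```python
import torch

# Random stand-in for the Kaggle MNIST training set
images = torch.rand(42000, 784)
labels = torch.randint(0, 10, (42000,))

# Shuffle before splitting so each subset is representative
perm = torch.randperm(len(images))
images, labels = images[perm], labels[perm]

n_val = int(0.15 * len(images))   # 6,300 validation images
n_test = int(0.30 * len(images))  # 12,600 testing images

val_imgs, val_labels = images[:n_val], labels[:n_val]
test_imgs, test_labels = images[n_val:n_val + n_test], labels[n_val:n_val + n_test]
train_imgs, train_labels = images[n_val + n_test:], labels[n_val + n_test:]

print(len(train_imgs), len(val_imgs), len(test_imgs))  # 23100 6300 12600
```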
If you want more content like this, I have a YouTube Channel