This course guides you through building a simple Multi-Layer Perceptron (MLP) from scratch using only JAX and its NumPy API. You'll tackle the classic XOR problem, learning how to manage model parameters as PyTrees, implement the forward pass, define a loss function, compute gradients, and write a manual training loop with a basic optimizer step.
Syllabus
- Unit 1: MLP Foundations: XOR Data Preparation and Parameter Initialization
  - Perfecting XOR Data for Neural Networks
  - Defining Neural Network Layer Sizes
  - Mastering Neural Network Weight Initialization
  - Building MLP Parameter Initialization
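To give a flavor of what Unit 1 builds, here is a minimal sketch of the XOR dataset and PyTree-style parameter initialization in JAX. The layer sizes (2-4-1) and the He-style weight scaling are illustrative assumptions, not the course's exact solution:

```python
import jax
import jax.numpy as jnp

# The four XOR input/target pairs.
X = jnp.array([[0., 0.], [0., 1.], [1., 0.], [1., 1.]])
y = jnp.array([[0.], [1.], [1.], [0.]])

def init_mlp_params(layer_sizes, key):
    """Build a list of {'w': ..., 'b': ...} dicts -- a JAX PyTree."""
    params = []
    for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:]):
        key, subkey = jax.random.split(key)
        params.append({
            # He-style scaling keeps activations from shrinking or blowing up.
            "w": jax.random.normal(subkey, (n_in, n_out)) * jnp.sqrt(2.0 / n_in),
            "b": jnp.zeros(n_out),
        })
    return params

params = init_mlp_params([2, 4, 1], jax.random.PRNGKey(0))  # sizes are illustrative
```

Storing each layer as a plain dict means the whole parameter set is a PyTree, which JAX's transformation and tree utilities can traverse automatically.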
- Unit 2: Forward Propagation: From Input to Prediction
  - Bringing Neural Networks to Life
  - Bringing Life to Neural Networks
  - Hunt for the Network Bug
  - Hunt Down the Forward Pass Bug
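Unit 2's forward pass can be sketched roughly as follows. The tanh hidden activation and sigmoid output are common choices for XOR but are assumptions here, as is the small random initialization:

```python
import jax
import jax.numpy as jnp

def init_mlp_params(layer_sizes, key):
    params = []
    for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:]):
        key, subkey = jax.random.split(key)
        params.append({"w": jax.random.normal(subkey, (n_in, n_out)) * 0.5,
                       "b": jnp.zeros(n_out)})
    return params

def forward(params, x):
    """Hidden layers use tanh; the output layer uses a sigmoid for a 0-1 prediction."""
    for layer in params[:-1]:
        x = jnp.tanh(x @ layer["w"] + layer["b"])
    last = params[-1]
    return jax.nn.sigmoid(x @ last["w"] + last["b"])

X = jnp.array([[0., 0.], [0., 1.], [1., 0.], [1., 1.]])
preds = forward(init_mlp_params([2, 4, 1], jax.random.PRNGKey(0)), X)
```

Because the input `X` is a batch of four examples, one matrix multiply per layer produces predictions for all four XOR inputs at once.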
- Unit 3: Measuring Error: Loss Functions and Gradients in JAX
  - Measuring Neural Network Error
  - Connect Loss to Model Predictions
  - Creating the Learning Engine
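The "learning engine" of Unit 3 pairs a loss function with `jax.grad`. A minimal sketch, assuming binary cross-entropy as the loss (the fixed parameter values below exist only to make the example runnable):

```python
import jax
import jax.numpy as jnp

def forward(params, x):
    for layer in params[:-1]:
        x = jnp.tanh(x @ layer["w"] + layer["b"])
    return jax.nn.sigmoid(x @ params[-1]["w"] + params[-1]["b"])

def bce_loss(params, x, targets):
    """Binary cross-entropy; the epsilon guards against log(0)."""
    preds = forward(params, x)
    eps = 1e-7
    return -jnp.mean(targets * jnp.log(preds + eps)
                     + (1 - targets) * jnp.log(1 - preds + eps))

X = jnp.array([[0., 0.], [0., 1.], [1., 0.], [1., 1.]])
y = jnp.array([[0.], [1.], [1.], [0.]])
params = [{"w": jnp.ones((2, 4)) * 0.1, "b": jnp.zeros(4)},
          {"w": jnp.ones((4, 1)) * 0.1, "b": jnp.zeros(1)}]

# jax.grad differentiates the scalar loss w.r.t. the entire parameter PyTree.
grads = jax.grad(bce_loss)(params, X, y)
```

The returned `grads` mirrors the structure of `params` exactly, one gradient array per weight and bias, which is what makes the Unit 4 update step a one-liner.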
- Unit 4: Learning by Example: The Manual Training Loop
  - Making Neural Networks Learn
  - Wire Up Your Training Loop
  - Complete the Training Loop
  - Neural Network Report Card
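Putting the units together, Unit 4's manual training loop can be sketched end to end. The learning rate, step count, and activations are illustrative assumptions; the SGD update itself is a `tree_map` over the parameter PyTree:

```python
import jax
import jax.numpy as jnp

def forward(params, x):
    for layer in params[:-1]:
        x = jnp.tanh(x @ layer["w"] + layer["b"])
    return jax.nn.sigmoid(x @ params[-1]["w"] + params[-1]["b"])

def bce_loss(params, x, targets):
    preds = forward(params, x)
    eps = 1e-7
    return -jnp.mean(targets * jnp.log(preds + eps)
                     + (1 - targets) * jnp.log(1 - preds + eps))

X = jnp.array([[0., 0.], [0., 1.], [1., 0.], [1., 1.]])
y = jnp.array([[0.], [1.], [1.], [0.]])

key = jax.random.PRNGKey(0)
params = []
for n_in, n_out in zip([2, 4], [4, 1]):
    key, sub = jax.random.split(key)
    params.append({"w": jax.random.normal(sub, (n_in, n_out)),
                   "b": jnp.zeros(n_out)})

loss_and_grad = jax.jit(jax.value_and_grad(bce_loss))
lr = 0.1  # illustrative learning rate
first_loss, _ = loss_and_grad(params, X, y)

for step in range(2000):
    loss, grads = loss_and_grad(params, X, y)
    # Manual SGD step applied leaf-by-leaf across the parameter PyTree.
    params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

final_loss, _ = loss_and_grad(params, X, y)
```

Wrapping the loss-and-gradient computation in `jax.jit` compiles it once, so each iteration of the Python loop runs the fused update efficiently.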