Overview
Master JAX from the ground up! This path takes you from NumPy-like basics and automatic differentiation, through advanced batching and PyTrees, to building and training deep neural networks with Flax and Optax—culminating in a real-world image classification project.
Syllabus
- Course 1: JAX Fundamentals: NumPy Power-Up
- Course 2: Advanced JAX: Transformations for Speed & Scale
- Course 3: JAX in Action: Neural Networks from Scratch
- Course 4: Beyond Pure JAX: Flax & Optax for Elegant ML
- Course 5: JAX in Action: Building an Image Classifier
Courses
- Course 1: JAX Fundamentals: NumPy Power-Up. This course introduces JAX, a high-performance numerical computation library. You'll learn how JAX extends NumPy's familiar API with key features like immutability, pure functions, automatic differentiation for gradient-based optimization, and just-in-time (JIT) compilation for significant speedups on accelerators like GPUs and TPUs.
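A minimal sketch of the three core ideas this course names: autodiff via `jax.grad`, JIT compilation via `jax.jit`, and array immutability. The function `f` is an arbitrary illustration, not from the course materials.

```python
import jax
import jax.numpy as jnp

# A pure function: f(x) = x^2 + 3x
def f(x):
    return x ** 2 + 3.0 * x

# jax.grad builds a new function computing df/dx = 2x + 3
df = jax.grad(f)
print(df(2.0))       # 7.0

# jax.jit compiles f with XLA for faster repeated calls
fast_f = jax.jit(f)
print(fast_f(2.0))   # 10.0

# JAX arrays are immutable: use .at[...].set(...) instead of in-place writes
a = jnp.zeros(3)
b = a.at[0].set(1.0)  # returns a new array; `a` is unchanged
```

Note that `df` and `fast_f` are ordinary Python callables, so transformations compose: `jax.jit(jax.grad(f))` is also valid.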
- Course 2: Advanced JAX: Transformations for Speed & Scale. Dive deeper into JAX's powerful functional transformations. This course covers explicit pseudo-random number generation for reproducibility, automatic vectorization with jax.vmap for batching, an introduction to jax.shard_map for multi-device parallelism, the concept of PyTrees for handling complex data structures, and basic profiling/debugging techniques.
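A short sketch of three of the techniques above: explicit PRNG keys, `jax.vmap` over a batch axis, and mapping a function over a PyTree. The shapes and the `dot_self` function are illustrative choices, not taken from the course.

```python
import jax
import jax.numpy as jnp

# Explicit PRNG: keys are split, never reused, so runs are reproducible
key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)
x = jax.random.normal(subkey, (4, 3))  # a batch of 4 vectors of size 3

# A function written for a single example ...
def dot_self(v):
    return jnp.dot(v, v)

# ... vectorized over the leading batch axis with jax.vmap
out = jax.vmap(dot_self)(x)  # shape (4,), one result per batch element

# PyTrees: nested containers of arrays that JAX utilities map over as a unit
params = {"w": jnp.ones((3,)), "b": jnp.zeros(())}
doubled = jax.tree_util.tree_map(lambda leaf: 2 * leaf, params)
```

The same pattern scales up: model parameters in later courses are PyTrees, and `tree_map` is how a whole parameter set is updated in one call.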
- Course 3: JAX in Action: Neural Networks from Scratch. This course guides you through building a simple Multi-Layer Perceptron (MLP) from scratch using only JAX and its NumPy API. You'll tackle the classic XOR problem, learning how to manage model parameters as PyTrees, implement the forward pass, define a loss function, compute gradients, and write a manual training loop with a basic optimizer step.
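The steps above can be sketched end to end. This is one possible from-scratch XOR solution, assuming a 2-4-1 tanh MLP with a sigmoid output and MSE loss; the architecture, learning rate, and step count are illustrative, not the course's.

```python
import jax
import jax.numpy as jnp

# XOR dataset
X = jnp.array([[0., 0.], [0., 1.], [1., 0.], [1., 1.]])
y = jnp.array([0., 1., 1., 0.])

# Parameters as a PyTree (hypothetical 2-4-1 architecture)
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
params = {
    "w1": jax.random.normal(k1, (2, 4)), "b1": jnp.zeros(4),
    "w2": jax.random.normal(k2, (4, 1)), "b2": jnp.zeros(1),
}

def forward(params, x):
    h = jnp.tanh(x @ params["w1"] + params["b1"])
    return (h @ params["w2"] + params["b2"]).squeeze(-1)

def loss_fn(params, x, y):
    preds = jax.nn.sigmoid(forward(params, x))
    return jnp.mean((preds - y) ** 2)

@jax.jit
def train_step(params, x, y, lr=0.5):
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    # Manual SGD: update every leaf of the parameter PyTree in one call
    params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return params, loss

init_loss = loss_fn(params, X, y)
for _ in range(1000):
    params, loss = train_step(params, X, y)
```

Because `grads` has the same PyTree structure as `params`, the optimizer step is a single `tree_map` over both trees at once.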
- Course 4: Beyond Pure JAX: Flax & Optax for Elegant ML. Elevate your JAX deep learning projects with Flax and Optax. This course introduces Flax for concisely defining neural network architectures as Modules, and Optax for a comprehensive suite of optimizers. You'll refactor the XOR classifier from the previous course, leveraging these powerful libraries for more structured and maintainable code.
- Course 5: JAX in Action: Building an Image Classifier. Apply your JAX, Flax, and Optax skills to a practical project: building an image classification pipeline. This course covers setting up a multi-file project, loading and preprocessing image data (e.g., MNIST), defining a Convolutional Neural Network (CNN) with Flax, and implementing robust training and evaluation loops.