Overview

Elevate your JAX deep learning projects with Flax and Optax. This course introduces Flax, for concisely defining neural network architectures as Modules, and Optax, which provides a comprehensive suite of gradient-based optimizers. You'll refactor the XOR classifier from the previous course, using these libraries to write more structured and maintainable code.

Syllabus
- Unit 1: Flax Modules: `setup` and `__call__` Demystified
  - Building Your First Flax Module
  - Complete Your Neural Layer
  - Bringing Your Layer to Life
  - Enhancing Your Custom Flax Module
- Unit 2: Building an MLP with `flax.linen.Dense`
  - Building Your Neural Network Foundation
  - Wiring Up Your Neural Network
  - Streamlining Your MLP Design
  - Bringing Your MLP to Life
- Unit 3: Optax Optimizers: Beyond Gradient Descent
  - Creating Your First Optimizer
  - Building Your Training Foundation
  - Completing Parameter Updates
  - Complete Parameter Transformation
  - Advanced Optimizer Transformations
- Unit 4: XOR Revisited: Building a Full Training Loop with Flax and Optax
  - Training Loop Detective Work
  - Building Your Training Foundation
  - Building the Learning Engine
  - Evaluating Your Model Performance