Overview

Dive deeper into JAX's powerful functional transformations. This course covers explicit pseudo-random number generation for reproducibility, automatic vectorization with jax.vmap for batching, an introduction to jax.shard_map for multi-device parallelism, PyTrees for handling nested data structures, and basic profiling and debugging techniques.

Syllabus
- Unit 1: Reproducible Randomness with jax.random: Keys, Splitting, and Determinism
  - Creating Your First Random Samples
  - Fixing Key Reuse for Independent Randomness
  - Verifying JAX's Reproducibility Promise
  - Multi-Distribution Key Management Strategy
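As a taste of what this unit covers, here is a minimal sketch (not taken from the course materials) of explicit keys, splitting, and determinism with the standard jax.random API:

```python
import jax

# Explicit PRNG state: a key fully determines the stream of samples.
key = jax.random.PRNGKey(0)

# Split instead of reusing the key: each subkey gives independent randomness.
key, subkey = jax.random.split(key)
x = jax.random.normal(subkey, (3,))

# Re-deriving the same subkey from the same seed reproduces the samples exactly.
key2 = jax.random.PRNGKey(0)
_, subkey2 = jax.random.split(key2)
y = jax.random.normal(subkey2, (3,))

print(bool((x == y).all()))  # True: same key, same samples
```

Because randomness is a pure function of the key, reruns with the same seed are bitwise reproducible.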
- Unit 2: Effortless Batching with jax.vmap
  - Vectorizing Vector Norms with vmap
  - Fixing Axis Bugs in vmap
  - Mixed Batching with Vector Transformations
  - Configuring Multiple Outputs with vmap
  - Nested Function Vectorization with Mixed Arguments
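A small sketch of the unit's core idea (illustrative, not from the course): write a function for a single example, then let jax.vmap add the batch dimension, using in_axes to mark which arguments are batched:

```python
import jax
import jax.numpy as jnp

def norm(v):
    # Written for a single vector; vmap supplies the batch dimension.
    return jnp.sqrt(jnp.sum(v ** 2))

batch = jnp.array([[3.0, 4.0], [6.0, 8.0]])
norms = jax.vmap(norm)(batch)  # maps over axis 0 by default: 5.0 and 10.0

# Mixed batching: in_axes=(None, 0) shares the matrix across the batch
# while mapping over the vectors.
def transform(m, v):
    return m @ v

m = jnp.eye(2) * 2.0
scaled = jax.vmap(transform, in_axes=(None, 0))(m, batch)
print(norms, scaled)
```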
- Unit 3: Parallel Universes: SPMD Parallelism with jax.shard_map and the Evolution Beyond pmap
  - Creating Device Meshes for Parallel Computing
  - Partitioning Data Across a Device Mesh
  - Applying shard_map for Parallel Computing
  - Implementing Collective Operations for Cross-Device Communication
  - Specifying Output Distribution in Parallel Computing
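A minimal sketch of the mesh → partition → shard_map → collective pipeline this unit builds up (illustrative, not from the course; the import location of shard_map varies by JAX version, so the fallback below is a compatibility assumption):

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

try:
    from jax import shard_map                     # newer JAX releases
except ImportError:
    from jax.experimental.shard_map import shard_map  # older releases

# 1-D device mesh over whatever devices are available (works on a single CPU).
mesh = Mesh(np.array(jax.devices()), axis_names=('i',))

def local_sum(block):
    # `block` is this device's shard; psum is the collective that
    # combines the partial results across the 'i' mesh axis.
    return jax.lax.psum(block.sum(), 'i')

n_dev = len(jax.devices())
x = jnp.arange(n_dev * 4.0)  # leading axis divisible by the mesh size

# in_specs=P('i') shards the input's leading axis across devices;
# out_specs=P() declares the output fully replicated.
f = shard_map(local_sum, mesh=mesh, in_specs=P('i'), out_specs=P())
total = f(x)
print(float(total))
```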
- Unit 4: Nested Data Structures: Mastering JAX PyTrees
  - Creating Your First Neural Network PyTree
  - Element-wise PyTree Operations
  - Handling Mixed Data in PyTrees
  - Building Complex Neural Network PyTrees
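A brief illustration of the PyTree idea (not from the course materials): nested containers of arrays that JAX's tree utilities can map over, flatten, and rebuild:

```python
import jax
import jax.numpy as jnp

# A PyTree: nested dicts of arrays, the typical shape of network parameters.
params = {
    'layer1': {'w': jnp.ones((2, 3)), 'b': jnp.zeros(3)},
    'layer2': {'w': jnp.ones((3, 1)), 'b': jnp.zeros(1)},
}

# Apply an element-wise operation to every leaf; the structure is preserved.
scaled = jax.tree_util.tree_map(lambda leaf: leaf * 0.5, params)

# Flatten to a flat list of leaves plus a treedef that can rebuild the tree.
leaves, treedef = jax.tree_util.tree_flatten(params)
rebuilt = jax.tree_util.tree_unflatten(treedef, leaves)
print(len(leaves))  # 4 leaf arrays
```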
- Unit 5: Peeking Inside: Profiling and Debugging JAX
  - Timing JAX Operations Accurately
  - Debugging JIT Functions with Debug Print
  - Building a JAX Timer Decorator
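A small sketch of the unit's two key habits (illustrative, not from the course): call block_until_ready() so asynchronous dispatch doesn't fool the timer, and use jax.debug.print for output from inside jit-compiled code:

```python
import time
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    # A regular print() would fire only once, at trace time;
    # jax.debug.print emits on every execution, even under jit.
    jax.debug.print("sum of input: {}", jnp.sum(x))
    return (x @ x.T).sum()

x = jnp.ones((500, 500))

f(x).block_until_ready()   # warm-up call: includes compilation time
t0 = time.perf_counter()
out = f(x)
out.block_until_ready()    # wait for the async result before stopping the clock
dt = time.perf_counter() - t0
print(f"steady-state time: {dt:.6f} s")
```

Without the blocking calls, perf_counter would measure only how fast JAX *dispatched* the work, not how long it actually ran.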