This course introduces JAX, a library for high-performance numerical computing. You'll learn how JAX keeps NumPy's familiar API while adding immutable arrays, a pure-function programming model, automatic differentiation for gradient-based optimization, and just-in-time (JIT) compilation through XLA, which can deliver significant speedups, especially on accelerators such as GPUs and TPUs.
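As a quick preview of how these pieces fit together, here is a minimal sketch using standard `jax` and `jax.numpy` calls. The function name `loss` and the data values are illustrative choices, not taken from the course material.

```python
import jax
import jax.numpy as jnp

# NumPy-style API: jnp mirrors most of numpy's functions.
x = jnp.arange(5.0)  # [0., 1., 2., 3., 4.]

# Immutability: in-place item assignment raises a TypeError...
try:
    x[0] = 10.0
    mutated = True
except TypeError:
    mutated = False

# ...so updates use the functional .at[...].set(...) form, which returns a new array.
y = x.at[0].set(10.0)

# A pure function: its output depends only on its inputs (w and the constant x),
# and it has no side effects.
def loss(w):
    return jnp.sum((w * x - 1.0) ** 2)

# Automatic differentiation: jax.grad returns a new function computing d(loss)/dw.
grad_loss = jax.grad(loss)

# JIT compilation: jax.jit traces the function and compiles it with XLA.
fast_grad = jax.jit(grad_loss)

g = fast_grad(2.0)
print(mutated, x[0], y[0], g)
```

Note that `x` is left untouched by the `.at` update; `y` is a separate array.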
Syllabus
- Unit 1: JAX Arrays: The Immutable Successor to NumPy
  - Creating JAX Arrays from Scratch
  - Basic Operations with JAX Arrays
  - Embracing Immutability with JAX Arrays
  - Mastering JAX Functional Array Updates
  - Benchmarking JAX JIT Compilation Performance
- Unit 2: Pure Functions: The Cornerstone of JAX
  - Identifying Pure and Impure Functions
  - Refactoring Impure Functions to Pure Alternatives
  - Pure Functions with JAX Arrays
  - Hunting Hidden Global Dependencies
- Unit 3: Automatic Differentiation with jax.grad
  - Fixing Gradient Function Syntax Error
  - Efficient Gradient and Value Computation
  - Gradients of Trigonometric Functions
  - Computing Gradients for Multiple Variables
- Unit 4: Speeding Up with jax.jit: Just-In-Time Compilation
  - Speeding Up Functions with JIT Decorator
  - Accurate Timing for Asynchronous JAX Operations
  - Fixing Control Flow with Static Arguments
- Unit 5: Control Flow in JAX: Mastering jax.lax Primitives
  - Fix Conditionals for JIT Compilation
  - Conditional Math with Multiple Operands
  - Cumulative Products with JAX Scan
  - Factorial Sequence with JAX Scan
  - Dynamic Loops Finding Multiples
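To give a flavor of the `jax.lax` primitives covered in Unit 5, the sketch below uses `lax.cond` for a branch that still works under `jax.jit` and `lax.scan` for a cumulative product. The helper names `safe_reciprocal` and `cumprod` are hypothetical illustrations, not the course's own solutions.

```python
import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def safe_reciprocal(x):
    # Cast to a concrete float32 so both branches produce identical types.
    x = jnp.asarray(x, dtype=jnp.float32)
    # A Python `if x != 0` would fail on a traced value under jit;
    # lax.cond selects a branch from the traced boolean instead.
    # Both branches must return the same shape and dtype.
    return lax.cond(x != 0.0, lambda v: 1.0 / v, lambda v: jnp.zeros_like(v), x)

@jax.jit
def cumprod(xs):
    # scan threads a carry through the sequence; each step returns
    # (new_carry, per_step_output), and the outputs are stacked.
    def step(carry, x):
        carry = carry * x
        return carry, carry

    init = jnp.ones((), dtype=xs.dtype)
    _, out = lax.scan(step, init, xs)
    return out

print(safe_reciprocal(4.0), safe_reciprocal(0.0))
print(cumprod(jnp.array([1.0, 2.0, 3.0, 4.0])))
```

In practice `jnp.cumprod` already exists; writing it with `lax.scan` is only meant to show how a sequential loop can be expressed in a form JAX can trace and compile.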