JAX Fundamentals: NumPy Power-Up

via CodeSignal

Overview

This course introduces JAX, a high-performance numerical computing library. You'll learn how JAX builds on NumPy's familiar API and adds key features: immutable arrays, a functional programming model built on pure functions, automatic differentiation for gradient-based optimization, and just-in-time (JIT) compilation that can deliver significant speedups on accelerators such as GPUs and TPUs.
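To make those four features concrete, here is a minimal sketch, assuming a recent JAX release installed with `pip install jax`. The names `loss`, `grad_loss`, and `fast_grad` are hypothetical illustrations, not material taken from the course. It shows an in-place update being rejected, the functional `.at[].set()` alternative, a pure function differentiated with `jax.grad`, and that gradient compiled with `jax.jit`.

```python
import jax
import jax.numpy as jnp

# Immutability: JAX arrays reject in-place item assignment.
x = jnp.array([1.0, 2.0, 3.0])
try:
    x[0] = 10.0
except TypeError as err:
    print("In-place update rejected:", type(err).__name__)

# Functional update: .at[...].set(...) returns a new array; x is untouched.
y = x.at[0].set(10.0)

# A pure function (hypothetical example): its result depends only on its
# argument and it has no side effects.
def loss(w):
    return jnp.sum(w ** 2)

# Automatic differentiation: jax.grad returns a new function computing d(loss)/dw.
grad_loss = jax.grad(loss)

# JIT compilation: jax.jit traces the function for a given input shape/dtype
# and compiles it with XLA; later calls with matching inputs reuse that code.
fast_grad = jax.jit(grad_loss)

g = fast_grad(x)  # the gradient of sum(w**2) is 2 * w

# value_and_grad returns the loss and its gradient from a single call.
value, g2 = jax.value_and_grad(loss)(x)

print(x, y, g, value)
```

Because `loss` is pure, JAX can safely trace it once and reuse the compiled result; a function that read or mutated global state could silently produce stale results under `jax.jit`.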

Syllabus

  • Unit 1: JAX Arrays: The Immutable Successor to NumPy
    • Creating JAX Arrays from Scratch
    • Basic Operations with JAX Arrays
    • Embracing Immutability with JAX Arrays
    • Mastering JAX Functional Array Updates
    • Benchmarking JAX JIT Compilation Performance
  • Unit 2: Pure Functions: The Cornerstone of JAX
    • Identifying Pure and Impure Functions
    • Refactoring Impure Functions to Pure Alternatives
    • Pure Functions with JAX Arrays
    • Hunting Hidden Global Dependencies
  • Unit 3: Automatic Differentiation with jax.grad
    • Fixing Gradient Function Syntax Error
    • Efficient Gradient and Value Computation
    • Gradients of Trigonometric Functions
    • Computing Gradients for Multiple Variables
  • Unit 4: Speeding Up with jax.jit: Just-In-Time Compilation
    • Speeding Up Functions with JIT Decorator
    • Accurate Timing for Asynchronous JAX Operations
    • Fixing Control Flow with Static Arguments
  • Unit 5: Control Flow in JAX: Mastering jax.lax Primitives
    • Fix Conditionals for JIT Compilation
    • Conditional Math with Multiple Operands
    • Cumulative Products with JAX Scan
    • Factorial Sequence with JAX Scan
    • Dynamic Loops Finding Multiples
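Units 4 and 5 deal with the fact that ordinary Python control flow on traced values breaks under `jax.jit`. The following is a minimal sketch of the usual remedies, assuming a recent JAX release: a static argument so a Python loop can unroll at trace time, `lax.cond` for branching, `lax.scan` for a cumulative product (which also yields factorials), and `lax.while_loop` for a loop whose length depends on runtime values. All function names here are hypothetical and are not taken from the course exercises.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax

# Static argument: n is a plain Python int known at trace time, so the
# Python for-loop unrolls during tracing instead of failing on a tracer.
@partial(jax.jit, static_argnums=1)
def power_loop(x, n):
    result = jnp.ones_like(x)
    for _ in range(n):
        result = result * x
    return result

# lax.cond picks a branch at runtime and works inside jit; both branches
# must return values with the same shape and dtype.
@jax.jit
def safe_sqrt_or_square(x):
    return lax.cond(x >= 0, jnp.sqrt, lambda v: v ** 2, x)

# lax.scan threads a carry through the sequence and stacks per-step outputs.
@jax.jit
def cumulative_product(xs):
    def step(carry, x):
        new = carry * x
        return new, new  # (next carry, value to stack)
    init = jnp.ones((), dtype=xs.dtype)
    _, out = lax.scan(step, init, xs)
    return out

# lax.while_loop runs until cond_fun is False; the trip count can depend
# on runtime values. Here: the smallest multiple of k that is >= start.
@jax.jit
def first_multiple_at_least(start, k):
    def cond_fun(v):
        return v % k != 0
    def body_fun(v):
        return v + 1
    return lax.while_loop(cond_fun, body_fun, start)

print(power_loop(jnp.array(2.0, dtype=jnp.float32), 3))
print(safe_sqrt_or_square(jnp.array(-3.0, dtype=jnp.float32)))
print(cumulative_product(jnp.arange(1.0, 6.0)))  # factorials 1! .. 5!
print(first_multiple_at_least(jnp.int32(10), jnp.int32(7)))
```

Each call to `power_loop` with a new value of `n` triggers a fresh compilation, so static arguments suit values that change rarely; `lax.while_loop` avoids recompilation when the loop length varies with the data.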
