
JAX in Action: Neural Networks from Scratch

via CodeSignal

Overview

This course guides you through building a simple Multi-Layer Perceptron (MLP) from scratch using only JAX and its NumPy API. You'll tackle the classic XOR problem, learning how to manage model parameters as PyTrees, implement the forward pass, define a loss function, compute gradients, and write a manual training loop with a basic optimizer step.
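To make the pieces concrete, here is a minimal sketch of the first step described above: initializing MLP parameters as a PyTree (a list of per-layer dicts) with JAX's functional random-number API. The function name `init_params`, the layer sizes, and the scaling choice are illustrative assumptions, not the course's actual code.

```python
import jax
import jax.numpy as jnp

def init_params(key, sizes):
    """Initialize weights and biases for each layer as a PyTree (list of dicts).

    `sizes` lists layer widths, e.g. [2, 4, 1] for a 2-input, 4-hidden,
    1-output MLP suited to XOR. Scaling by sqrt(2 / n_in) is one common
    initialization choice (He-style); the course may use another.
    """
    params = []
    for n_in, n_out in zip(sizes[:-1], sizes[1:]):
        key, subkey = jax.random.split(key)  # JAX RNG keys are split explicitly
        params.append({
            "W": jax.random.normal(subkey, (n_in, n_out)) * jnp.sqrt(2.0 / n_in),
            "b": jnp.zeros((n_out,)),
        })
    return params

params = init_params(jax.random.PRNGKey(0), [2, 4, 1])
```

Because the parameters form a PyTree, later steps (gradients, optimizer updates) can operate on the whole structure at once with `jax.tree_util.tree_map`.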

Syllabus

  • Unit 1: MLP Foundations: XOR Data Preparation and Parameter Initialization
    • Perfecting XOR Data for Neural Networks
    • Defining Neural Network Layer Sizes
    • Mastering Neural Network Weight Initialization
    • Building MLP Parameter Initialization
  • Unit 2: Forward Propagation: From Input to Prediction
    • Bringing Neural Networks to Life
    • Bringing Life to Neural Networks
    • Hunt for the Network Bug
    • Hunt Down the Forward Pass Bug
  • Unit 3: Measuring Error: Loss Functions and Gradients in JAX
    • Measuring Neural Network Error
    • Connect Loss to Model Predictions
    • Creating the Learning Engine
  • Unit 4: Learning by Example: The Manual Training Loop
    • Making Neural Networks Learn
    • Wire Up Your Training Loop
    • Complete the Training Loop
    • Neural Network Report Card
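The steps in Units 2-4 (forward pass, loss, gradients, manual training loop) can be sketched end-to-end on the XOR task. This is a hedged illustration under assumed choices: the hidden width, `tanh` activation, mean-squared-error loss, learning rate, and step count are all assumptions, not the course's actual code.

```python
import jax
import jax.numpy as jnp

# XOR inputs and targets (the canonical four cases)
X = jnp.array([[0., 0.], [0., 1.], [1., 0.], [1., 1.]])
y = jnp.array([[0.], [1.], [1.], [0.]])

# Parameters as a PyTree: a list of per-layer dicts
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
params = [
    {"W": jax.random.normal(k1, (2, 4)) * 0.5, "b": jnp.zeros(4)},
    {"W": jax.random.normal(k2, (4, 1)) * 0.5, "b": jnp.zeros(1)},
]

def forward(params, x):
    # Hidden layer with tanh, then a linear output layer (logits)
    h = jnp.tanh(x @ params[0]["W"] + params[0]["b"])
    return h @ params[1]["W"] + params[1]["b"]

def loss_fn(params, x, targets):
    preds = jax.nn.sigmoid(forward(params, x))
    return jnp.mean((preds - targets) ** 2)

# value_and_grad gives loss and gradients together; jit compiles the step
grad_fn = jax.jit(jax.value_and_grad(loss_fn))

lr = 1.0
initial_loss = loss_fn(params, X, y)
for step in range(2000):
    loss, grads = grad_fn(params, X, y)
    # Manual SGD: one tree_map applies the update to every leaf of the PyTree
    params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
final_loss = loss_fn(params, X, y)

preds = jax.nn.sigmoid(forward(params, X))
```

The `tree_map` update is the payoff of storing parameters as a PyTree: the same one-line SGD step works no matter how many layers the model has.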
