Google

Checkpointing Flax NNX Models with Orbax - Part 1

Google via YouTube

Overview

Learn how to save and restore Flax NNX models with Orbax, the standard checkpointing library in the JAX ecosystem. Flax NNX takes a Pythonic, stateful approach that feels closer to frameworks like PyTorch while keeping the advantages of JAX, and this tutorial explains how Orbax works with that approach. It covers the core concepts: what NNX state is and how NNX manages it, how Orbax is structured, and the complete workflow for saving and restoring a single, basic NNX model. It also offers a first look at techniques for handling distributed training scenarios. This 13-minute tutorial from Google is the first part of a two-episode series, aimed at developers who work with JAX-based machine learning models and need reliable checkpointing.

Syllabus

Checkpointing Flax NNX Models with Orbax (Part 1)

Taught by

Google Developers
