JAX - Why is Everyone So Excited About This Framework

Introduction

The field of AI is going through an exponential surge, with new findings appearing at an unprecedented rate. With data and compute demands growing at a Moore's-law pace, a highly performant ML framework is an absolute necessity; ultimately, unlocking the machine's FLOPS is the main goal of any framework. Frameworks such as TensorFlow, PyTorch, and more recently JAX have all tried to unlock these FLOPS, and for the purpose of this blog, we'll focus on JAX. There are a lot of things that make JAX unique, so let us jump right into it.

What is JAX?

JAX has been gaining a lot of traction recently, and for good reason. JAX allows researchers to write Python programs that are automatically scaled to leverage accelerators and supercomputers, without any additional effort. It was developed at Google to meet a simple goal: balance rapid prototyping and quick iteration with the ability to deploy experiments at scale. For those familiar with NumPy, think of JAX as NumPy with autodiff and nice distributed support. Keep these points in the back of your mind, and let us try to understand why these qualities in JAX are so important and have many people (and frontier AI labs) pivoting to it.
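
To make the NumPy analogy concrete, here is a minimal sketch (illustrative, not from the original draft): jax.numpy mirrors most of the NumPy API, and the same code runs transparently on CPU, GPU, or TPU.

    import jax.numpy as jnp

    # jax.numpy mirrors the NumPy API; the same code runs on CPU, GPU, or TPU.
    x = jnp.linspace(0.0, 1.0, 5)
    print(jnp.sin(x) + jnp.cos(x))    # familiar NumPy-style operations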

Core Features - JAX

Don't just take my word for it: Francois Chollet (creator of Keras) recently tweeted that almost all players in generative AI are pivoting to JAX because it is fast, scales really well, and has TPU support. A one-line explanation of JAX would go something like this - *JAX is basically NumPy on steroids, made for researchers to write efficient, performant and scalable workloads.* JAX brings a lot of unique propositions to a performant library, so let us look into what makes JAX special:

Automatic Differentiation

Autodiff automatically computes gradients of numerical functions, which is essential for training ML models with gradient-based optimization. We'll cover autodiff in a lot of detail in the upcoming sections.
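
As a quick taste, here is a minimal sketch (illustrative, not from the original draft) of jax.grad, which transforms a function into one that computes its gradient:

    import jax
    import jax.numpy as jnp

    # A toy scalar function: f(x) = sum(x**2), whose gradient is 2*x.
    def f(x):
        return jnp.sum(x ** 2)

    grad_f = jax.grad(f)              # grad_f is itself a callable
    x = jnp.array([1.0, 2.0, 3.0])
    print(grad_f(x))                  # [2. 4. 6.]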

Just-In-Time Compilation

JAX uses a JIT compiler to speed up entire blocks of code by fusing operations and exploiting any parallelism between them. A function is compiled (via XLA) on its first call, and the optimized version is reused on subsequent calls, allowing efficient computation and lookup.
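
A minimal sketch of jax.jit in action (the selu activation here is just an illustrative example, not something specific to the draft):

    import jax
    import jax.numpy as jnp

    @jax.jit                          # traced and compiled with XLA on first call
    def selu(x, alpha=1.67, lmbda=1.05):
        return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

    x = jnp.arange(5.0)
    print(selu(x))                    # later calls reuse the compiled version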

VMap (Auto-Vectorization)

VMap (automatic vectorization) allows us to apply a function along one or more axes of a tensor. It vectorizes the function by adding a batch dimension to every primitive operation inside it, so we can write code for a single example and run it on a whole batch.
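
A minimal sketch (illustrative): jax.vmap turns a function written for single vectors into one that works on a batch of vectors:

    import jax
    import jax.numpy as jnp

    def dot(v, w):
        return jnp.dot(v, w)          # written for two single vectors

    # Map `dot` over the leading (batch) axis of both arguments.
    batched_dot = jax.vmap(dot, in_axes=(0, 0))
    vs = jnp.ones((8, 3))
    ws = jnp.ones((8, 3))
    print(batched_dot(vs, ws).shape)  # (8,)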

PMap (SPMD Programming)

JAX has built-in support for Single Program Multiple Data (SPMD) programming, allowing the same function to run in parallel on multiple XLA devices, each operating on its own shard of the input.
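
A minimal sketch (illustrative; assumes more than one XLA device is available): jax.pmap replicates a function across devices, with the leading axis of the input sharded one slice per device:

    import jax
    import jax.numpy as jnp

    n = jax.local_device_count()      # number of XLA devices available

    # Leading axis of size n: one slice per device, same program everywhere.
    xs = jnp.arange(n * 4.0).reshape(n, 4)
    squared = jax.pmap(lambda x: x ** 2)(xs)
    print(squared.shape)              # (n, 4), computed across devices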

In the introductory JAX blog post, it was mentioned that "JAX has una anima di pura programmazione funzionale" (JAX has a soul of pure functional programming), so let us now try to see why JAX is so awesome!
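
One immediate consequence of that functional soul (a minimal sketch, illustrative): JAX arrays are immutable, so instead of in-place mutation you use the .at syntax, which returns an updated copy:

    import jax.numpy as jnp

    # JAX arrays are immutable: .at[...] returns a new array; the original is untouched.
    x = jnp.zeros(3)
    y = x.at[0].set(1.0)
    print(x)                          # [0. 0. 0.]
    print(y)                          # [1. 0. 0.]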

Understanding JAX on a Deeper Level
