Article

Website Source: blog / jax_blog

Summary

Pending synthesis from local website source.

Original source title: An object with a name, to which we attach interpretation rules.

Extracted Preview

Note: This blog is inspired by the Autodidax tutorial from JAX, where the JAX core is implemented from scratch. I found it quite confusing, so I made my own version. I hope I was able to do it justice.

JAX - Why is Everyone So Excited About This Framework

Introduction

The field of AI is going through an exponential surge, with new findings emerging at an unprecedented rate. Given the ever-growing scale of models and data, a highly performant ML framework is an absolute necessity; ultimately, unlocking the machine FLOPS is the main goal of any framework. Many frameworks, such as TensorFlow, PyTorch, and more recently JAX, have tried to unlock these machine FLOPS, and for the purpose of this blog, we'll focus on JAX. There are a lot of things that make JAX unique, so let us jump right into it.

What is JAX?

JAX has been gaining a lot of traction in recent times, and for the right reasons. JAX allows researchers to write Python programs that are automatically scaled to leverage accelerators and supercomputers, without any additional effort. JAX was developed at Google Research (and heavily adopted by DeepMind) to meet a simple goal: balance rapid prototyping and quick iteration with the ability to deploy experiments at scale. For those familiar with NumPy, think of JAX as NumPy with autodiff and first-class distributed support. Keep these points in the back of your mind as we try to understand why these qualities are so important, and why so many people (and frontier AI labs) are pivoting to JAX.
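To make the "NumPy with autodiff" framing concrete, here is a minimal sketch using the real `jax.numpy` API; the `mse` function and its parameters are just an illustrative example, not anything from JAX itself:

```python
import jax.numpy as jnp

# jax.numpy mirrors the familiar NumPy API, but arrays can live on
# accelerators (GPU/TPU) and every operation is traceable by JAX.
def mse(params, x, y):
    pred = params["w"] * x + params["b"]
    return jnp.mean((pred - y) ** 2)

params = {"w": jnp.float32(2.0), "b": jnp.float32(0.0)}
x = jnp.arange(4.0)       # [0., 1., 2., 3.]
y = 2.0 * x               # targets that w=2, b=0 fit exactly
loss = mse(params, x, y)  # mean squared error; 0.0 for this perfect fit
```

Anyone who has written NumPy can read this immediately; the payoff, as we'll see next, is that the same function can be differentiated and compiled without changes.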

Core Features - JAX

Don't take just my word for it: Francois Chollet (Keras founder) recently tweeted that almost all players in generative AI are pivoting to JAX because it is fast, scales really well, and has TPU support. A one-line explanation of JAX would go something like this: *JAX is basically NumPy on steroids, made for researchers to write efficient, performant, and scalable workloads.* JAX has a lot of unique propositions for a performant library, so let us look into what makes JAX special:

Automatic Differentiation

Autodiff automatically computes gradients of arbitrary numerical Python functions, which is essential for gradient-based training in ML workflows. We'll cover autodiff in a lot of detail in the upcoming sections.
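As a quick taste before the detailed sections, here is autodiff in one line; `jax.grad` is the real JAX API, while the function `f` is just a made-up example:

```python
import jax

def f(x):
    # A simple scalar function: f(x) = x^2 + 3x, so f'(x) = 2x + 3.
    return x ** 2 + 3.0 * x

# jax.grad transforms f into a new function that computes its derivative.
df = jax.grad(f)
slope = df(2.0)  # evaluates f'(2) = 2*2 + 3 = 7
```

No tape-recording or manual backward pass: you transform the function itself, and the result is an ordinary Python function you can call, jit, or differentiate again.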

Just-In-Time Compilation

JAX uses a JIT compiler to speed up entire blocks of code by fusing operations and exploiting any parallelism between them. A function is compiled on its first call, and the cached, optimized version is reused on subsequent calls, allowing efficient computation.

Integration Notes

  • Source section: blog
  • Local source: /home/yashs/Desktop/Programming/yash_blog/yash-srivastava19.github.io/blog/jax_blog.md
  • Raw copy: raw/website/yash-srivastava19-github-io/blog/jax_blog.md

Links Created Or Updated

Open Questions