Learning JAX by Implementing a Simple Language Model
JAX is a cool machine learning framework from Google. Language models (LMs) are the family of models that include ChatGPT and GPT-4, which have taken the machine learning world by storm.
What better way to learn about both than to implement a simple LM in JAX?
I’ll heavily rely on Andrej Karpathy’s amazing course Neural Networks: Zero to Hero for this implementation. The course starts from the very basics of neural networks and slowly builds up to implementing a GPT model.
If you are interested in learning about all of these things in detail, I highly recommend checking the course out. This blog post is only concerned with porting the implementation to JAX and won’t go into the theoretical background of what is being implemented.
As for JAX, I’ve found the official documentation adequate so far.
The model I’m interested in building is the Multilayer Perceptron (MLP) model from Building makemore Part 2: MLP. I think this is a good starting point due to the simplicity of the model built in this chapter.
Following the steps in this video, I will be building a character-level LM that looks at the previous three characters to predict which character will come next.
While displaying the relevant code blocks, I’ll also provide the PyTorch counterparts in a separate tab for easier comparison.
Let’s start by initializing the model parameters:
```python
import jax.numpy as jnp
import jax.random as jrandom

alphabet_size = 27  # 26 letters plus '.'
context_size = 3    # number of previous characters the model sees
neurons = 100       # hidden layer size
d = 2               # embedding dimension

key = jrandom.PRNGKey(0)
key, *subkeys = jrandom.split(key, 6)  # one fresh subkey per parameter
subkeys = iter(subkeys)

C = jrandom.normal(next(subkeys), shape=(alphabet_size, d))
W1 = jrandom.normal(next(subkeys), shape=(d * context_size, neurons))
b1 = jrandom.normal(next(subkeys), shape=(neurons, ))
W2 = jrandom.normal(next(subkeys), shape=(neurons, alphabet_size))
b2 = jrandom.normal(next(subkeys), shape=(alphabet_size, ))
parameters = [C, W1, b1, W2, b2]
```
Notice that the two versions are very similar. Apart from method and parameter names, the only difference is how JAX handles random number generation. You can read more about that in the corresponding section of 🔪 JAX - The Sharp Bits 🔪.
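A minimal sketch of the explicit-key model described there: JAX keeps no hidden global RNG state, so the same key always yields the same numbers, and you get fresh randomness by splitting a key into new ones.

```python
import jax.numpy as jnp
import jax.random as jrandom

key = jrandom.PRNGKey(0)

# Reusing the same key gives identical samples.
a = jrandom.normal(key, shape=(3,))
b = jrandom.normal(key, shape=(3,))

# Splitting derives new keys; a subkey gives different (but reproducible) samples.
key, subkey = jrandom.split(key)
c = jrandom.normal(subkey, shape=(3,))

print(jnp.array_equal(a, b))  # True
print(jnp.array_equal(a, c))  # False
```

This is why the initialization code splits one key into several subkeys instead of calling `jrandom.normal` repeatedly with the same key.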
Now let’s write the forward pass:
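```python
def model(C, W1, b1, W2, b2, X):
    emb = C[X]  # (batch, 3, d): embedding of each context character
    h = jnp.tanh(emb.reshape((-1, 6)) @ W1 + b1)  # concatenate to (batch, 3 * d = 6)
    logits = h @ W2 + b2  # (batch, 27)
    # Softmax, shifting by the row maximum for numerical stability
    logits -= logits.max(1, keepdims=True)
    counts = jnp.exp(logits)
    probs = counts / counts.sum(1, keepdims=True)
    return probs
```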
Again very similar.
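The last three lines of `model` are a hand-written softmax. Subtracting the row-wise maximum keeps `jnp.exp` from overflowing and leaves the result unchanged, because the shift cancels between numerator and denominator. A quick sanity check against JAX's built-in `jax.nn.softmax`, using made-up logits:

```python
import jax
import jax.numpy as jnp

def manual_softmax(logits):
    # Shift each row so its maximum is 0; exp() then stays in (0, 1].
    logits = logits - logits.max(1, keepdims=True)
    counts = jnp.exp(logits)
    return counts / counts.sum(1, keepdims=True)

# Illustrative logits; the second row would overflow exp() without the shift.
logits = jnp.array([[1.0, 2.0, 3.0],
                    [1000.0, 0.0, -1000.0]])

manual = manual_softmax(logits)
builtin = jax.nn.softmax(logits, axis=1)
print(jnp.allclose(manual, builtin))  # True
```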
Now we need a loss function:
```python
def loss_fn(C, W1, b1, W2, b2, X, Y):
    probs = model(C, W1, b1, W2, b2, X)
    # Negative log-likelihood of the correct next character, averaged over the batch
    loss = -jnp.log(probs[jnp.arange(X.shape[0]), Y]).mean()
    return loss
```
Here literally the only change is replacing
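Taking `jnp.log` of probabilities that were already normalized can lose precision when a probability underflows to zero. A common alternative, sketched here as a swap-in rather than what this post's code does, is to compute log-probabilities directly with `jax.nn.log_softmax` and pick out the target entries. The function name `cross_entropy_from_logits` and the data are illustrative.

```python
import jax
import jax.numpy as jnp

def cross_entropy_from_logits(logits, Y):
    # Log-probabilities computed stably in one step (log-sum-exp internally).
    logprobs = jax.nn.log_softmax(logits, axis=1)
    # Log-probability of the correct next character in every row.
    return -logprobs[jnp.arange(logits.shape[0]), Y].mean()

logits = jnp.array([[2.0, 0.5, -1.0],
                    [0.0, 3.0, 1.0]])
Y = jnp.array([0, 1])

# Same quantity as the probability-based loss in loss_fn.
probs = jax.nn.softmax(logits, axis=1)
reference = -jnp.log(probs[jnp.arange(logits.shape[0]), Y]).mean()
print(jnp.allclose(cross_entropy_from_logits(logits, Y), reference))  # True
```

Using this in the post's code would mean having `model` return `logits` instead of `probs`.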
And now the training function, which is the first major difference between PyTorch and JAX:
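```python
from jax import grad

def train(C, W1, b1, W2, b2, Xb, Yb, lr=0.1):
    # Gradients of the loss w.r.t. the first five arguments (the parameters)
    grad_fn = grad(loss_fn, argnums=(0, 1, 2, 3, 4))
    gC, gW1, gb1, gW2, gb2 = grad_fn(C, W1, b1, W2, b2, Xb, Yb)
    # Plain SGD step; each line binds the name to a new array
    C -= lr * gC
    W1 -= lr * gW1
    b1 -= lr * gb1
    W2 -= lr * gW2
    b2 -= lr * gb2
    return C, W1, b1, W2, b2
```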
In PyTorch, we get a tensor from the loss function, which we can call `backward` on to backpropagate the gradients to our parameters. Then, we can loop through all our parameters and update them in place. We also need to remember to zero out the gradients of each parameter after every iteration; otherwise they would accumulate.
In JAX, we call `grad` on the loss function to obtain a new function that evaluates the gradients of the loss with respect to the arguments we select with `argnums`.
We then use that function to find the gradients at the current point in our parameter space, and use those gradients to update them.
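A toy illustration of how `grad` and `argnums` behave, using a made-up function `f`: passing a tuple of positions makes the returned function produce a tuple of gradients, one per selected argument, in the same order.

```python
import jax.numpy as jnp
from jax import grad

def f(w, b, x):
    # A scalar-valued function of three arguments.
    return jnp.sum((w * x + b) ** 2)

w = jnp.array(2.0)
b = jnp.array(1.0)
x = jnp.array(3.0)

# Differentiate with respect to w (position 0) and b (position 1), but not x.
df = grad(f, argnums=(0, 1))
gw, gb = df(w, b, x)
print(gw, gb)  # 42.0 14.0
```

By hand: `w * x + b = 7`, so the derivative w.r.t. `w` is `2 * 7 * 3 = 42` and w.r.t. `b` is `2 * 7 = 14`.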
Also notable: we return new arrays that replace the old ones instead of updating them in place, because JAX arrays are immutable and JAX follows a more functional paradigm than PyTorch.
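A small sketch of what immutability means in practice: `W -= ...` rebinds the name to a brand-new array (other references still see the old values), direct item assignment raises an error, and indexed updates go through the `.at[...]` syntax, which also returns a new array.

```python
import jax.numpy as jnp

W = jnp.zeros(3)
alias = W  # a second name for the same immutable array

# Augmented assignment builds a new array and rebinds W; alias is untouched.
W -= 0.1 * jnp.ones(3)

# In-place item assignment is not allowed on JAX arrays.
try:
    W[0] = 1.0
except TypeError:
    print("JAX arrays do not support item assignment")

# The functional alternative returns an updated copy and leaves W as it was.
W2 = W.at[0].set(5.0)

print(alias)  # still all zeros
```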
Even though I passed each parameter and received their gradients explicitly in this implementation, JAX also has ways to pass all parameters and receive all gradients together using pytrees.
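A minimal sketch of that pytree-based variant, assuming the parameters are kept in a dict. The names `params`, `loss_fn_tree`, and `train_tree`, as well as the toy batch, are hypothetical and not part of the original code. `jax.grad` then returns a dict of gradients with the same structure, and `jax.tree_util.tree_map` applies the SGD update to every leaf at once.

```python
import jax
import jax.numpy as jnp
import jax.random as jrandom

def loss_fn_tree(params, X, Y):
    # Same MLP and loss as above, with all parameters looked up in one dict (a pytree).
    emb = params["C"][X]
    h = jnp.tanh(emb.reshape((X.shape[0], -1)) @ params["W1"] + params["b1"])
    logits = h @ params["W2"] + params["b2"]
    logits = logits - logits.max(1, keepdims=True)
    counts = jnp.exp(logits)
    probs = counts / counts.sum(1, keepdims=True)
    return -jnp.log(probs[jnp.arange(X.shape[0]), Y]).mean()

def train_tree(params, Xb, Yb, lr=0.1):
    # grad differentiates w.r.t. the first argument: the whole params dict.
    grads = jax.grad(loss_fn_tree)(params, Xb, Yb)
    # SGD update applied leaf by leaf; the result is a new dict.
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

alphabet_size, context_size, d, neurons = 27, 3, 2, 100
keys = jrandom.split(jrandom.PRNGKey(0), 5)
params = {
    "C": jrandom.normal(keys[0], (alphabet_size, d)),
    "W1": jrandom.normal(keys[1], (d * context_size, neurons)),
    "b1": jrandom.normal(keys[2], (neurons,)),
    "W2": jrandom.normal(keys[3], (neurons, alphabet_size)),
    "b2": jrandom.normal(keys[4], (alphabet_size,)),
}

# Toy batch: 4 contexts of 3 character indices and their target characters.
Xb = jnp.array([[0, 0, 0], [0, 0, 5], [0, 5, 13], [5, 13, 13]])
Yb = jnp.array([5, 13, 13, 1])
new_params = train_tree(params, Xb, Yb)
```

The update no longer needs to name each parameter, so adding a layer only means adding an entry to the dict.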
Finally, the training loop:
```python
key = jrandom.PRNGKey(0)
for _ in range(10000):
    # Sample a minibatch of 32 random training examples (Xtr, Ytr: training set)
    key, subkey = jrandom.split(key)
    ix = jrandom.randint(subkey, (32,), minval=0, maxval=Xtr.shape[0])
    Xb, Yb = Xtr[ix], Ytr[ix]
    C, W1, b1, W2, b2 = train(C, W1, b1, W2, b2, Xb, Yb)
```
Again, the only notable differences come from how the two libraries handle random number generation.
I’ll use the same dataset of 32K English names as Andrej and split it into train / dev / test sets in the same way.
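For readers who want to reproduce this, here is a sketch of turning names into (context, next character) pairs. It follows the scheme from Andrej's video as I understand it and assumes the names contain only lowercase a to z: `'.'` (index 0) serves as both padding and end-of-name marker, and the letters take indices 1 to 26, giving the `alphabet_size` of 27. The names `build_dataset`, `stoi`, and `itos` and the two-word list are illustrative; shuffling and splitting the word list into train / dev / test is omitted.

```python
import string
import jax.numpy as jnp

context_size = 3
# '.' (index 0) is both padding and end-of-name token; letters are 1..26.
stoi = {ch: i + 1 for i, ch in enumerate(string.ascii_lowercase)}
stoi["."] = 0
itos = {i: ch for ch, i in stoi.items()}

def build_dataset(words):
    X, Y = [], []
    for w in words:
        context = [0] * context_size  # every name starts with "..."
        for ch in w + ".":
            ix = stoi[ch]
            X.append(context)
            Y.append(ix)
            context = context[1:] + [ix]  # slide the window by one character
    return jnp.array(X), jnp.array(Y)

Xtr, Ytr = build_dataset(["emma", "ava"])
print(Xtr.shape, Ytr.shape)  # (9, 3) (9,)
```

Each name of length n contributes n + 1 examples, one per character plus one for the terminating `'.'`.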
Training & Evaluation
Here are the training and dev losses during training:
These are the learned character embeddings:
Here is a sample of names generated by the LM:
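The sampling code itself isn't shown in this post. One way to generate such names in JAX, sketched here with the hypothetical helper `sample_name` and mapping `itos`, is to feed the current three-character context to the model, draw the next character with `jrandom.categorical`, and stop at `'.'`.

```python
import string
import jax.numpy as jnp
import jax.random as jrandom

# Hypothetical index-to-character mapping: 0 -> '.', 1..26 -> 'a'..'z'.
itos = {i + 1: ch for i, ch in enumerate(string.ascii_lowercase)}
itos[0] = "."

def sample_name(probs_fn, key, context_size=3, max_len=20):
    """Draw one name, character by character, from probs_fn(X) -> (1, 27) probabilities."""
    context = [0] * context_size
    out = []
    for _ in range(max_len):
        probs = probs_fn(jnp.array([context]))
        key, subkey = jrandom.split(key)
        # categorical expects (unnormalized) log-probabilities, so pass log(probs).
        ix = int(jrandom.categorical(subkey, jnp.log(probs[0])))
        if ix == 0:  # '.' ends the name
            break
        out.append(itos[ix])
        context = context[1:] + [ix]
    return "".join(out)
```

With trained parameters, `probs_fn` would be something like `lambda X: model(C, W1, b1, W2, b2, X)`, and each new name should use a fresh key obtained by splitting.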
If I run 10000 steps of training on my CPU based on the code presented above, I get the following times:
But JAX has one more core feature I haven’t mentioned yet: just-in-time (JIT) compilation. When a function is wrapped with `jax.jit` (for example, by using it as a decorator), JAX traces the function and compiles its operations with XLA into an optimized version, which is cached and reused for later calls with inputs of the same shapes and dtypes.
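A small sketch of wrapping a function with `jax.jit`; the toy least-squares `sgd_step` and its data are made up for illustration. The first call traces and compiles, and later calls with the same shapes and dtypes reuse the compiled version. Because JAX dispatches work asynchronously, any timing code should call `block_until_ready()` on the result.

```python
import jax
import jax.numpy as jnp

def loss(w, x, y):
    return jnp.mean((x @ w - y) ** 2)

def sgd_step(w, x, y, lr=0.1):
    return w - lr * jax.grad(loss)(w, x, y)

# Compiled version of the same function.
sgd_step_jit = jax.jit(sgd_step)

x = jnp.arange(12.0).reshape(4, 3) / 10
y = jnp.array([1.0, 2.0, 3.0, 4.0])
w = jnp.zeros(3)

w_eager = sgd_step(w, x, y)
w_jit = sgd_step_jit(w, x, y).block_until_ready()  # first call: trace + compile
print(jnp.allclose(w_eager, w_jit))  # True
```

For this post, the equivalent is `train = jax.jit(train)`, or decorating `train` with `@jax.jit`.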
After we JIT the `train` function, we get:
Table: training time for 10,000 steps on CPU (columns: PyTorch, JAX, JAX after JIT).
Nobody would normally train language models on MacBook CPUs, but this at least gives an idea of how important it is to know how to use JIT when working with JAX.
I believe this was a good introduction to JAX and to implementing language models. A good next step would be to implement the more sophisticated LMs from later in the series and to use that as an opportunity to learn Haiku!