Learning JAX by Implementing a Simple Language Model

JAX is a cool machine learning framework from Google. Language models (LMs) are the family of models that includes ChatGPT and GPT-4, which have taken the machine learning world by storm.

What better way to learn about both than to implement a simple LM in JAX?


I’ll heavily rely on Andrej Karpathy’s amazing course Neural Networks: Zero to Hero for this implementation. The course starts from the very basics of neural networks and slowly builds up to implementing a GPT model.

If you are interested in learning about all of these things in detail, I highly recommend checking out the course. This blog post is concerned only with porting the implementation to JAX and won’t go into the theoretical background of what is being implemented.

As for JAX, I’ve found the official documentation adequate so far.


The model I’m interested in building is the Multi Layer Perceptron (MLP) model from Building makemore Part 2: MLP. I think this is a good starting point due to the simplicity of the model built in this chapter.

Following the steps in this video, I will be building a character-level LM that looks at the previous three characters to predict which character will come next.


While displaying the relevant code blocks, I’ll also provide the PyTorch counterparts in a separate tab for easier comparison.

Let’s start by initializing the model parameters:

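import jax.numpy as jnp
import jax.random as jrandom
from jax import grad

alphabet_size = 27  # number of distinct characters
context_size = 3    # how many previous characters the model sees
neurons = 100       # hidden layer width
d = 2               # embedding dimension

# JAX needs an explicit key for every random call, so split off one per parameter
key = jrandom.PRNGKey(0)
key, *subkeys = jrandom.split(key, 6)
subkeys = iter(subkeys)

C = jrandom.normal(next(subkeys), shape=(alphabet_size, d))
W1 = jrandom.normal(next(subkeys), shape=(d * context_size, neurons))
b1 = jrandom.normal(next(subkeys), shape=(neurons, ))
W2 = jrandom.normal(next(subkeys), shape=(neurons, alphabet_size))
b2 = jrandom.normal(next(subkeys), shape=(alphabet_size, ))

parameters = [C, W1, b1, W2, b2]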

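As you can see, the two versions are very similar. Other than method and parameter names, the only difference is how JAX handles random number generation: instead of relying on a global seed, every random call takes an explicit key, and keys are split to get fresh randomness. You can read more about that in the corresponding section of 🔪 JAX - The Sharp Bits 🔪.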
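To make the random number difference concrete, here is a minimal sketch of explicit key handling. The seed and the shapes are arbitrary choices for illustration, not values from the model above.

```python
import jax.numpy as jnp
import jax.random as jrandom

key = jrandom.PRNGKey(42)

# Reusing the same key reproduces exactly the same numbers.
a = jrandom.normal(key, shape=(3,))
b = jrandom.normal(key, shape=(3,))

# Splitting yields fresh keys; each subkey gives different numbers.
key, subkey1, subkey2 = jrandom.split(key, 3)
c = jrandom.normal(subkey1, shape=(3,))
e = jrandom.normal(subkey2, shape=(3,))

same_key_matches = bool(jnp.array_equal(a, b))
split_keys_differ = not bool(jnp.array_equal(c, e))
```

This is why the initialization code splits the key once per parameter: passing the same key to every call would give correlated (in fact identical, for equal shapes) draws.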

Now let’s write the forward pass:

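def model(C, W1, b1, W2, b2, X):
    emb = C[X]  # look up embeddings: (batch, context_size, d)
    # flatten each context into context_size * d = 6 features
    h = jnp.tanh(emb.reshape((-1, 6)) @ W1 + b1)
    logits = h @ W2 + b2
    # subtract the row max for numerical stability before exponentiating
    logits -= logits.max(1, keepdims=True)
    counts = jnp.exp(logits)
    probs = counts / counts.sum(1, keepdims=True)  # softmax over characters
    return probs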

Again very similar.
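The last three lines of the forward pass are a hand-written softmax. As a quick sanity check, the sketch below compares that manual version with jax.nn.softmax on some made-up logits. The second row is deliberately large, to show why the row maximum is subtracted first.

```python
import jax
import jax.numpy as jnp

# Made-up logits; exponentiating the second row directly would overflow float32.
logits = jnp.array([[1.0, 2.0, 3.0],
                    [1000.0, 1001.0, 1002.0]])

# Manual softmax, following the same steps as the forward pass.
shifted = logits - logits.max(1, keepdims=True)
counts = jnp.exp(shifted)
manual = counts / counts.sum(1, keepdims=True)

# Built-in equivalent.
builtin = jax.nn.softmax(logits, axis=1)

matches = bool(jnp.allclose(manual, builtin))
```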

Now we need a loss function:

def loss_fn(C, W1, b1, W2, b2, X, Y):
    probs = model(C, W1, b1, W2, b2, X)
    loss = -jnp.log(probs[jnp.arange(X.shape[0]), Y]).mean()
    return loss

Here literally the only change is replacing torch with jnp.
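The indexing expression in the loss picks, for every row, the probability assigned to the correct next character. A tiny worked example with made-up probabilities and labels:

```python
import jax.numpy as jnp

# Two examples, three classes (made-up numbers).
probs = jnp.array([[0.1, 0.2, 0.7],
                   [0.5, 0.25, 0.25]])
Y = jnp.array([2, 0])  # correct class for each row

# Row i is paired with column Y[i], so this selects probs[0, 2] and probs[1, 0].
picked = probs[jnp.arange(probs.shape[0]), Y]

# Average negative log-likelihood of the correct classes.
loss = -jnp.log(picked).mean()
```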

And now the training function, which is the first major difference between PyTorch and JAX:

def train(C, W1, b1, W2, b2, Xb, Yb, lr=0.1):
    grad_fn = grad(loss_fn, argnums=(0, 1, 2, 3, 4))
    gC, gW1, gb1, gW2, gb2 = grad_fn(C, W1, b1, W2, b2, Xb, Yb)

    C -= lr * gC
    W1 -= lr * gW1
    b1 -= lr * gb1
    W2 -= lr * gW2
    b2 -= lr * gb2

    return C, W1, b1, W2, b2

In PyTorch, the loss function returns a tensor, and calling backward on it backpropagates the gradients to our parameters. We can then loop through the parameters and update them in place. We also need to remember to zero the gradients of each parameter after every iteration; otherwise they would accumulate.

In JAX, we apply grad to the loss function to obtain a new function that evaluates the gradients of the loss with respect to the parameters we choose (selected here with argnums). We then call that function to get the gradients at the current point in parameter space and use them to update the parameters. Also notable: JAX arrays are immutable and JAX follows a more functional paradigm than PyTorch, so instead of updating the parameters in place we return new arrays that replace the old ones.
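Here is the same pattern on a toy function where the gradients are easy to check by hand. The function f and its inputs are made up for illustration; only grad and argnums come from the code above.

```python
import jax.numpy as jnp
from jax import grad

def f(w, b, x):
    # Scalar-valued function: sum of squares of an elementwise linear map.
    return jnp.sum((w * x + b) ** 2)

# Differentiate with respect to arguments 0 and 1 (w and b), but not x.
grad_fn = grad(f, argnums=(0, 1))

w = jnp.array([1.0, 2.0])
b = jnp.array([0.5, -0.5])
x = jnp.array([3.0, 4.0])

# One gradient per selected argument, each with the shape of that argument.
gw, gb = grad_fn(w, b, x)

# Functional update: build new arrays and leave the originals untouched.
lr = 0.1
new_w = w - lr * gw
new_b = b - lr * gb
```

By hand, w * x + b is [3.5, 7.5], so the gradient with respect to b is twice that and the gradient with respect to w is that times 2x.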

Even though I passed each parameter and received their gradients explicitly in this implementation, JAX also has ways to pass all parameters and receive all gradients together using pytrees.
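As a rough sketch of what the pytree version could look like, the example below keeps the parameters in a dict and updates all of them with jax.tree_util.tree_map. To keep it short it uses a toy linear model with made-up data rather than the MLP above, and the parameter names "W" and "b" are my own choice.

```python
import jax
import jax.numpy as jnp

def loss_fn(params, x, y):
    # params is a dict (a pytree) instead of several positional arguments.
    pred = x @ params["W"] + params["b"]
    return jnp.mean((pred - y) ** 2)

params = {"W": jnp.ones((2, 1)), "b": jnp.zeros((1,))}
x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
y = jnp.array([[1.0], [2.0]])

# grad differentiates with respect to the first argument by default,
# so grads is a dict with the same keys and shapes as params.
grads = jax.grad(loss_fn)(params, x, y)

# Apply the SGD update to every leaf of the pytree at once.
lr = 0.1
new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
```

With this layout, adding a layer only means adding entries to the dict; the gradient and update code stays the same.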

Finally, the training loop:

key = jrandom.PRNGKey(0)
for _ in range(10000):
    key, subkey = jrandom.split(key)
    ix = jrandom.randint(subkey, (32,), minval=0, maxval=Xtr.shape[0])
    Xb, Yb = Xtr[ix], Ytr[ix]
    C, W1, b1, W2, b2 = train(C, W1, b1, W2, b2, Xb, Yb)

Again, the only differences come from how the two libraries handle random number generation.


I’ll use the same dataset of 32K English names as Andrej and split it into train / dev / test sets in the same way.
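For reference, here is a sketch of how arrays like Xtr and Ytr can be built. It rests on several assumptions that are not spelled out in this post: a vocabulary where '.' is index 0 and 'a' to 'z' are 1 to 26, an 80/10/10 split by word, and a short hypothetical word list standing in for the real names file.

```python
import random
import jax.numpy as jnp

# Hypothetical stand-in for the names file.
words = ["emma", "olivia", "ava", "isabella", "sophia",
         "mia", "amelia", "harper", "evelyn", "luna"]
context_size = 3

# Assumed vocabulary: '.' marks start/end (index 0), letters are 1..26.
stoi = {ch: i + 1 for i, ch in enumerate("abcdefghijklmnopqrstuvwxyz")}
stoi["."] = 0

def build_dataset(words):
    X, Y = [], []
    for w in words:
        context = [0] * context_size      # pad the start with '.'
        for ch in w + ".":
            ix = stoi[ch]
            X.append(context)             # the three previous characters
            Y.append(ix)                  # the character to predict
            context = context[1:] + [ix]  # slide the window forward
    return jnp.array(X), jnp.array(Y)

# Assumed 80/10/10 split over shuffled words.
random.seed(42)
random.shuffle(words)
n1 = int(0.8 * len(words))
n2 = int(0.9 * len(words))

Xtr, Ytr = build_dataset(words[:n1])
Xdev, Ydev = build_dataset(words[n1:n2])
Xte, Yte = build_dataset(words[n2:])
```

Each word of length n contributes n + 1 examples, because the end marker is also a prediction target.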

Training & Evaluation

Here are the training and dev losses during training:

[Figure: training and dev loss curves, shown for both the Torch and JAX versions]


These are the learned character embeddings:

[Figure: learned character embeddings, shown for both the Torch and JAX versions]


Here is a sample of names generated by the LM:

  • brone
  • jieai
  • krillo
  • jrayle
  • karireio
  • addelie
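Sampling names like these can be sketched as follows. This is not the exact code used to produce the list above: the index-to-character mapping (0 for '.', 1 to 26 for the letters), the max_len cap, and the use of jrandom.categorical are assumptions made for this example, and the parameters here are random rather than trained, so the output is gibberish.

```python
import jax.numpy as jnp
import jax.random as jrandom

# Assumed mapping: 0 is the end marker '.', 1..26 are 'a'..'z'.
itos = {0: "."}
itos.update({i + 1: ch for i, ch in enumerate("abcdefghijklmnopqrstuvwxyz")})

def sample_name(params, key, context_size=3, max_len=20):
    C, W1, b1, W2, b2 = params
    context = [0] * context_size
    out = []
    for _ in range(max_len):                  # hypothetical safety cap
        emb = C[jnp.array([context])]         # (1, context_size, d)
        h = jnp.tanh(emb.reshape((1, -1)) @ W1 + b1)
        logits = h @ W2 + b2                  # (1, alphabet_size)
        key, subkey = jrandom.split(key)      # fresh key for every draw
        ix = int(jrandom.categorical(subkey, logits[0]))
        if ix == 0:                           # end marker: stop
            break
        out.append(itos[ix])
        context = context[1:] + [ix]
    return "".join(out)

# Untrained random parameters, with the same shapes as the model above.
key = jrandom.PRNGKey(0)
key, k1, k2, k3, k4, k5, ksample = jrandom.split(key, 7)
params = (
    jrandom.normal(k1, (27, 2)),
    jrandom.normal(k2, (6, 100)),
    jrandom.normal(k3, (100,)),
    jrandom.normal(k4, (100, 27)),
    jrandom.normal(k5, (27,)),
)
name = sample_name(params, ksample)
```

jrandom.categorical takes unnormalized logits, so there is no need to compute the softmax before sampling.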


If I run 10000 steps of training on my CPU based on the code presented above, I get the following times:

PyTorch | JAX
5.6s | 158s

But JAX has one more core feature I haven’t mentioned: just-in-time (JIT) compilation. When a function is wrapped with jit, JAX traces it and compiles all of its operations into a more efficient version, which is then reused for any inputs with the same shapes and dtypes.
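A minimal sketch of what this looks like in code, using a toy update step with made-up shapes instead of the full train function:

```python
import jax.numpy as jnp
from jax import grad, jit

def loss_fn(w, x, y):
    return jnp.mean((jnp.tanh(x @ w) - y) ** 2)

def step(w, x, y, lr=0.1):
    # One gradient descent step, written as a pure function.
    return w - lr * grad(loss_fn)(w, x, y)

# jit can be used as a decorator or called directly on an existing function.
fast_step = jit(step)

w = jnp.ones((4, 1))
x = jnp.ones((8, 4))
y = jnp.zeros((8, 1))

w_plain = step(w, x, y)
w_jit = fast_step(w, x, y)        # first call traces and compiles
w_jit_again = fast_step(w, x, y)  # same shapes and dtypes: compiled code is reused
```

The compiled and uncompiled versions compute the same update; only the execution speed differs. Calling fast_step with differently shaped inputs would trigger a new compilation.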

After we JIT the train function, we get:

PyTorch | JAX | JAX (after JIT)
5.6s | 158s | 5.9s

Nobody would normally train language models on a MacBook CPU, but this at least gives an idea of how important it is to know how to use JIT when working with JAX.


I believe this was a good introduction to JAX and to implementing language models. A good next step would be to implement the more sophisticated LMs from later in the series, and to use that as an opportunity to learn Haiku!