Building Advanced Neural Nets with JAX, Flax & Optax

ThinkTools Team

AI Research Lead

Introduction

In the rapidly evolving landscape of deep learning, the ability to prototype and train sophisticated architectures with minimal friction is becoming a decisive factor for researchers and practitioners alike. The combination of JAX, Flax, and Optax has emerged as a powerful trio that marries the speed of just‑in‑time compilation, the modularity of a high‑level neural network library, and the flexibility of a modern optimizer ecosystem. This tutorial takes you through the end‑to‑end process of building a deep neural network that incorporates residual connections, a self‑attention module, and an adaptive learning‑rate schedule, all while keeping the code clean, readable, and highly performant.

Residual connections, popularized by ResNet, give gradients a direct path through deep stacks of layers. This mitigates the vanishing-gradient problem and makes it practical to train networks with dozens or even hundreds of layers. Self-attention, the core of transformer architectures, lets a model weigh the importance of every position in a sequence relative to every other position. The resulting global context is especially valuable for tasks such as language modeling, time-series forecasting, and image segmentation. Finally, adaptive optimizers like AdamW, combined with learning-rate schedules such as a linear warm-up followed by cosine decay, often make training converge faster and more reliably.
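To make the gradient-flow argument concrete, the toy sketch below compares the derivative of the output with respect to the input through 20 stacked scalar layers, with and without a skip connection. The depth, the shared weight W = 0.1, and the tanh layers are illustrative assumptions. They are not part of the model built later in this post.

```python
import jax
import jax.numpy as jnp

DEPTH = 20  # number of stacked layers (illustrative choice)
W = 0.1     # small scalar weight shared by every layer (illustrative choice)

def plain_stack(x):
    # x_{k+1} = tanh(W * x_k): each layer scales the gradient by W * sech^2(.) <= W
    for _ in range(DEPTH):
        x = jnp.tanh(W * x)
    return x

def residual_stack(x):
    # x_{k+1} = x_k + tanh(W * x_k): each layer scales the gradient by 1 + W * sech^2(.) >= 1
    for _ in range(DEPTH):
        x = x + jnp.tanh(W * x)
    return x

plain_grad = jax.grad(plain_stack)(1.0)
residual_grad = jax.grad(residual_stack)(1.0)
print(f"plain: {float(plain_grad):.3e}  residual: {float(residual_grad):.3f}")
```

Each plain layer multiplies the gradient by at most W, so after 20 layers it shrinks to roughly 1e-20. With the identity shortcut, every layer contributes a factor of at least 1, so the gradient cannot vanish.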

By the end of this post you will have a working implementation that can be adapted to a wide range of domains, from natural language processing to computer vision, and you will understand how each component interacts within the JAX/Flax/Optax ecosystem.

Main Content

Defining the Residual Block

A residual block is the building block of any deep residual network. In Flax, a block is typically defined as a subclass of nn.Module. The key idea is to apply a sequence of transformations to the input and then add the original input back to the transformed output. The following code demonstrates a simple 2‑layer residual block that uses nn.Dense layers and a ReLU activation:

import jax.numpy as jnp
from flax import linen as nn

class ResidualBlock(nn.Module):
    features: int

    @nn.compact
    def __call__(self, x):
        residual = x
        x = nn.Dense(self.features)(x)
        x = nn.relu(x)
        x = nn.Dense(self.features)(x)
        return nn.relu(x + residual)

Notice how the @nn.compact decorator allows us to declare sub-modules inline. The addition x + residual is the hallmark of the residual connection. As written, the block assumes the input's last dimension already equals features. If it does not, x + residual fails with a shape mismatch, and the shortcut path needs a projection (typically an extra nn.Dense) to match the widths.

Building the Self‑Attention Module

Self‑attention is a mechanism that lets each element in a sequence attend to every other element. Flax provides a convenient nn.SelfAttention module, but for educational purposes we’ll construct a minimal version from scratch. The module computes queries, keys, and values, scales the dot‑product, applies a softmax, and then aggregates the values weighted by the attention scores.
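Without learned projections or multiple heads, the computation for a single head is just a few jax.numpy operations. The function below is a hypothetical helper for illustration only, with random arrays standing in for queries, keys, and values. The Flax module that follows wraps the same steps.

```python
import jax
import jax.numpy as jnp

def scaled_dot_product_attention(q, k, v):
    """Single-head attention for arrays of shape (seq_len, head_dim)."""
    d = q.shape[-1]
    # Similarity of every query with every key, scaled by sqrt(head_dim)
    scores = q @ k.T / jnp.sqrt(d)             # (seq_len, seq_len)
    # Softmax over the key axis: each row becomes a probability distribution
    weights = jax.nn.softmax(scores, axis=-1)  # (seq_len, seq_len)
    # Each output position is a weighted average of the value vectors
    return weights @ v, weights

key_q, key_k, key_v = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(key_q, (4, 8))  # 4 positions, head_dim = 8
k = jax.random.normal(key_k, (4, 8))
v = jax.random.normal(key_v, (4, 8))
out, weights = scaled_dot_product_attention(q, k, v)
print(out.shape, weights.sum(axis=-1))  # each row of weights sums to ~1
```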

class SimpleSelfAttention(nn.Module):
    num_heads: int
    head_dim: int

    @nn.compact
    def __call__(self, x):
        # x: (batch, seq_len, features)
        batch, seq_len, features = x.shape
        total_dim = self.num_heads * self.head_dim
        # Learned linear projections to queries, keys, and values
        query = nn.Dense(total_dim)(x)
        key = nn.Dense(total_dim)(x)
        value = nn.Dense(total_dim)(x)

        # Split the last dimension into (num_heads, head_dim)
        query = query.reshape(batch, seq_len, self.num_heads, self.head_dim)
        key = key.reshape(batch, seq_len, self.num_heads, self.head_dim)
        value = value.reshape(batch, seq_len, self.num_heads, self.head_dim)

        # Scaled dot-product scores: (batch, num_heads, seq_len_q, seq_len_k)
        attn_scores = jnp.einsum('bqhd,bkhd->bhqk', query, key) / jnp.sqrt(self.head_dim)
        # Softmax over the key axis so each query's weights sum to 1
        attn_weights = nn.softmax(attn_scores, axis=-1)
        # Weighted sum of values, then merge the heads back together
        attn_output = jnp.einsum('bhqk,bkhd->bqhd', attn_weights, value)
        attn_output = attn_output.reshape(batch, seq_len, total_dim)
        # Project back to the input feature size
        return nn.Dense(features)(attn_output)

This implementation demonstrates the core operations of multi-head attention: linear projections, reshaping for head separation, scaled dot-product, softmax weighting, and a final linear projection to restore the original feature dimension. In a production setting you would typically use Flax's built-in nn.SelfAttention (a subclass of nn.MultiHeadDotProductAttention) instead. It adds attention masking, dropout, and configurable bias and initializers.

Integrating Residuals and Attention

A common design pattern combines residual blocks with attention layers. The residual blocks learn per-position (token-wise) transformations, and the attention layers learn interactions across the whole sequence. The following HybridBlock runs the two paths in parallel on the same input and merges them:

class HybridBlock(nn.Module):
    features: int
    num_heads: int
    head_dim: int

    @nn.compact
    def __call__(self, x):
        # Residual path
        res = ResidualBlock(self.features)(x)
        # Attention path
        attn = SimpleSelfAttention(self.num_heads, self.head_dim)(x)
        # Merge and apply another residual
        merged = res + attn
        return ResidualBlock(self.features)(merged)

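The residual path outputs features channels, while SimpleSelfAttention projects back to the input width. As a result, res + attn has matching shapes only when the input's last dimension equals features. By stacking several HybridBlock instances, you can build a deep network in which every block mixes per-position transformations with sequence-wide attention.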

Constructing the Full Model

With the building blocks defined, the full model is a simple stack of hybrid layers followed by a classification head. The model is parameterized by the number of layers, the feature size, and the attention hyper‑parameters.

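class AdvancedModel(nn.Module):
    num_layers: int
    features: int
    num_heads: int
    head_dim: int
    num_classes: int

    @nn.compact
    def __call__(self, x):
        # x: (batch, seq_len, input_dim). Project to the model width first so the
        # residual additions inside every HybridBlock have matching shapes.
        x = nn.Dense(self.features)(x)
        for _ in range(self.num_layers):
            x = HybridBlock(self.features, self.num_heads, self.head_dim)(x)
        x = jnp.mean(x, axis=1)  # Global average pooling over the sequence axis
        return nn.Dense(self.num_classes)(x)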

The global average pooling step collapses the sequence dimension, producing a fixed‑size representation that can be fed into a final dense layer for classification.

Choosing an Optimizer and Scheduler

Optax offers a rich set of optimizers and learning‑rate schedules. A common choice for transformer‑style models is AdamW with a cosine decay schedule and a linear warm‑up. The following snippet shows how to set up such an optimizer:

import optax

learning_rate = 1e-4
warmup_steps = 1000
total_steps = 10000

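schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0,
    peak_value=learning_rate,
    warmup_steps=warmup_steps,
    decay_steps=total_steps,  # total schedule length, warm-up included
    end_value=0.0,
)

optimizer = optax.adamw(learning_rate=schedule, weight_decay=1e-4)

The warmup_cosine_decay_schedule ramps the learning rate linearly from init_value to peak_value over the first warmup_steps, then decays it along a half-cosine curve to end_value. In Optax, decay_steps is the total length of the schedule including warm-up. Passing total_steps (rather than total_steps - warmup_steps) therefore makes the decay end exactly at the last training step. Warm-up followed by cosine decay is a widely used default for transformer-style models, and the warm-up phase helps avoid large, unstable updates early in training.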

Training Loop with JAX’s Functional Paradigm

JAX encourages a functional style of programming where the model parameters and optimizer state are explicitly passed to functions. The code below implements a single training step: it computes the loss and its gradients, then updates the parameters through the optimizer.

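import jax
import optax
from flax.training import train_state

class TrainState(train_state.TrainState):
    # Subclass to add extra fields (e.g. batch_stats) as your model needs them
    pass

@jax.jit
def train_step(state, batch):
    def loss_fn(params):
        logits = state.apply_fn({'params': params}, batch['inputs'])
        loss = optax.softmax_cross_entropy_with_integer_labels(logits, batch['labels']).mean()
        return loss
    # Compute the loss and its gradients in one pass so the loss can be logged
    loss, grads = jax.value_and_grad(loss_fn)(state.params)
    # apply_gradients runs the optimizer update and increments state.step
    state = state.apply_gradients(grads=grads)
    return state, loss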

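The @jax.jit decorator traces train_step and compiles it into a single XLA computation. That computation runs on whichever backend JAX is using (CPU, GPU, or TPU) without per-operation Python overhead. Compilation happens on the first call and is reused as long as input shapes and dtypes stay the same. The TrainState class bundles the parameters, the optimizer (tx) and its state, the step counter, and the model's apply_fn.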

Evaluation and Logging

After training, you typically evaluate the model on a held‑out validation set. The evaluation function mirrors the training step but without gradient computation:

@jax.jit
def eval_step(state, batch):
    logits = state.apply_fn({'params': state.params}, batch['inputs'])
    loss = optax.softmax_cross_entropy_with_integer_labels(logits, batch['labels']).mean()
    accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == batch['labels'])
    return loss, accuracy

You can then aggregate the metrics over the entire validation set, weighting each batch by its size because eval_step returns per-batch means. Log the results with a tool such as tensorboardX or wandb.
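The aggregation step might look like the following plain-Python sketch. aggregate_metrics is a hypothetical helper that expects (loss, accuracy, batch_size) tuples built from eval_step outputs converted with float(...).

```python
def aggregate_metrics(per_batch):
    """Hypothetical helper: size-weighted average of per-batch mean loss and accuracy."""
    total = sum(n for _, _, n in per_batch)
    loss = sum(l * n for l, _, n in per_batch) / total
    accuracy = sum(a * n for _, a, n in per_batch) / total
    return {"loss": loss, "accuracy": accuracy}

# Two batches of different sizes (the last batch of an epoch is often smaller)
print(aggregate_metrics([(1.0, 0.5, 2), (2.0, 1.0, 6)]))
# {'loss': 1.75, 'accuracy': 0.875}
```

A plain average of batch means would give the small two-example batch the same weight as the six-example batch and report a loss of 1.5 instead of 1.75.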

Conclusion

The combination of JAX, Flax, and Optax provides a compelling platform for building modern neural networks. Residual connections keep deep stacks trainable, self-attention gives the model a global view of the input, and adaptive optimizers with warm-up and decay schedules help training converge. Structure your code around modular nn.Module classes and use JAX's just-in-time compilation, and you can iterate on new architectures quickly and with little boilerplate.

Whether you are tackling a new research problem or deploying a production‑grade model, the patterns illustrated in this tutorial will serve as a solid foundation. The key takeaway is that a well‑designed architecture, coupled with the right training machinery, can unlock performance gains that would otherwise be difficult to achieve.

Call to Action

If you found this guide helpful, consider experimenting with the code on a dataset that interests you—be it text, images, or time‑series data. Feel free to tweak the number of layers, the attention heads, or the learning‑rate schedule to see how the model behaves. Share your results on GitHub or a blog post, and let the community know what worked and what didn’t. For deeper dives, explore the official JAX, Flax, and Optax documentation, or try integrating other libraries such as Haiku or Equinox for alternative design patterns. Happy coding, and may your models learn efficiently and elegantly!
