Usage

The following script is a basic example of how to use the library:

from jax import numpy as jnp, random
import synax

# Create a module.
module = synax.MLP([2, 32, 3])

# Create a PRNG key.
key = random.key(0)

# Sample initial parameters.
w = module.init_params(key)

# Define an input.
x = jnp.ones(2)

# Compute the output.
y = module.apply(w, x)

# Print the output.
print(y)

Output:

[-1.2567853  -0.80044776  0.5694267 ]

A module has the following methods:

  • init_params takes a JAX PRNG key and returns initial parameters for the module.

  • apply takes the module’s parameters, together with any inputs, and returns the output of the module (see the sketch below).
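Because parameters are plain data passed in explicitly rather than stored on the module, they compose directly with JAX transformations such as jax.grad and jax.jit. Here is a minimal sketch, reusing module, w, and x from the script above; the regression target t is made up for illustration:

import jax

# Hypothetical target, for illustration only.
t = jnp.zeros(3)

# The loss is an ordinary function of the parameters...
def loss(w):
    return jnp.sum((module.apply(w, x) - t) ** 2)

# ...so its gradient is a pytree with the same structure as w.
grads = jax.grad(loss)(w)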

Defining a custom module

The following script shows how to define a custom module:

from jax import random, nn


class Affine:
    """Affine map."""

    def __init__(
        self,
        input_dim,
        output_dim,
        weight_init=nn.initializers.he_normal(),
        bias_init=nn.initializers.zeros,
    ):
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.weight_init = weight_init
        self.bias_init = bias_init

    def init_params(self, key):
        keys = random.split(key)
        weight = self.weight_init(keys[0], (self.input_dim, self.output_dim))
        bias = self.bias_init(keys[1], (self.output_dim,))
        return {"weight": weight, "bias": bias}

    def apply(self, params, input):
        return input @ params["weight"] + params["bias"]


module = Affine(3, 2)
key = random.key(0)
params = module.init_params(key)
print(params)

Output:

{'weight': Array([[ 0.8737965 , -0.79177886],
       [-0.65683264, -1.0112412 ],
       [-0.7620363 ,  0.5188657 ]], dtype=float32), 'bias': Array([0., 0.], dtype=float32)}
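The custom module is then used like any built-in one: pass the sampled parameters and an input to apply. A short continuation of the script above, with a made-up input x:

from jax import numpy as jnp

x = jnp.ones(3)  # input_dim is 3
y = module.apply(params, x)
print(y.shape)  # (2,), since output_dim is 2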

Example: Training on MNIST

The following script trains a model on the MNIST dataset; the additional dependencies it needs are listed in the docstring at the top.

  1"""
  2pip install tensorflow-datasets optax matplotlib
  3"""
  4
  5import argparse
  6
  7import jax
  8import optax
  9import synax
 10import tensorflow_datasets as tfds
 11from jax import lax, random
 12from jax import numpy as jnp
 13from jax.experimental import io_callback
 14from matplotlib import pyplot as plt
 15
 16
 17def parse_args():
 18    p = argparse.ArgumentParser()
 19    p.add_argument("--dataset", type=str, default="mnist")
 20    p.add_argument("--model", type=str, default="lenet")
 21    p.add_argument("--optimizer", type=str, default="adam")
 22    p.add_argument("--lr", type=float, default=1e-3)
 23    p.add_argument("--seed", type=int, default=0)
 24    p.add_argument("--epochs", type=int, default=10)
 25    p.add_argument("--batch_size", type=int, default=32)
 26    return p.parse_args()
 27
 28
 29def sample_batch_indices(key, num_examples, batch_size):
 30    num_batches = num_examples // batch_size
 31    perm = random.permutation(key, num_examples)
 32    limit = num_batches * batch_size
 33    batch_indices = perm[:limit].reshape((num_batches, batch_size))
 34    remainder = perm[limit:]
 35    return batch_indices, remainder
 36
 37
 38def get_dataset_size(ds):
 39    leaves = jax.tree.leaves(ds)
 40    size = leaves[0].shape[0]
 41    assert all(leaf.shape[0] == size for leaf in leaves[1:])
 42    return size
 43
 44
 45def train(ds, model, optimizer, key, epochs, batch_size, epoch_callback, loss_fn):
 46    def get_example_loss(params, example):
 47        image = example["image"]
 48        label = example["label"]
 49        image /= 255
 50        logits = model.apply(params, image)
 51        loss = loss_fn(logits, label)
 52        error = logits.argmax() != label
 53        return loss, {"loss": loss, "error": error}
 54
 55    def get_batch_loss(params, batch):
 56        losses, metrics = jax.vmap(get_example_loss, [None, 0])(params, batch)
 57        mean_loss = losses.mean(0)
 58        mean_metrics = jax.tree.map(lambda x: x.mean(0), metrics)
 59        return mean_loss, mean_metrics
 60
 61    def run_batch(state, batch):
 62        params = state["params"]
 63        opt_state = state["optimizer"]
 64        grads, metrics = jax.grad(get_batch_loss, has_aux=True)(params, batch)
 65        updates, opt_state = optimizer.update(grads, opt_state, params)
 66        params = optax.apply_updates(params, updates)
 67        state = {
 68            "params": params,
 69            "optimizer": opt_state,
 70            "batches": state["batches"] + 1,
 71            "epochs": state["epochs"],
 72        }
 73        return state, metrics
 74
 75    def run_epoch(state, key):
 76        def f(state, batch_indices):
 77            batch = jax.tree.map(lambda x: x[batch_indices], ds["train"])
 78            return run_batch(state, batch)
 79
 80        num_examples = get_dataset_size(ds["train"])
 81        batch_indices, _ = sample_batch_indices(key, num_examples, batch_size)
 82        state, train_metrics = lax.scan(f, state, batch_indices)
 83        state |= {"epochs": state["epochs"] + 1}
 84        train_metrics = jax.tree.map(lambda x: x.mean(0), train_metrics)
 85
 86        _, test_metrics = get_batch_loss(state["params"], ds["test"])
 87
 88        metrics = {"train": train_metrics, "test": test_metrics}
 89
 90        io_callback(epoch_callback, None, metrics, state)
 91
 92        return state, metrics
 93
 94    def run_trial(key):
 95        key, subkey = random.split(key)
 96        params = model.init_params(subkey)
 97
 98        opt_state = optimizer.init(params)
 99
100        state = {
101            "params": params,
102            "optimizer": opt_state,
103            "batches": 0,
104            "epochs": 0,
105        }
106
107        keys = random.split(key, epochs)
108        state, metrics = lax.scan(run_epoch, state, keys)
109
110        return state, metrics
111
112    return run_trial(key)
113
114
115def get_optimizer(args):
116    match args.optimizer:
117        case "adam":
118            return optax.adam(args.lr)
119        case other:
120            raise NotImplementedError(other)
121
122
123def get_model(args, info):
124    image_shape = info.features["image"].shape
125    num_labels = info.features["label"].num_classes
126    match args.model:
127        case "lenet":
128            assert image_shape[:2] == (28, 28)
129            return synax.LeNet(input_channels=image_shape[2], outputs=num_labels)
130        case other:
131            raise NotImplementedError(other)
132
133
134def plot_metrics(metrics):
135    axes = {}
136
137    for split in metrics.keys():
138        for metric_name in metrics[split].keys():
139            if metric_name not in axes:
140                fig, ax = plt.subplots(constrained_layout=True)
141                axes[metric_name] = ax
142            ax = axes[metric_name]
143            ax.plot(metrics[split][metric_name], label=split)
144
145    for metric_name, ax in axes.items():
146        ax.legend()
147        ax.set(xlabel="epoch")
148        ax.set(ylabel=metric_name)
149
150
151def main(args):
152    ds, info = tfds.load(args.dataset, batch_size=-1, with_info=True)  # type: ignore
153    print(f"Dataset: {info.description}\n")
154    ds = tfds.as_numpy(ds)
155    ds = jax.tree.map(jnp.asarray, ds)
156
157    model = get_model(args, info)
158    optimizer = get_optimizer(args)
159    key = random.key(args.seed)
160
161    def epoch_callback(metrics, state):
162        print(f"epochs: {state['epochs']}")
163        print(f"batches: {state['batches']}")
164        for split in info.splits.keys():
165            for metric_name in metrics[split].keys():
166                value = metrics[split][metric_name]
167                print(f"{split} {metric_name}: {value:g}")
168        print()
169
170    state, metrics = train(
171        ds=ds,
172        model=model,
173        optimizer=optimizer,
174        key=key,
175        epochs=args.epochs,
176        batch_size=args.batch_size,
177        epoch_callback=epoch_callback,
178        loss_fn=optax.softmax_cross_entropy_with_integer_labels,
179    )
180
181    plot_metrics(metrics)
182
183    plt.show()
184
185
186if __name__ == "__main__":
187    main(parse_args())
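Assuming the script is saved as train.py (the filename is arbitrary), it can be run with any of the flags defined in parse_args, for example:

python train.py --epochs 5 --batch_size 64 --lr 3e-4

The defaults train the LeNet model on MNIST with Adam, printing metrics after each epoch and plotting them at the end.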