Usage

The following script demonstrates basic usage of the library:

from jax import numpy as jnp, random
import synax

# Create a module.
module = synax.MLP([2, 32, 3])

# Create a PRNG key.
key = random.key(0)

# Sample initial parameters.
w = module.init(key)

# Define an input.
x = jnp.ones(2)

# Compute the output.
y = module.apply(w, x)

# Print the output.
print(y)

Output:

[-1.2567853  -0.80044776  0.5694267 ]

A module has the following methods:

  • init takes a JAX PRNG key and returns initial parameters for the module.

  • apply takes the module’s parameters, together with any inputs, and returns the output of the module.
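Because init and apply are pure functions that take parameters explicitly, modules compose directly with JAX transformations. As a minimal sketch (reusing the MLP from the script above), the forward pass can be jit-compiled and differentiated with respect to the parameters:

import jax
from jax import numpy as jnp, random
import synax

module = synax.MLP([2, 32, 3])
w = module.init(random.key(0))
x = jnp.ones(2)

# Compile the forward pass; parameters and inputs are ordinary arguments.
y = jax.jit(module.apply)(w, x)

# Differentiate a scalar loss with respect to the parameters.
def loss(w):
    return jnp.sum(module.apply(w, x) ** 2)

grads = jax.grad(loss)(w)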

Defining a custom module

The following script shows how to define a custom module:

from jax import random, nn


class Affine:
    """Affine map."""

    def __init__(
        self,
        input_dim,
        output_dim,
        weight_init=nn.initializers.he_normal(),
        bias_init=nn.initializers.zeros,
    ):
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.weight_init = weight_init
        self.bias_init = bias_init

    def init(self, key):
        keys = random.split(key)
        weight = self.weight_init(keys[0], (self.input_dim, self.output_dim))
        bias = self.bias_init(keys[1], (self.output_dim,))
        return {"weight": weight, "bias": bias}

    def apply(self, params, input):
        return input @ params["weight"] + params["bias"]


module = Affine(3, 2)
key = random.key(0)
params = module.init(key)
print(params)

Output:

{'weight': Array([[ 0.8737965 , -0.79177886],
       [-0.65683264, -1.0112412 ],
       [-0.7620363 ,  0.5188657 ]], dtype=float32), 'bias': Array([0., 0.], dtype=float32)}
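The custom module follows the same init/apply protocol, so it can be used like any built-in module. A minimal sketch of the forward pass, continuing the script above (the batched call assumes nothing beyond standard jax.vmap semantics):

import jax
from jax import numpy as jnp

# Apply the module to a single input.
x = jnp.ones(3)
y = module.apply(params, x)  # shape (2,)

# Since apply is pure, batching is just a vmap over the input axis.
xs = jnp.ones((8, 3))
ys = jax.vmap(module.apply, in_axes=(None, 0))(params, xs)  # shape (8, 2)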

Example: Training on MNIST

The following script trains a model on the MNIST dataset.

  1"""
  2pip install tensorflow-datasets optax matplotlib
  3"""
  4
  5import argparse
  6
  7import jax
  8import optax
  9import synax
 10import tensorflow_datasets as tfds
 11from jax import lax, random
 12from jax import numpy as jnp
 13from jax.experimental import io_callback
 14from matplotlib import pyplot as plt
 15
 16
 17def parse_args():
 18    p = argparse.ArgumentParser()
 19    p.add_argument("--dataset", type=str, default="mnist")
 20    p.add_argument("--model", type=str, default="lenet")
 21    p.add_argument("--optimizer", type=str, default="adam")
 22    p.add_argument("--lr", type=float, default=1e-3)
 23    p.add_argument("--seed", type=int, default=0)
 24    p.add_argument("--epochs", type=int, default=10)
 25    p.add_argument("--batch_size", type=int, default=32)
 26    return p.parse_args()
 27
 28
 29def sample_batch_indices(key, num_examples, batch_size):
 30    num_batches = num_examples // batch_size
 31    perm = random.permutation(key, num_examples)
 32    limit = num_batches * batch_size
 33    batch_indices = perm[:limit].reshape((num_batches, batch_size))
 34    remainder = perm[limit:]
 35    return batch_indices, remainder
 36
 37
 38def train(ds, info, model, optimizer, key, epochs, batch_size, epoch_callback):
 39    def get_example_loss(params, instance):
 40        image = instance["image"]
 41        label = instance["label"]
 42        image /= 255
 43        logits = model.apply(params, image)
 44        loss = optax.softmax_cross_entropy_with_integer_labels(logits, label)
 45        error = logits.argmax() != label
 46        return loss, {"loss": loss, "error": error}
 47
 48    def get_batch_loss(params, batch):
 49        losses, metrics = jax.vmap(get_example_loss, [None, 0])(params, batch)
 50        mean_loss = losses.mean(0)
 51        mean_metrics = jax.tree.map(lambda x: x.mean(0), metrics)
 52        return mean_loss, mean_metrics
 53
 54    def run_batch(state, batch):
 55        params = state["params"]
 56        opt_state = state["optimizer"]
 57        grads, metrics = jax.grad(get_batch_loss, has_aux=True)(params, batch)
 58        updates, opt_state = optimizer.update(grads, opt_state, params)
 59        params = optax.apply_updates(params, updates)
 60        state = {
 61            "params": params,
 62            "optimizer": opt_state,
 63            "batches": state["batches"] + 1,
 64            "epochs": state["epochs"],
 65        }
 66        return state, metrics
 67
 68    def run_epoch(state, key):
 69        def f(state, batch_indices):
 70            batch = jax.tree.map(lambda x: x[batch_indices], ds["train"])
 71            return run_batch(state, batch)
 72
 73        num_examples = info.splits["train"].num_examples
 74        batch_indices, _ = sample_batch_indices(key, num_examples, batch_size)
 75        state, train_metrics = lax.scan(f, state, batch_indices)
 76        state |= {"epochs": state["epochs"] + 1}
 77        train_metrics = jax.tree.map(lambda x: x.mean(0), train_metrics)
 78
 79        _, test_metrics = get_batch_loss(state["params"], ds["test"])
 80
 81        metrics = {"train": train_metrics, "test": test_metrics}
 82
 83        io_callback(epoch_callback, None, metrics, state)
 84
 85        return state, metrics
 86
 87    def run_trial(key):
 88        key, subkey = random.split(key)
 89        params = model.init(subkey)
 90
 91        opt_state = optimizer.init(params)
 92
 93        state = {
 94            "params": params,
 95            "optimizer": opt_state,
 96            "batches": 0,
 97            "epochs": 0,
 98        }
 99
100        keys = random.split(key, epochs)
101        state, metrics = lax.scan(run_epoch, state, keys)
102
103        return state, metrics
104
105    return run_trial(key)
106
107
108def get_optimizer(args):
109    match args.optimizer:
110        case "adam":
111            return optax.adam(args.lr)
112        case other:
113            raise NotImplementedError(other)
114
115
116def get_model(args, info):
117    image_shape = info.features["image"].shape
118    num_labels = info.features["label"].num_classes
119    match args.model:
120        case "lenet":
121            assert image_shape[:2] == (28, 28)
122            return synax.LeNet(input_channels=image_shape[2], outputs=num_labels)
123        case other:
124            raise NotImplementedError(other)
125
126
127def main(args):
128    ds, info = tfds.load(args.dataset, batch_size=-1, with_info=True)  # type: ignore
129    print(info.description)
130    ds = tfds.as_numpy(ds)
131    ds = jax.tree.map(jnp.asarray, ds)
132
133    model = get_model(args, info)
134    optimizer = get_optimizer(args)
135    key = random.key(args.seed)
136
137    def epoch_callback(metrics, state):
138        print(f"epochs: {state['epochs']}")
139        print(f"batches: {state['batches']}")
140        print(f"train loss: {metrics['train']['loss']:g}")
141        print(f"test loss: {metrics['test']['loss']:g}")
142        print(f"train error: {metrics['train']['error']:g}")
143        print(f"test error: {metrics['test']['error']:g}")
144        print()
145
146    state, metrics = train(
147        ds=ds,
148        info=info,
149        model=model,
150        optimizer=optimizer,
151        key=key,
152        epochs=args.epochs,
153        batch_size=args.batch_size,
154        epoch_callback=epoch_callback,
155    )
156
157    for metric_name in ["loss", "error"]:
158        fig, ax = plt.subplots()
159        ax.set(xlabel="epoch", ylabel=metric_name)
160        for split in metrics.keys():
161            ax.plot(metrics[split][metric_name], label=split)
162        ax.legend()
163
164    plt.show()
165
166
167if __name__ == "__main__":
168    main(parse_args())
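Assuming the script is saved as, for example, train_mnist.py (the filename is arbitrary), install the extra dependencies listed in its docstring and run it with the flags defined in parse_args:

pip install tensorflow-datasets optax matplotlib
python train_mnist.py --epochs 10 --batch_size 32 --lr 1e-3

Each epoch, the script prints the train and test loss and error rate, and after training it plots both metrics over epochs.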