Usage
The following script is a basic example of how to use the library:
from jax import numpy as jnp, random
import synax

# Create a module: an MLP with input size 2, one hidden layer
# of 32 units, and output size 3.
module = synax.MLP([2, 32, 3])

# Create a PRNG key.
key = random.key(0)

# Sample initial parameters.
w = module.init_params(key)

# Define an input.
x = jnp.ones(2)

# Compute the output.
y = module.apply(w, x)

# Print the output.
print(y)
Output:
[-1.2567853 -0.80044776 0.5694267 ]
A module has the following methods:

init_params
    takes a JAX PRNG key and returns initial parameters for the module.

apply
    takes the module's parameters, together with any inputs, and returns the output of the module.
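Because init_params and apply are plain functions of explicit parameters and inputs (as in the examples on this page), a module should compose directly with standard JAX transformations. A minimal sketch, reusing the MLP from above; jax.jit and jax.vmap are ordinary JAX, everything else comes from the first example:

import jax
from jax import numpy as jnp, random
import synax

module = synax.MLP([2, 32, 3])
w = module.init_params(random.key(0))

# Compile the forward pass, and map it over a batch of inputs
# while holding the parameters fixed.
fast_apply = jax.jit(module.apply)
batched_apply = jax.vmap(module.apply, in_axes=(None, 0))

print(fast_apply(w, jnp.ones(2)))          # single input -> shape (3,)
print(batched_apply(w, jnp.ones((8, 2))))  # batch of 8 -> shape (8, 3)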
Defining a custom module
The following script shows how to define a custom module:
from jax import random, nn


class Affine:
    """Affine map."""

    def __init__(
        self,
        input_dim,
        output_dim,
        weight_init=nn.initializers.he_normal(),
        bias_init=nn.initializers.zeros,
    ):
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.weight_init = weight_init
        self.bias_init = bias_init

    def init_params(self, key):
        keys = random.split(key)
        weight = self.weight_init(keys[0], (self.input_dim, self.output_dim))
        bias = self.bias_init(keys[1], (self.output_dim,))
        return {"weight": weight, "bias": bias}

    def apply(self, params, input):
        return input @ params["weight"] + params["bias"]


module = Affine(3, 2)
key = random.key(0)
params = module.init_params(key)
print(params)
Output:
{'weight': Array([[ 0.8737965 , -0.79177886],
       [-0.65683264, -1.0112412 ],
       [-0.7620363 ,  0.5188657 ]], dtype=float32), 'bias': Array([0., 0.], dtype=float32)}
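Since the custom module stores its parameters in a plain dict, gradients flow through it like any other pytree. A minimal sketch of one gradient-descent step through Affine, continuing the script above; the squared-error loss and the 0.1 step size are illustrative choices, not part of the library:

import jax
from jax import numpy as jnp

def loss_fn(params, x, y):
    pred = module.apply(params, x)
    return jnp.mean((pred - y) ** 2)

x = jnp.ones(3)   # matches input_dim=3
y = jnp.zeros(2)  # matches output_dim=2
grads = jax.grad(loss_fn)(params, x, y)

# grads mirrors the {'weight', 'bias'} structure of params.
params = jax.tree.map(lambda p, g: p - 0.1 * g, params, grads)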
Example: Training on MNIST
The following script trains a model on the MNIST dataset. By default it trains a LeNet with the Adam optimizer; the dataset, model, optimizer, and hyperparameters are all configurable via command-line flags.
1"""
2pip install tensorflow-datasets optax matplotlib
3"""
4
5import argparse
6
7import jax
8import optax
9import synax
10import tensorflow_datasets as tfds
11from jax import lax, random
12from jax import numpy as jnp
13from jax.experimental import io_callback
14from matplotlib import pyplot as plt
15
16
17def parse_args():
18 p = argparse.ArgumentParser()
19 p.add_argument("--dataset", type=str, default="mnist")
20 p.add_argument("--model", type=str, default="lenet")
21 p.add_argument("--optimizer", type=str, default="adam")
22 p.add_argument("--lr", type=float, default=1e-3)
23 p.add_argument("--seed", type=int, default=0)
24 p.add_argument("--epochs", type=int, default=10)
25 p.add_argument("--batch_size", type=int, default=32)
26 return p.parse_args()
27
28
29def sample_batch_indices(key, num_examples, batch_size):
30 num_batches = num_examples // batch_size
31 perm = random.permutation(key, num_examples)
32 limit = num_batches * batch_size
33 batch_indices = perm[:limit].reshape((num_batches, batch_size))
34 remainder = perm[limit:]
35 return batch_indices, remainder
36
37
38def get_dataset_size(ds):
39 leaves = jax.tree.leaves(ds)
40 size = leaves[0].shape[0]
41 assert all(leaf.shape[0] == size for leaf in leaves[1:])
42 return size
43
44
45def train(ds, model, optimizer, key, epochs, batch_size, epoch_callback, loss_fn):
46 def get_example_loss(params, example):
47 image = example["image"]
48 label = example["label"]
49 image /= 255
50 logits = model.apply(params, image)
51 loss = loss_fn(logits, label)
52 error = logits.argmax() != label
53 return loss, {"loss": loss, "error": error}
54
55 def get_batch_loss(params, batch):
56 losses, metrics = jax.vmap(get_example_loss, [None, 0])(params, batch)
57 mean_loss = losses.mean(0)
58 mean_metrics = jax.tree.map(lambda x: x.mean(0), metrics)
59 return mean_loss, mean_metrics
60
61 def run_batch(state, batch):
62 params = state["params"]
63 opt_state = state["optimizer"]
64 grads, metrics = jax.grad(get_batch_loss, has_aux=True)(params, batch)
65 updates, opt_state = optimizer.update(grads, opt_state, params)
66 params = optax.apply_updates(params, updates)
67 state = {
68 "params": params,
69 "optimizer": opt_state,
70 "batches": state["batches"] + 1,
71 "epochs": state["epochs"],
72 }
73 return state, metrics
74
75 def run_epoch(state, key):
76 def f(state, batch_indices):
77 batch = jax.tree.map(lambda x: x[batch_indices], ds["train"])
78 return run_batch(state, batch)
79
80 num_examples = get_dataset_size(ds["train"])
81 batch_indices, _ = sample_batch_indices(key, num_examples, batch_size)
82 state, train_metrics = lax.scan(f, state, batch_indices)
83 state |= {"epochs": state["epochs"] + 1}
84 train_metrics = jax.tree.map(lambda x: x.mean(0), train_metrics)
85
86 _, test_metrics = get_batch_loss(state["params"], ds["test"])
87
88 metrics = {"train": train_metrics, "test": test_metrics}
89
90 io_callback(epoch_callback, None, metrics, state)
91
92 return state, metrics
93
94 def run_trial(key):
95 key, subkey = random.split(key)
96 params = model.init_params(subkey)
97
98 opt_state = optimizer.init(params)
99
100 state = {
101 "params": params,
102 "optimizer": opt_state,
103 "batches": 0,
104 "epochs": 0,
105 }
106
107 keys = random.split(key, epochs)
108 state, metrics = lax.scan(run_epoch, state, keys)
109
110 return state, metrics
111
112 return run_trial(key)
113
114
115def get_optimizer(args):
116 match args.optimizer:
117 case "adam":
118 return optax.adam(args.lr)
119 case other:
120 raise NotImplementedError(other)
121
122
123def get_model(args, info):
124 image_shape = info.features["image"].shape
125 num_labels = info.features["label"].num_classes
126 match args.model:
127 case "lenet":
128 assert image_shape[:2] == (28, 28)
129 return synax.LeNet(input_channels=image_shape[2], outputs=num_labels)
130 case other:
131 raise NotImplementedError(other)
132
133
134def plot_metrics(metrics):
135 axes = {}
136
137 for split in metrics.keys():
138 for metric_name in metrics[split].keys():
139 if metric_name not in axes:
140 fig, ax = plt.subplots(constrained_layout=True)
141 axes[metric_name] = ax
142 ax = axes[metric_name]
143 ax.plot(metrics[split][metric_name], label=split)
144
145 for metric_name, ax in axes.items():
146 ax.legend()
147 ax.set(xlabel="epoch")
148 ax.set(ylabel=metric_name)
149
150
151def main(args):
152 ds, info = tfds.load(args.dataset, batch_size=-1, with_info=True) # type: ignore
153 print(f"Dataset: {info.description}\n")
154 ds = tfds.as_numpy(ds)
155 ds = jax.tree.map(jnp.asarray, ds)
156
157 model = get_model(args, info)
158 optimizer = get_optimizer(args)
159 key = random.key(args.seed)
160
161 def epoch_callback(metrics, state):
162 print(f"epochs: {state['epochs']}")
163 print(f"batches: {state['batches']}")
164 for split in info.splits.keys():
165 for metric_name in metrics[split].keys():
166 value = metrics[split][metric_name]
167 print(f"{split} {metric_name}: {value:g}")
168 print()
169
170 state, metrics = train(
171 ds=ds,
172 model=model,
173 optimizer=optimizer,
174 key=key,
175 epochs=args.epochs,
176 batch_size=args.batch_size,
177 epoch_callback=epoch_callback,
178 loss_fn=optax.softmax_cross_entropy_with_integer_labels,
179 )
180
181 plot_metrics(metrics)
182
183 plt.show()
184
185
186if __name__ == "__main__":
187 main(parse_args())
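Assuming the script is saved as train_mnist.py (a hypothetical filename), it can be run with its defaults or with any of the flags defined in parse_args, for example:

python train_mnist.py --epochs 5 --lr 3e-4 --batch_size 64

Each epoch it prints the train and test loss and error rate, and once training finishes it plots the metric curves with matplotlib.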