Usage¶
The following script is a basic example of how to use the library:
from jax import numpy as jnp, random
import synax

# Create a module.
module = synax.MLP([2, 32, 3])

# Create a PRNG key.
key = random.key(0)

# Sample initial parameters.
w = module.init(key)

# Define an input.
x = jnp.ones(2)

# Compute the output.
y = module.apply(w, x)

# Print the output.
print(y)
Output:
[-1.2567853 -0.80044776 0.5694267 ]
A module has the following methods:
init
    takes a JAX PRNG key and returns initial parameters for the module.
apply
    takes the module's parameters, together with any inputs, and returns the output of the module.
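Because init and apply are ordinary pure functions, a module composes directly with JAX transformations such as jit, vmap, and grad. A minimal sketch, reusing the MLP from the script above (the batch size of 16 is arbitrary):

import jax
from jax import numpy as jnp, random
import synax

module = synax.MLP([2, 32, 3])
w = module.init(random.key(0))

# Map apply over a batch of inputs while sharing the parameters,
# then JIT-compile the batched function.
batched_apply = jax.jit(jax.vmap(module.apply, in_axes=(None, 0)))

xs = jnp.ones((16, 2))     # a batch of 16 inputs
ys = batched_apply(w, xs)  # shape (16, 3)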
Defining a custom module¶
The following script shows how to define a custom module:
from jax import random, nn


class Affine:
    """Affine map."""

    def __init__(
        self,
        input_dim,
        output_dim,
        weight_init=nn.initializers.he_normal(),
        bias_init=nn.initializers.zeros,
    ):
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.weight_init = weight_init
        self.bias_init = bias_init

    def init(self, key):
        keys = random.split(key)
        weight = self.weight_init(keys[0], (self.input_dim, self.output_dim))
        bias = self.bias_init(keys[1], (self.output_dim,))
        return {"weight": weight, "bias": bias}

    def apply(self, params, input):
        return input @ params["weight"] + params["bias"]


module = Affine(3, 2)
key = random.key(0)
params = module.init(key)
print(params)
Output:
{'weight': Array([[ 0.8737965 , -0.79177886],
[-0.65683264, -1.0112412 ],
[-0.7620363 , 0.5188657 ]], dtype=float32), 'bias': Array([0., 0.], dtype=float32)}
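A custom module can then be used like any built-in one. A minimal sketch, continuing the script above (the squared-norm loss is only for illustration):

import jax
from jax import numpy as jnp

x = jnp.ones(3)
y = module.apply(params, x)  # shape (2,)


def loss(w):
    # Toy loss for illustration: squared norm of the output.
    return jnp.sum(module.apply(w, x) ** 2)


# Gradients have the same pytree structure as the parameters.
grads = jax.grad(loss)(params)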
Example: Training on MNIST¶
The following script trains a model on the MNIST dataset.
1"""
2pip install tensorflow-datasets optax matplotlib
3"""
4
5import argparse
6
7import jax
8import optax
9import synax
10import tensorflow_datasets as tfds
11from jax import lax, random
12from jax import numpy as jnp
13from jax.experimental import io_callback
14from matplotlib import pyplot as plt
15
16
17def parse_args():
18 p = argparse.ArgumentParser()
19 p.add_argument("--dataset", type=str, default="mnist")
20 p.add_argument("--model", type=str, default="lenet")
21 p.add_argument("--optimizer", type=str, default="adam")
22 p.add_argument("--lr", type=float, default=1e-3)
23 p.add_argument("--seed", type=int, default=0)
24 p.add_argument("--epochs", type=int, default=10)
25 p.add_argument("--batch_size", type=int, default=32)
26 return p.parse_args()
27
28
29def sample_batch_indices(key, num_examples, batch_size):
30 num_batches = num_examples // batch_size
31 perm = random.permutation(key, num_examples)
32 limit = num_batches * batch_size
33 batch_indices = perm[:limit].reshape((num_batches, batch_size))
34 remainder = perm[limit:]
35 return batch_indices, remainder
36
37
38def train(ds, info, model, optimizer, key, epochs, batch_size, epoch_callback):
39 def get_example_loss(params, instance):
40 image = instance["image"]
41 label = instance["label"]
42 image /= 255
43 logits = model.apply(params, image)
44 loss = optax.softmax_cross_entropy_with_integer_labels(logits, label)
45 error = logits.argmax() != label
46 return loss, {"loss": loss, "error": error}
47
48 def get_batch_loss(params, batch):
49 losses, metrics = jax.vmap(get_example_loss, [None, 0])(params, batch)
50 mean_loss = losses.mean(0)
51 mean_metrics = jax.tree.map(lambda x: x.mean(0), metrics)
52 return mean_loss, mean_metrics
53
54 def run_batch(state, batch):
55 params = state["params"]
56 opt_state = state["optimizer"]
57 grads, metrics = jax.grad(get_batch_loss, has_aux=True)(params, batch)
58 updates, opt_state = optimizer.update(grads, opt_state, params)
59 params = optax.apply_updates(params, updates)
60 state = {
61 "params": params,
62 "optimizer": opt_state,
63 "batches": state["batches"] + 1,
64 "epochs": state["epochs"],
65 }
66 return state, metrics
67
68 def run_epoch(state, key):
69 def f(state, batch_indices):
70 batch = jax.tree.map(lambda x: x[batch_indices], ds["train"])
71 return run_batch(state, batch)
72
73 num_examples = info.splits["train"].num_examples
74 batch_indices, _ = sample_batch_indices(key, num_examples, batch_size)
75 state, train_metrics = lax.scan(f, state, batch_indices)
76 state |= {"epochs": state["epochs"] + 1}
77 train_metrics = jax.tree.map(lambda x: x.mean(0), train_metrics)
78
79 _, test_metrics = get_batch_loss(state["params"], ds["test"])
80
81 metrics = {"train": train_metrics, "test": test_metrics}
82
83 io_callback(epoch_callback, None, metrics, state)
84
85 return state, metrics
86
87 def run_trial(key):
88 key, subkey = random.split(key)
89 params = model.init(subkey)
90
91 opt_state = optimizer.init(params)
92
93 state = {
94 "params": params,
95 "optimizer": opt_state,
96 "batches": 0,
97 "epochs": 0,
98 }
99
100 keys = random.split(key, epochs)
101 state, metrics = lax.scan(run_epoch, state, keys)
102
103 return state, metrics
104
105 return run_trial(key)
106
107
108def get_optimizer(args):
109 match args.optimizer:
110 case "adam":
111 return optax.adam(args.lr)
112 case other:
113 raise NotImplementedError(other)
114
115
116def get_model(args, info):
117 image_shape = info.features["image"].shape
118 num_labels = info.features["label"].num_classes
119 match args.model:
120 case "lenet":
121 assert image_shape[:2] == (28, 28)
122 return synax.LeNet(input_channels=image_shape[2], outputs=num_labels)
123 case other:
124 raise NotImplementedError(other)
125
126
127def main(args):
128 ds, info = tfds.load(args.dataset, batch_size=-1, with_info=True) # type: ignore
129 print(info.description)
130 ds = tfds.as_numpy(ds)
131 ds = jax.tree.map(jnp.asarray, ds)
132
133 model = get_model(args, info)
134 optimizer = get_optimizer(args)
135 key = random.key(args.seed)
136
137 def epoch_callback(metrics, state):
138 print(f"epochs: {state['epochs']}")
139 print(f"batches: {state['batches']}")
140 print(f"train loss: {metrics['train']['loss']:g}")
141 print(f"test loss: {metrics['test']['loss']:g}")
142 print(f"train error: {metrics['train']['error']:g}")
143 print(f"test error: {metrics['test']['error']:g}")
144 print()
145
146 state, metrics = train(
147 ds=ds,
148 info=info,
149 model=model,
150 optimizer=optimizer,
151 key=key,
152 epochs=args.epochs,
153 batch_size=args.batch_size,
154 epoch_callback=epoch_callback,
155 )
156
157 for metric_name in ["loss", "error"]:
158 fig, ax = plt.subplots()
159 ax.set(xlabel="epoch", ylabel=metric_name)
160 for split in metrics.keys():
161 ax.plot(metrics[split][metric_name], label=split)
162 ax.legend()
163
164 plt.show()
165
166
167if __name__ == "__main__":
168 main(parse_args())
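The command-line flags correspond to the arguments defined in parse_args. For example, assuming the script is saved as train_mnist.py (a name chosen here for illustration), it can be run as:

python train_mnist.py --epochs 5 --batch_size 64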