synax.Switch

class synax.Switch(module: Module, branches: int)

Switch map.

Multiplex between \(n\) copies of the same module, where \(n\) is branches, each with its own independent parameters. Given a branch index \(i \in [n]\), computes

\[y = f_i(x)\]

where \(\{f_i\}_{i \in [n]}\) is the ensemble of modules.
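The switch map can be illustrated with a minimal JAX sketch that does not use synax itself. The dense-layer "module" (init_dense, apply_dense) and the helpers init_switch and apply_switch are hypothetical names chosen for this example, not the library's API. The sketch also assumes that the \(n\) branch parameters are stored stacked along a leading axis of every parameter array.

```python
import jax
import jax.numpy as jnp


# Hypothetical stand-in for a module: a single dense layer.
# (Not the synax API; used only to illustrate the switch map.)
def init_dense(key, in_dim=3, out_dim=2):
    w_key, b_key = jax.random.split(key)
    return {
        "w": jax.random.normal(w_key, (in_dim, out_dim)),
        "b": jax.random.normal(b_key, (out_dim,)),
    }


def apply_dense(params, x):
    return x @ params["w"] + params["b"]


def init_switch(key, branches):
    # One independent key per branch gives independent parameters for
    # f_0, ..., f_{n-1}, stacked along a new leading axis of size `branches`.
    keys = jax.random.split(key, branches)
    return jax.vmap(init_dense)(keys)


def apply_switch(params, i, x):
    # Pick the i-th branch's slice from every parameter leaf, then apply it:
    # y = f_i(x).
    branch_params = jax.tree_util.tree_map(lambda p: p[i], params)
    return apply_dense(branch_params, x)


params = init_switch(jax.random.PRNGKey(0), branches=4)
x = jnp.arange(3.0)
y = jax.jit(apply_switch)(params, 2, x)  # output of branch 2, shape (2,)
```

Because the branches live in one stacked pytree, selecting a branch is just an array index on each leaf, so the index \(i\) can be a traced value under jax.jit. Calling jax.lax.switch on a list of per-branch functions would be an alternative way to express the same selection.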

Parameters:
  • module – Module to apply.

  • branches – Number of branches \(n\), i.e. independent copies of the module.

Methods

init_params(→ Any)

Sample initial parameters.

param_loss(→ jax.Array)

Parameter loss.