synax.Attention.parameter_loss

synax.Attention.parameter_loss(parameters: dict[str, jax.Array]) → float

Compute a scalar loss from the module's parameters.

Parameters:

parameters – Dictionary mapping string keys to parameter arrays.

Returns:

Scalar loss value.
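
This page does not say how the loss is computed. A function with this signature, from a dictionary of arrays to a scalar, is commonly a regularization term added to a training objective. The sketch below illustrates that shape only. It does not call synax, and `l2_parameter_loss`, its `scale` argument, and the dictionary keys are all hypothetical names chosen for illustration, not part of the library.

```python
import jax
import jax.numpy as jnp


def l2_parameter_loss(
    parameters: dict[str, jax.Array], scale: float = 1e-2
) -> jax.Array:
    """Hypothetical stand-in: scaled sum of squared entries over all arrays."""
    return scale * sum(jnp.sum(jnp.square(p)) for p in parameters.values())


# Hypothetical parameter dictionary; the real key names depend on the module.
parameters = {
    "query_kernel": jnp.ones((2, 3)),
    "key_kernel": jnp.full((2, 3), 2.0),
}

# A zero-dimensional array, usable as an additive term in a training loss.
loss = l2_parameter_loss(parameters)
```

The sketch returns a zero-dimensional `jax.Array` rather than a Python `float`, so it stays traceable under `jax.jit` and `jax.grad`.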