Squashed Multivariate Normal Distribution
lerax.distribution.SquashedMultivariateNormalDiag
Bases: AbstractTransformedDistribution[Float[Array, ' dims']]
Multivariate normal distribution with diagonal covariance, passed through a squashing bijector so that outputs are bounded.
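The following minimal sketch in plain JAX shows how a single bounded sample could be produced. It assumes the squashing bijector is `tanh` followed by an affine rescale onto `[low, high]`, and it treats `scale_diag` as per-dimension standard deviations (an assumption based on the usual meaning of the name). The helper `squashed_sample` is hypothetical and not part of lerax.

```python
import jax
import jax.numpy as jnp


def squashed_sample(key, loc, scale_diag, low=-1.0, high=1.0):
    """Draw one bounded sample (hypothetical helper, not lerax API).

    Assumes the squash is tanh followed by an affine map onto [low, high].
    """
    loc = jnp.asarray(loc)
    scale_diag = jnp.asarray(scale_diag)
    # Unbounded draw from N(loc, diag(scale_diag**2)).
    eps = jax.random.normal(key, loc.shape)
    u = loc + scale_diag * eps
    # tanh squashes each coordinate into (-1, 1).
    t = jnp.tanh(u)
    # Affine rescale from (-1, 1) onto (low, high); scalar bounds broadcast over dims.
    return low + 0.5 * (high - low) * (t + 1.0)


key = jax.random.PRNGKey(0)
x = squashed_sample(key, jnp.zeros(3), jnp.ones(3), low=0.0, high=2.0)
```

Scalar `low` and `high` broadcast the same bounds across every dimension, so only the shapes of `loc` and `scale_diag` determine the sample's shape.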
Attributes:
| Name | Type | Description |
|---|---|---|
| `distribution` | `distributions.Transformed` | The underlying distreqx `Transformed` distribution. |
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `loc` | `Float[ArrayLike, ' dims']` | The mean of the underlying multivariate normal distribution. | required |
| `scale_diag` | `Float[ArrayLike, ' dims']` | The diagonal of the scale matrix, i.e. the per-dimension standard deviations of the underlying normal; the covariance is `diag(scale_diag**2)`. | required |
| `high` | `Float[ArrayLike, ' dims']` | The upper bound for bounded squashing. | `jnp.array(1.0)` |
| `low` | `Float[ArrayLike, ' dims']` | The lower bound for bounded squashing. | `jnp.array(-1.0)` |
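The bounds also enter the density through the change-of-variables formula: for `x = f(u)` with `u` normal, `log p(x) = log p_normal(u) - sum_i log|dx_i/du_i|`, where `u = f^-1(x)`. The sketch below assumes `f` is `tanh` followed by an affine rescale onto `[low, high]` and that `scale_diag` holds per-dimension standard deviations; the helper `squashed_log_prob` and its `eps` argument are hypothetical, not lerax API.

```python
import jax.numpy as jnp
from jax.scipy.stats import norm


def squashed_log_prob(x, loc, scale_diag, low=-1.0, high=1.0, eps=1e-6):
    """Log-density of a bounded value x (hypothetical helper, not lerax API).

    Assumes x = low + (high - low) / 2 * (tanh(u) + 1), u ~ N(loc, diag(scale_diag**2)).
    """
    x = jnp.asarray(x)
    half_range = 0.5 * (jnp.asarray(high) - jnp.asarray(low))
    # Undo the affine rescale: map [low, high] back to [-1, 1], clipped away from +/-1.
    t = jnp.clip((x - low) / half_range - 1.0, -1.0 + eps, 1.0 - eps)
    # Undo the tanh squash.
    u = jnp.arctanh(t)
    # Base log-density: independent normals, summed over dims.
    base = jnp.sum(norm.logpdf(u, loc=loc, scale=scale_diag))
    # Per dim, log|dx/du| = log(half_range) + log(1 - tanh(u)**2), and tanh(u) == t.
    log_det = jnp.sum(jnp.log(half_range) + jnp.log1p(-t**2))
    return base - log_det
```

Clipping `t` by `eps` keeps the log-density finite when `x` sits exactly on a bound, at the cost of a small bias right at the edges.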