diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py index f76a98546e..142ee4fcd8 100644 --- a/pymc/distributions/multivariate.py +++ b/pymc/distributions/multivariate.py @@ -2705,31 +2705,31 @@ class ZeroSumNormal(Distribution): .. math:: - \begin{align*} - ZSN(\sigma) = N \Big( 0, \sigma^2 (I - \tfrac{1}{n}J) \Big) \\ - \text{where} \ ~ J_{ij} = 1 \ ~ \text{and} \\ - n = \text{nbr of zero-sum axes} - \end{align*} + \begin{align*} + ZSN(\sigma) = N \Big( 0, \sigma^2 (I_K - \tfrac{1}{K}J_K) \Big) \\ + \text{where} \ ~ J_{ij} = 1 \ ~ \text{and} \\ + K = \text{size (length) of the constrained axis} + \end{align*} Parameters ---------- sigma : tensor_like of float - Scale parameter (sigma > 0). - It's actually the standard deviation of the underlying, unconstrained Normal distribution. - Defaults to 1 if not specified. ``sigma`` cannot have length > 1 across the zero-sum axes. + Scale parameter (sigma > 0). + It's actually the standard deviation of the underlying, unconstrained Normal distribution. + Defaults to 1 if not specified. ``sigma`` cannot have length > 1 across the zero-sum axes. n_zerosum_axes: int, defaults to 1 - Number of axes along which the zero-sum constraint is enforced, starting from the rightmost position. - Defaults to 1, i.e the rightmost axis. + Number of axes along which the zero-sum constraint is enforced, starting from the rightmost position. + Defaults to 1, i.e the rightmost axis. dims: sequence of strings, optional - Dimension names of the distribution. Works the same as for other PyMC distributions. - Necessary if ``shape`` is not passed. + Dimension names of the distribution. Works the same as for other PyMC distributions. + Necessary if ``shape`` is not passed. shape: tuple of integers, optional - Shape of the distribution. Works the same as for other PyMC distributions. - Necessary if ``dims`` or ``observed`` is not passed. + Shape of the distribution. Works the same as for other PyMC distributions. + Necessary if ``dims`` or ``observed`` is not passed. Warnings -------- - Currently, ``sigma``cannot have length > 1 across the zero-sum axes to ensure the zero-sum constraint. + Currently, ``sigma`` cannot have length > 1 across the zero-sum axes to ensure the zero-sum constraint. ``n_zerosum_axes`` has to be > 0. If you want the behavior of ``n_zerosum_axes = 0``, just use ``pm.Normal``. @@ -2737,23 +2737,23 @@ class ZeroSumNormal(Distribution): Examples -------- Define a `ZeroSumNormal` variable, with `sigma=1` and - `n_zerosum_axes=1` by default:: - - COORDS = { - "regions": ["a", "b", "c"], - "answers": ["yes", "no", "whatever", "don't understand question"], - } - with pm.Model(coords=COORDS) as m: - # the zero sum axis will be 'answers' - v = pm.ZeroSumNormal("v", dims=("regions", "answers")) - - with pm.Model(coords=COORDS) as m: - # the zero sum axes will be 'answers' and 'regions' - v = pm.ZeroSumNormal("v", dims=("regions", "answers"), n_zerosum_axes=2) - - with pm.Model(coords=COORDS) as m: - # the zero sum axes will be the last two - v = pm.ZeroSumNormal("v", shape=(3, 4, 5), n_zerosum_axes=2) + `n_zerosum_axes=1` by default:: + + COORDS = { + "regions": ["a", "b", "c"], + "answers": ["yes", "no", "whatever", "don't understand question"], + } + with pm.Model(coords=COORDS) as m: + # the zero sum axis will be 'answers' + v = pm.ZeroSumNormal("v", dims=("regions", "answers")) + + with pm.Model(coords=COORDS) as m: + # the zero sum axes will be 'answers' and 'regions' + v = pm.ZeroSumNormal("v", dims=("regions", "answers"), n_zerosum_axes=2) + + with pm.Model(coords=COORDS) as m: + # the zero sum axes will be the last two + v = pm.ZeroSumNormal("v", shape=(3, 4, 5), n_zerosum_axes=2) """ rv_type = ZeroSumNormalRV