@@ -2591,8 +2591,7 @@ class ZeroSumNormal(Distribution):
25912591 sigma : tensor_like of float
25922592 Scale parameter (sigma > 0).
25932593 It's actually the standard deviation of the underlying, unconstrained Normal distribution.
2594- Defaults to 1 if not specified.
2595- For now, ``sigma`` has to be a scalar, to ensure the zero-sum constraint.
2594+ Defaults to 1 if not specified. ``sigma`` cannot have length > 1 across the zero-sum axes.
25962595 n_zerosum_axes: int, defaults to 1
25972596 Number of axes along which the zero-sum constraint is enforced, starting from the rightmost position.
25982597 Defaults to 1, i.e the rightmost axis.
@@ -2606,8 +2605,7 @@ class ZeroSumNormal(Distribution):
26062605
26072606 Warnings
26082607 --------
2609- ``sigma`` has to be a scalar, to ensure the zero-sum constraint.
2610- The ability to specify a vector of ``sigma`` may be added in future versions.
2608+ Currently, ``sigma``cannot have length > 1 across the zero-sum axes to ensure the zero-sum constraint.
26112609
26122610 ``n_zerosum_axes`` has to be > 0. If you want the behavior of ``n_zerosum_axes = 0``,
26132611 just use ``pm.Normal``.
@@ -2669,8 +2667,8 @@ def dist(cls, sigma=1, n_zerosum_axes=None, support_shape=None, **kwargs):
26692667 n_zerosum_axes = cls .check_zerosum_axes (n_zerosum_axes )
26702668
26712669 sigma = pt .as_tensor_variable (floatX (sigma ))
2672- if sigma .ndim > 0 :
2673- raise ValueError ("sigma has to be a scalar " )
2670+ if not all ( sigma .type . broadcastable [ - n_zerosum_axes :]) :
2671+ raise ValueError ("sigma must have length one across the zero-sum axes " )
26742672
26752673 support_shape = get_support_shape (
26762674 support_shape = support_shape ,
@@ -2681,9 +2679,7 @@ def dist(cls, sigma=1, n_zerosum_axes=None, support_shape=None, **kwargs):
26812679 if support_shape is None :
26822680 if n_zerosum_axes > 0 :
26832681 raise ValueError ("You must specify dims, shape or support_shape parameter" )
2684- # TODO: edge-case doesn't work for now, because pt.stack in get_support_shape fails
2685- # else:
2686- # support_shape = () # because it's just a Normal in that case
2682+
26872683 support_shape = pt .as_tensor_variable (intX (support_shape ))
26882684
26892685 assert n_zerosum_axes == pt .get_vector_length (
@@ -2706,7 +2702,12 @@ def check_zerosum_axes(cls, n_zerosum_axes: Optional[int]) -> int:
27062702
27072703 @classmethod
27082704 def rv_op (cls , sigma , n_zerosum_axes , support_shape , size = None ):
2709- shape = to_tuple (size ) + tuple (support_shape )
2705+ if size is not None :
2706+ shape = tuple (size ) + tuple (support_shape )
2707+ else :
2708+ # Size is implied by shape of sigma
2709+ shape = tuple (sigma .shape [:- n_zerosum_axes ]) + tuple (support_shape )
2710+
27102711 normal_dist = pm .Normal .dist (sigma = sigma , shape = shape )
27112712
27122713 if n_zerosum_axes > normal_dist .ndim :
0 commit comments