# Copyright 2025 - present The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytensor.xtensor as ptx
import pytensor.xtensor.random as pxr
from pytensor.xtensor import as_xtensor
from pymc.dims.distributions.core import (
DimDistribution,
PositiveDimDistribution,
UnitDimDistribution,
)
from pymc.distributions.continuous import Beta as RegularBeta
from pymc.distributions.continuous import Gamma as RegularGamma
from pymc.distributions.continuous import HalfStudentTRV, flat, halfflat
def _get_sigma_from_either_sigma_or_tau(*, sigma, tau):
if sigma is not None and tau is not None:
raise ValueError("Can't pass both tau and sigma")
if sigma is None and tau is None:
return 1.0
if sigma is not None:
return sigma
return ptx.math.reciprocal(ptx.math.sqrt(tau))
class Flat(DimDistribution):
    """Dims-aware wrapper around the ``flat`` random variable.

    The core ``flat`` RV comes from ``pymc.distributions.continuous`` and is
    lifted to an xtensor random variable with ``pxr.as_xrv``.
    """

    xrv_op = pxr.as_xrv(flat)

    @classmethod
    def dist(cls, **kwargs):
        """Create a ``Flat`` variable.

        The distribution takes no parameters, so an empty parameter list is
        forwarded. All keyword arguments go straight to the parent ``dist``.
        """
        parameters = []
        return super().dist(parameters, **kwargs)
class HalfFlat(PositiveDimDistribution):
    """Dims-aware wrapper around the ``halfflat`` random variable.

    The core ``halfflat`` RV comes from ``pymc.distributions.continuous``.
    """

    # NOTE(review): the explicit ``[]`` / ``()`` arguments presumably declare
    # empty core input/output dims. ``Flat`` above omits them; confirm
    # against ``pxr.as_xrv`` whether both forms are equivalent.
    xrv_op = pxr.as_xrv(halfflat, [], ())

    @classmethod
    def dist(cls, **kwargs):
        """Create a ``HalfFlat`` variable.

        No distribution parameters are taken; keyword arguments are forwarded
        unchanged to the parent ``dist``.
        """
        parameters = []
        return super().dist(parameters, **kwargs)
class Normal(DimDistribution):
    """Dims-aware normal distribution backed by ``pxr.normal``."""

    xrv_op = pxr.normal

    @classmethod
    def dist(cls, mu=0, sigma=None, *, tau=None, **kwargs):
        """Create a normal variable.

        Parameters
        ----------
        mu : location, default ``0``.
        sigma : scale. Mutually exclusive with ``tau``.
        tau : precision, converted to ``sigma = 1 / sqrt(tau)``.

        If neither ``sigma`` nor ``tau`` is given, the scale defaults to ``1.0``.
        Remaining keyword arguments are forwarded to the parent ``dist``.
        """
        scale = _get_sigma_from_either_sigma_or_tau(sigma=sigma, tau=tau)
        return super().dist([mu, scale], **kwargs)
class HalfNormal(PositiveDimDistribution):
    """Dims-aware half-normal distribution backed by ``pxr.halfnormal``."""

    xrv_op = pxr.halfnormal

    @classmethod
    def dist(cls, sigma=None, *, tau=None, **kwargs):
        """Create a half-normal variable.

        Parameters
        ----------
        sigma : scale. Mutually exclusive with ``tau``.
        tau : precision, converted to ``sigma = 1 / sqrt(tau)``.

        The scale defaults to ``1.0`` when neither is given. The first parameter
        passed to ``pxr.halfnormal`` (its location) is pinned at ``0.0``.
        """
        scale = _get_sigma_from_either_sigma_or_tau(sigma=sigma, tau=tau)
        location = 0.0
        return super().dist([location, scale], **kwargs)
class LogNormal(PositiveDimDistribution):
    """Dims-aware log-normal distribution backed by ``pxr.lognormal``."""

    xrv_op = pxr.lognormal

    @classmethod
    def dist(cls, mu=0, sigma=None, *, tau=None, **kwargs):
        """Create a log-normal variable.

        Parameters
        ----------
        mu : first parameter forwarded to ``pxr.lognormal``, default ``0``.
        sigma : second parameter (scale). Mutually exclusive with ``tau``.
        tau : precision, converted to ``sigma = 1 / sqrt(tau)``.

        The scale defaults to ``1.0`` when neither ``sigma`` nor ``tau`` is given.
        """
        scale = _get_sigma_from_either_sigma_or_tau(sigma=sigma, tau=tau)
        return super().dist([mu, scale], **kwargs)
class StudentT(DimDistribution):
    """Dims-aware Student-t distribution backed by ``pxr.t``."""

    xrv_op = pxr.t

    @classmethod
    def dist(cls, nu, mu=0, sigma=None, *, lam=None, **kwargs):
        """Create a Student-t variable.

        Parameters
        ----------
        nu : degrees-of-freedom parameter, forwarded first to ``pxr.t``.
        mu : location, default ``0``.
        sigma : scale. Mutually exclusive with ``lam``.
        lam : precision-like parameter, converted to ``sigma = 1 / sqrt(lam)``.

        The scale defaults to ``1.0`` when neither ``sigma`` nor ``lam`` is given.
        """
        scale = _get_sigma_from_either_sigma_or_tau(sigma=sigma, tau=lam)
        return super().dist([nu, mu, scale], **kwargs)
class HalfStudentT(PositiveDimDistribution):
    """Dims-aware half-Student-t distribution.

    Unlike the other distributions in this module, ``xrv_op`` is not a
    pre-built ``pxr`` op. It is assembled per call from the core
    ``HalfStudentTRV`` op (see ``xrv_op`` below).
    """

    @classmethod
    def dist(cls, nu, sigma=None, *, lam=None, **kwargs):
        """Create a half-Student-t variable.

        Parameters
        ----------
        nu : degrees-of-freedom parameter.
        sigma : scale. Mutually exclusive with ``lam``.
        lam : precision-like parameter, converted to ``sigma = 1 / sqrt(lam)``.

        The scale defaults to ``1.0`` when neither ``sigma`` nor ``lam`` is given.
        Remaining keyword arguments are forwarded to the parent ``dist``.
        """
        sigma = _get_sigma_from_either_sigma_or_tau(sigma=sigma, tau=lam)
        return super().dist([nu, sigma], **kwargs)

    @classmethod
    def xrv_op(cls, nu, sigma, core_dims=None, extra_dims=None, rng=None):
        """Build and apply an xtensor RV op for the half-Student-t distribution.

        ``HalfStudentTRV.rv_op`` is called on the plain (dim-less) tensors
        only to obtain the core ``Op`` it constructs. That op is then lifted
        with ``pxr.as_xrv`` and applied to the original xtensors, so their
        named dimensions are kept.

        Parameters
        ----------
        nu, sigma : distribution parameters, converted with ``as_xtensor``.
        core_dims, extra_dims, rng : forwarded unchanged to the lifted op.
        """
        nu = as_xtensor(nu)
        sigma = as_xtensor(sigma)
        core_op = HalfStudentTRV.rv_op(nu=nu.values, sigma=sigma.values).owner.op
        lifted_op = pxr.as_xrv(core_op)
        return lifted_op(nu, sigma, core_dims=core_dims, extra_dims=extra_dims, rng=rng)
class Cauchy(DimDistribution):
    """Dims-aware Cauchy distribution backed by ``pxr.cauchy``."""

    xrv_op = pxr.cauchy

    @classmethod
    def dist(cls, alpha, beta, **kwargs):
        """Create a Cauchy variable.

        ``alpha`` and ``beta`` are forwarded, in that order, as the parameters
        of ``pxr.cauchy``. In PyMC's ``Cauchy`` convention they are location
        and scale. Other keyword arguments go to the parent ``dist``.
        """
        params = [alpha, beta]
        return super().dist(params, **kwargs)
class HalfCauchy(PositiveDimDistribution):
    """Dims-aware half-Cauchy distribution backed by ``pxr.halfcauchy``."""

    xrv_op = pxr.halfcauchy

    @classmethod
    def dist(cls, beta, **kwargs):
        """Create a half-Cauchy variable with scale ``beta``.

        The first parameter of ``pxr.halfcauchy`` (its location) is pinned at
        ``0.0``.
        """
        location = 0.0
        return super().dist([location, beta], **kwargs)
class Beta(UnitDimDistribution):
    """Dims-aware beta distribution (a ``UnitDimDistribution``), backed by ``pxr.beta``."""

    xrv_op = pxr.beta

    @classmethod
    def dist(cls, alpha=None, beta=None, *, mu=None, sigma=None, nu=None, **kwargs):
        """Create a beta variable.

        Accepts any parametrization supported by the regular
        ``pymc.distributions.continuous.Beta.get_alpha_beta``. That helper
        converts ``mu``/``sigma``/``nu`` into ``alpha``/``beta`` and validates
        the combination.
        """
        shape_a, shape_b = RegularBeta.get_alpha_beta(
            alpha=alpha, beta=beta, mu=mu, sigma=sigma, nu=nu
        )
        return super().dist([shape_a, shape_b], **kwargs)
class Laplace(DimDistribution):
    """Dims-aware Laplace distribution backed by ``pxr.laplace``."""

    xrv_op = pxr.laplace

    @classmethod
    def dist(cls, mu=0, b=1, **kwargs):
        """Create a Laplace variable.

        ``mu`` (default ``0``) and ``b`` (default ``1``) are forwarded, in that
        order, as the parameters of ``pxr.laplace``.
        """
        params = [mu, b]
        return super().dist(params, **kwargs)
class Exponential(PositiveDimDistribution):
    """Dims-aware exponential distribution backed by ``pxr.exponential``.

    The single parameter forwarded to ``pxr.exponential`` is a scale.
    """

    xrv_op = pxr.exponential

    @classmethod
    def dist(cls, lam=None, *, scale=None, **kwargs):
        """Create an exponential variable from either a rate or a scale.

        Parameters
        ----------
        lam : rate; converted to ``scale = 1 / lam``.
        scale : scale, used directly.

        Defaults to ``scale = 1.0`` when neither is given.

        Raises
        ------
        ValueError
            If both ``lam`` and ``scale`` are provided.
        """
        if lam is not None and scale is not None:
            raise ValueError("Cannot pass both 'lam' and 'scale'. Use one of them.")
        if lam is not None:
            scale = 1 / lam
        elif scale is None:
            scale = 1.0
        return super().dist([scale], **kwargs)
class Gamma(PositiveDimDistribution):
    """Dims-aware gamma distribution backed by ``pxr.gamma``.

    ``beta`` is treated as a rate: ``pxr.gamma`` receives ``alpha`` and
    ``1 / beta`` (a scale).
    """

    xrv_op = pxr.gamma

    @classmethod
    def dist(cls, alpha=None, beta=None, *, mu=None, sigma=None, **kwargs):
        """Create a gamma variable from ``(alpha, beta)`` or ``(mu, sigma)``.

        Parameters
        ----------
        alpha, beta : shape and rate. Used as-is when both are given.
        mu, sigma : mean and standard deviation. Converted via
            ``alpha = mu**2 / sigma**2`` and ``beta = mu / sigma**2``
            when ``alpha``/``beta`` are not both given.

        Raises
        ------
        ValueError
            If neither complete parametrization is provided.
        """
        if (alpha is not None) and (beta is not None):
            pass
        elif (mu is not None) and (sigma is not None):
            # sigma only enters squared, so a negative sigma would otherwise
            # yield a valid-looking alpha. Multiplying by sign(sigma) makes alpha
            # negative instead, so it is not silently accepted (presumably
            # rejected by the downstream alpha > 0 check).
            alpha = (mu**2 / sigma**2) * ptx.math.sign(sigma)
            beta = mu / sigma**2
        else:
            raise ValueError(
                "Incompatible parameterization. Either use alpha and beta, or mu and sigma."
            )
        # Both alpha and beta are set at this point. Delegating to
        # RegularGamma.get_alpha_beta would only hit its pass-through branch,
        # so it is not called.
        return super().dist([alpha, ptx.math.reciprocal(beta)], **kwargs)
class InverseGamma(PositiveDimDistribution):
    """Dims-aware inverse-gamma distribution backed by ``pxr.invgamma``."""

    xrv_op = pxr.invgamma

    @classmethod
    def dist(cls, alpha=None, beta=None, *, mu=None, sigma=None, **kwargs):
        """Create an inverse-gamma variable.

        Parameters
        ----------
        alpha : shape. When given, ``mu``/``sigma`` are ignored.
        beta : second parameter; defaults to ``1.0`` when ``alpha`` is given
            without it.
        mu, sigma : mean and standard deviation, used only when ``alpha`` is
            None. Converted via ``alpha = (2 sigma^2 + mu^2) / sigma^2`` and
            ``beta = mu (mu^2 + sigma^2) / sigma^2``.

        Raises
        ------
        ValueError
            If ``alpha`` is None and ``mu``/``sigma`` are not both given.
        """
        if alpha is None:
            if mu is None or sigma is None:
                raise ValueError(
                    "Incompatible parameterization. Either use alpha and (optionally) beta, or mu and sigma"
                )
            # sigma only enters squared; multiplying by sign(sigma) keeps a
            # negative sigma from producing a valid-looking alpha.
            alpha = ((2 * sigma**2 + mu**2) / sigma**2) * ptx.math.sign(sigma)
            beta = mu * (mu**2 + sigma**2) / sigma**2
        elif beta is None:
            beta = 1.0
        return super().dist([alpha, beta], **kwargs)