Custom Priors

This guide covers advanced prior specifications beyond the standard distributions.

Overview

While discoverysamplers provides built-in support for common priors (uniform, log-uniform, normal, fixed), you may need custom priors for:

  • Non-standard distributions (e.g., Beta, Gamma, truncated distributions)

  • Physical constraints (e.g., mass ratios, triangle inequalities)

  • Informative priors from previous measurements

  • Hierarchical priors

Callable Prior Interface

Basic Callable Prior

Create a custom prior by implementing a class with a logpdf method:

import numpy as np

class CustomPrior:
    def __init__(self, param1, param2):
        """Initialize prior with hyperparameters."""
        self.param1 = param1
        self.param2 = param2
        # Optional: specify bounds for samplers that need them
        self.bounds = (lower, upper)

    def logpdf(self, value):
        """
        Compute log probability density.

        Parameters
        ----------
        value : float or array
            Parameter value(s)

        Returns
        -------
        float or array
            Log probability density
        """
        # Implement your prior here
        if self.bounds[0] <= value <= self.bounds[1]:
            return np.log(density_function(value))
        return -np.inf

# Use with bridge
priors = {
    'param1': CustomPrior(a=1.0, b=2.0),
    'param2': ('uniform', 0, 10),
}

bridge = DiscoveryNessaiBridge(model, priors)

Common Custom Priors

Beta Distribution

Useful for parameters bounded in [0, 1]:

from scipy.stats import beta as beta_dist

class BetaPrior:
    def __init__(self, a, b, loc=0, scale=1):
        """
        Beta distribution prior.

        Parameters
        ----------
        a, b : float
            Shape parameters (a, b > 0)
        loc : float
            Lower bound (default: 0)
        scale : float
            Range (default: 1)
        """
        self.dist = beta_dist(a, b, loc=loc, scale=scale)
        self.bounds = (loc, loc + scale)

    def logpdf(self, value):
        return self.dist.logpdf(value)

# Example: Beta(2, 5) on [0, 1]
priors = {
    'eccentricity': BetaPrior(a=2, b=5),
}

Gamma Distribution

For positive parameters:

from scipy.stats import gamma as gamma_dist

class GammaPrior:
    def __init__(self, a, scale=1.0):
        """
        Gamma distribution prior.

        Parameters
        ----------
        a : float
            Shape parameter
        scale : float
            Scale parameter
        """
        self.dist = gamma_dist(a=a, scale=scale)
        self.bounds = (0, np.inf)

    def logpdf(self, value):
        if value < 0:
            return -np.inf
        return self.dist.logpdf(value)

# Example: Gamma(2, scale=0.5)
priors = {
    'rate_parameter': GammaPrior(a=2, scale=0.5),
}

Truncated Gaussian

Gaussian with hard bounds:

from scipy.stats import truncnorm

class TruncatedNormalPrior:
    def __init__(self, mean, std, lower, upper):
        """
        Truncated normal distribution.

        Parameters
        ----------
        mean, std : float
            Mean and standard deviation of underlying Gaussian
        lower, upper : float
            Truncation bounds
        """
        # Convert to standard form
        a = (lower - mean) / std
        b = (upper - mean) / std
        self.dist = truncnorm(a, b, loc=mean, scale=std)
        self.bounds = (lower, upper)

    def logpdf(self, value):
        return self.dist.logpdf(value)

# Example: N(0, 1) truncated to [-2, 2]
priors = {
    'param': TruncatedNormalPrior(mean=0, std=1, lower=-2, upper=2),
}

Mixture Prior

Mixture of distributions:

class MixturePrior:
    def __init__(self, components, weights):
        """
        Mixture of distributions.

        Parameters
        ----------
        components : list of distributions
            Each must have logpdf method
        weights : array
            Mixing weights (must sum to 1)
        """
        self.components = components
        self.weights = np.array(weights)
        assert np.allclose(self.weights.sum(), 1.0)

        # Bounds: union of component bounds
        self.bounds = (
            min(c.bounds[0] for c in components),
            max(c.bounds[1] for c in components)
        )

    def logpdf(self, value):
        # Log-sum-exp of weighted components
        log_probs = [
            np.log(w) + c.logpdf(value)
            for w, c in zip(self.weights, self.components)
        ]
        from scipy.special import logsumexp
        return logsumexp(log_probs)

# Example: 50% U(0,1) + 50% U(9,10) (bimodal)
from scipy.stats import uniform

priors = {
    'param': MixturePrior(
        components=[
            uniform(loc=0, scale=1),
            uniform(loc=9, scale=1),
        ],
        weights=[0.5, 0.5]
    )
}

JAX-Compatible Custom Priors

For JAX-NS and JIT Compilation

Custom priors must be JAX-compatible:

import jax.numpy as jnp
from jax import jit

class JAXBetaPrior:
    def __init__(self, a, b):
        self.a = a
        self.b = b
        self.bounds = (0, 1)

    def logpdf(self, value):
        """JAX-compatible log-pdf."""
        # Beta distribution log-pdf
        from jax.scipy.special import betaln

        log_pdf = (
            (self.a - 1) * jnp.log(value) +
            (self.b - 1) * jnp.log(1 - value) -
            betaln(self.a, self.b)
        )

        # Use jnp.where instead of if/else
        return jnp.where(
            (value > 0) & (value < 1),
            log_pdf,
            -jnp.inf
        )

# Use with JAX-NS
from discoverysamplers.jaxns_interface import DiscoveryJAXNSBridge

priors = {
    'param': JAXBetaPrior(a=2, b=5),
}

bridge = DiscoveryJAXNSBridge(jax_model, priors, jit=True)

Best Practices

  1. Always Specify Bounds: Nested samplers need bounded priors

  2. Test Normalization: Verify prior integrates to 1

  3. Use Log-Space: Always work with log-probabilities for numerical stability

  4. JAX Compatibility: Avoid Python control flow if using JAX

  5. Document Priors: Clearly document prior choices and rationale

See Also