Custom Priors
This guide covers advanced prior specifications beyond the standard distributions.
Overview
While discoverysamplers provides built-in support for common priors (uniform, log-uniform, normal, fixed), you may need custom priors for:
Non-standard distributions (e.g., Beta, Gamma, truncated distributions)
Physical constraints (e.g., mass ratios, triangle inequalities)
Informative priors from previous measurements
Hierarchical priors
Callable Prior Interface
Basic Callable Prior
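Create a custom prior by implementing a class with a logpdf method that returns the log probability density, and optionally a bounds attribute, a (lower, upper) tuple, for samplers that need bounds: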
import numpy as np
class CustomPrior:
    """Template for a user-defined prior exposing ``logpdf`` and ``bounds``.

    Replace the ``density_function`` placeholder in :meth:`logpdf` with the
    density of your own prior.
    """

    def __init__(self, param1, param2, lower=-np.inf, upper=np.inf):
        """Initialize prior with hyperparameters.

        Parameters
        ----------
        param1, param2 : float
            Hyperparameters of the prior.
        lower, upper : float, optional
            Lower and upper edge of the support. They default to an
            unbounded support; pass finite values for samplers that
            require bounded priors.
        """
        self.param1 = param1
        self.param2 = param2
        # Optional: specify bounds for samplers that need them
        self.bounds = (lower, upper)

    def logpdf(self, value):
        """
        Compute log probability density.

        Parameters
        ----------
        value : float or array
            Parameter value(s)

        Returns
        -------
        float or array
            Log probability density; ``-inf`` for values outside
            ``self.bounds``. Scalar input gives a scalar, array input
            gives an array of the same shape.
        """
        scalar_input = np.ndim(value) == 0
        values = np.atleast_1d(np.asarray(value, dtype=float))
        lower, upper = self.bounds
        # A boolean mask (rather than ``if lower <= value <= upper``) keeps
        # the method usable with array input.
        inside = (values >= lower) & (values <= upper)
        log_density = np.full(values.shape, -np.inf)
        if inside.any():
            # Implement your prior here: ``density_function`` is a
            # placeholder for the (vectorized) density of your prior.
            log_density[inside] = np.log(density_function(values[inside]))
        return log_density[0] if scalar_input else log_density
# Use with bridge
# Keyword names must match CustomPrior.__init__ (param1, param2).
priors = {
    'param1': CustomPrior(param1=1.0, param2=2.0),
    'param2': ('uniform', 0, 10),
}
bridge = DiscoveryNessaiBridge(model, priors)
Common Custom Priors
Beta Distribution
Useful for parameters bounded in [0, 1] or, via loc and scale, in any finite interval [loc, loc + scale]:
from scipy.stats import beta as beta_dist
class BetaPrior:
    """Beta-distributed prior supported on ``[loc, loc + scale]``."""

    def __init__(self, a, b, loc=0, scale=1):
        """
        Build the frozen Beta distribution and record its support.

        Parameters
        ----------
        a, b : float
            Shape parameters (a, b > 0)
        loc : float
            Lower edge of the support (default: 0)
        scale : float
            Width of the support (default: 1)
        """
        frozen = beta_dist(a, b, loc=loc, scale=scale)
        upper_edge = loc + scale
        self.dist = frozen
        self.bounds = (loc, upper_edge)

    def logpdf(self, value):
        """Return the log-density of the frozen Beta distribution at ``value``."""
        log_density = self.dist.logpdf(value)
        return log_density


# Example: Beta(2, 5) on [0, 1]
priors = {'eccentricity': BetaPrior(a=2, b=5)}
Gamma Distribution
For positive parameters:
from scipy.stats import gamma as gamma_dist
class GammaPrior:
    """Gamma-distributed prior for positive parameters."""

    def __init__(self, a, scale=1.0):
        """
        Gamma distribution prior.

        Parameters
        ----------
        a : float
            Shape parameter
        scale : float
            Scale parameter
        """
        self.dist = gamma_dist(a=a, scale=scale)
        self.bounds = (0, np.inf)

    def logpdf(self, value):
        """
        Compute log probability density.

        Parameters
        ----------
        value : float or array
            Parameter value(s)

        Returns
        -------
        float or array
            Log probability density; ``-inf`` for negative values.
        """
        # No explicit ``if value < 0`` guard: it fails on array input, and
        # the frozen scipy distribution already returns -inf outside its
        # support for scalars and arrays alike.
        return self.dist.logpdf(value)


# Example: Gamma(2, scale=0.5)
priors = {
    'rate_parameter': GammaPrior(a=2, scale=0.5),
}
Truncated Gaussian
Gaussian with hard bounds:
from scipy.stats import truncnorm
class TruncatedNormalPrior:
    """Gaussian prior restricted to the interval ``[lower, upper]``."""

    def __init__(self, mean, std, lower, upper):
        """
        Build the frozen truncated normal distribution.

        Parameters
        ----------
        mean, std : float
            Mean and standard deviation of the underlying Gaussian
        lower, upper : float
            Truncation bounds, given in the parameter's own units
        """
        # truncnorm expects its limits in units of standard deviations
        # from the mean, so standardize both edges first.
        std_lower, std_upper = (
            (edge - mean) / std for edge in (lower, upper)
        )
        self.dist = truncnorm(std_lower, std_upper, loc=mean, scale=std)
        self.bounds = (lower, upper)

    def logpdf(self, value):
        """Return the log-density of the truncated normal at ``value``."""
        log_density = self.dist.logpdf(value)
        return log_density


# Example: N(0, 1) truncated to [-2, 2]
priors = {'param': TruncatedNormalPrior(mean=0, std=1, lower=-2, upper=2)}
Mixture Prior
Mixture of distributions:
class MixturePrior:
    """Weighted mixture of one-dimensional priors."""

    def __init__(self, components, weights):
        """
        Mixture of distributions.

        Parameters
        ----------
        components : list of distributions
            Each must have a logpdf method and either a ``bounds``
            attribute or a ``support()`` method (as frozen scipy
            distributions do).
        weights : array
            Mixing weights (must sum to 1)

        Raises
        ------
        ValueError
            If the number of weights differs from the number of
            components, or the weights do not sum to 1.
        """
        self.components = list(components)
        self.weights = np.array(weights)
        # Explicit exceptions instead of ``assert``: assertions are
        # stripped when Python runs with -O.
        if self.weights.shape != (len(self.components),):
            raise ValueError(
                "weights must be one-dimensional with one entry per component"
            )
        if not np.allclose(self.weights.sum(), 1.0):
            raise ValueError("Mixing weights must sum to 1")
        # Bounds: union of component bounds
        edges = [self._component_bounds(c) for c in self.components]
        self.bounds = (
            min(lower for lower, _ in edges),
            max(upper for _, upper in edges),
        )

    @staticmethod
    def _component_bounds(component):
        """Return ``(lower, upper)`` for a single mixture component."""
        bounds = getattr(component, 'bounds', None)
        if bounds is not None:
            return tuple(bounds)
        # Frozen scipy distributions have no ``bounds`` attribute but
        # report their support through ``support()``.
        return tuple(component.support())

    def logpdf(self, value):
        """
        Compute log probability density of the mixture.

        Parameters
        ----------
        value : float or array
            Parameter value(s)

        Returns
        -------
        float or array
            Log probability density, with the same shape as ``value``.
        """
        from scipy.special import logsumexp

        # Log-sum-exp of weighted components
        log_probs = [
            np.log(w) + c.logpdf(value)
            for w, c in zip(self.weights, self.components)
        ]
        # axis=0 reduces over the components only, so array input keeps
        # its shape instead of being collapsed to a single number.
        return logsumexp(log_probs, axis=0)


# Example: 50% U(0,1) + 50% U(9,10) (bimodal)
from scipy.stats import uniform
priors = {
    'param': MixturePrior(
        components=[
            uniform(loc=0, scale=1),
            uniform(loc=9, scale=1),
        ],
        weights=[0.5, 0.5],
    )
}
JAX-Compatible Custom Priors
For JAX-NS and JIT Compilation
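When using JAX-NS or JIT compilation, custom priors must be JAX-compatible: write logpdf with jax.numpy operations and avoid Python branching on the parameter value: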
import jax.numpy as jnp
from jax import jit
class JAXBetaPrior:
    """Beta(a, b) prior on (0, 1) written with JAX operations."""

    def __init__(self, a, b):
        """
        Store the shape parameters and the support.

        Parameters
        ----------
        a, b : float
            Shape parameters of the Beta distribution
        """
        self.a = a
        self.b = b
        self.bounds = (0, 1)

    def logpdf(self, value):
        """
        JAX-compatible log-pdf.

        Parameters
        ----------
        value : float or array
            Parameter value(s)

        Returns
        -------
        float or array
            Beta log-density for values strictly inside (0, 1),
            ``-inf`` elsewhere.
        """
        from jax.scipy.special import betaln

        inside = (value > 0) & (value < 1)
        # Evaluate the density on an in-support stand-in for out-of-support
        # inputs. Selecting with jnp.where alone still differentiates the
        # unselected branch, and log(0) there turns the gradient into NaN.
        safe_value = jnp.where(inside, value, 0.5)
        # Beta distribution log-pdf
        log_pdf = (
            (self.a - 1) * jnp.log(safe_value) +
            (self.b - 1) * jnp.log(1 - safe_value) -
            betaln(self.a, self.b)
        )
        # Use jnp.where instead of if/else
        return jnp.where(inside, log_pdf, -jnp.inf)
# Use with JAX-NS
from discoverysamplers.jaxns_interface import DiscoveryJAXNSBridge
# One Beta(2, 5) prior, keyed by the parameter name 'param'.
priors = {'param': JAXBetaPrior(a=2, b=5)}
# Hand the model and priors to the bridge, passing jit=True.
bridge = DiscoveryJAXNSBridge(
    jax_model,
    priors,
    jit=True,
)
Best Practices
Always Specify Bounds: Nested samplers need bounded priors
Test Normalization: Verify that the prior integrates to 1 over its bounds
Use Log-Space: Always work with log-probabilities for numerical stability
JAX Compatibility: Avoid Python control flow (if/else) on parameter values when using JAX; use jnp.where instead
Document Priors: Clearly document prior choices and rationale
See Also
Prior Specification - Standard priors
Model Requirements - Model requirements
Performance Optimization - Optimization tips