Model Requirements
This page describes the requirements for Discovery models to work with discoverysamplers bridge interfaces.
Basic Requirements
Your model must satisfy these fundamental requirements:
Callable: The model must be a callable object (function, method, or class with __call__)
Dictionary Input: Must accept a dictionary mapping parameter names (strings) to parameter values
Scalar Output: Must return a single scalar value representing the log-likelihood
Deterministic: Should return the same output for the same input (for reproducibility)
Note: For Discovery likelihoods, you can pass either the likelihood object or likelihood.logL; when given the object, the bridges extract the callable automatically.
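As a rough illustration of the four requirements above, the following sketch checks a candidate model before it is handed to a bridge. `check_model` is a hypothetical helper written for this page, not part of discoverysamplers, and it assumes only NumPy.

```python
import numpy as np

def check_model(model, sample_params: dict) -> float:
    """Hypothetical pre-flight check; not part of discoverysamplers."""
    # Requirement 1: the model is callable.
    if not callable(model):
        raise TypeError("model must be callable")
    # Requirement 2: it accepts a dict of parameter names -> values.
    value = model(dict(sample_params))
    # Requirement 3: it returns a single scalar.
    if np.ndim(value) != 0:
        raise ValueError("model must return a scalar log-likelihood")
    # Requirement 4: repeated calls with the same input agree.
    if model(dict(sample_params)) != value:
        raise ValueError("model is not deterministic")
    return float(value)

def gaussian_logl(params: dict) -> float:
    # Toy log-likelihood: standard normal in x and y, up to a constant.
    return -0.5 * (params["x"] ** 2 + params["y"] ** 2)

logl = check_model(gaussian_logl, {"x": 1.0, "y": 2.0})
```

A check like this is cheap to run once before sampling and tends to catch signature mistakes earlier than a failing sampler run would.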
Valid Model Signatures
Discovery Likelihood (Recommended)
The easiest way to use discoverysamplers is with Discovery likelihoods:
import discovery as ds
# Load pulsar and create likelihood
psr = ds.Pulsar.read_feather('path/to/pulsar.feather')
likelihood = ds.PulsarLikelihood([
    psr.residuals,
    ds.makenoise_measurement(psr, psr.noisedict),
    ds.makegp_timing(psr, svd=True),
])
# You can pass either likelihood or likelihood.logL to bridges
from discoverysamplers.nessai_interface import DiscoveryNessaiBridge
bridge = DiscoveryNessaiBridge(
    discovery_model=likelihood,  # Bridge extracts .logL automatically
    priors=priors
)
# OR equivalently:
bridge = DiscoveryNessaiBridge(
    discovery_model=likelihood.logL,
    priors=priors
)
Simple Function
The most straightforward model is a simple function:
def my_model(params: dict) -> float:
    """
    Compute log-likelihood from parameters.
    Parameters
    ----------
    params : dict
        Dictionary mapping parameter names to values
    Returns
    -------
    float
        Log-likelihood value
    """
    x = params['x']
    y = params['y']
    return -0.5 * (x**2 + y**2)
Function with Additional Arguments
If your model needs additional data or configuration, you can use closures or partial functions:
import numpy as np
from functools import partial
def model_with_data(params: dict, times: np.ndarray, data: np.ndarray, noise_std: float) -> float:
    """Model that requires data."""
    # Evaluate the sinusoid at the observation times, then compare with the data
    signal = params['amplitude'] * np.sin(2 * np.pi * params['frequency'] * times)
    residuals = data - signal
    return -0.5 * np.sum((residuals / noise_std)**2)
# Create a partial function with fixed data
times = np.linspace(0.0, 1.0, 100)
data = np.random.randn(100)
model = partial(model_with_data, times=times, data=data, noise_std=0.1)
# Now model(params) has the correct signature
from discoverysamplers.nessai_interface import DiscoveryNessaiBridge
bridge = DiscoveryNessaiBridge(model, priors)
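The closure alternative mentioned above can be sketched as follows. The factory name `make_model`, the sinusoidal signal, and the synthetic data are illustrative assumptions; the only point is that the returned `model(params)` captures its data and exposes the single-dictionary signature.

```python
import numpy as np

def make_model(times: np.ndarray, observed: np.ndarray, noise_std: float):
    """Return a model(params) callable that captures the data in a closure."""
    def model(params: dict) -> float:
        signal = params["amplitude"] * np.sin(2 * np.pi * params["frequency"] * times)
        residuals = observed - signal
        return float(-0.5 * np.sum((residuals / noise_std) ** 2))
    return model

# Synthetic, noise-free data generated from known parameters.
times = np.linspace(0.0, 1.0, 50)
observed = 2.0 * np.sin(2 * np.pi * 3.0 * times)
model = make_model(times, observed, noise_std=0.1)

best = model({"amplitude": 2.0, "frequency": 3.0})   # parameters that generated the data
worse = model({"amplitude": 1.0, "frequency": 3.0})  # wrong amplitude, lower log-likelihood
```

Compared with `partial`, a closure leaves room for one-off precomputation inside the factory before the inner function is defined.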
Callable Class
For complex models, use a class with a __call__ method:
class MyComplexModel:
    def __init__(self, data, config):
        """Initialize model with data and configuration."""
        self.data = data
        self.config = config
        self._precompute()
    def _precompute(self):
        """Pre-compute expensive quantities."""
        self.data_fft = np.fft.fft(self.data)
        # ... other precomputations
    def __call__(self, params: dict) -> float:
        """Compute log-likelihood."""
        # Use precomputed quantities for efficiency
        signal = self._generate_signal(params)
        return self._compute_log_likelihood(signal)
    def _generate_signal(self, params):
        # Implementation details
        pass
    def _compute_log_likelihood(self, signal):
        # Implementation details
        pass
# Use the class instance as the model
model = MyComplexModel(data, config)
bridge = DiscoveryNessaiBridge(model, priors)
Parameter Dictionary Structure
Input Dictionary
The input params dictionary contains all sampled and fixed parameters:
params = {
    'mass': 1.5,            # float
    'distance': 100.0,      # float
    'sky_loc': [0.5, 1.2],  # Can be array-like for vector parameters
}
Key Points:
Keys are parameter name strings
Values are typically floats or NumPy arrays
All sampled parameters (from priors) will be present
Fixed parameters are automatically injected by the bridge
Parameter order doesn’t matter: look values up by name rather than relying on the order of the keys
Handling Vector Parameters
If your model has vector-valued parameters:
def model_with_vectors(params):
    """Model with vector parameters."""
    # Scalar parameters
    mass = params['mass']
    # Vector parameters (arrays)
    sky_location = params['sky_location']  # Shape: (2,)
    ra, dec = sky_location
    # 2D parameters
    covariance = params['covariance']  # Shape: (3, 3)
    return compute_log_likelihood(mass, ra, dec, covariance)
# Priors for vector parameters
priors = {
    'mass': ('uniform', 1.0, 3.0),
    'sky_location': [
        ('uniform', 0, 2*np.pi),         # RA
        ('uniform', -np.pi/2, np.pi/2),  # Dec
    ],
    # For complex structures, use fixed parameters
    'covariance': ('fixed', np.eye(3)),
}
Return Value
Log-Likelihood
The model must return the log-likelihood (not likelihood or chi-squared):
def correct_model(params):
    # Compute chi-squared or similar
    chi2 = compute_chi_squared(params)
    # Return log-likelihood
    return -0.5 * chi2
def incorrect_model_likelihood(params):
    chi2 = compute_chi_squared(params)
    # DON'T return likelihood
    likelihood = np.exp(-0.5 * chi2)
    return likelihood  # Wrong!
def incorrect_model_chi2(params):
    chi2 = compute_chi_squared(params)
    # DON'T return chi-squared
    return chi2  # Wrong!
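To make the distinction concrete, here is a small sketch for independent Gaussian noise, where the log-likelihood is -0.5 * chi2 plus a normalization term. The function name and the residual values are made up for illustration.

```python
import numpy as np

def gaussian_log_likelihood(residuals: np.ndarray, sigma: float) -> float:
    """Log-likelihood for independent Gaussian noise with standard deviation sigma."""
    chi2 = np.sum((residuals / sigma) ** 2)
    # The normalization is constant for fixed sigma, but it must be kept
    # whenever sigma itself is a sampled parameter.
    log_norm = -0.5 * residuals.size * np.log(2 * np.pi * sigma**2)
    return float(-0.5 * chi2 + log_norm)

good_fit = np.array([0.1, -0.2, 0.05])  # small residuals
poor_fit = np.array([1.0, -2.0, 0.5])   # large residuals

logl_good = gaussian_log_likelihood(good_fit, sigma=0.5)
logl_poor = gaussian_log_likelihood(poor_fit, sigma=0.5)
# A better fit has a smaller chi-squared but a larger log-likelihood, so
# returning chi-squared directly would make a sampler favor poor fits.
```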
Special Values
Use -np.inf to indicate impossible parameter combinations:
import numpy as np
def model_with_constraints(params):
    mass1 = params['mass1']
    mass2 = params['mass2']
    # Physical constraint: mass1 >= mass2
    if mass1 < mass2:
        return -np.inf
    # Numerical stability: avoid log(0)
    if mass1 < 1e-10:
        return -np.inf
    # Otherwise compute likelihood
    return compute_log_likelihood(mass1, mass2)
Never return +np.inf or np.nan as these will break the samplers.
Performance Considerations
JAX Compatibility
For best performance with JAX-based samplers (JAX-NS, or when using jit=True):
import jax.numpy as jnp
from jax import jit
def jax_model(params):
    """JAX-compatible model using jax.numpy."""
    x = params['x']
    y = params['y']
    # Use jax.numpy instead of numpy
    return -0.5 * (jnp.square(x) + jnp.square(y))
# Optionally JIT-compile the model
jax_model_compiled = jit(jax_model)
bridge = DiscoveryJAXNSBridge(jax_model_compiled, priors, jit=True)
JAX Tips:
Use jax.numpy instead of numpy
Avoid Python control flow (if/else); use jnp.where instead
Avoid in-place array updates
See JAX documentation for more details
Vectorization
Some samplers (especially JAX-NS) can evaluate the likelihood for multiple parameter sets simultaneously:
def vectorized_model(params):
    """Model that supports vectorized evaluation."""
    x = params['x']  # Could be shape (N,) instead of scalar
    y = params['y']  # Could be shape (N,)
    # Operations work element-wise
    return -0.5 * (x**2 + y**2)  # Returns shape (N,)
bridge = DiscoveryJAXNSBridge(vectorized_model, priors)
bridge.configure_array_api(order=['x', 'y'])  # Enable vectorization
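A Python `if` statement cannot branch on an array, so constraints in a vectorized model are usually written as masks. The sketch below uses plain NumPy; the mass constraint mirrors the Special Values example, and the quadratic term is only a stand-in for a real likelihood.

```python
import numpy as np

def vectorized_constrained_model(params: dict) -> np.ndarray:
    """Works for scalar inputs and for arrays of shape (N,)."""
    mass1 = np.asarray(params["mass1"], dtype=float)
    mass2 = np.asarray(params["mass2"], dtype=float)
    # Stand-in log-likelihood, evaluated element-wise.
    logl = -0.5 * ((mass1 - 2.0) ** 2 + (mass2 - 1.0) ** 2)
    # Apply the constraint mass1 >= mass2 as a mask instead of an if-statement.
    return np.where(mass1 >= mass2, logl, -np.inf)

batch = {
    "mass1": np.array([2.0, 1.0, 3.0]),
    "mass2": np.array([1.0, 2.0, 1.0]),
}
values = vectorized_constrained_model(batch)  # second entry violates the constraint
```

The same masking pattern carries over to JAX by swapping `np.where` for `jnp.where`, which is what the JAX tips above recommend in place of `if`/`else`.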
Caching and Memoization
If your likelihood involves expensive computations, consider caching:
import numpy as np
class CachedModel:
    def __init__(self):
        # Cache of previously computed log-likelihood values
        self._cache = {}
    def __call__(self, params):
        # Create a hashable key from params. Values are converted to bytes
        # because NumPy arrays and lists are not hashable.
        key = tuple(
            (name, np.asarray(params[name], dtype=float).tobytes())
            for name in sorted(params)
        )
        if key not in self._cache:
            self._cache[key] = self._compute_expensive(params)
        return self._cache[key]
    def _compute_expensive(self, params):
        # Expensive likelihood computation goes here
        return compute_log_likelihood(params)
model = CachedModel()
Warning: Ensure cache keys are truly unique to avoid incorrect results.
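When the expensive step depends only on hashable settings rather than on the sampled parameters, `functools.lru_cache` can handle the bookkeeping. In this sketch the Fourier basis and the parameter names `a1`, `a2`, `a3` are made up; they stand in for whatever precomputation a real model needs.

```python
from functools import lru_cache
import numpy as np

@lru_cache(maxsize=8)
def fourier_basis(n_modes: int, n_samples: int) -> np.ndarray:
    """Parameter-independent precomputation (illustrative stand-in)."""
    t = np.linspace(0.0, 1.0, n_samples, endpoint=False)
    k = np.arange(1, n_modes + 1)
    basis = np.sin(2 * np.pi * np.outer(t, k))
    basis.setflags(write=False)  # the cached array is shared, so keep it read-only
    return basis

def model(params: dict) -> float:
    basis = fourier_basis(3, 64)  # built on the first call, then served from the cache
    coeffs = np.array([params["a1"], params["a2"], params["a3"]])
    signal = basis @ coeffs
    return float(-0.5 * np.sum(signal**2))

first = model({"a1": 1.0, "a2": 0.0, "a3": 0.0})
second = model({"a1": 0.5, "a2": 0.5, "a3": 0.0})
info = fourier_basis.cache_info()  # records one miss (the build) and one hit
```

Because the cache key here is a pair of integers, the uniqueness concern in the warning above does not arise.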
Common Patterns
Pattern 1: Separating Model Components
class ModularModel:
    def __init__(self, data):
        self.data = data
    def __call__(self, params):
        signal = self.signal_model(params)
        noise = self.noise_model(params)
        return self.log_likelihood(signal, noise)
    def signal_model(self, params):
        """Generate signal from parameters."""
        return params['amplitude'] * np.sin(2 * np.pi * params['frequency'] * self.data)
    def noise_model(self, params):
        """Model noise properties."""
        return params['noise_level']
    def log_likelihood(self, signal, noise):
        """Compute log-likelihood from signal and noise."""
        residuals = signal - self.data
        return -0.5 * np.sum((residuals / noise)**2)
Pattern 2: Model with Units
from astropy import units as u
class PhysicalModel:
    def __init__(self, data, units_dict):
        """
        Parameters
        ----------
        units_dict : dict
            Maps parameter names to astropy units
        """
        self.data = data
        self.units = units_dict
    def __call__(self, params):
        # Attach units
        mass = params['mass'] * self.units['mass']
        distance = params['distance'] * self.units['distance']
        # Compute physical quantities with unit checking
        quantity = (mass / distance**2).to(u.solMass / u.Mpc**2)
        # Return dimensionless log-likelihood
        return float(-0.5 * quantity.value**2)
Pattern 3: Hierarchical Model
def hierarchical_model(params):
    """Model with hierarchical parameters."""
    # Population-level parameters
    mu = params['population_mean']
    sigma = params['population_std']
    # Individual-level parameters
    individuals = np.array([params[f'individual_{i}'] for i in range(10)])
    # Population prior: Gaussian in each individual parameter. The
    # -log(sigma) term must be kept because sigma is itself sampled.
    log_prior = np.sum(-0.5 * ((individuals - mu) / sigma)**2 - np.log(sigma))
    # Individual likelihoods
    log_like = sum(individual_likelihood(ind) for ind in individuals)
    return log_prior + log_like
Error Handling
Graceful Failure
Handle potential errors gracefully:
def robust_model(params):
    try:
        result = compute_likelihood(params)
        # Check for invalid results
        if not np.isfinite(result):
            return -np.inf
        return result
    except Exception as e:
        # Log the error for debugging
        print(f"Error in likelihood: {e}, params: {params}")
        return -np.inf
Numerical Stability
Avoid numerical issues:
import numpy as np
from scipy.special import logsumexp
def stable_model(params):
    x = params['x']
    # Avoid division by zero
    if abs(x) < 1e-10:
        return -np.inf
    # Use log-space for products
    # Instead of: p = p1 * p2 * p3
    # Use: log_p = log_p1 + log_p2 + log_p3
    # (p1, p2, p3 are placeholders for positive factors computed from params)
    log_p = np.log(p1) + np.log(p2) + np.log(p3)
    # Use logsumexp for sums in log-space
    # (log_a, log_b, log_c are placeholders for log-terms computed from params)
    log_sum = logsumexp([log_a, log_b, log_c])
    return log_p + log_sum
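As a concrete instance of the log-space advice, the sketch below evaluates a two-component Gaussian mixture far out in the tails, where the naive sum of exponentials underflows. It uses `np.logaddexp`, NumPy's two-argument counterpart of `logsumexp`; the mixture and its parameters are illustrative assumptions.

```python
import numpy as np

def log_gauss(x: float, mu: float, sigma: float) -> float:
    """Log of a normal density."""
    return -0.5 * ((x - mu) / sigma) ** 2 - np.log(sigma * np.sqrt(2 * np.pi))

def mixture_logl_naive(x: float) -> float:
    # exp() underflows to 0 for very negative arguments, so log() gives -inf.
    with np.errstate(divide="ignore"):
        total = 0.5 * np.exp(log_gauss(x, -1.0, 1.0)) + 0.5 * np.exp(log_gauss(x, 1.0, 1.0))
        return float(np.log(total))

def mixture_logl_stable(x: float) -> float:
    # Stay in log-space: log(0.5*p1 + 0.5*p2) = logaddexp(log 0.5 + log p1, log 0.5 + log p2)
    return float(np.logaddexp(np.log(0.5) + log_gauss(x, -1.0, 1.0),
                              np.log(0.5) + log_gauss(x, 1.0, 1.0)))

x = 60.0
naive = mixture_logl_naive(x)    # both exponentials underflow
stable = mixture_logl_stable(x)  # remains finite
```

Near the bulk of the distribution the two versions agree; they differ only where the naive form loses all precision, which is exactly where a sampler exploring wide priors can end up.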
Testing Your Model
Before running samplers, test your model:
def test_model():
    """Test model with sample parameters."""
    # Test with typical values
    params = {'x': 1.0, 'y': 2.0}
    result = my_model(params)
    assert np.isfinite(result), "Model returned non-finite value"
    assert isinstance(result, (float, np.floating)), "Model must return scalar"
    # Test boundary cases
    edge_params = {'x': -5.0, 'y': 5.0}  # Prior bounds
    result = my_model(edge_params)
    assert np.isfinite(result), "Model failed at prior boundaries"
    # Test multiple calls (reproducibility)
    result1 = my_model(params)
    result2 = my_model(params)
    assert result1 == result2, "Model is not deterministic"
    print("Model tests passed!")
test_model()
RJMCMC Model Requirements
For reversible jump MCMC (trans-dimensional sampling), models have different requirements:
Signal Constructor Pattern:
RJMCMC models typically use signal constructors that return a (delay_function, param_names) tuple:
def my_signal_constructor(psr):
    """
    Signal constructor for RJMCMC.
    Parameters
    ----------
    psr : Pulsar
        Pulsar object
    Returns
    -------
    delay_fn : callable
        Function that computes signal delay given parameters
    param_names : list of str
        Names of parameters for this signal component
    """
    def delay_fn(params):
        # Compute signal delay from parameters
        f = params['f']
        h = params['h']
        phi = params['phi']
        return compute_cw_delay(f, h, phi, psr.toas)
    return delay_fn, ['f', 'h', 'phi']
Using RJ_Discovery_model:
The RJ_Discovery_model wrapper handles caching likelihoods for different component counts:
from discoverysamplers.eryn_RJ_interface import RJ_Discovery_model
rj_model = RJ_Discovery_model(
    signal_constructors=[my_signal_constructor],
    pulsar=psr,
    variable_component_numbers=[0, 1, 2, 3],  # Allowed component counts
)
# The model caches likelihoods for each configuration
# logL receives parameters as nested lists:
# params[branch_idx][component_idx] is shape (nwalkers, 1)
logL Signature for RJMCMC:
class RJ_Discovery_model:
    def logL(self, *params):
        """
        Compute log-likelihood for variable number of components.
        Parameters
        ----------
        *params : tuple of lists
            params[i] is a list of arrays for branch i
            params[i][j] has shape (nwalkers, 1) for component j
        Returns
        -------
        array
            Log-likelihood values, shape (nwalkers,)
        """
        # Determine active components from params structure
        cw_params = params[0]  # First branch
        n_sources = len(cw_params)
        # Select pre-cached likelihood for this configuration
        # (combined_params stands for the active components' parameters merged
        # into a single input; its construction is omitted from this outline)
        return self.likelihoods[n_sources](combined_params)
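The idea of keeping one likelihood per component count can be sketched generically. This is not how RJ_Discovery_model is implemented; `LikelihoodCache`, the toy builder, and the `h_0`, `h_1` parameter naming are all hypothetical and serve only to show the lookup-by-count pattern.

```python
from typing import Callable, Dict, Iterable

LogL = Callable[[dict], float]

class LikelihoodCache:
    """Hypothetical helper: builds one likelihood per component count, once."""

    def __init__(self, build: Callable[[int], LogL], allowed_counts: Iterable[int]):
        self._build = build
        self._allowed = set(allowed_counts)
        self._cache: Dict[int, LogL] = {}

    def get(self, n_components: int) -> LogL:
        if n_components not in self._allowed:
            raise ValueError(f"{n_components} components is not an allowed count")
        if n_components not in self._cache:
            # Build lazily on first use, then reuse on every later call.
            self._cache[n_components] = self._build(n_components)
        return self._cache[n_components]

build_calls = []  # records how often the builder runs

def build_toy_likelihood(n_components: int) -> LogL:
    """Toy builder: Gaussian penalty on each component's amplitude h_k."""
    build_calls.append(n_components)
    def logl(params: dict) -> float:
        return -0.5 * sum(params[f"h_{k}"] ** 2 for k in range(n_components))
    return logl

cache = LikelihoodCache(build_toy_likelihood, allowed_counts=[0, 1, 2, 3])
two_source = cache.get(2)
value = two_source({"h_0": 1.0, "h_1": 2.0})
same_again = cache.get(2)  # served from the cache; the builder is not called again
```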
See Eryn MCMC Sampler for complete RJMCMC examples.
Checklist
Before using your model with discoverysamplers, verify:
[ ] Model accepts a dictionary and returns a scalar
[ ] Model returns log-likelihood (not likelihood or chi-squared)
[ ] Model returns -np.inf for invalid parameters (not exceptions)
[ ] Model is deterministic (same input → same output)
[ ] Model works with all parameters in the prior dictionary
[ ] Model handles fixed parameters correctly (they’re auto-injected)
[ ] Model is numerically stable across the prior range
[ ] If using JAX: model uses jax.numpy and avoids unsupported operations
[ ] Model has been tested with sample parameter values
See Also
Prior Specification - How to specify priors for parameters (including RJMCMC)
Quick Start Guide - Basic usage examples
Eryn MCMC Sampler - RJMCMC and parallel tempering
Performance Optimization - Performance optimization guide