Model Requirements

This page describes the requirements for Discovery models to work with discoverysamplers bridge interfaces.

Basic Requirements

Your model must satisfy these fundamental requirements:

  1. Callable: The model must be a callable object (function, method, or class with __call__)

  2. Dictionary Input: Must accept a dictionary mapping parameter names (strings) to parameter values

  3. Scalar Output: Must return a single scalar value representing the log-likelihood

  4. Deterministic: Should return the same output for the same input (for reproducibility)

Note: For Discovery likelihoods, you can pass either the likelihood object or likelihood.logL — all bridges automatically extract the callable if you pass the object.

Valid Model Signatures

Simple Function

The most straightforward model is a simple function:

def my_model(params: dict) -> float:
    """
    Compute log-likelihood from parameters.

    Parameters
    ----------
    params : dict
        Dictionary mapping parameter names to values

    Returns
    -------
    float
        Log-likelihood value
    """
    x = params['x']
    y = params['y']

    return -0.5 * (x**2 + y**2)

Function with Additional Arguments

If your model needs additional data or configuration, you can use closures or partial functions:

import numpy as np
from functools import partial

def model_with_data(params: dict, data: np.ndarray, noise_std: float) -> float:
    """Model that requires data."""
    signal = params['amplitude'] * np.sin(2 * np.pi * params['frequency'] * data)
    residuals = signal - data
    return -0.5 * np.sum((residuals / noise_std)**2)

# Create a partial function with fixed data
data = np.random.randn(100)
model = partial(model_with_data, data=data, noise_std=0.1)

# Now model(params) has the correct signature
from discoverysamplers.nessai_interface import DiscoveryNessaiBridge
bridge = DiscoveryNessaiBridge(model, priors)

Callable Class

For complex models, use a class with a __call__ method:

class MyComplexModel:
    def __init__(self, data, config):
        """Initialize model with data and configuration."""
        self.data = data
        self.config = config
        self._precompute()

    def _precompute(self):
        """Pre-compute expensive quantities."""
        self.data_fft = np.fft.fft(self.data)
        # ... other precomputations

    def __call__(self, params: dict) -> float:
        """Compute log-likelihood."""
        # Use precomputed quantities for efficiency
        signal = self._generate_signal(params)
        return self._compute_log_likelihood(signal)

    def _generate_signal(self, params):
        # Implementation details
        pass

    def _compute_log_likelihood(self, signal):
        # Implementation details
        pass

# Use the class instance as the model
model = MyComplexModel(data, config)
bridge = DiscoveryNessaiBridge(model, priors)

Parameter Dictionary Structure

Input Dictionary

The input params dictionary contains all sampled and fixed parameters:

params = {
    'mass': 1.5,           # float
    'distance': 100.0,     # float
    'sky_loc': [0.5, 1.2], # Can be array-like for vector parameters
}

Key Points:

  • Keys are parameter name strings

  • Values are typically floats or NumPy arrays

  • All sampled parameters (from priors) will be present

  • Fixed parameters are automatically injected by the bridge

  • Parameter order doesn’t matter (dictionary is unordered)

Handling Vector Parameters

If your model has vector-valued parameters:

def model_with_vectors(params):
    """Model with vector parameters."""
    # Scalar parameters
    mass = params['mass']

    # Vector parameters (arrays)
    sky_location = params['sky_location']  # Shape: (2,)
    ra, dec = sky_location

    # 2D parameters
    covariance = params['covariance']  # Shape: (3, 3)

    return compute_log_likelihood(mass, ra, dec, covariance)

# Priors for vector parameters
priors = {
    'mass': ('uniform', 1.0, 3.0),
    'sky_location': [
        ('uniform', 0, 2*np.pi),    # RA
        ('uniform', -np.pi/2, np.pi/2),  # Dec
    ],
    # For complex structures, use fixed parameters
    'covariance': ('fixed', np.eye(3)),
}

Return Value

Log-Likelihood

The model must return the log-likelihood (not likelihood or chi-squared):

def correct_model(params):
    # Compute chi-squared or similar
    chi2 = compute_chi_squared(params)

    # Return log-likelihood
    return -0.5 * chi2

def incorrect_model(params):
    # DON'T return likelihood
    likelihood = np.exp(-0.5 * chi2)  # Wrong!
    return likelihood

    # DON'T return chi-squared
    return chi2  # Wrong!

Special Values

Use -np.inf to indicate impossible parameter combinations:

import numpy as np

def model_with_constraints(params):
    mass1 = params['mass1']
    mass2 = params['mass2']

    # Physical constraint: mass1 >= mass2
    if mass1 < mass2:
        return -np.inf

    # Numerical stability: avoid log(0)
    if mass1 < 1e-10:
        return -np.inf

    # Otherwise compute likelihood
    return compute_log_likelihood(mass1, mass2)

Never return +np.inf or np.nan as these will break the samplers.

Performance Considerations

JAX Compatibility

For best performance with JAX-based samplers (JAX-NS, or when using jit=True):

import jax.numpy as jnp
from jax import jit

def jax_model(params):
    """JAX-compatible model using jax.numpy."""
    x = params['x']
    y = params['y']

    # Use jax.numpy instead of numpy
    return -0.5 * (jnp.square(x) + jnp.square(y))

# Optionally JIT-compile the model
jax_model_compiled = jit(jax_model)

bridge = DiscoveryJAXNSBridge(jax_model_compiled, priors, jit=True)

JAX Tips:

  • Use jax.numpy instead of numpy

  • Avoid Python control flow (if/else); use jnp.where instead

  • Avoid in-place array updates

  • See JAX documentation for more details

Vectorization

Some samplers (especially JAX-NS) can evaluate the likelihood for multiple parameter sets simultaneously:

def vectorized_model(params):
    """Model that supports vectorized evaluation."""
    x = params['x']  # Could be shape (N,) instead of scalar
    y = params['y']  # Could be shape (N,)

    # Operations work element-wise
    return -0.5 * (x**2 + y**2)  # Returns shape (N,)

bridge = DiscoveryJAXNSBridge(vectorized_model, priors)
bridge.configure_array_api(order=['x', 'y'])  # Enable vectorization

Caching and Memoization

If your likelihood involves expensive computations, consider caching:

from functools import lru_cache

class CachedModel:
    def __init__(self):
        # Cache expensive precomputations
        self._cache = {}

    def __call__(self, params):
        # Create a hashable key from params
        key = tuple(sorted(params.items()))

        if key not in self._cache:
            self._cache[key] = self._compute_expensive(params)

        return self._cache[key]

    def _compute_expensive(self, params):
        # Expensive likelihood computation
        return log_likelihood

model = CachedModel()

Warning: Ensure cache keys are truly unique to avoid incorrect results.

Common Patterns

Pattern 1: Separating Model Components

class ModularModel:
    def __init__(self, data):
        self.data = data

    def __call__(self, params):
        signal = self.signal_model(params)
        noise = self.noise_model(params)
        return self.log_likelihood(signal, noise)

    def signal_model(self, params):
        """Generate signal from parameters."""
        return params['amplitude'] * np.sin(2 * np.pi * params['frequency'] * self.data)

    def noise_model(self, params):
        """Model noise properties."""
        return params['noise_level']

    def log_likelihood(self, signal, noise):
        """Compute log-likelihood from signal and noise."""
        residuals = signal - self.data
        return -0.5 * np.sum((residuals / noise)**2)

Pattern 2: Model with Units

from astropy import units as u

class PhysicalModel:
    def __init__(self, data, units_dict):
        """
        Parameters
        ----------
        units_dict : dict
            Maps parameter names to astropy units
        """
        self.data = data
        self.units = units_dict

    def __call__(self, params):
        # Attach units
        mass = params['mass'] * self.units['mass']
        distance = params['distance'] * self.units['distance']

        # Compute physical quantities with unit checking
        quantity = (mass / distance**2).to(u.solMass / u.Mpc**2)

        # Return dimensionless log-likelihood
        return float(-0.5 * quantity.value**2)

Pattern 3: Hierarchical Model

def hierarchical_model(params):
    """Model with hierarchical parameters."""
    # Population-level parameters
    mu = params['population_mean']
    sigma = params['population_std']

    # Individual-level parameters
    individuals = [params[f'individual_{i}'] for i in range(10)]

    # Population prior
    log_prior = -0.5 * np.sum([(ind - mu)**2 / sigma**2 for ind in individuals])

    # Individual likelihoods
    log_like = sum([individual_likelihood(ind) for ind in individuals])

    return log_prior + log_like

Error Handling

Graceful Failure

Handle potential errors gracefully:

def robust_model(params):
    try:
        result = compute_likelihood(params)

        # Check for invalid results
        if not np.isfinite(result):
            return -np.inf

        return result

    except Exception as e:
        # Log the error for debugging
        print(f"Error in likelihood: {e}, params: {params}")
        return -np.inf

Numerical Stability

Avoid numerical issues:

def stable_model(params):
    x = params['x']

    # Avoid division by zero
    if abs(x) < 1e-10:
        return -np.inf

    # Use log-space for products
    # Instead of: p = p1 * p2 * p3
    # Use: log_p = log_p1 + log_p2 + log_p3
    log_p = np.log(p1) + np.log(p2) + np.log(p3)

    # Use logsumexp for sums in log-space
    from scipy.special import logsumexp
    log_sum = logsumexp([log_a, log_b, log_c])

    return log_p + log_sum

Testing Your Model

Before running samplers, test your model:

def test_model():
    """Test model with sample parameters."""
    # Test with typical values
    params = {'x': 1.0, 'y': 2.0}
    result = my_model(params)

    assert np.isfinite(result), "Model returned non-finite value"
    assert isinstance(result, (float, np.floating)), "Model must return scalar"

    # Test boundary cases
    edge_params = {'x': -5.0, 'y': 5.0}  # Prior bounds
    result = my_model(edge_params)
    assert np.isfinite(result), "Model failed at prior boundaries"

    # Test multiple calls (reproducibility)
    result1 = my_model(params)
    result2 = my_model(params)
    assert result1 == result2, "Model is not deterministic"

    print("Model tests passed!")

test_model()

RJMCMC Model Requirements

For reversible jump MCMC (trans-dimensional sampling), models have different requirements:

Signal Constructor Pattern:

RJMCMC models typically use signal constructors that return a (delay_function, param_names) tuple:

def my_signal_constructor(psr):
    """
    Signal constructor for RJMCMC.

    Parameters
    ----------
    psr : Pulsar
        Pulsar object

    Returns
    -------
    delay_fn : callable
        Function that computes signal delay given parameters
    param_names : list of str
        Names of parameters for this signal component
    """
    def delay_fn(params):
        # Compute signal delay from parameters
        f = params['f']
        h = params['h']
        phi = params['phi']
        return compute_cw_delay(f, h, phi, psr.toas)

    return delay_fn, ['f', 'h', 'phi']

Using RJ_Discovery_model:

The RJ_Discovery_model wrapper handles caching likelihoods for different component counts:

from discoverysamplers.eryn_RJ_interface import RJ_Discovery_model

rj_model = RJ_Discovery_model(
    signal_constructors=[my_signal_constructor],
    pulsar=psr,
    variable_component_numbers=[0, 1, 2, 3],  # Allowed component counts
)

# The model caches likelihoods for each configuration
# logL receives parameters as nested lists:
# params[branch_idx][component_idx] is shape (nwalkers, 1)

logL Signature for RJMCMC:

class RJ_Discovery_model:
    def logL(self, *params):
        """
        Compute log-likelihood for variable number of components.

        Parameters
        ----------
        *params : tuple of lists
            params[i] is a list of arrays for branch i
            params[i][j] has shape (nwalkers, 1) for component j

        Returns
        -------
        array
            Log-likelihood values, shape (nwalkers,)
        """
        # Determine active components from params structure
        cw_params = params[0]  # First branch
        n_sources = len(cw_params)

        # Select pre-cached likelihood for this configuration
        return self.likelihoods[n_sources](combined_params)

See Eryn MCMC Sampler for complete RJMCMC examples.

Checklist

Before using your model with discoverysamplers, verify:

  • [ ] Model accepts a dictionary and returns a scalar

  • [ ] Model returns log-likelihood (not likelihood or chi-squared)

  • [ ] Model returns -np.inf for invalid parameters (not exceptions)

  • [ ] Model is deterministic (same input → same output)

  • [ ] Model works with all parameters in the prior dictionary

  • [ ] Model handles fixed parameters correctly (they’re auto-injected)

  • [ ] Model is numerically stable across the prior range

  • [ ] If using JAX: model uses jax.numpy and avoids unsupported operations

  • [ ] Model has been tested with sample parameter values

See Also