Likelihood Module
This module provides model adapter utilities for wrapping Discovery models to work with various sampler backends. It handles JIT compilation, fixed parameter injection, and vectorized evaluation.
Discovery model adapter utilities.
This module provides adapters for wrapping Discovery models to work with various sampler backends, including support for JAX JIT compilation and vectorized evaluation.
- class discoverysamplers.likelihood.LikelihoodWrapper(model, fixed_params=None, jit=True, allow_array_api=False)
Bases: object
Adapter to wrap Discovery models for sampler interfaces.
This class provides a consistent interface for Discovery models, handling JIT compilation and optional vectorization.
- Parameters:
model (callable) – Discovery model callable (typically likelihood.logL)
fixed_params (dict, optional) – Dictionary of fixed parameter values to inject
jit (bool, default=True) – Whether to JIT-compile the model using JAX
allow_array_api (bool, default=False) – Whether to support vectorized (batched) evaluation
- model
The wrapped model
- Type:
callable
Examples
>>> import discovery as ds
>>> from discoverysamplers.likelihood import LikelihoodWrapper
>>> psr = ds.Pulsar.read_feather('pulsar.feather')
>>> likelihood = ds.PulsarLikelihood([...])
>>>
>>> # Wrap the likelihood
>>> adapter = LikelihoodWrapper(
...     model=likelihood.logL,
...     fixed_params={'param1': 1.0},
...     jit=True
... )
>>>
>>> # Evaluate (fixed parameters are injected automatically)
>>> log_L = adapter.log_likelihood({'param2': 2.0})
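The example above relies on fixed parameters being merged into each call. The following is a minimal, hypothetical sketch of that injection step only, assuming the wrapped model accepts a single dict of parameter values. `ToyLikelihoodWrapper` and `toy_logL` are illustrative names, not part of the library, and the sketch omits JIT compilation entirely. It also assumes that sampled and fixed names are disjoint; where they overlap, this sketch lets the sampled value win.

```python
import math
from typing import Callable, Dict, Mapping, Optional


def toy_logL(params: Mapping[str, float]) -> float:
    # Hypothetical stand-in for a Discovery logL: a 2D unit Gaussian log-density.
    x, y = params["param1"], params["param2"]
    return -0.5 * (x**2 + y**2) - math.log(2.0 * math.pi)


class ToyLikelihoodWrapper:
    """Hypothetical sketch of fixed-parameter injection (no JAX, no JIT)."""

    def __init__(self, model: Callable[[Mapping[str, float]], float],
                 fixed_params: Optional[Dict[str, float]] = None):
        self.model = model
        # Copy so later changes to the caller's dict do not leak into the wrapper.
        self.fixed_params = dict(fixed_params or {})

    def log_likelihood(self, params_dict: Mapping[str, float]) -> float:
        # Merge fixed values first, then sampled values (sampled wins on overlap).
        full = {**self.fixed_params, **params_dict}
        return self.model(full)


adapter = ToyLikelihoodWrapper(toy_logL, fixed_params={"param1": 1.0})
value = adapter.log_likelihood({"param2": 2.0})  # evaluates toy_logL at (1.0, 2.0)
```

Copying `fixed_params` at construction keeps the wrapper's behaviour stable even if the caller later mutates the dict it passed in.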
- __init__(model, fixed_params=None, jit=True, allow_array_api=False)
Initialize the discovery adapter.
- configure_array_api(order)
Configure vectorized (batched) likelihood evaluation.
- Parameters:
order (list of str) – Order of parameters for array construction
- Raises:
RuntimeError – If the array API was not enabled in __init__ (allow_array_api=True)
Examples
>>> adapter = LikelihoodWrapper(model, allow_array_api=True)
>>> adapter.configure_array_api(['param1', 'param2'])
>>> # Batched evaluation is now available
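The `order` list fixes which column of a parameter array belongs to which parameter name. The sketch below shows that mapping in both directions, assuming column `j` corresponds to `order[j]`. The helpers `dict_to_matrix` and `row_to_dict` are hypothetical illustrations, not library functions.

```python
from typing import Dict, List, Sequence


def dict_to_matrix(params: Dict[str, Sequence[float]], order: List[str]) -> List[List[float]]:
    """Stack per-parameter columns into N rows of ndim values, columns in `order`."""
    if not order:
        raise ValueError("order must name at least one parameter")
    n = len(params[order[0]])
    if any(len(params[name]) != n for name in order):
        raise ValueError("all parameter arrays must have the same length N")
    return [[float(params[name][i]) for name in order] for i in range(n)]


def row_to_dict(row: Sequence[float], order: List[str]) -> Dict[str, float]:
    """Map one row back to a name -> value dict using the same `order`."""
    if len(row) != len(order):
        raise ValueError(f"expected {len(order)} values, got {len(row)}")
    return dict(zip(order, row))


order = ["param1", "param2"]
# Dict key order is irrelevant; only `order` decides the column layout.
matrix = dict_to_matrix({"param2": [0.5, 1.0], "param1": [1.0, 2.0]}, order)
```

Keeping one explicit `order` avoids relying on dict iteration order, which can differ between the sampler and the model.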
- log_likelihood(params_dict)
Evaluate the log-likelihood for a single parameter set. Fixed parameters supplied at construction are injected automatically.
- Parameters:
params_dict (dict) – Dictionary of parameter values (sampled parameters only)
- Returns:
Log-likelihood value
- Return type:
Examples
>>> log_L = adapter.log_likelihood({'param2': 2.0, 'param3': 3.0})
- log_likelihood_row(params_dict)
Evaluate the log-likelihood for a batch of parameter sets (array API).
Each value in params_dict is expected to be an array of shape (N,), where N is the batch size.
- Parameters:
params_dict (dict) – Dictionary with array values for each parameter
- Returns:
Log-likelihood values, shape (N,)
- Return type:
array
- Raises:
RuntimeError – If the array API was not configured
Examples
>>> import jax.numpy as jnp
>>> params = {
...     'param1': jnp.array([1.0, 2.0, 3.0]),
...     'param2': jnp.array([0.5, 1.0, 1.5])
... }
>>> log_L_batch = adapter.log_likelihood_row(params)
>>> # log_L_batch has shape (3,)
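A batched call should agree with evaluating the model once per sample. The loop below is a slow, pure-Python reference for that contract, handy for testing a vectorized path against. In JAX, `jax.vmap` is the usual tool for vectorizing such a per-sample function; whether this module uses it internally is an assumption. `log_likelihood_row_reference` and `toy_logL` are hypothetical names.

```python
import math
from typing import Callable, Dict, List, Mapping, Optional, Sequence


def toy_logL(params: Mapping[str, float]) -> float:
    # Hypothetical stand-in for a Discovery logL: independent unit Gaussians.
    return sum(-0.5 * v * v - 0.5 * math.log(2.0 * math.pi) for v in params.values())


def log_likelihood_row_reference(
    model: Callable[[Mapping[str, float]], float],
    params_dict: Mapping[str, Sequence[float]],
    fixed_params: Optional[Dict[str, float]] = None,
) -> List[float]:
    """Evaluate `model` once per sample; a slow reference for batched calls."""
    names = list(params_dict)
    if not names:
        raise ValueError("params_dict must contain at least one parameter")
    n = len(params_dict[names[0]])
    if any(len(params_dict[name]) != n for name in names):
        raise ValueError("all parameter arrays must have shape (N,)")
    results = []
    for i in range(n):
        # Start from the fixed values, then add sample i of every sampled parameter.
        sample = dict(fixed_params or {})
        sample.update({name: float(params_dict[name][i]) for name in names})
        results.append(model(sample))
    return results


batch = log_likelihood_row_reference(
    toy_logL, {"param1": [1.0, 2.0, 3.0], "param2": [0.5, 1.0, 1.5]}
)  # one value per sample, length 3
```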
- log_likelihood_matrix(params_array)
Evaluate the log-likelihood for a 2D array of parameters, whose columns follow the order passed to configure_array_api.
- Parameters:
params_array (array, shape (N, ndim)) – Parameter values for N samples
- Returns:
Log-likelihood values
- Return type:
array, shape (N,)
- Raises:
RuntimeError – If the array API was not configured
Examples
>>> import jax.numpy as jnp
>>> # Array with 3 samples, 2 parameters each (columns in the configured order)
>>> params = jnp.array([
...     [1.0, 0.5],
...     [2.0, 1.0],
...     [3.0, 1.5]
... ])
>>> log_L_batch = adapter.log_likelihood_matrix(params)
>>> # log_L_batch has shape (3,)
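Because each row is interpreted through the configured order, the same numbers give different results under a different order. The sketch below makes that concrete, assuming column `j` of every row holds the value of `order[j]`. `log_likelihood_matrix_reference` and the asymmetric `toy_logL` are hypothetical and only mirror the documented shapes, (N, ndim) in and (N,) out.

```python
from typing import Callable, List, Mapping, Sequence


def toy_logL(params: Mapping[str, float]) -> float:
    # Hypothetical asymmetric model, so swapping columns changes the result.
    return -0.5 * (params["param1"] - 1.0) ** 2 - 0.5 * (params["param2"] / 2.0) ** 2


def log_likelihood_matrix_reference(
    model: Callable[[Mapping[str, float]], float],
    params_array: Sequence[Sequence[float]],
    order: List[str],
) -> List[float]:
    """Evaluate `model` on each row of an (N, ndim) array whose columns follow `order`."""
    results = []
    for row in params_array:
        if len(row) != len(order):
            raise ValueError(f"each row needs {len(order)} values, got {len(row)}")
        # Column j of the row is the value of parameter order[j].
        results.append(model(dict(zip(order, row))))
    return results


params = [[1.0, 0.5], [2.0, 1.0], [3.0, 1.5]]
values = log_likelihood_matrix_reference(toy_logL, params, ["param1", "param2"])
# values == [-0.03125, -0.625, -2.28125]
```

Passing `["param2", "param1"]` instead would silently evaluate a different point for every row, which is why the order set in `configure_array_api` must match how the sampler builds its arrays.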
See Also
Model Requirements - Requirements for Discovery models
Performance Optimization - Performance optimization tips