Likelihood Module

This module provides model adapter utilities for wrapping Discovery models to work with various sampler backends. It handles JIT compilation, fixed parameter injection, and vectorized evaluation.


class discoverysamplers.likelihood.LikelihoodWrapper(model, fixed_params=None, jit=True, allow_array_api=False)

Bases: object

Adapter to wrap Discovery models for sampler interfaces.

This class provides a consistent interface for Discovery models, handling JIT compilation and optional vectorization.

Parameters:
  • model (callable) – Discovery model callable (typically likelihood.logL)

  • fixed_params (dict, optional) – Dictionary of fixed parameter values to inject

  • jit (bool, default=True) – Whether to JIT-compile the model using JAX

  • allow_array_api (bool, default=False) – Whether to support vectorized (batched) evaluation

Attributes:
  • model (callable) – The wrapped model

  • fixed_params (dict) – Fixed parameters to inject

  • jit_enabled (bool) – Whether JIT compilation is enabled

  • array_api_enabled (bool) – Whether the array (batched) API is enabled

  • array_order (list, optional) – Parameter order used by the array API
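The model argument is typically likelihood.logL from a Discovery likelihood. To test sampler plumbing without pulsar data, a stand-in callable can help. The sketch below assumes the contract is "dict of parameter values in, scalar log-likelihood out". That contract is inferred from the examples on this page, not confirmed. `toy_logL` and its parameter names are hypothetical.

```python
import math

# Hypothetical stand-in for a Discovery logL callable. ASSUMPTION: the real
# model takes a dict mapping parameter names to values and returns a scalar.
# This toy treats 'param1' and 'param2' as independent standard normals.
def toy_logL(params):
    """Sum of standard-normal log-densities over 'param1' and 'param2'."""
    total = 0.0
    for name in ("param1", "param2"):
        x = params[name]
        total += -0.5 * x * x - 0.5 * math.log(2.0 * math.pi)
    return total


print(toy_logL({"param1": 0.0, "param2": 0.0}))  # equals -log(2*pi)
```

A callable like this can be passed as model= when experimenting with the wrapper's options.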

Examples

>>> import discovery as ds
>>> from discoverysamplers.likelihood import LikelihoodWrapper
>>> psr = ds.Pulsar.read_feather('pulsar.feather')
>>> likelihood = ds.PulsarLikelihood([...])
>>>
>>> # Wrap the likelihood
>>> adapter = LikelihoodWrapper(
...     model=likelihood.logL,
...     fixed_params={'param1': 1.0},
...     jit=True
... )
>>>
>>> # Evaluate (fixed params auto-injected)
>>> log_L = adapter.log_likelihood({'param2': 2.0})
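One way to picture fixed-parameter injection is as a merge of the fixed dictionary into each sampled dictionary before the model is called. The plain-Python sketch below works under that assumption. `inject_fixed` and `make_log_likelihood` are hypothetical helpers, and rejecting a name that is both fixed and sampled is an illustrative choice. The real LikelihoodWrapper may handle clashes differently.

```python
# Hypothetical sketch of fixed-parameter injection; LikelihoodWrapper's real
# internals may differ (for example in how name clashes are handled).
def inject_fixed(sampled, fixed):
    """Return a new dict containing both fixed and sampled parameters."""
    clash = set(sampled) & set(fixed)
    if clash:
        # Illustrative choice: treat a sampled/fixed name clash as a user error.
        raise ValueError(f"parameters both fixed and sampled: {sorted(clash)}")
    return {**fixed, **sampled}


def make_log_likelihood(model, fixed_params=None):
    """Wrap `model` so callers pass only the sampled parameters."""
    fixed = dict(fixed_params or {})

    def log_likelihood(params_dict):
        return float(model(inject_fixed(params_dict, fixed)))

    return log_likelihood


# Usage with a stand-in model that simply sums its two inputs.
loglike = make_log_likelihood(lambda p: p["param1"] + p["param2"],
                              fixed_params={"param1": 1.0})
print(loglike({"param2": 2.0}))  # 3.0
```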
__init__(model, fixed_params=None, jit=True, allow_array_api=False)

Initialize the wrapper. The arguments are described under the class Parameters above.

configure_array_api(order)

Configure vectorized (batched) likelihood evaluation.

Parameters:

order (list of str) – Order of parameters for array construction

Raises:

RuntimeError – If the wrapper was created with allow_array_api=False

Examples

>>> adapter = LikelihoodWrapper(model, allow_array_api=True)
>>> adapter.configure_array_api(['param1', 'param2'])
>>> # Batched evaluation (log_likelihood_row, log_likelihood_matrix) is now available
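After an order is configured, a flat vector of values can be matched to parameter names by position. The sketch below mirrors the two guards documented on this page: enabling at construction, then configuring. It uses a hypothetical class, `ArrayOrderSketch`, that is not part of discoverysamplers. The `row_to_dict` method is likewise an illustrative helper.

```python
# Hypothetical stand-in mirroring the documented enable/configure guards;
# it is not the real LikelihoodWrapper.
class ArrayOrderSketch:
    def __init__(self, allow_array_api=False):
        self.array_api_enabled = allow_array_api
        self.array_order = None

    def configure_array_api(self, order):
        # Mirrors the documented RuntimeError when the array API is not enabled.
        if not self.array_api_enabled:
            raise RuntimeError("array API not enabled; pass allow_array_api=True")
        self.array_order = list(order)

    def row_to_dict(self, row):
        """Map a flat vector of values onto parameter names, by position."""
        if self.array_order is None:
            raise RuntimeError("array API not configured; call configure_array_api()")
        if len(row) != len(self.array_order):
            raise ValueError(
                f"expected {len(self.array_order)} values, got {len(row)}"
            )
        return dict(zip(self.array_order, row))


sketch = ArrayOrderSketch(allow_array_api=True)
sketch.configure_array_api(["param1", "param2"])
print(sketch.row_to_dict([1.0, 0.5]))  # {'param1': 1.0, 'param2': 0.5}
```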
log_likelihood(params_dict)

Evaluate log-likelihood for a single parameter set.

Parameters:

params_dict (dict) – Dictionary of parameter values (sampled parameters only)

Returns:

Log-likelihood value

Return type:

float

Examples

>>> log_L = adapter.log_likelihood({'param2': 2.0, 'param3': 3.0})
log_likelihood_row(params_dict)

Evaluate log-likelihood for a batch of parameter sets (array API).

Each value in params_dict must be an array of shape (N,), with the same N for every parameter.

Parameters:

params_dict (dict) – Dictionary with array values for each parameter

Returns:

Log-likelihood values, shape (N,)

Return type:

array

Raises:

RuntimeError – If configure_array_api() has not been called

Examples

>>> import jax.numpy as jnp
>>> params = {
...     'param1': jnp.array([1.0, 2.0, 3.0]),
...     'param2': jnp.array([0.5, 1.0, 1.5])
... }
>>> log_L_batch = adapter.log_likelihood_row(params)
>>> # Returns array of shape (3,)
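To sanity-check batched output, compare it against a per-sample loop over the same inputs. The sketch below shows that loop in plain Python, with lists standing in for JAX arrays. `loop_log_likelihood_row` is hypothetical. It shows only the expected semantics (entry i is the model evaluated at sample i), not how the batched method is implemented.

```python
# Hypothetical reference loop: evaluates the model once per sample, which is
# what a vectorized row evaluation should agree with.
def loop_log_likelihood_row(model, params_dict):
    """Evaluate `model` per sample from a dict of equal-length sequences."""
    if not params_dict:
        raise ValueError("no parameters given")
    lengths = {len(v) for v in params_dict.values()}
    if len(lengths) != 1:
        raise ValueError("all parameter arrays must have the same length N")
    (n,) = lengths
    return [
        float(model({name: values[i] for name, values in params_dict.items()}))
        for i in range(n)
    ]


params = {"param1": [1.0, 2.0, 3.0], "param2": [0.5, 1.0, 1.5]}
print(loop_log_likelihood_row(lambda p: -(p["param1"] * p["param2"]), params))
# [-0.5, -2.0, -4.5]
```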
log_likelihood_matrix(params_array)

Evaluate log-likelihood for a 2D array of parameters.

Parameters:

params_array (array, shape (N, ndim)) – Parameter values for N samples

Returns:

Log-likelihood values

Return type:

array, shape (N,)

Raises:

RuntimeError – If configure_array_api() has not been called

Examples

>>> import jax.numpy as jnp
>>> # Array with 3 samples, 2 parameters each
>>> params = jnp.array([
...     [1.0, 0.5],
...     [2.0, 1.0],
...     [3.0, 1.5]
... ])
>>> log_L_batch = adapter.log_likelihood_matrix(params)
>>> # Returns array of shape (3,)
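log_likelihood_matrix appears to take the same information as log_likelihood_row in a different layout: one row per sample. Its columns are presumably in the order passed to configure_array_api, though that convention is an assumption. The plain-Python sketch below converts between the two layouts. `matrix_to_columns` is a hypothetical helper, and nested lists stand in for a 2D JAX array.

```python
# Hypothetical helper: converts an (N, ndim) nested list into the
# dict-of-columns layout, ASSUMING columns follow the configured order.
def matrix_to_columns(order, params_array):
    """Turn an (N, ndim) nested list into {name: [N values]} using `order`."""
    ndim = len(order)
    for i, row in enumerate(params_array):
        if len(row) != ndim:
            raise ValueError(f"row {i} has {len(row)} values, expected {ndim}")
    # zip(*rows) transposes rows into columns; handle N == 0 separately.
    columns = zip(*params_array) if len(params_array) > 0 else [[] for _ in order]
    return {name: list(col) for name, col in zip(order, columns)}


params = [[1.0, 0.5], [2.0, 1.0], [3.0, 1.5]]
print(matrix_to_columns(["param1", "param2"], params))
# {'param1': [1.0, 2.0, 3.0], 'param2': [0.5, 1.0, 1.5]}
```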

See Also