Eryn Interface

This module provides the bridge interface connecting Discovery models to the Eryn ensemble MCMC sampler. It supports parallel tempering and fixed-dimensional sampling.

Bridge utilities to run Eryn MCMC on Discovery likelihoods for pulsar timing arrays (PTAs).

Tested with:
  • discovery 0.5 (JAX-based PTA analysis)

  • eryn >= 1.2 (emcee-like API)

This file provides:
  • DiscoveryErynBridge: packs/unpacks parameter dicts, computes the log-probability, and creates and runs the Eryn sampler

Notes

  • Discovery’s likelihoods are JAX-ready callables that accept a dict of named parameters. We wrap them with a flat θ-vector interface that Eryn expects.

  • Priors: by default, uniform within bounds; you can supply custom log-prior callables per parameter.

  • Initialization: walkers are drawn from priors (or a Gaussian ball around an initial point if provided).

  • Sampling: uses Eryn’s EnsembleSampler with optional parallel tempering.
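The wrapping described in these notes can be sketched as follows. A flat θ-vector is unpacked into a named dict, a uniform log-prior is applied within bounds unless a custom log-prior callable is supplied for that parameter, and the dict-based log-likelihood is added. This is a minimal illustration under stated assumptions: make_log_prob, bounds and custom_logpriors are hypothetical names, and in the bridge itself the priors go through Eryn's own prior container (eryn_prior_container) rather than being summed into a single function like this.

```python
import math
from typing import Callable, Dict, Mapping, Optional, Sequence, Tuple

import numpy as np


def make_log_prob(
    log_like: Callable[[Dict[str, float]], float],
    names: Sequence[str],
    bounds: Mapping[str, Tuple[float, float]],
    custom_logpriors: Optional[Mapping[str, Callable[[float], float]]] = None,
) -> Callable[[np.ndarray], float]:
    """Hypothetical helper: wrap a dict-based log-likelihood as a flat-theta log-posterior."""
    custom = dict(custom_logpriors or {})

    def log_prob(theta: np.ndarray) -> float:
        lp = 0.0
        for name, value in zip(names, theta):
            if name in custom:
                # A user-supplied log-prior replaces the uniform default.
                lp += float(custom[name](float(value)))
            else:
                lo, hi = bounds[name]
                if not lo <= value <= hi:
                    return -np.inf  # outside the uniform prior support
                lp -= math.log(hi - lo)  # normalized uniform density
        if not np.isfinite(lp):
            return -np.inf
        # Unpack the flat vector into the named dict the likelihood expects.
        params = {n: float(v) for n, v in zip(names, theta)}
        return lp + float(log_like(params))

    return log_prob


def toy_log_like(p: Dict[str, float]) -> float:
    # Stand-in for a Discovery likelihood: a unit Gaussian in two parameters.
    return -0.5 * (p["x"] ** 2 + p["y"] ** 2)


log_prob = make_log_prob(toy_log_like, ["x", "y"], {"x": (-5.0, 5.0), "y": (-5.0, 5.0)})
print(log_prob(np.array([0.0, 0.0])))  # -2*log(10), about -4.605
print(log_prob(np.array([6.0, 0.0])))  # -inf (x outside its bounds)
```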

Gotchas:
  • Make sure Discovery and Eryn are installed in your Python environment.

  • Ensure parameter names in priors match those in the Discovery model.

  • If no parameters are sampled (all fixed), create_sampler() raises a ValueError.

  • Currently, the interface assumes a single model, i.e., no reversible-jump sampling.

class discoverysamplers.eryn_interface.DiscoveryErynBridge(model, priors=None, latex_labels=None)

Bases: object

__init__(model, priors=None, latex_labels=None)

Initialize the Eryn interface for Discovery models.

This class creates an interface between Discovery models and the Eryn sampler, handling parameter management, prior specifications, and likelihood calculations.

Parameters:
  • model (object) – Discovery model object that must implement:
      - model.logL(params: dict) -> float : log-likelihood function
      - model.params : model parameters

  • priors (None | list[Param] | dict, optional) – Prior specifications for model parameters. Can be:
      - None: uses default priors from the model, if available
      - list[Param]: list of parameter specifications (legacy format)
      - dict: Cobaya-style prior specifications {name: {dist:…, …}}

  • latex_labels (dict, optional) – Dictionary mapping parameter names to their LaTeX representations for plotting and display purposes. If not provided, parameter names are used as labels.

Attributes:
  • discovery_paramnames (list) – List of all parameter names in model order

  • sampled_prior_dict (dict) – Dictionary of prior specifications for sampled parameters

  • fixed_param_dict (dict) – Dictionary of fixed parameter values

  • fixed_names (list) – Names of fixed parameters

  • sampled_names (list) – Names of sampled parameters

  • n_fixed (int) – Number of fixed parameters

  • n_sampled (int) – Number of sampled parameters

  • ndim (int) – Dimension of parameter space (same as n_sampled)

  • eryn_mapping (dict) – Maps parameter names to their indices in the θ-vector

  • eryn_prior_dict (dict) – Prior specifications mapped to θ-vector indices

  • eryn_prior_container (ProbDistContainer) – Eryn prior object for sampled parameters

  • latex_labels (dict) – Mapping of parameter names to LaTeX labels

  • latex_list (list) – LaTeX labels for all parameters

  • sampled_names_latex (list) – LaTeX labels for sampled parameters

  • fixed_names_latex (list) – LaTeX labels for fixed parameters

Raises:

ValueError – If any model parameters are missing from the prior specifications
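For trying the bridge without a full PTA setup, a stand-in model only needs the two members listed under the model parameter. The sketch below is hypothetical: ToyGaussianModel is not part of Discovery, model.params is assumed to be a plain list of names, and the prior dict keys (dist, min, max) are a guess at the Cobaya-style format, so check them against the bridge's prior parsing before use. Every model parameter appears in the priors, since missing ones raise a ValueError.

```python
import numpy as np


class ToyGaussianModel:
    """Hypothetical stand-in for a Discovery model: exposes .params and .logL."""

    def __init__(self, mean, sigma):
        self.mean = {k: float(v) for k, v in mean.items()}
        self.sigma = float(sigma)
        # Assumption: params is a list of parameter names, in model order.
        self.params = list(self.mean)

    def logL(self, params: dict) -> float:
        # Isotropic Gaussian log-likelihood, up to an additive constant.
        r2 = sum((params[k] - m) ** 2 for k, m in self.mean.items())
        return -0.5 * r2 / self.sigma**2


model = ToyGaussianModel({"log10_A": -14.5, "gamma": 13 / 3}, sigma=0.5)

# Assumed Cobaya-style key names; verify against the bridge before relying on them.
priors = {
    "log10_A": {"dist": "uniform", "min": -18.0, "max": -11.0},
    "gamma": {"dist": "uniform", "min": 0.0, "max": 7.0},
}

# With Discovery and Eryn installed, one would then build the bridge, e.g.:
# bridge = DiscoveryErynBridge(model, priors=priors)
```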

create_sampler(nwalkers, **kwargs)

Create an ensemble sampler for MCMC sampling. This method initializes an EnsembleSampler object for Markov chain Monte Carlo sampling, using the provided likelihood function and priors.

Parameters:
  • nwalkers (int) – Number of walkers to use in the ensemble sampler

  • **kwargs (dict) – Additional keyword arguments to pass to the EnsembleSampler constructor

Returns:

Initialized ensemble sampler object

Return type:

EnsembleSampler

Raises:

ValueError – If no parameters are marked for sampling (ndim = 0)

Notes

The method creates an internal likelihood function that combines both fixed and sampled parameters before evaluation. The sampler is stored as an instance attribute and the initial shape for p0 is recorded.
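The internal likelihood described in this note can be sketched as a closure that overlays the sampled θ-vector on the fixed values before calling the model. make_full_loglike is a hypothetical name, and the bridge's actual closure may differ in detail (for example, in how it handles JAX arrays).

```python
from typing import Callable, Dict, Mapping, Sequence

import numpy as np


def make_full_loglike(
    logL: Callable[[Dict[str, float]], float],
    sampled_names: Sequence[str],
    fixed_param_dict: Mapping[str, float],
) -> Callable[[np.ndarray], float]:
    """Hypothetical sketch: merge fixed values with a sampled theta-vector, then call logL."""
    fixed = dict(fixed_param_dict)

    def log_like(theta: np.ndarray) -> float:
        params = dict(fixed)  # start from the fixed values
        params.update(zip(sampled_names, map(float, theta)))  # overlay sampled values
        return float(logL(params))

    return log_like


def toy_logL(p: Dict[str, float]) -> float:
    # Depends on two sampled parameters (a, b) and one fixed parameter (c).
    return -0.5 * ((p["a"] - 1.0) ** 2 + (p["b"] - p["c"]) ** 2)


log_like = make_full_loglike(toy_logL, ["a", "b"], {"c": 2.0})
```

Because the fixed values are copied once when the closure is built, each call only allocates a small dict for the current θ.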

run_sampler(nsteps, p0=None, **kwargs)

Run the MCMC sampler for a specified number of steps. This method executes the MCMC sampling process using the previously created sampler. It can start from provided initial positions or generate them from the prior distributions.

Parameters:
  • nsteps (int) – Number of steps to run the MCMC sampler

  • p0 (array-like, optional) – Initial positions for the walkers. If None, positions are drawn from the prior distributions. The shape must match the p0 shape recorded by create_sampler().

  • **kwargs (dict) – Additional keyword arguments to pass to the sampler’s run_mcmc method

Returns:

sampler – The MCMC sampler object after running the chain

Return type:

object

Raises:

ValueError – If the sampler has not been created or if no parameters are marked for sampling

Notes

The method requires that create_sampler has been called first and that at least one parameter has been marked as non-fixed in the prior distributions.
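Starting positions can come from the priors or, as the module notes mention, from a Gaussian ball around an initial point. The sketch below assumes uniform priors given as (low, high) bounds; initial_positions and its arguments are hypothetical, and lead_shape stands in for whatever leading axes the sampler expects (for example (nwalkers,) or (ntemps, nwalkers)). Eryn may require additional axes, so match the shape recorded by create_sampler().

```python
from typing import Optional, Sequence, Tuple

import numpy as np


def initial_positions(
    bounds: Sequence[Tuple[float, float]],
    lead_shape: Tuple[int, ...],
    center: Optional[Sequence[float]] = None,
    scale: float = 1e-3,
    seed: Optional[int] = None,
) -> np.ndarray:
    """Hypothetical p0 helper: uniform prior draws, or a Gaussian ball around `center`.

    Returns an array of shape lead_shape + (ndim,).
    """
    rng = np.random.default_rng(seed)
    lo, hi = np.asarray(bounds, dtype=float).T
    size = tuple(lead_shape) + (lo.size,)
    if center is None:
        # Draw each coordinate independently from its uniform prior.
        return rng.uniform(lo, hi, size=size)
    # Gaussian ball: perturbations scaled by each prior width,
    # clipped so every walker starts inside the prior support.
    ball = np.asarray(center, dtype=float) + scale * (hi - lo) * rng.standard_normal(size)
    return np.clip(ball, lo, hi)


p0 = initial_positions([(-18.0, -11.0), (0.0, 7.0)], (4, 32), seed=0)
print(p0.shape)  # (4, 32, 2)
```

Scaling the ball by the prior width keeps the spread sensible when parameters have very different ranges.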

return_all_samples()

Return all MCMC samples, including both sampled and fixed parameters. This method retrieves the MCMC chain from the sampler and combines the sampled parameters with the fixed parameters to create a complete parameter set for each sample.

Returns:

A dictionary containing:

  • ‘names’ (list): Names of all parameters (sampled and fixed)

  • ‘labels’ (list): LaTeX labels for all parameters

  • ‘chain’ (ndarray): Array of shape (nwalkers*nsteps, n_all_params) containing all parameter samples, where n_all_params is the total number of parameters (both sampled and fixed)

Return type:

dict

Raises:
  • ValueError – If the sampler has not been created using create_sampler()

  • RuntimeError – If the MCMC chain cannot be retrieved (e.g., if sampling hasn’t been run)
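Combining the sampled chain with the fixed parameters amounts to broadcasting each fixed value down its own column. In this sketch the output columns follow all_names (assumed here to be model order); combine_with_fixed is a hypothetical helper, and the 'names' entry returned by the bridge remains the authority on the actual column order.

```python
from typing import Mapping, Sequence

import numpy as np


def combine_with_fixed(
    sampled_chain: np.ndarray,
    sampled_names: Sequence[str],
    fixed_param_dict: Mapping[str, float],
    all_names: Sequence[str],
) -> np.ndarray:
    """Hypothetical sketch: build an (nsamples, n_all_params) chain ordered by all_names."""
    n_samples = sampled_chain.shape[0]
    full = np.empty((n_samples, len(all_names)), dtype=float)
    col = {name: i for i, name in enumerate(sampled_names)}
    for j, name in enumerate(all_names):
        if name in col:
            full[:, j] = sampled_chain[:, col[name]]  # copy the sampled column
        else:
            full[:, j] = fixed_param_dict[name]  # broadcast the fixed value
    return full


chain = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
full = combine_with_fixed(chain, ["a", "c"], {"b": 9.0}, ["a", "b", "c"])
```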

return_sampled_samples()

Return the sampled parameters and their names from the MCMC chain. This method retrieves the sampling chain from the sampler and returns it along with the parameter names and their LaTeX representations.

Returns:

A dictionary containing:

  • ‘names’ (list): List of parameter names

  • ‘labels’ (list): List of parameter names in LaTeX format

  • ‘chain’ (ndarray): MCMC chain with shape (nwalkers*nsteps, n_sampled_params)

Return type:

dict

Raises:
  • ValueError – If sampler has not been created using create_sampler() method

  • RuntimeError – If sampling has not been run or chain cannot be retrieved

Notes

The returned chain combines all walkers and steps into a single array, flattening the typical (nwalkers, nsteps, ndim) shape into (nwalkers*nsteps, ndim).
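The flattening described above, together with an optional burn-in cut along the step axis, is a single slice and reshape in NumPy. This sketch assumes the (nwalkers, nsteps, ndim) layout given in the note; flatten_chain is a hypothetical helper.

```python
import numpy as np


def flatten_chain(chain: np.ndarray, burn: int = 0) -> np.ndarray:
    """Drop the first `burn` steps of every walker, then merge walkers and steps."""
    if chain.ndim != 3:
        raise ValueError("expected an array of shape (nwalkers, nsteps, ndim)")
    ndim = chain.shape[2]
    # Row w * (nsteps - burn) + s of the result is walker w at kept step s.
    return chain[:, burn:, :].reshape(-1, ndim)


chain = np.arange(2 * 5 * 3, dtype=float).reshape(2, 5, 3)
flat = flatten_chain(chain)
print(flat.shape)  # (10, 3)
```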

return_logZ(*, results=None)

Return the log evidence estimate.

Note: Eryn is an MCMC sampler and does not compute the Bayesian evidence. This method is provided for API consistency but raises NotImplementedError.

Raises:

NotImplementedError – Always raised, since MCMC samplers do not compute the evidence

Return type:

Dict[str, float]

plot_trace(burn=0, plot_fixed=False, **kwargs)

Plot the MCMC chains for all parameters.

Parameters:
  • burn (int, optional) – Number of initial steps to discard from the plot, by default 0

  • plot_fixed (bool, optional) – If True, includes fixed parameters in the plot, by default False

  • **kwargs – Additional keyword arguments passed to plots.plot_trace()

Returns:

Figure object containing the trace plots

Return type:

matplotlib.figure.Figure

plot_corner(burn=0, temp=0, **kwargs)

Create corner plots for the MCMC chain.

Parameters:
  • burn (int, optional) – Number of initial samples to discard as burn-in period, by default 0.

  • temp (int, optional) – Temperature index to plot (0 = cold chain), by default 0.

  • **kwargs – Additional keyword arguments passed to corner.corner().

Returns:

Corner plot figure.

Return type:

matplotlib.figure.Figure
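Selecting a temperature index and a burn-in cut before a corner plot is likewise a slice and reshape. The sketch assumes a per-branch chain laid out as (nsteps, ntemps, nwalkers, nleaves, ndim), which we believe matches recent Eryn get_chain() output but which should be treated as an assumption; cold_chain_samples is a hypothetical helper. The resulting 2-D array could then be passed to corner.corner(samples, labels=...).

```python
import numpy as np


def cold_chain_samples(chain: np.ndarray, burn: int = 0, temp: int = 0) -> np.ndarray:
    """Pick one temperature, drop burn-in steps, and flatten to (nsamples, ndim).

    Assumed layout: (nsteps, ntemps, nwalkers, nleaves, ndim).
    """
    if chain.ndim != 5:
        raise ValueError("expected (nsteps, ntemps, nwalkers, nleaves, ndim)")
    ndim = chain.shape[-1]
    # temp=0 selects the cold (untempered) chain.
    return chain[burn:, temp].reshape(-1, ndim)


chain = np.zeros((10, 3, 4, 1, 2))
chain[:, 0] = 1.0  # mark the cold-chain entries
samples = cold_chain_samples(chain, burn=4)
print(samples.shape)  # (24, 2)
```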

names()
Return type:

List[str]

pack(d)
Return type:

ndarray

unpack(theta)
Return type:

Dict[str, float]

pack_all(d)
Return type:

ndarray

unpack_all(theta)
Return type:

Dict[str, float]
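The pack/unpack helpers above carry no description. The sketch below shows one plausible reading, inferred only from their names and return types: pack/unpack convert between a dict and the θ-vector of sampled parameters, while pack_all/unpack_all do the same over all parameters. ParamPacker is hypothetical and is not the bridge's code, so check the source before relying on these semantics.

```python
from typing import Dict, Mapping, Sequence

import numpy as np


class ParamPacker:
    """Hypothetical sketch of pack/unpack semantics; not the bridge's actual code."""

    def __init__(self, sampled_names: Sequence[str], all_names: Sequence[str]):
        self.sampled_names = list(sampled_names)
        self.all_names = list(all_names)

    def pack(self, d: Mapping[str, float]) -> np.ndarray:
        # Sampled parameters only, in theta-vector order.
        return np.array([d[n] for n in self.sampled_names], dtype=float)

    def unpack(self, theta: np.ndarray) -> Dict[str, float]:
        return {n: float(v) for n, v in zip(self.sampled_names, theta)}

    def pack_all(self, d: Mapping[str, float]) -> np.ndarray:
        # Every parameter (sampled and fixed), in model order.
        return np.array([d[n] for n in self.all_names], dtype=float)

    def unpack_all(self, theta: np.ndarray) -> Dict[str, float]:
        return {n: float(v) for n, v in zip(self.all_names, theta)}


pk = ParamPacker(["a", "c"], ["a", "b", "c"])
d = {"a": 1.0, "b": 9.0, "c": 3.0}
```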

See Also