Source code for discoverysamplers.eryn_interface

# Filename: eryn.py
# Author: Jonas El Gammal
# Description: Interface to Eryn
# Date Created: 2025-09-01

"""
Bridge utilities to run Eryn MCMC on Discovery likelihoods (PTAs).

Tested with:
  - discovery 0.5 (JAX-based PTA analysis)  
  - eryn >= 1.2 (emcee-like API)

This file provides:
  - DiscoveryErynBridge: packs/unpacks parameter dicts, computes log-prob

Notes:
  • Discovery’s likelihoods are JAX-ready callables that accept a dict of
    named parameters. We wrap them with a flat θ-vector interface that Eryn
    expects.
  • Priors: by default, uniform within bounds; you can supply custom
    log-prior callables per parameter.
  • Initialization: walkers are drawn from priors (or a Gaussian ball around
    an initial point if provided).
  • Sampling: uses Eryn’s EnsembleSampler with optional parallel tempering.

Gotchas:
    • Make sure Discovery and Eryn are installed in your Python environment.
    • Ensure parameter names in priors match those in the Discovery model.
    • If no parameters are sampled (all fixed), Eryn will raise an error.
    • Currently, the interface assumes that there is only one model, i.e. no
      reversible-jump sampling.
"""


import numpy as np
import re
from typing import Mapping, Sequence, Optional, Dict, List, Any
import warnings

try:
    from eryn.prior import uniform_dist, log_uniform, ProbDistContainer
    from eryn.ensemble import EnsembleSampler
except ImportError:
    raise ImportError("eryn is not installed. Please install it to use this module.")


[docs] class DiscoveryErynBridge:
[docs] def __init__(self, model, priors: Optional[Any] = None, latex_labels: Optional[Dict[str, str]] = None): """ Initialize the Eryn interface for Discovery models. This class creates an interface between Discovery models and the Eryn sampler, handling parameter management, prior specifications, and likelihood calculations. Parameters ---------- model : object Discovery model object that must implement: - model.logL(params: dict) -> float : likelihood function - model.params : model parameters priors : None | list[Param] | dict, optional Prior specifications for model parameters. Can be: - None: Uses default priors from the model if available - list[Param]: List of parameter specifications (legacy format) - dict: Cobaya-style prior specifications {name: {dist:..., ...}} latex_labels : dict, optional Dictionary mapping parameter names to their LaTeX representations for plotting and display purposes. If not provided, parameter names are used as labels. Attributes ---------- discovery_paramnames : list List of all parameter names in model order sampled_prior_dict : dict Dictionary of prior specifications for sampled parameters fixed_param_dict : dict Dictionary of fixed parameter values fixed_names : list Names of fixed parameters sampled_names : list Names of sampled parameters n_fixed : int Number of fixed parameters n_sampled : int Number of sampled parameters ndim : int Dimension of parameter space (same as n_sampled) eryn_mapping : dict Maps parameter names to their indices in the θ-vector eryn_prior_dict : dict Prior specifications mapped to θ-vector indices eryn_prior_container : ProbDistContainer Eryn prior object for sampled parameters latex_labels : dict Mapping of parameter names to LaTeX labels latex_list : list LaTeX labels for all parameters sampled_names_latex : list LaTeX labels for sampled parameters fixed_names_latex : list LaTeX labels for fixed parameters Raises ------ ValueError If any model parameters are missing from the prior specifications """ self.model = model # 1) Get all model parameter names (order = sampler order) self.discovery_paramnames = self._infer_model_param_names(model) # 2) Normalize priors to a single internal spec per parameter self.sampled_prior_dict, self.fixed_param_dict = self._check_and_create_priors(priors, model, self.discovery_paramnames) # 3) Sampled vs fixed parameters self.fixed_names = list(self.fixed_param_dict.keys()) self.sampled_names = list(self.sampled_prior_dict.keys()) # make sure that the sampled + fixed = all parameters all_names = set(self.fixed_names).union(set(self.sampled_names)) if set(self.discovery_paramnames) != all_names: missing = set(self.discovery_paramnames) - all_names raise ValueError(f"Parameters missing from priors: {missing}. Please provide priors for all parameters.") # Make sure the order is the same as in discovery_paramnames self.fixed_names.sort(key=self.discovery_paramnames.index) self.sampled_names.sort(key=self.discovery_paramnames.index) self.n_fixed = len(self.fixed_names) self.n_sampled = len(self.sampled_names) self.ndim = self.n_sampled # for Eryn # map from sampled param name to index in θ-vector self.eryn_mapping = {name: i for i, name in enumerate(self.sampled_names)} # Create a eryn prior dict that maps the names to the θ-vector self.eryn_prior_dict = {self.eryn_mapping[name]: self.sampled_prior_dict[name] for name in self.sampled_names} # 4) Build Eryn prior object for sampled parameters self.eryn_prior_container = ProbDistContainer(self.eryn_prior_dict) if latex_labels is None: latex_labels = {} self.latex_labels = {name: latex_labels.get(name, name) for name in self.discovery_paramnames} self.latex_list = [self.latex_labels[name] for name in self.discovery_paramnames] self.sampled_names_latex = [self.latex_labels[name] for name in self.sampled_names] self.fixed_names_latex = [self.latex_labels[name] for name in self.fixed_names] # 5) Likelihood function (Discovery accepts a dict) # JIT optional — your earlier code used jax, keep the same idea if you want. self._loglike_fn = model.logL # must accept a dict of {name: value}
[docs] def create_sampler(self, nwalkers, **kwargs): """Create an ensemble sampler for MCMC sampling. This method initializes an EnsembleSampler object for Markov Chain Monte Carlo sampling, using the provided likelihood function and priors. Parameters ---------- nwalkers : int Number of walkers to use in the ensemble sampler **kwargs : dict Additional keyword arguments to pass to the EnsembleSampler constructor Returns ------- EnsembleSampler Initialized ensemble sampler object Raises ------ ValueError If no parameters are marked for sampling (ndim = 0) Notes ----- The method creates an internal likelihood function that combines both fixed and sampled parameters before evaluation. The sampler is stored as an instance attribute and the initial shape for p0 is recorded. """ if self.ndim == 0: raise ValueError("No sampled parameters. Provide priors that mark at least one parameter as non-fixed.") def _loglike_only(theta): # Merge fixed + current sampled theta params = self.unpack(theta) # sampled params.update(self.fixed_param_dict) # add fixed val = float(self._loglike_fn(params)) return val self.sampler = EnsembleSampler( nwalkers, self.ndim, _loglike_only, priors=self.eryn_prior_container, **kwargs ) self.p0_shape = self.sampler.backend.shape["model_0"][:-1] return self.sampler
[docs] def run_sampler(self, nsteps, p0=None, **kwargs): """ Run the MCMC sampler for a specified number of steps. This method executes the MCMC sampling process using the previously created sampler. It can start from provided initial positions or generate them from the prior distributions. Parameters ---------- nsteps : int Number of steps to run the MCMC sampler p0 : array-like, optional Initial positions for the walkers. If None, positions are drawn from the prior distributions. Shape should match sampler requirements. **kwargs : dict Additional keyword arguments to pass to the sampler's run_mcmc method Returns ------- sampler : object The MCMC sampler object after running the chain Raises ------ ValueError If the sampler has not been created or if no parameters are marked for sampling Notes ----- The method requires that create_sampler has been called first and that at least one parameter has been marked as non-fixed in the prior distributions. """ if self.sampler is None: raise ValueError("Sampler not created. Call the create_sampler method first.") if self.ndim == 0: raise ValueError("No sampled parameters. Provide priors that mark at least one parameter as non-fixed.") # Initial positions from priors using Eryn's prior container if p0 is None: p0 = self.eryn_prior_container.rvs(size=self.p0_shape) self.sampler.run_mcmc(p0, nsteps, **kwargs) return self.sampler
[docs] def return_all_samples(self): """ Returns all MCMC samples including both sampled and fixed parameters. This method retrieves the MCMC chain from the sampler and combines the sampled parameters with the fixed parameters to create a complete parameter set for each sample. Returns ------- dict A dictionary containing: - 'names' (list): Names of all parameters (sampled and fixed) - 'labels' (list): LaTeX labels for all parameters - 'chain' (ndarray): Array of shape (nwalkers*nsteps, n_all_params) containing all parameter samples, where n_all_params is the total number of parameters (both sampled and fixed) Raises ------ ValueError If the sampler has not been created using create_sampler() RuntimeError If the MCMC chain cannot be retrieved (e.g., if sampling hasn't been run) """ if self.sampler is None: raise ValueError("Sampler not created. Call the create_sampler method first.") try: chain = self.sampler.get_chain()["model_0"] # shape (nwalkers*nsteps, ndim) except Exception as e: raise RuntimeError("Could not get chain from sampler. Make sure sampling has been run.") from e chain_shape = chain.shape all_samples = np.zeros(chain_shape[:-1] + (len(self.discovery_paramnames),), dtype=float) # Fill in sampled parameters for i, name in enumerate(self.sampled_names): idx = self.discovery_paramnames.index(name) all_samples[..., idx] = chain[..., i] # Fill in fixed parameters for name in self.fixed_names: idx = self.discovery_paramnames.index(name) all_samples[..., idx] = self.fixed_param_dict[name] return {"names": self.discovery_paramnames, "labels": self.latex_list, "chain": all_samples} # shape (nwalkers*nsteps, n_all_params)
[docs] def return_sampled_samples(self): """ Returns the sampled parameters and their names from the MCMC chain. This method retrieves the sampling chain from the sampler and returns it along with parameter names and their LaTeX representations. Returns ------- dict Dictionary containing: - 'names' (list): List of parameter names - 'labels' (list): List of parameter names in LaTeX format - 'chain' (ndarray): MCMC chain with shape (nwalkers*nsteps, n_sampled_params) Raises ------ ValueError If sampler has not been created using create_sampler() method RuntimeError If sampling has not been run or chain cannot be retrieved Notes ----- The returned chain combines all walkers and steps into a single array, flattening the typical (nwalkers, nsteps, ndim) shape into (nwalkers*nsteps, ndim). """ if self.sampler is None: raise ValueError("Sampler not created. Call the create_sampler method first.") try: chain = self.sampler.get_chain()["model_0"] # shape (nwalkers*nsteps, ndim) except Exception as e: raise RuntimeError("Could not get chain from sampler. Make sure sampling has been run.") from e return {"names": self.sampled_names, "labels": self.sampled_names_latex, "chain": chain} # shape (nwalkers*nsteps, n_sampled_params)
[docs] def return_logZ(self, *, results=None) -> Dict[str, float]: """ Return the log evidence estimate. Note: Eryn is an MCMC sampler and does not compute the Bayesian evidence. This method is provided for API consistency but raises NotImplementedError. Raises ------ NotImplementedError Always raised - MCMC samplers do not compute evidence """ raise NotImplementedError( "Eryn is an MCMC sampler and does not compute the Bayesian evidence (logZ). " "Use a nested sampling method (Nessai, JAX-NS) if you need evidence estimates." )
[docs] def plot_trace(self, burn=0, plot_fixed=False, **kwargs): """Plot the MCMC chains for all parameters. Parameters ---------- burn : int, optional Number of initial steps to discard from the plot, by default 0 plot_fixed : bool, optional If True, includes fixed parameters in the plot, by default False **kwargs Additional keyword arguments passed to plots.plot_trace() Returns ------- matplotlib.figure.Figure Figure object containing the trace plots """ from .plots import plot_trace if plot_fixed: samples = self.return_all_samples() else: samples = self.return_sampled_samples() return plot_trace( samples, burn=burn, fixed_params=self.fixed_param_dict, fixed_names=self.fixed_names, **kwargs )
[docs] def plot_corner(self, burn=0, temp=0, **kwargs): """Create corner plots for the MCMC chain. Parameters ---------- burn : int, optional Number of initial samples to discard as burn-in period, by default 0. temp : int, optional Temperature index to plot (0 = cold chain), by default 0. **kwargs Additional keyword arguments passed to corner.corner(). Returns ------- matplotlib.figure.Figure Corner plot figure. """ from .plots import plot_corner samples = self.return_sampled_samples() return plot_corner(samples, burn=burn, temp=temp, **kwargs)
# ----- Packing helpers (sampled only) ----- #
[docs] def names(self) -> List[str]: return list(self.sampled_names)
[docs] def pack(self, d: Mapping[str, float]) -> np.ndarray: return np.array([float(d[n]) for n in self.sampled_names], dtype=float)
[docs] def unpack(self, theta: Sequence[float]) -> Dict[str, float]: return {n: float(v) for n, v in zip(self.sampled_names, theta)}
[docs] def pack_all(self, d: Mapping[str, float]) -> np.ndarray: return np.array([float(d[n]) for n in self.discovery_paramnames], dtype=float)
[docs] def unpack_all(self, theta: Sequence[float]) -> Dict[str, float]: return {n: float(v) for n, v in zip(self.discovery_paramnames, theta)}
# ====================================================================== # # Internals # ====================================================================== # def _infer_model_param_names(self, model) -> List[str]: """Infer parameter names from the model object.""" names = [] if hasattr(model.logL, "params"): p = getattr(model.logL, "params") if isinstance(p, (list, tuple)) and all(isinstance(x, str) for x in p): names = list(p) else: raise ValueError("model.logL.params doesn't seem to exist or return a list of strings. Please provide a correct discovery model.") return names def _check_and_create_priors(self, priors, model, discovery_paramnames) -> dict: """Check and create prior distributions for the model parameters.""" try: from discovery.prior import priordict_standard except ImportError: warnings.warn("Could not find default priors in Discovery. You must provide priors explicitly.") priordict_standard = {} if priors is None: sampled_prior_dict = {} for parname in discovery_paramnames: for par, range in priordict_standard.items(): if re.match(par, parname): sampled_prior_dict[parname] = uniform_dist(range[0], range[1]) break else: raise KeyError(f"No known prior for {parname}. Please provide priors explicitly.") return sampled_prior_dict, {} elif isinstance(priors, dict): keys = priors.keys() missing = [par for par in discovery_paramnames if par not in keys] sampled_prior_dict = {} fixed_params = {} if missing: raise ValueError(f"Priors missing for parameters: {missing}. You can provide None if ", "you want to use default priors from the model but all parameters must be covered.") for i, parname in enumerate(discovery_paramnames): spec = priors[parname] if spec is None or spec == "default": # Use default prior from Discovery for par, range in priordict_standard.items(): if re.match(par, parname): sampled_prior_dict[parname] = uniform_dist(range[0], range[1]) break else: raise KeyError(f"No known default prior for {parname}. Please provide priors explicitly.") elif not isinstance(spec, dict): # Custom prior object - must have logpdf and rvs methods if hasattr(spec, 'logpdf') and hasattr(spec, 'rvs'): sampled_prior_dict[parname] = spec else: raise ValueError(f"Prior object for {parname} must have logpdf and rvs methods.") elif 'dist' not in spec: raise ValueError(f"Prior for {parname} missing 'dist' key.") elif spec['dist'] == 'uniform': if 'min' not in spec or 'max' not in spec: raise ValueError(f"Uniform prior for {parname} requires 'min' and 'max'.") sampled_prior_dict[parname] = uniform_dist(spec['min'], spec['max']) elif spec['dist'] == 'loguniform': if 'a' not in spec or 'b' not in spec: raise ValueError(f"Log-uniform prior for {parname} requires 'a' and 'b'.") sampled_prior_dict[parname] = log_uniform(spec['a'], spec['b']) elif spec['dist'] == 'fixed': if 'value' not in spec: raise ValueError(f"Fixed prior for {parname} requires 'value'.") fixed_params[parname] = spec['value'] else: raise ValueError(f"Unsupported prior dist '{spec['dist']}' for {parname}. Supported: uniform, loguniform, fixed.") return sampled_prior_dict, fixed_params