Source code for discoverysamplers.eryn_RJ_interface

"""
Reversible-jump bridge for Discovery models using Eryn.

This module provides an interface between Discovery models with variable dimensions
(e.g., variable number of gravitational wave sources) and Eryn's reversible-jump
MCMC sampler.

The key components are:
- ``RJ_Discovery_model``: A wrapper that caches likelihoods for all model configurations
  (e.g., 1 source, 2 sources, etc.) and provides a unified ``logL`` function that
  Eryn's RJ sampler can call.  Supports both *variable* branches (RJ) and *fixed*
  branches (always-present parameter groups like pulsar noise or GW background).
- ``DiscoveryErynRJBridge``: The interface class that sets up the Eryn sampler with
  proper priors and handles sampling, result extraction, and plotting.

Example usage (single-branch, backward compatible)::

    rj_model = RJ_Discovery_model(
        psrs=pulsars,
        fixed_components={'per_psr': {'base': make_fixed_components}},
        variable_components={'global': {'cw': (signal_constructor, base_param_names)}},
        variable_component_numbers={'cw': (1, 4)},
    )
    priors = {"cw": {0: uniform_dist(-20, -11), ...}}
    bridge = DiscoveryErynRJBridge(rj_model, priors=priors)
    bridge.create_sampler(nwalkers=32, ntemps=2)
    bridge.run_sampler(nsteps=5000)

Example usage (multi-branch with fixed noise/GW branches)::

    rj_model = RJ_Discovery_model(
        psrs=pulsars,
        fixed_components={'per_psr': {'base': make_fixed_components}},
        variable_components={'global': {'cw': (signal_constructor, cw_param_names)}},
        variable_component_numbers={'cw': (0, 4)},
        fixed_branches={
            'psrn': ['red_noise', 'dm_gp', 'chrom', 'dmexp'],
            'gw': ['gw_'],
        },
    )
    priors = {
        "psrn": {0: ..., 1: ..., ...},
        "gw": {0: ..., 1: ...},
        "cw": {0: ..., 1: ..., ...},
    }
    bridge = DiscoveryErynRJBridge(rj_model, priors=priors)
    bridge.create_sampler(nwalkers=32, ntemps=2)
    bridge.run_sampler(nsteps=5000)

Example usage (single-likelihood mode — one JIT'd likelihood for all configs)::

    rj_model = RJ_Discovery_model(
        psrs=pulsars,
        fixed_components={
            'per_psr': {'base': make_fixed_components},
            'global': {'globalgp': globalgp},
        },
        variable_components={'global': {'cw': (signal_constructor, cw_param_names)}},
        variable_component_numbers={'cw': (0, 5)},
        fixed_branches={
            'psrn': ['red_noise', 'dm_gp', 'chrom', 'dmexp'],
            'gw': ['gw_'],
        },
        single_likelihood=True,  # Build 1 likelihood, silence dead sources
    )

Tested with:
  - discovery >= 0.5
  - eryn >= 1.2
"""

from __future__ import annotations

import itertools
import re
from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union
import warnings
import numpy as np

try:
    from eryn.prior import uniform_dist, ProbDistContainer
    from eryn.ensemble import EnsembleSampler
    from eryn.state import State
    from eryn.moves import GaussianMove
    from eryn.backends import HDFBackend
except ImportError:
    raise ImportError("eryn is not installed. Please install it to use this module.")


class RJ_Discovery_model:
    """
    Discovery model wrapper for reversible-jump MCMC sampling with Eryn.

    This class manages multiple model configurations (e.g., different numbers
    of gravitational wave sources) by pre-computing and caching the likelihood
    for each configuration. It provides a unified ``logL`` interface that
    Eryn's RJ sampler can call with nested parameter lists.

    It also supports **fixed branches**: groups of always-present parameters
    (e.g., pulsar noise, GW background) that are exposed as separate Eryn
    branches with ``nleaves=1``. This enables Gibbs-style block updates
    between noise, background, and signal parameters.

    Parameters
    ----------
    psrs : list
        List of pulsar objects (Discovery Pulsar instances).
    fixed_components : dict
        Components that don't change in number. Structure::

            {
                'per_psr': {'name': constructor_function},
                'global': {'name': constructor_function}  # optional
            }

        Where constructor_function(psr) returns a list of model components.
    variable_components : dict
        Components that can vary in number. Structure::

            {
                'global': {'branch_name': (constructor_func, base_param_names)}
            }

        Where constructor_func() returns (delay_function, param_names) and
        base_param_names is a list like ['log10_h0', 'log10_f0', ...].
    variable_component_numbers : dict
        Min/max counts for each variable component::

            {'branch_name': (min_count, max_count)}

    fixed_branches : dict or None, optional
        Groups of always-present parameters to expose as separate Eryn
        branches. Keys are branch names, values are lists of substring
        patterns used to match Discovery parameter names. Example::

            {
                'psrn': ['red_noise', 'dm_gp', 'chrom', 'dmexp'],
                'gw': ['gw_'],
            }

        Parameters matching any pattern for a branch are assigned to that
        branch. Parameters not matching any fixed branch or variable component
        become truly fixed (not sampled). Default ``None`` (no fixed branches —
        backward-compatible behavior).
    custom_logL : callable or None, optional
        Custom log-likelihood function with signature
        ``custom_logL(param_dict, config) -> float``, where ``param_dict`` is a
        flat dictionary of all parameter values and ``config`` is a dict
        mapping variable component names to their active counts (e.g.,
        ``{'cw': 1}``). When provided, this replaces the standard
        ``GlobalLikelihood.logL`` call. Useful for phase-marginalized
        likelihoods. Default ``None``.
    verbose : bool, optional
        Print detailed information during setup. Default False.
    single_likelihood : bool, optional
        If True, build only ONE likelihood with the maximum number of sources
        and silence dead sources by setting their amplitude to
        ``zero_amplitude_value``. This avoids caching N separate likelihoods
        and ensures the JAX-compiled likelihood function is never recompiled
        during sampling. Default False.
    zero_amplitude_value : float, optional
        Value assigned to the amplitude parameter of dead sources in
        single_likelihood mode. Default -300.0 (i.e., h0 = 10^-300 ≈ 0).
    zero_amplitude_param : str, optional
        Substring used to identify the amplitude parameter among the base
        parameter names. Default ``"log10_h0"``.

    Notes
    -----
    **single_likelihood edge cases:**

    - When all sources are dead (``n_active=0``), all CW amplitudes are set to
      ``zero_amplitude_value`` and the likelihood equals the noise-only value.
      Verified to match the standard 0-source likelihood to machine precision.
    - Dead-source non-amplitude parameters default to ``0.0``. This is safe
      because ``h0 ~ 0`` makes the CW contribution vanish regardless of other
      parameter values (the signal is linear in amplitude).
    - This mode requires that the signal model is **linear in the amplitude
      parameter** (i.e., setting amplitude to zero produces zero signal). All
      Discovery CW delay models satisfy this.

    **Truly fixed parameters:** parameters that belong neither to a variable
    component nor to a fixed branch are never sampled. ``logL`` seeds them from
    ``fixed_param_values``; in the legacy (no fixed branches) calling
    convention a leading positional vector of fixed values is honoured only
    when it is actually supplied.

    Attributes
    ----------
    likelihood_cache : dict
        Cache of Discovery GlobalLikelihood objects for each configuration.
    params : dict
        Dictionary with 'fixed', 'variable', and 'fixed_branches' parameter
        info.
    fixed_branch_names : list
        Ordered list of fixed branch names (empty if no fixed branches).
    fixed_branch_param_names : dict
        Maps each fixed branch name to its ordered list of Discovery parameter
        names.

    Examples
    --------
    >>> rj_model = RJ_Discovery_model(
    ...     psrs=pulsars,
    ...     fixed_components={'per_psr': {'base': make_fixed}},
    ...     variable_components={'global': {'cw': (make_cw_signal, param_names)}},
    ...     variable_component_numbers={'cw': (1, 4)},
    ...     fixed_branches={'psrn': ['red_noise', 'dm_gp'], 'gw': ['gw_']},
    ... )
    """

    def __init__(
        self,
        psrs: List[Any],
        fixed_components: Dict[str, Dict[str, Callable]],
        variable_components: Dict[str, Dict[str, Tuple[Callable, List[str]]]],
        variable_component_numbers: Dict[str, Tuple[int, int]],
        fixed_branches: Optional[Dict[str, List[str]]] = None,
        custom_logL: Optional[Callable] = None,
        verbose: bool = False,
        single_likelihood: bool = False,
        zero_amplitude_value: float = -300.0,
        zero_amplitude_param: str = "log10_h0",
    ) -> None:
        """Build (and cache) the likelihoods and classify all parameters.

        Raises
        ------
        ImportError
            If ``discovery`` cannot be imported.
        ValueError
            If no ``'global'`` variable components are provided.
        NotImplementedError
            If ``'per_psr'`` variable components are provided.
        """
        # Delayed import to avoid circular dependencies
        try:
            import discovery as ds
        except ImportError as err:
            raise ImportError(
                "discovery is not installed. Please install it to use RJ_Discovery_model."
            ) from err
        self._ds = ds

        self.psrs = psrs
        self.verbose = verbose
        self._custom_logL = custom_logL
        self._single_likelihood = single_likelihood
        self._zero_amplitude_value = zero_amplitude_value
        self._zero_amplitude_param = zero_amplitude_param

        # Single-likelihood state always exists so that helper methods never
        # hit an AttributeError in the standard (multi-likelihood) mode.
        self._jit_logL: Optional[Callable] = None
        self._max_config: Optional[Dict[str, int]] = None
        self._max_config_key: Optional[Tuple] = None
        self._dead_source_defaults: Dict[str, Dict[str, float]] = {}

        # Parse fixed components
        self.fixed_components = fixed_components
        self.fixed_per_psr = fixed_components.get("per_psr", {})
        # Global fixed components are passed as globalgp to GlobalLikelihood
        # Expected format: {'globalgp': callable_returning_globalgp_component}
        self.fixed_global = fixed_components.get("global", None)

        # Parse variable components
        self.variable_components = variable_components
        self.variable_global = variable_components.get("global", None)
        self.variable_per_psr = variable_components.get("per_psr", None)

        if self.variable_global is None:
            raise ValueError(
                "No global variable components provided. At least one variable component "
                "is required for RJMCMC. Use the standard DiscoveryErynBridge for fixed models."
            )
        if self.variable_per_psr is not None:
            raise NotImplementedError("Variable per-pulsar components are not implemented.")

        self.variable_component_numbers = variable_component_numbers

        # Fixed branches configuration
        self._fixed_branches_spec = fixed_branches or {}
        self.fixed_branch_names: List[str] = list(self._fixed_branches_spec.keys())
        self.fixed_branch_param_names: Dict[str, List[str]] = {}

        # Caches, all keyed by ``tuple(sorted(config.items()))``
        self.likelihood_cache: Dict[Tuple, Any] = {}
        self.param_dicts_cache: Dict[Tuple, Dict] = {}
        self.param_mappings_cache: Dict[Tuple, List] = {}

        # Store base parameter names for each variable component
        self.base_param_names_variable: Dict[str, List[str]] = {
            comp_name: base_names
            for comp_name, (_, base_names) in self.variable_global.items()
        }

        # Pre-compute configurations
        if self._single_likelihood:
            self._precompute_single_likelihood()
        else:
            self._precompute_configurations()

        self._determine_all_params()

        # In single_likelihood mode, JIT-compile the max-config likelihood once
        if self._single_likelihood:
            self._setup_single_likelihood_jit()

    def _determine_all_params(self) -> None:
        """Determine fixed, variable, and fixed-branch parameter sets."""
        # Use the cached configuration exposing the most parameters as the
        # reference parameter list (first one wins on ties).
        max_params: List[str] = max(
            (lkl.logL.params for lkl in self.likelihood_cache.values()),
            key=len,
            default=[],
        )

        # Identify variable parameters (belonging to RJ components): names
        # starting with "<component><index>_".
        prefixes = tuple(
            f"{comp_name}{i}_"
            for comp_name, (_, max_count) in self.variable_component_numbers.items()
            for i in range(max_count)
        )
        variable_param_set = {p for p in max_params if p.startswith(prefixes)}

        # Non-variable parameters (candidates for fixed branches or truly fixed)
        non_variable_params = [p for p in max_params if p not in variable_param_set]

        # Assign non-variable params to fixed branches (if specified)
        assigned_to_branch = set()
        for branch_name, patterns in self._fixed_branches_spec.items():
            branch_params = [
                param
                for param in non_variable_params
                if any(pattern in param for pattern in patterns)
            ]
            assigned_to_branch.update(branch_params)
            self.fixed_branch_param_names[branch_name] = branch_params

            if self.verbose:
                print(f"Fixed branch '{branch_name}': {len(branch_params)} parameters")
                for p in branch_params:
                    print(f" {p}")

        # Truly fixed params: not variable, not in any fixed branch
        # These need default values since they won't be sampled.
        self.fixed_params = [
            p for p in non_variable_params if p not in assigned_to_branch
        ]

        # Default values for truly fixed params (zeros). Users can override
        # via the ``fixed_param_values`` attribute after construction.
        self.fixed_param_values: Dict[str, float] = {p: 0.0 for p in self.fixed_params}

        if self.verbose and self.fixed_params:
            print(f"Truly fixed (not sampled) parameters: {len(self.fixed_params)}")
            for p in self.fixed_params:
                print(f" {p}")

        # Variable params are the base names for each component type
        self.variable_params: Dict[str, List[str]] = {
            comp_name: base_names
            for comp_name, (_, base_names) in self.variable_global.items()
        }

        self.params = {
            "fixed": self.fixed_params,
            "variable": self.variable_params,
            "fixed_branches": self.fixed_branch_param_names,
        }

    def _precompute_configurations(self, recompute: bool = True) -> None:
        """Pre-compute all valid model configurations and their likelihoods."""
        if self.likelihood_cache and not recompute:
            if self.verbose:
                print("Likelihood configurations already pre-computed, skipping.")
            return

        if self.likelihood_cache and recompute:
            if self.verbose:
                print("Warning: Recomputing likelihood configurations.")
            self.likelihood_cache = {}
            self.param_mappings_cache = {}
            self.param_dicts_cache = {}

        # Get all possible counts for each variable component
        component_names = list(self.variable_component_numbers.keys())
        component_ranges = [
            range(min_count, max_count + 1)
            for min_count, max_count in self.variable_component_numbers.values()
        ]

        # Generate all combinations
        for counts in itertools.product(*component_ranges):
            config = dict(zip(component_names, counts))
            config_key = tuple(sorted(config.items()))

            if self.verbose:
                print(f"Building likelihood for configuration: {config}")

            # Build likelihood for this configuration
            likelihood = self._build_likelihood(config)
            self.likelihood_cache[config_key] = likelihood

            # Determine parameter mapping
            param_dict, param_mapping = self._find_param_mapping(config, likelihood)
            self.param_dicts_cache[config_key] = param_dict
            self.param_mappings_cache[config_key] = param_mapping

            if self.verbose:
                print(f" Parameters: {likelihood.logL.params}")
                print(f" Total params: {len(likelihood.logL.params)}")

        if self.verbose:
            print(f"Pre-computed {len(self.likelihood_cache)} model configurations")

    def _precompute_single_likelihood(self) -> None:
        """Build only the max-config likelihood (single_likelihood mode).

        Instead of caching one likelihood per configuration, we build a single
        likelihood with all sources at their maximum count. Dead sources are
        silenced at evaluation time by setting their amplitude to ~0.
        """
        # Build the max-config only
        max_config = {
            comp_name: max_count
            for comp_name, (_, max_count) in self.variable_component_numbers.items()
        }
        max_config_key = tuple(sorted(max_config.items()))

        if self.verbose:
            print(f"[single_likelihood] Building likelihood for max config: {max_config}")

        likelihood = self._build_likelihood(max_config)
        self.likelihood_cache[max_config_key] = likelihood

        param_dict, param_mapping = self._find_param_mapping(max_config, likelihood)
        self.param_dicts_cache[max_config_key] = param_dict
        self.param_mappings_cache[max_config_key] = param_mapping

        self._max_config = max_config
        self._max_config_key = max_config_key

        if self.verbose:
            print(f"[single_likelihood] Parameters: {likelihood.logL.params}")
            print(f"[single_likelihood] Total params: {len(likelihood.logL.params)}")

        # Build default param values for dead sources.
        # For each variable component, store the "zero" parameter values
        # that silence a source (amplitude → 0, other params at 0.0).
        self._dead_source_defaults = {}
        for comp_name, (_, base_names) in self.variable_global.items():
            self._dead_source_defaults[comp_name] = {
                pname: (
                    self._zero_amplitude_value
                    if self._zero_amplitude_param in pname
                    else 0.0
                )
                for pname in base_names
            }

        if self.verbose:
            print(f"[single_likelihood] Dead-source defaults: {self._dead_source_defaults}")

    def _setup_single_likelihood_jit(self) -> None:
        """JIT-compile the single max-config likelihood."""
        try:
            import jax
        except ImportError:
            if self.verbose:
                print("[single_likelihood] JAX not available, skipping JIT")
            self._jit_logL = None
            return

        max_lkl = self.likelihood_cache[self._max_config_key]
        self._jit_logL = jax.jit(max_lkl.logL)

        if self.verbose:
            print("[single_likelihood] JIT-compiled max-config likelihood")

    def _build_likelihood(self, config: Dict[str, int]) -> Any:
        """Build a Discovery likelihood for the given configuration."""
        ds = self._ds

        pslmodels = []
        for psr in self.psrs:
            model_components = []

            # Add fixed components (a constructor may return one component
            # or a list of them)
            for comp_constructor in self.fixed_per_psr.values():
                component = comp_constructor(psr)
                if isinstance(component, list):
                    model_components.extend(component)
                else:
                    model_components.append(component)

            # Add variable components based on current configuration
            for comp_name, count in config.items():
                if count == 0:
                    continue

                constructor_func, base_names = self.variable_global[comp_name]
                base_delay_fn = constructor_func()[0]  # Get the delay function

                # Add 'count' instances of this component
                for i in range(count):
                    source_name = f"{comp_name}{i}"
                    common_params = [f"{source_name}_{param}" for param in base_names]
                    delay_component = ds.makedelay(
                        psr,
                        base_delay_fn,
                        components=None,
                        common=common_params,
                        name=source_name,
                    )
                    model_components.append(delay_component)

            pslmodels.append(ds.PulsarLikelihood(model_components))

        # Build GlobalLikelihood with optional globalgp
        globalgp_kwargs = {}
        if self.fixed_global is not None:
            # The 'globalgp' key should map to a pre-built globalgp component
            # or a callable that returns one
            globalgp_component = self.fixed_global.get("globalgp", None)
            if globalgp_component is not None:
                if callable(globalgp_component):
                    globalgp_component = globalgp_component()
                globalgp_kwargs["globalgp"] = globalgp_component

        return ds.GlobalLikelihood(psls=pslmodels, **globalgp_kwargs)

    def _find_param_mapping(
        self, config: Dict[str, int], lkl: Any
    ) -> Tuple[Dict[str, Any], List]:
        """
        Find the mapping of parameters for a given configuration.

        Returns a dict splitting fixed/variable params and a nested list
        structure that matches what Eryn expects: the first entry is the list
        of fixed parameter names, followed by one list per active variable
        component holding the per-source parameter-name lists.
        """
        params_list = list(lkl.logL.params)

        # Determine variable parameters with their naming scheme
        variable_params: Dict[str, Dict[str, List[str]]] = {}
        for comp_name, count in config.items():
            if count == 0:
                continue
            _, base_names = self.variable_global[comp_name]
            variable_params[comp_name] = {
                f"{comp_name}{i}": [f"{comp_name}{i}_{param}" for param in base_names]
                for i in range(count)
            }

        # Fixed params are all others (order of the likelihood preserved)
        variable_names = {
            param
            for sources in variable_params.values()
            for source_params in sources.values()
            for param in source_params
        }
        fixed_params = [p for p in params_list if p not in variable_names]

        params_dict = {"fixed": fixed_params, "variable": variable_params}

        # Create nested list structure for Eryn; sources were inserted in
        # index order, so dict order is source order.
        param_mapping: List = [fixed_params]
        for sources in variable_params.values():
            param_mapping.append(list(sources.values()))

        return params_dict, param_mapping

    def get_likelihood_for_config(self, config: Dict[str, int]) -> Any:
        """Get the cached likelihood for a given configuration (or ``None``)."""
        config_key = tuple(sorted(config.items()))
        return self.likelihood_cache.get(config_key)

    def get_current_config_from_params(self, params: Dict[str, Any]) -> Dict[str, int]:
        """Determine the current model configuration from the parameter dictionary.

        For each variable component, sources are counted consecutively from
        index 0 until no ``"<component><index>_<base_name>"`` key is present.
        """
        config = {}
        for comp_name in self.variable_global.keys():
            base_names = self.variable_params[comp_name]
            count = 0
            while any(f"{comp_name}{count}_{param}" in params for param in base_names):
                count += 1
            config[comp_name] = count
        return config

    def _logL(self, params: Dict[str, Any]) -> float:
        """Evaluate log-likelihood for given parameters (dict interface).

        Raises
        ------
        ValueError
            If no likelihood is cached for the configuration inferred from
            the keys of *params*.
        """
        config = self.get_current_config_from_params(params)
        likelihood = self.get_likelihood_for_config(config)
        if likelihood is None:
            raise ValueError(f"No likelihood found for configuration: {config}")
        return likelihood.logL(params)

    # ------------------------------------------------------------------
    # Private helpers for logL (split out for readability)
    # ------------------------------------------------------------------

    def _unpack_multi_branch_params(self, params, param_dict):
        """Unpack parameters in multi-branch mode (fixed + variable branches).

        Eryn sends ``[x_branch0, x_branch1, ...]`` as a single list arg when
        ``nbranches > 1``. Fixed branches are unpacked first, then variable
        (RJ) branches, in the order of ``self.variable_params``.

        *param_dict* is modified in-place.

        Returns
        -------
        dict
            Configuration override mapping each variable component name to
            its number of active sources, derived directly from the array
            shapes.
        """
        # Accept both the wrapped form ``logL([b0, b1, ...])`` and the
        # unwrapped form ``logL(b0, b1, ...)``.
        if len(params) == 1 and isinstance(params[0], (list, tuple)):
            all_branch_data = list(params[0])
        else:
            all_branch_data = list(params)

        n_fixed = len(self.fixed_branch_names)

        # Fixed branches: always a single leaf; only row 0 is read
        for idx, branch_name in enumerate(self.fixed_branch_names):
            branch_array = all_branch_data[idx]
            if branch_array.ndim == 1:
                branch_array = branch_array[np.newaxis, :]
            for j, pname in enumerate(self.fixed_branch_param_names[branch_name]):
                param_dict[pname] = branch_array[0, j]

        # Variable (RJ) branches: one row per active leaf
        config_override = {}
        for comp_index, (comp_name, base_names) in enumerate(
            self.variable_params.items()
        ):
            branch_array = all_branch_data[n_fixed + comp_index]
            if branch_array.ndim == 1:
                branch_array = branch_array[np.newaxis, :]
            n_sources = branch_array.shape[0]
            config_override[comp_name] = n_sources

            for source_index in range(n_sources):
                source_name = f"{comp_name}{source_index}"
                for j, param_name in enumerate(base_names):
                    param_dict[f"{source_name}_{param_name}"] = branch_array[
                        source_index, j
                    ]

        return config_override

    def _unpack_legacy_params(self, params, param_dict):
        """Unpack parameters in legacy single-branch mode (backward compatible).

        Handles the original calling convention where fixed parameters come
        first (if any), followed by one array per variable component. The
        leading fixed-parameter vector is optional: when only the per-component
        arrays are passed (as Eryn does for a single RJ branch), the truly
        fixed parameters keep whatever values *param_dict* already holds.

        *param_dict* is modified in-place.

        Returns
        -------
        None
            Always returns ``None`` (no config override in legacy mode).

        Raises
        ------
        ValueError
            If the fixed vector or any source row is shorter than expected.
        """
        n_components = len(self.variable_params)

        # Only consume a leading fixed vector when one was actually supplied;
        # otherwise the first positional IS the first component's array.
        has_fixed_vector = bool(self.fixed_params) and len(params) > n_components
        if has_fixed_vector:
            fixed_params_arr = params[0]
            if len(fixed_params_arr) < len(self.fixed_params):
                raise ValueError(
                    f"Not enough fixed parameters. Expected {len(self.fixed_params)}, "
                    f"got {len(fixed_params_arr)}."
                )
            for param_name, value in zip(self.fixed_params, fixed_params_arr):
                param_dict[param_name] = value
            offset = 1
        else:
            offset = 0

        # Variable parameters (subsequent lists)
        for comp_index, (comp_name, base_names) in enumerate(
            self.variable_params.items()
        ):
            comp_param_lists = params[comp_index + offset]

            for source_index, source_params in enumerate(comp_param_lists):
                source_name = f"{comp_name}{source_index}"
                if len(source_params) < len(base_names):
                    raise ValueError(
                        f"Not enough parameters for {source_name}. "
                        f"Expected {len(base_names)}, got {len(source_params)}."
                    )
                for j, param_name in enumerate(base_names):
                    param_dict[f"{source_name}_{param_name}"] = source_params[j]

        return None

    def _pad_dead_sources(self, param_dict, config):
        """Pad dead sources with zero-amplitude defaults (single-likelihood mode).

        Ensures *param_dict* always contains keys for the maximum number of
        sources so that JAX JIT traces remain stable. Also fills any
        max-config likelihood parameter still missing (e.g. ``phi_psr``,
        ``d_psr``) from ``fixed_param_values``, falling back to ``0.0``.

        *param_dict* is modified in-place.
        """
        for comp_name, (_, max_count) in self.variable_component_numbers.items():
            n_active = config.get(comp_name, 0)
            defaults = self._dead_source_defaults[comp_name]
            base_names = self.variable_params[comp_name]
            for i in range(n_active, max_count):
                source_name = f"{comp_name}{i}"
                for pname in base_names:
                    param_dict[f"{source_name}_{pname}"] = defaults[pname]

        # Also fill per-pulsar params for dead sources (e.g. phi_psr, d_psr)
        max_lkl = self.likelihood_cache[self._max_config_key]
        for p in max_lkl.logL.params:
            if p not in param_dict:
                param_dict[p] = self.fixed_param_values.get(p, 0.0)

    def _evaluate_likelihood(self, param_dict, config, config_override):
        """Dispatch to the appropriate likelihood evaluation path.

        Handles custom log-likelihood, JIT-compiled single-likelihood, and the
        config-based cache lookup.

        Parameters
        ----------
        param_dict : dict
            Flat parameter dictionary.
        config : dict
            Current model configuration (component name -> active count).
        config_override : dict or None
            Non-``None`` when the configuration was determined directly from
            multi-branch array shapes (as opposed to inferred from param
            names). Takes precedence over *config* for the cache lookup.

        Returns
        -------
        float
            The log-likelihood value, or ``-np.inf`` if no cached likelihood
            matches the requested configuration.
        """
        if self._single_likelihood:
            if self._custom_logL is not None:
                return self._custom_logL(param_dict, config)
            if self._jit_logL is not None:
                return self._jit_logL(param_dict)
            return self.likelihood_cache[self._max_config_key].logL(param_dict)

        if self._custom_logL is not None:
            return self._custom_logL(param_dict, config)

        # Same lookup for the multi-branch and legacy paths, so that an
        # unsupported configuration is rejected with -inf in both.
        lookup_config = config_override if config_override is not None else config
        if lookup_config is None:
            lookup_config = self.get_current_config_from_params(param_dict)
        likelihood = self.get_likelihood_for_config(lookup_config)
        if likelihood is None:
            return -np.inf
        return likelihood.logL(param_dict)

    # ------------------------------------------------------------------
    # Main log-likelihood entry point
    # ------------------------------------------------------------------

    def logL(self, *params) -> float:
        """
        Log-likelihood function for Eryn's RJ sampler.

        When called with a single branch (no fixed branches), Eryn passes a
        single 2D array of shape ``(n_active_leaves, ndim)``. When called with
        multiple branches (fixed + variable), Eryn passes a list of 2D arrays,
        one per branch in the order given by ``branch_names`` (fixed branches
        first, then variable branches).

        Truly fixed parameters are seeded from ``fixed_param_values`` before
        the sampled values are unpacked on top of them.

        Parameters
        ----------
        *params : arrays
            Nested structure of parameters from Eryn.

        Returns
        -------
        float
            Log-likelihood value (or -inf for invalid configurations). NaN
            results are mapped to -inf.
        """
        param_dict = dict(self.fixed_param_values)

        if self.fixed_branch_names:
            config_override = self._unpack_multi_branch_params(params, param_dict)
        else:
            config_override = self._unpack_legacy_params(params, param_dict)

        config = (
            config_override
            if config_override is not None
            else self.get_current_config_from_params(param_dict)
        )

        if self._single_likelihood:
            self._pad_dead_sources(param_dict, config)

        logL_val = self._evaluate_likelihood(param_dict, config, config_override)

        return float(
            np.nan_to_num(logL_val, nan=-np.inf, posinf=np.inf, neginf=-np.inf)
        )

    def get_all_configurations(self) -> List[Dict[str, int]]:
        """Return all pre-computed configurations."""
        return [dict(config_key) for config_key in self.likelihood_cache.keys()]

    def params_all_configurations(self) -> List[str]:
        """Return the sorted list of all parameters across all configurations."""
        all_params: set = set()
        for likelihood in self.likelihood_cache.values():
            all_params.update(likelihood.logL.params)
        return sorted(all_params)

    def get_param_dict_for_config(self, config: Dict[str, int]) -> Optional[Dict]:
        """Get parameter dictionary for a specific configuration."""
        return self.param_dicts_cache.get(tuple(sorted(config.items())))

    def get_param_mapping_for_config(self, config: Dict[str, int]) -> Optional[List]:
        """Get parameter mapping for a specific configuration."""
        return self.param_mappings_cache.get(tuple(sorted(config.items())))
[docs] class DiscoveryErynRJBridge: """ Bridge between RJ_Discovery_model and Eryn's reversible-jump MCMC sampler. Supports two modes: 1. **Single-branch** (backward compatible): only variable (RJ) branches. 2. **Multi-branch**: fixed branches (always-present, ``nleaves=1``) alongside variable (RJ) branches, enabling Gibbs-style block updates. Parameters ---------- rj_model : RJ_Discovery_model The model with cached likelihoods for all configurations. priors : dict Priors in Eryn RJ format:: { "branch_name": { 0: prior_for_param_0, 1: prior_for_param_1, ... } } Must include entries for every fixed branch (if any) **and** every variable branch. latex_labels : dict, optional LaTeX labels for parameter names. Attributes ---------- all_branch_names : list All branch names (fixed + variable), in the order Eryn sees them. rj_branch_names : list Names of variable (RJ) branches only. fixed_branch_names : list Names of fixed (always-present) branches only. ndims_dict : dict Maps branch name -> number of parameters in that branch. nleaves_min_dict / nleaves_max_dict : dict Maps branch name -> min/max leaf count. """
def __init__(
    self,
    rj_model: "RJ_Discovery_model",
    priors: Dict[str, Dict[int, Any]],
    latex_labels: Optional[Dict[str, str]] = None,
) -> None:
    """Collect branch names, dimensions and leaf bounds from *rj_model*.

    Parameters
    ----------
    rj_model : RJ_Discovery_model
        Model providing ``fixed_branch_names``, ``fixed_branch_param_names``,
        ``variable_params`` and ``variable_component_numbers``.
    priors : dict
        Priors keyed by branch name, then by parameter index.
    latex_labels : dict, optional
        Maps parameter names to LaTeX labels; missing names fall back to
        the raw parameter name.

    Raises
    ------
    ValueError
        If the model exposes no variable (RJ) branch, or a variable branch
        has no entry in ``variable_component_numbers``.
    """
    self.rj_model = rj_model
    self.priors = priors
    self.latex_labels = latex_labels or {}

    # ---- Identify branches ----
    self.fixed_branch_names = list(rj_model.fixed_branch_names)
    # The RJ branch order must follow ``rj_model.variable_params`` because
    # that is the order in which the model's logL unpacks the per-branch
    # arrays; taking it from ``variable_component_numbers`` would silently
    # mis-assign parameters when the two dicts are ordered differently.
    self.rj_branch_names = list(rj_model.variable_params.keys())
    if not self.rj_branch_names:
        raise ValueError("rj_model exposes no variable (RJ) branches.")
    missing = [
        branch
        for branch in self.rj_branch_names
        if branch not in rj_model.variable_component_numbers
    ]
    if missing:
        raise ValueError(
            f"No entry in variable_component_numbers for branches: {missing}"
        )

    # Eryn branch order: fixed branches first, then variable
    self.all_branch_names = self.fixed_branch_names + self.rj_branch_names

    # ---- Dimensions per branch ----
    self.ndims_dict: Dict[str, int] = {}
    for branch in self.fixed_branch_names:
        self.ndims_dict[branch] = len(rj_model.fixed_branch_param_names[branch])
    for branch in self.rj_branch_names:
        self.ndims_dict[branch] = len(rj_model.variable_params[branch])

    # ---- Leaf counts per branch ----
    self.nleaves_min_dict: Dict[str, int] = {}
    self.nleaves_max_dict: Dict[str, int] = {}
    for branch in self.fixed_branch_names:
        # Fixed branches are always present exactly once
        self.nleaves_min_dict[branch] = 1
        self.nleaves_max_dict[branch] = 1
    for branch in self.rj_branch_names:
        mn, mx = rj_model.variable_component_numbers[branch]
        self.nleaves_min_dict[branch] = mn
        self.nleaves_max_dict[branch] = mx

    # ---- Backward-compatible shortcuts ----
    # These always describe the FIRST RJ branch, whether there is one
    # RJ branch or several.
    first_rj = self.rj_branch_names[0]
    self.ndim = self.ndims_dict[first_rj]
    self.nleaves_min = self.nleaves_min_dict[first_rj]
    self.nleaves_max = self.nleaves_max_dict[first_rj]
    self.base_param_names = rj_model.variable_params[first_rj]
    self.latex_list = [
        self.latex_labels.get(name, name) for name in self.base_param_names
    ]

    # Also keep the old attribute name for backward compatibility
    self.branch_names = self.all_branch_names

    # Store param name lists for each branch (for results extraction)
    self._branch_param_names: Dict[str, List[str]] = {}
    for branch in self.fixed_branch_names:
        self._branch_param_names[branch] = rj_model.fixed_branch_param_names[branch]
    for branch in self.rj_branch_names:
        self._branch_param_names[branch] = rj_model.variable_params[branch]

    # Sampler will be created later
    self.sampler: Optional[EnsembleSampler] = None
    self.nwalkers: Optional[int] = None
    self.ntemps: int = 1
@property
def has_fixed_branches(self) -> bool:
    """True when at least one fixed (always-present) branch is configured."""
    return bool(self.fixed_branch_names)
def create_sampler(
    self,
    nwalkers: int,
    ntemps: int = 1,
    moves: Optional[Any] = None,
    move_cov_factor: float = 0.01,
    rj_moves: Any = True,
    checkpoint_file: Optional[str] = None,
    **kwargs,
) -> EnsembleSampler:
    """
    Create the Eryn ensemble sampler for RJMCMC.

    Parameters
    ----------
    nwalkers : int
        Number of walkers per temperature.
    ntemps : int, optional
        Number of temperatures for parallel tempering. Default 1.
        Only used to build ``tempering_kwargs`` when none is supplied.
    moves : eryn Move, optional
        Custom move proposal. If None, uses GaussianMove with a diagonal
        covariance per branch.
    move_cov_factor : float, optional
        Value placed on the diagonal of the default GaussianMove
        covariance. Default 0.01.
    rj_moves : bool or str or list, optional
        Reversible-jump move configuration passed to Eryn. When the bridge
        has fixed branches and this is left at ``True``, it is replaced by
        ``"separate_branches"``. Default ``True``.
    checkpoint_file : str, optional
        Path to an HDF5 file for checkpointing. If given and no ``backend``
        is present in ``kwargs``, an ``HDFBackend`` on this path is passed
        to the sampler. Use ``resume`` to continue from the last stored
        sample.
    **kwargs
        Additional arguments passed to EnsembleSampler. A
        ``tempering_kwargs`` entry, if present, takes precedence over
        ``ntemps``.

    Returns
    -------
    EnsembleSampler
        The configured Eryn sampler (also stored on ``self.sampler``).
    """
    self.nwalkers = nwalkers
    self.ntemps = ntemps
    self._checkpoint_file = checkpoint_file

    # Set up HDF5 backend for checkpointing unless the caller brought one.
    if checkpoint_file is not None and "backend" not in kwargs:
        kwargs["backend"] = HDFBackend(checkpoint_file)

    # Build ndims, nleaves_min, nleaves_max as dicts for Eryn
    branches = self.all_branch_names
    ndims = {b: self.ndims_dict[b] for b in branches}
    nleaves_max = {b: self.nleaves_max_dict[b] for b in branches}
    nleaves_min = {b: self.nleaves_min_dict[b] for b in branches}

    # Default moves: Gaussian with diagonal covariance per branch
    if moves is None:
        cov = {b: np.eye(self.ndims_dict[b]) * move_cov_factor for b in branches}
        moves = GaussianMove(cov)

    # Tempering kwargs. Eryn's documented default for "no tempering" is an
    # empty dict, so never forward None.
    # NOTE(review): None is assumed not to be accepted by EnsembleSampler;
    # confirm against the installed Eryn version.
    tempering_kwargs = kwargs.pop("tempering_kwargs", None)
    if tempering_kwargs is None:
        tempering_kwargs = {"ntemps": ntemps} if ntemps > 1 else {}

    # For multi-branch with fixed branches, recommend "separate_branches"
    if self.has_fixed_branches and rj_moves is True:
        rj_moves = "separate_branches"

    self.sampler = EnsembleSampler(
        nwalkers,
        ndims,
        self.rj_model.logL,
        priors=self.priors,
        tempering_kwargs=tempering_kwargs,
        nbranches=len(branches),
        branch_names=branches,
        nleaves_max=nleaves_max,
        nleaves_min=nleaves_min,
        moves=moves,
        rj_moves=rj_moves,
        **kwargs,
    )

    # A caller-supplied tempering_kwargs may imply a different number of
    # temperatures than ``ntemps``; follow the sampler so the state arrays
    # built in initialize_state have a matching shape.
    # NOTE(review): presumes the sampler exposes ``ntemps``; falls back to
    # the argument otherwise.
    self.ntemps = getattr(self.sampler, "ntemps", ntemps)
    return self.sampler
def initialize_state(
    self,
    initial_nleaves: Optional[Union[int, Dict[str, int]]] = None,
    initial_points: Optional[Dict[str, np.ndarray]] = None,
    scatter: float = 1e-6,
) -> State:
    """
    Initialize the sampler state.

    Parameters
    ----------
    initial_nleaves : int or dict, optional
        Number of active leaves to start with. If an integer, applies to
        the first RJ branch. If a dict, maps branch name -> count. Fixed
        branches always start with 1 leaf. Defaults to ``nleaves_min`` for
        RJ branches. Requested counts are clamped to the branch's
        ``[nleaves_min, nleaves_max]`` range.
    initial_points : dict, optional
        Initial parameter values per branch. Keys are branch names, values
        are arrays of shape ``(nleaves, ndim)``; a single 1-D point of
        length ``ndim`` is also accepted. If fewer rows than active leaves
        are given, the first row is reused for the missing leaves.
        Branches without an entry are drawn from ``self.priors``.
    scatter : float, optional
        Standard deviation for Gaussian scatter around initial points.
        Default 1e-6.

    Returns
    -------
    State
        Eryn State object ready for sampling.

    Raises
    ------
    ValueError
        If ``create_sampler`` has not been called yet.
    """
    if self.sampler is None:
        raise ValueError("Sampler not created. Call create_sampler() first.")

    # Normalize initial_nleaves to a dict. NumPy integers are accepted as
    # well as plain ints.
    if initial_nleaves is None:
        initial_nleaves_dict = {}
    elif isinstance(initial_nleaves, (int, np.integer)):
        # Apply to first RJ branch only
        initial_nleaves_dict = {self.rj_branch_names[0]: int(initial_nleaves)}
    else:
        initial_nleaves_dict = initial_nleaves

    initial_points = initial_points or {}
    walker_shape = (self.ntemps, self.nwalkers)

    coords = {}
    inds = {}
    for branch in self.all_branch_names:
        ndim_b = self.ndims_dict[branch]
        nleaves_min_b = self.nleaves_min_dict[branch]
        nleaves_max_b = self.nleaves_max_dict[branch]

        # Determine how many leaves to activate
        if branch in self.fixed_branch_names:
            n_active = 1  # Fixed branches always have 1 leaf
        else:
            n_active = initial_nleaves_dict.get(branch, nleaves_min_b)
            # Clamp on both sides: coords only has nleaves_max_b slots, so
            # a larger request would index past the end below.
            n_active = min(max(n_active, nleaves_min_b), nleaves_max_b)

        coords[branch] = np.zeros(walker_shape + (nleaves_max_b, ndim_b))
        inds[branch] = np.zeros(walker_shape + (nleaves_max_b,), dtype=bool)
        inds[branch][:, :, :n_active] = True

        if branch in initial_points:
            # Use provided initial point with scatter. atleast_2d lets a
            # single 1-D point stand in for a one-row array.
            init_pt = np.atleast_2d(np.asarray(initial_points[branch]))
            for nn in range(n_active):
                row = init_pt[nn] if nn < len(init_pt) else init_pt[0]
                for i in range(ndim_b):
                    coords[branch][:, :, nn, i] = np.random.normal(
                        loc=row[i],
                        scale=scatter,
                        size=walker_shape,
                    )
        elif branch in self.priors:
            # Draw from priors. A branch with neither initial points nor
            # priors keeps its zero-filled coordinates.
            for nn in range(n_active):
                for i, prior in self.priors[branch].items():
                    coords[branch][:, :, nn, i] = prior.rvs(size=walker_shape)

    # Compute initial log-prior and log-likelihood
    log_prior = self.sampler.compute_log_prior(coords, inds=inds)
    log_like = self.sampler.compute_log_like(coords, inds=inds, logp=log_prior)[0]

    return State(coords, log_like=log_like, log_prior=log_prior, inds=inds)
def run_sampler(
    self,
    nsteps: int,
    initial_state: Optional[State] = None,
    initial_nleaves: Optional[Union[int, Dict[str, int]]] = None,
    initial_points: Optional[Dict[str, np.ndarray]] = None,
    burn: int = 0,
    thin_by: int = 1,
    progress: bool = True,
    # Backward compat: accept old kwarg name
    initial_point: Optional[np.ndarray] = None,
    **kwargs,
) -> State:
    """
    Run the RJMCMC sampler.

    Parameters
    ----------
    nsteps : int
        Number of MCMC steps to run.
    initial_state : State, optional
        Starting state. If None, one is chosen as follows: when neither
        ``initial_nleaves`` nor ``initial_points`` is given and the
        checkpoint backend already holds completed steps, sampling
        continues from the last stored sample; otherwise a fresh state is
        built with ``initialize_state()``.
    initial_nleaves : int or dict, optional
        Passed to initialize_state() if initial_state is None.
    initial_points : dict, optional
        Passed to initialize_state() if initial_state is None.
    initial_point : array, optional
        **Deprecated.** Use ``initial_points`` instead. If provided and
        ``initial_points`` is None, it is applied to the first RJ branch
        (a 1-D point is promoted to a one-row 2-D array).
    burn : int, optional
        Burn-in setting forwarded to sampler.run_mcmc(). Default 0.
    thin_by : int, optional
        Thinning factor forwarded to sampler.run_mcmc(). Default 1.
    progress : bool, optional
        Show progress bar. Default True.
    **kwargs
        Additional arguments passed to sampler.run_mcmc().

    Returns
    -------
    State
        Final state after sampling.

    Raises
    ------
    ValueError
        If ``create_sampler`` has not been called yet.
    """
    if self.sampler is None:
        raise ValueError("Sampler not created. Call create_sampler() first.")

    # Backward compatibility for initial_point (singular)
    if initial_point is not None:
        warnings.warn(
            "initial_point is deprecated. Use initial_points={'branch_name': array} instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        if initial_points is None:
            # initialize_state indexes points as [leaf, dim]; promote a
            # bare 1-D point so the legacy call style keeps working.
            initial_points = {
                self.rj_branch_names[0]: np.atleast_2d(initial_point)
            }

    if initial_state is None:
        explicit_start = initial_points is not None or initial_nleaves is not None
        if not explicit_start and self.can_resume and self.completed_steps > 0:
            # A checkpoint with stored steps exists and the caller did not
            # ask for a particular start: continue the stored chain rather
            # than appending a freshly initialised one to it.
            initial_state = self.sampler.get_last_sample()
        else:
            initial_state = self.initialize_state(
                initial_nleaves=initial_nleaves,
                initial_points=initial_points,
            )

    return self.sampler.run_mcmc(
        initial_state,
        nsteps,
        burn=burn,
        thin_by=thin_by,
        progress=progress,
        **kwargs,
    )
def get_last_state(self) -> State:
    """
    Fetch the most recent sample held by the sampler.

    The returned state can be given to ``run_sampler`` as
    ``initial_state`` to keep sampling from where the previous run
    stopped, even without an HDF5 backend.

    Returns
    -------
    State
        Whatever ``sampler.get_last_sample()`` returns.

    Raises
    ------
    ValueError
        If no sampler has been created yet.
    """
    sampler = self.sampler
    if sampler is not None:
        return sampler.get_last_sample()
    raise ValueError("Sampler not created. Call create_sampler() first.")
@property
def checkpoint_file(self) -> Optional[str]:
    """Path to the HDF5 checkpoint file, or None.

    Also None when ``create_sampler`` has not been called yet, since the
    path is only recorded there.
    """
    return getattr(self, "_checkpoint_file", None)


@property
def can_resume(self) -> bool:
    """Whether the sampler can resume from a checkpoint file.

    False when there is no sampler or no checkpoint path; otherwise the
    value of the sampler backend's ``initialized`` flag.
    """
    if self.sampler is None or getattr(self, "_checkpoint_file", None) is None:
        return False
    return self.sampler.backend.initialized


@property
def completed_steps(self) -> int:
    """Number of completed steps reported by ``sampler.iteration``.

    Returns 0 when no sampler has been created.
    """
    if self.sampler is None:
        return 0
    return self.sampler.iteration
def resume(
    self,
    nsteps: int,
    progress: bool = True,
    **kwargs,
) -> State:
    """
    Continue sampling from the sampler's last stored sample.

    The starting state is taken from ``sampler.get_last_sample()`` and the
    chain is advanced by ``nsteps`` further steps.

    Parameters
    ----------
    nsteps : int
        Number of additional MCMC steps to run.
    progress : bool, optional
        Show progress bar. Default True.
    **kwargs
        Additional arguments passed to sampler.run_mcmc().

    Returns
    -------
    State
        Final state after the additional steps.

    Raises
    ------
    ValueError
        If no sampler has been created yet.
    """
    sampler = self.sampler
    if sampler is None:
        raise ValueError("Sampler not created. Call create_sampler() first.")

    start = sampler.get_last_sample()
    return sampler.run_mcmc(start, nsteps, progress=progress, **kwargs)
def return_sampled_samples(
    self, branch: Optional[str] = None, temperature: int = 0
) -> Dict[str, Any]:
    """
    Collect the stored chain of one branch at one temperature.

    Parameters
    ----------
    branch : str, optional
        Branch name. Defaults to first RJ branch.
    temperature : int, optional
        Index selected along the chain's second axis. Default 0.

    Returns
    -------
    dict
        ``'names'`` (parameter names of the branch), ``'labels'`` (LaTeX
        label where one is registered, else the name) and ``'chain'``
        (the branch chain sliced at ``temperature``).

    Raises
    ------
    ValueError
        If no sampler has been created yet.
    """
    if self.sampler is None:
        raise ValueError("Sampler not created. Call create_sampler() first.")

    selected = self.rj_branch_names[0] if branch is None else branch

    branch_chain = self.sampler.get_chain()[selected]
    names = self._branch_param_names[selected]

    result: Dict[str, Any] = {}
    result["names"] = names
    result["labels"] = [self.latex_labels.get(name, name) for name in names]
    result["chain"] = branch_chain[:, temperature]
    return result
def return_flat_samples(
    self, branch: Optional[str] = None, temperature: int = 0
) -> np.ndarray:
    """
    Flatten a branch chain to 2-D and drop inactive (NaN) rows.

    A row is treated as inactive when its first parameter is NaN.

    Parameters
    ----------
    branch : str, optional
        Branch name. Defaults to first RJ branch.
    temperature : int, optional
        Temperature index. Default 0.

    Returns
    -------
    ndarray
        Flattened samples of shape (n_valid_samples, ndim).
    """
    chain = self.return_sampled_samples(
        branch=branch, temperature=temperature
    )["chain"]
    rows = chain.reshape(-1, chain.shape[-1])
    is_active = np.logical_not(np.isnan(rows[:, 0]))
    return rows[is_active]
def return_nleaves(self, branch: Optional[str] = None) -> np.ndarray:
    """
    Look up the stored leaf counts of one branch.

    Parameters
    ----------
    branch : str, optional
        Branch name. Defaults to first RJ branch.

    Returns
    -------
    ndarray
        The entry for ``branch`` in ``sampler.get_nleaves()``.
        NOTE(review): the original documentation gives the shape as
        (nsteps, ntemps, nwalkers) - confirm against Eryn.

    Raises
    ------
    ValueError
        If no sampler has been created yet.
    """
    if self.sampler is None:
        raise ValueError("Sampler not created. Call create_sampler() first.")

    selected = self.rj_branch_names[0] if branch is None else branch
    leaf_counts = self.sampler.get_nleaves()
    return leaf_counts[selected]
def return_logZ(self, *, results=None) -> Dict[str, float]:
    """Always raise: evidence estimates are not available from this bridge.

    Raises
    ------
    NotImplementedError
        Unconditionally; ``results`` is ignored.
    """
    message = (
        "Eryn RJMCMC is an MCMC sampler and does not compute Bayesian evidence (logZ). "
        "Use nested sampling (Nessai, JAX-NS) for evidence estimates."
    )
    raise NotImplementedError(message)
def plot_nleaves_histogram(
    self,
    branch: Optional[str] = None,
    figsize: Tuple[int, int] = (6, 12),
) -> "plt.Figure":
    """
    Draw one leaf-count histogram per temperature, stacked vertically.

    Parameters
    ----------
    branch : str, optional
        RJ branch to plot. Defaults to first RJ branch.
    figsize : tuple, optional
        Figure size. Default (6, 12).

    Returns
    -------
    matplotlib.figure.Figure
        The figure object.
    """
    import matplotlib.pyplot as plt

    selected = self.rj_branch_names[0] if branch is None else branch

    leaf_counts = self.return_nleaves(branch=selected)
    lowest = self.nleaves_min_dict[selected]
    highest = self.nleaves_max_dict[selected]

    # Bin edges sit on half-integers so each integer count gets its own bar.
    edges = np.arange(lowest - 0.5, highest + 1.5)

    # squeeze=False always yields a 2-D Axes array, so the single-temperature
    # case needs no special handling; take the only column.
    fig, grid = plt.subplots(
        self.ntemps, 1, sharex=True, figsize=figsize, squeeze=False
    )
    panels = list(grid[:, 0])

    bottom = panels[-1]
    bottom.set_xlabel("Number of Models")
    bottom.set_xticks(np.arange(lowest, highest + 1))

    for temp_index, panel in enumerate(panels):
        panel.hist(leaf_counts[:, temp_index].flatten(), bins=edges)
        panel.text(
            1.02,
            0.45,
            f"Temperature {temp_index}",
            horizontalalignment="left",
            transform=panel.transAxes,
        )

    fig.tight_layout()
    return fig
def plot_corner(
    self, branch: Optional[str] = None, temperature: int = 0, **kwargs
) -> "plt.Figure":
    """
    Create a corner plot of the sampled parameters.

    Inactive entries (rows whose first parameter is NaN) are dropped
    before plotting.

    Parameters
    ----------
    branch : str, optional
        Branch name. Defaults to first RJ branch.
    temperature : int, optional
        Temperature index. Default 0.
    **kwargs
        Additional arguments passed to corner.corner(). These override the
        defaults set here (``labels``, ``show_titles``, ``title_fmt``,
        ``title_kwargs``).

    Returns
    -------
    matplotlib.figure.Figure
        The corner plot figure.
    """
    import corner

    samples = self.return_sampled_samples(branch=branch, temperature=temperature)
    chain = samples["chain"]
    flat_chain = chain.reshape(-1, chain.shape[-1])
    valid = ~np.isnan(flat_chain[:, 0])

    # Merge defaults with caller options instead of passing both as
    # keywords; otherwise e.g. ``show_titles=False`` fails with a
    # "multiple values for keyword argument" TypeError.
    corner_kwargs = {
        "labels": samples["labels"],
        "show_titles": True,
        "title_fmt": ".2f",
        "title_kwargs": {"fontsize": 12},
    }
    corner_kwargs.update(kwargs)

    return corner.corner(flat_chain[valid], **corner_kwargs)
def print_config_summary(self) -> None:
    """Write a human-readable overview of branches and configurations to stdout.

    Lists fixed branches (if any), then RJ branches with their leaf bounds,
    then either the single-likelihood details or the parameter count of
    every available configuration.
    """
    print("RJ Model Configuration Summary")
    print("=" * 50)

    if self.has_fixed_branches:
        print(f"Fixed branches: {self.fixed_branch_names}")
        for name in self.fixed_branch_names:
            print(f" {name}: {self.ndims_dict[name]} params (always active)")

    print(f"RJ branches: {self.rj_branch_names}")
    for name in self.rj_branch_names:
        print(
            f" {name}: {self.ndims_dict[name]} params per source, "
            f"leaves [{self.nleaves_min_dict[name]}, {self.nleaves_max_dict[name]}]"
        )
    print()

    if not self.rj_model._single_likelihood:
        # One likelihood per configuration: report each one's size.
        print("Available configurations:")
        for config in self.rj_model.get_all_configurations():
            cfg_likelihood = self.rj_model.get_likelihood_for_config(config)
            print(f" {config}: {len(cfg_likelihood.logL.params)} total parameters")
        return

    print("Mode: single_likelihood (1 JIT-compiled likelihood for all configs)")
    max_cfg = self.rj_model._max_config
    max_lkl = self.rj_model.likelihood_cache[self.rj_model._max_config_key]
    print(f" Max config {max_cfg}: {len(max_lkl.logL.params)} total parameters")
    print(f" Dead sources silenced via {self.rj_model._zero_amplitude_param} = {self.rj_model._zero_amplitude_value}")
# ErynRJBridge is the class's original name, retained as an alias so existing
# imports keep working. New code should refer to DiscoveryErynRJBridge.
ErynRJBridge = DiscoveryErynRJBridge

# Public names exported by ``from <module> import *``.
__all__ = [
    "RJ_Discovery_model",
    "DiscoveryErynRJBridge",
    "ErynRJBridge",
    "GaussianMove",
]