Eryn Reversible-Jump Interface

This module provides the bridge interface for reversible-jump MCMC (RJMCMC) using Eryn. RJMCMC enables trans-dimensional sampling where the number of model components can vary during sampling.

Hint

RJMCMC is useful for model selection problems where you want to determine the optimal number of signal components (e.g., how many gravitational wave sources are present in the data).

Reversible-jump bridge for Discovery models using Eryn.

This module provides an interface between Discovery models with variable dimensions (e.g., variable number of gravitational wave sources) and Eryn’s reversible-jump MCMC sampler.

The key components are:

  • RJ_Discovery_model: A wrapper that caches likelihoods for all model configurations (e.g., 1 source, 2 sources, etc.) and provides a unified logL function that Eryn’s RJ sampler can call. Supports both variable branches (RJ) and fixed branches (always-present parameter groups like pulsar noise or GW background).

  • DiscoveryErynRJBridge: The interface class that sets up the Eryn sampler with proper priors and handles sampling, result extraction, and plotting.

Example usage (single-branch, backward compatible):

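from eryn.prior import uniform_dist
from discoverysamplers.eryn_RJ_interface import DiscoveryErynRJBridge, RJ_Discovery_model

rj_model = RJ_Discovery_model(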
    psrs=pulsars,
    fixed_components={'per_psr': {'base': make_fixed_components}},
    variable_components={'global': {'cw': (signal_constructor, base_param_names)}},
    variable_component_numbers={'cw': (1, 4)},
)
priors = {"cw": {0: uniform_dist(-20, -11), ...}}
bridge = DiscoveryErynRJBridge(rj_model, priors=priors)
bridge.create_sampler(nwalkers=32, ntemps=2)
bridge.run_sampler(nsteps=5000)

Example usage (multi-branch with fixed noise/GW branches):

rj_model = RJ_Discovery_model(
    psrs=pulsars,
    fixed_components={'per_psr': {'base': make_fixed_components}},
    variable_components={'global': {'cw': (signal_constructor, cw_param_names)}},
    variable_component_numbers={'cw': (0, 4)},
    fixed_branches={
        'psrn': ['red_noise', 'dm_gp', 'chrom', 'dmexp'],
        'gw': ['gw_'],
    },
)
priors = {
    "psrn": {0: ..., 1: ..., ...},
    "gw": {0: ..., 1: ...},
    "cw": {0: ..., 1: ..., ...},
}
bridge = DiscoveryErynRJBridge(rj_model, priors=priors)
bridge.create_sampler(nwalkers=32, ntemps=2)
bridge.run_sampler(nsteps=5000)

Example usage (single-likelihood mode: one JIT-compiled likelihood for all configurations):

rj_model = RJ_Discovery_model(
    psrs=pulsars,
    fixed_components={
        'per_psr': {'base': make_fixed_components},
        'global': {'globalgp': globalgp},
    },
    variable_components={'global': {'cw': (signal_constructor, cw_param_names)}},
    variable_component_numbers={'cw': (0, 5)},
    fixed_branches={
        'psrn': ['red_noise', 'dm_gp', 'chrom', 'dmexp'],
        'gw': ['gw_'],
    },
    single_likelihood=True,  # Build 1 likelihood, silence dead sources
)

Tested with:
  • discovery >= 0.5

  • eryn >= 1.2

class discoverysamplers.eryn_RJ_interface.RJ_Discovery_model(psrs, fixed_components, variable_components, variable_component_numbers, fixed_branches=None, custom_logL=None, verbose=False, single_likelihood=False, zero_amplitude_value=-300.0, zero_amplitude_param='log10_h0')

Bases: object

Discovery model wrapper for reversible-jump MCMC sampling with Eryn.

This class manages multiple model configurations (e.g., different numbers of gravitational wave sources) by pre-computing and caching the likelihood for each configuration. It provides a unified logL interface that Eryn’s RJ sampler can call with nested parameter lists.
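The caching idea can be sketched in plain Python, assuming one cache entry per combination of counts allowed by variable_component_numbers (min and max inclusive). The names enumerate_configs, config_key, build_cache and the make_likelihood builder are hypothetical illustrations, not this module's API; a real builder would construct a Discovery likelihood.

```python
import itertools
from typing import Callable, Dict, List, Tuple

Config = Dict[str, int]

def enumerate_configs(variable_component_numbers: Dict[str, Tuple[int, int]]) -> List[Config]:
    """Every combination of active counts, with min and max both inclusive."""
    names = list(variable_component_numbers)
    ranges = [range(lo, hi + 1) for lo, hi in variable_component_numbers.values()]
    return [dict(zip(names, counts)) for counts in itertools.product(*ranges)]

def config_key(config: Config) -> Tuple[Tuple[str, int], ...]:
    """Dicts are unhashable, so key the cache on a sorted tuple of (name, count)."""
    return tuple(sorted(config.items()))

def build_cache(variable_component_numbers: Dict[str, Tuple[int, int]],
                make_likelihood: Callable[[Config], object]) -> Dict[tuple, object]:
    """Build one likelihood object per configuration, up front."""
    return {config_key(c): make_likelihood(c)
            for c in enumerate_configs(variable_component_numbers)}

# Stand-in builder: returns a string instead of a real likelihood object.
cache = build_cache({"cw": (1, 4)}, lambda c: f"likelihood with {c['cw']} CW source(s)")
print(cache[config_key({"cw": 3})])  # likelihood with 3 CW source(s)
```

The number of cached objects is the product of (max - min + 1) over all variable components, which is the cost that single_likelihood mode avoids.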

It also supports fixed branches: groups of always-present parameters (e.g., pulsar noise, GW background) that are exposed as separate Eryn branches with nleaves=1. This enables Gibbs-style block updates between noise, background, and signal parameters.

Parameters:
  • psrs (list) – List of pulsar objects (Discovery Pulsar instances).

  • fixed_components (dict) –

    Components that don’t change in number. Structure:

    {
        'per_psr': {'name': constructor_function},
        'global': {'name': constructor_function}  # optional
    }
    

    Where constructor_function(psr) returns a list of model components.

  • variable_components (dict) –

    Components that can vary in number. Structure:

    {
        'global': {'branch_name': (constructor_func, base_param_names)}
    }
    

    Where constructor_func() returns (delay_function, param_names) and base_param_names is a list like ['log10_h0', 'log10_f0', …].

  • variable_component_numbers (dict) –

    Min/max counts for each variable component:

    {'branch_name': (min_count, max_count)}
    

  • fixed_branches (dict or None, optional) –

    Groups of always-present parameters to expose as separate Eryn branches. Keys are branch names, values are lists of substring patterns used to match Discovery parameter names. Example:

    {
        'psrn': ['red_noise', 'dm_gp', 'chrom', 'dmexp'],
        'gw': ['gw_'],
    }
    

    Parameters matching any pattern for a branch are assigned to that branch. Parameters not matching any fixed branch or variable component become truly fixed (not sampled). Default None (no fixed branches — backward-compatible behavior).

  • custom_logL (callable or None, optional) – Custom log-likelihood function with signature custom_logL(param_dict, config) -> float, where param_dict is a flat dictionary of all parameter values and config is a dict mapping variable component names to their active counts (e.g., {'cw': 1}). When provided, this replaces the standard GlobalLikelihood.logL call. Useful for phase-marginalized likelihoods. Default None.

  • verbose (bool, optional) – Print detailed information during setup. Default False.

  • single_likelihood (bool, optional) – If True, build only ONE likelihood with the maximum number of sources and silence dead sources by setting their amplitude to zero_amplitude_value. This avoids caching N separate likelihoods and ensures the JAX-compiled likelihood function is never recompiled during sampling. Default False.

  • zero_amplitude_value (float, optional) – Value assigned to the amplitude parameter of dead sources in single_likelihood mode. Default -300.0 (i.e., h0 = 10^-300 ≈ 0).

  • zero_amplitude_param (str, optional) – Substring used to identify the amplitude parameter among the base parameter names. Default "log10_h0".
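The fixed_branches rule (substring patterns assign a parameter to a branch; anything unmatched is held fixed) can be sketched as below. This is an illustration under stated assumptions, not the module's code: assign_params is a hypothetical helper, variable (RJ) parameter names are assumed to be claimed before fixed-branch matching, the first matching branch in dict order is assumed to win, and the parameter names are illustrative.

```python
from typing import Dict, List, Sequence, Tuple

def assign_params(param_names: Sequence[str],
                  fixed_branches: Dict[str, List[str]],
                  variable_param_names: Sequence[str]
                  ) -> Tuple[Dict[str, List[str]], List[str], List[str]]:
    """Split parameter names into fixed-branch, variable (RJ), and truly fixed groups."""
    branches: Dict[str, List[str]] = {name: [] for name in fixed_branches}
    variable: List[str] = []
    truly_fixed: List[str] = []
    variable_set = set(variable_param_names)
    for pname in param_names:
        if pname in variable_set:           # assumption: RJ params are claimed first
            variable.append(pname)
            continue
        for branch, patterns in fixed_branches.items():
            if any(pat in pname for pat in patterns):   # plain substring match
                branches[branch].append(pname)
                break                       # assumption: first matching branch wins
        else:
            truly_fixed.append(pname)       # matched nothing: held fixed, not sampled
    return branches, variable, truly_fixed

names = ["J0030+0451_red_noise_log10_A", "J0030+0451_red_noise_gamma",
         "J0030+0451_efac", "gw_log10_A", "cw0_log10_h0"]
branches, variable, fixed = assign_params(
    names, {"psrn": ["red_noise", "dm_gp"], "gw": ["gw_"]}, ["cw0_log10_h0"])
print(fixed)  # ['J0030+0451_efac']
```

Because matching is by substring, a short pattern can capture more names than intended, so patterns should be specific enough to match only the parameters meant for that branch.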

Notes

single_likelihood edge cases:

  • When all sources are dead (n_active=0), all CW amplitudes are set to zero_amplitude_value and the likelihood equals the noise-only value. Verified to match the standard 0-source likelihood to machine precision.

  • Dead-source non-amplitude parameters default to 0.0. This is safe because h0 ~ 0 makes the CW contribution vanish regardless of other parameter values (the signal is linear in amplitude).

  • This mode requires that the signal model is linear in the amplitude parameter (i.e., setting amplitude to zero produces zero signal). All Discovery CW delay models satisfy this.
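The single_likelihood padding described in these notes can be sketched as follows: active sources fill the first slots, and each remaining slot gets zero_amplitude_value for its amplitude and 0.0 for everything else. pad_dead_sources, the "cw{k}_{name}" naming scheme, and toy_delay (a made-up signal that is linear in h0) are hypothetical and only illustrate why a dead source contributes nothing.

```python
import math
from typing import Dict, Sequence

def pad_dead_sources(active_rows: Sequence[Sequence[float]], n_max: int,
                     base_param_names: Sequence[str],
                     zero_amplitude_param: str = "log10_h0",
                     zero_amplitude_value: float = -300.0) -> Dict[str, float]:
    """Flat parameter dict for n_max sources: active rows first, then silenced ones."""
    if len(active_rows) > n_max:
        raise ValueError("more active sources than n_max")
    # Locate the amplitude parameter by substring, as zero_amplitude_param does.
    amp_idx = next(i for i, n in enumerate(base_param_names) if zero_amplitude_param in n)
    params: Dict[str, float] = {}
    for k in range(n_max):
        if k < len(active_rows):
            row = list(active_rows[k])
        else:
            row = [0.0] * len(base_param_names)   # dead source: other params 0.0
            row[amp_idx] = zero_amplitude_value   # amplitude silenced (h0 = 10**-300 by default)
        for name, value in zip(base_param_names, row):
            params[f"cw{k}_{name}"] = value       # hypothetical naming scheme
    return params

def toy_delay(log10_h0: float, log10_f0: float, t: float) -> float:
    """Hypothetical CW-like delay, linear in h0 = 10**log10_h0."""
    return 10.0 ** log10_h0 * math.sin(2.0 * math.pi * 10.0 ** log10_f0 * t)

p = pad_dead_sources([[-14.0, -8.0]], n_max=3, base_param_names=["log10_h0", "log10_f0"])
print(p["cw2_log10_h0"], p["cw2_log10_f0"])  # -300.0 0.0
print(abs(toy_delay(p["cw2_log10_h0"], p["cw2_log10_f0"], 3.0e8)) < 1e-290)  # True
```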

Attributes:
  • likelihood_cache (dict) – Cache of Discovery GlobalLikelihood objects for each configuration.

  • params (dict) – Dictionary with 'fixed', 'variable', and 'fixed_branches' parameter info.

  • fixed_branch_names (list) – Ordered list of fixed branch names (empty if no fixed branches).

  • fixed_branch_param_names (dict) – Maps each fixed branch name to its ordered list of Discovery parameter names.

Examples

>>> rj_model = RJ_Discovery_model(
...     psrs=pulsars,
...     fixed_components={'per_psr': {'base': make_fixed}},
...     variable_components={'global': {'cw': (make_cw_signal, param_names)}},
...     variable_component_numbers={'cw': (1, 4)},
...     fixed_branches={'psrn': ['red_noise', 'dm_gp'], 'gw': ['gw_']},
... )
__init__(psrs, fixed_components, variable_components, variable_component_numbers, fixed_branches=None, custom_logL=None, verbose=False, single_likelihood=False, zero_amplitude_value=-300.0, zero_amplitude_param='log10_h0')

get_likelihood_for_config(config)

Get the cached likelihood for a given configuration.

Return type:

Any

get_current_config_from_params(params)

Determine the current model configuration from the parameter dictionary.

Return type:

Dict[str, int]

logL(*params)

Log-likelihood function for Eryn’s RJ sampler.

When called with a single branch (no fixed branches), Eryn passes a single 2D array of shape (n_active_leaves, ndim).

When called with multiple branches (fixed + variable), Eryn passes a list of 2D arrays, one per branch in the order given by branch_names (fixed branches first, then variable branches).

Parameters:

*params (arrays) – Nested structure of parameters from Eryn.

Returns:

Log-likelihood value (or -inf for invalid configurations).

Return type:

float
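The dispatch described for logL (one 2D array per branch, fixed branches first, the row count of each RJ branch giving the active configuration) can be sketched as below. unpack_branches, rj_logL, the "{branch}{k}_{name}" naming and the loglike_for_config lookup are hypothetical stand-ins for the cached-likelihood machinery; only len() and row iteration are used, so nested lists and NumPy arrays both work.

```python
import math
from typing import Dict

def unpack_branches(branch_arrays, branch_names, fixed_branch_params, rj_base_names, nleaves_limits):
    """Turn per-branch arrays into (param_dict, config), or (None, None) if invalid."""
    param_dict: Dict[str, float] = {}
    config: Dict[str, int] = {}
    for name, rows in zip(branch_names, branch_arrays):
        if name in fixed_branch_params:
            if len(rows) != 1:                      # fixed branches always have one leaf
                return None, None
            param_dict.update(zip(fixed_branch_params[name], rows[0]))
        else:
            lo, hi = nleaves_limits[name]
            if not lo <= len(rows) <= hi:           # count outside (min, max)
                return None, None
            config[name] = len(rows)                # active leaves = active sources
            for k, row in enumerate(rows):
                for base, value in zip(rj_base_names[name], row):
                    param_dict[f"{name}{k}_{base}"] = value   # hypothetical naming
    return param_dict, config

def rj_logL(branch_arrays, branch_names, fixed_branch_params, rj_base_names,
            nleaves_limits, loglike_for_config):
    """-inf for invalid configurations, else the value of that configuration's likelihood."""
    params, config = unpack_branches(branch_arrays, branch_names, fixed_branch_params,
                                     rj_base_names, nleaves_limits)
    if config is None:
        return -math.inf
    return loglike_for_config(config)(params)

arrays = [[[4.33, -14.5]], [[-14.0, -8.0], [-15.0, -8.5]]]   # psrn: 1 leaf, cw: 2 leaves
fake_cache = lambda config: (lambda p: -0.5 * config["cw"])  # stand-in likelihood lookup
print(rj_logL(arrays, ["psrn", "cw"], {"psrn": ["gamma", "log10_A"]},
              {"cw": ["log10_h0", "log10_f0"]}, {"cw": (0, 4)}, fake_cache))  # -1.0
```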

get_all_configurations()

Return all pre-computed configurations.

Return type:

List[Dict[str, int]]

params_all_configurations()

Return the list of all parameters across all configurations.

Return type:

List[str]

get_param_dict_for_config(config)

Get parameter dictionary for a specific configuration.

Return type:

Optional[Dict]

get_param_mapping_for_config(config)

Get parameter mapping for a specific configuration.

Return type:

Optional[List]

class discoverysamplers.eryn_RJ_interface.DiscoveryErynRJBridge(rj_model, priors, latex_labels=None)

Bases: object

Bridge between RJ_Discovery_model and Eryn’s reversible-jump MCMC sampler.

Supports two modes:

  1. Single-branch (backward compatible): only variable (RJ) branches.

  2. Multi-branch: fixed branches (always-present, nleaves=1) alongside variable (RJ) branches, enabling Gibbs-style block updates.

Parameters:
  • rj_model (RJ_Discovery_model) – The model with cached likelihoods for all configurations.

  • priors (dict) –

    Priors in Eryn RJ format:

    {
        "branch_name": {
            0: prior_for_param_0,
            1: prior_for_param_1,
            ...
        }
    }
    

    Must include entries for every fixed branch (if any) and every variable branch.

  • latex_labels (dict, optional) – LaTeX labels for parameter names.
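The priors layout above (branch name -> {parameter index -> distribution}, with an entry for every fixed and variable branch) can be checked before building a bridge. Uniform is a hypothetical stand-in for an Eryn distribution such as uniform_dist, and check_priors is a hypothetical helper; the sketch assumes each branch needs exactly the indices 0..ndim-1.

```python
import math
from typing import Dict, Sequence

class Uniform:
    """Hypothetical stand-in for an Eryn prior distribution such as uniform_dist(lo, hi)."""
    def __init__(self, lo: float, hi: float):
        self.lo, self.hi = lo, hi

    def logpdf(self, x: float) -> float:
        return -math.log(self.hi - self.lo) if self.lo <= x <= self.hi else -math.inf

def check_priors(priors: Dict[str, Dict[int, object]], fixed_branch_names: Sequence[str],
                 rj_branch_names: Sequence[str], ndims: Dict[str, int]) -> None:
    """Raise ValueError unless every branch has priors for indices 0..ndim-1."""
    for branch in list(fixed_branch_names) + list(rj_branch_names):
        if branch not in priors:
            raise ValueError(f"missing priors for branch {branch!r}")
        expected, got = set(range(ndims[branch])), set(priors[branch])
        if got != expected:
            raise ValueError(f"branch {branch!r}: prior indices {sorted(got)}, "
                             f"expected {sorted(expected)}")

priors = {
    "psrn": {0: Uniform(0.0, 7.0), 1: Uniform(-20.0, -11.0)},
    "cw": {0: Uniform(-20.0, -11.0), 1: Uniform(-9.0, -7.0)},
}
check_priors(priors, ["psrn"], ["cw"], {"psrn": 2, "cw": 2})  # passes silently
```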

Attributes:
  • all_branch_names (list) – All branch names (fixed + variable), in the order Eryn sees them.

  • rj_branch_names (list) – Names of variable (RJ) branches only.

  • fixed_branch_names (list) – Names of fixed (always-present) branches only.

  • ndims_dict (dict) – Maps branch name -> number of parameters in that branch.

  • nleaves_min_dict / nleaves_max_dict (dict) – Maps branch name -> min/max leaf count.
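A sketch of how these per-branch dictionaries relate to the model: fixed branches come first with exactly one leaf, and RJ branches take their (min, max) from variable_component_numbers. branch_layout is a hypothetical helper and rj_ndims (parameters per source) is an assumed input; per-branch dicts of this shape are what Eryn's sampler is configured with for multi-branch runs, though the exact keyword names should be checked against the Eryn documentation.

```python
from typing import Dict, List, Tuple

def branch_layout(fixed_branch_param_names: Dict[str, List[str]],
                  variable_component_numbers: Dict[str, Tuple[int, int]],
                  rj_ndims: Dict[str, int]):
    """Fixed branches first (exactly one leaf each), then RJ branches with their ranges."""
    all_branch_names = list(fixed_branch_param_names) + list(variable_component_numbers)
    ndims = {b: len(p) for b, p in fixed_branch_param_names.items()}
    ndims.update(rj_ndims)
    nleaves_min = {b: 1 for b in fixed_branch_param_names}
    nleaves_max = dict(nleaves_min)
    for b, (lo, hi) in variable_component_numbers.items():
        nleaves_min[b], nleaves_max[b] = lo, hi
    return all_branch_names, ndims, nleaves_min, nleaves_max

names, ndims, nmin, nmax = branch_layout(
    {"psrn": ["gamma", "log10_A", "dm_log10_A"], "gw": ["gw_log10_A", "gw_gamma"]},
    {"cw": (0, 4)}, {"cw": 8})
print(names)  # ['psrn', 'gw', 'cw']
print(nmax)   # {'psrn': 1, 'gw': 1, 'cw': 4}
```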

__init__(rj_model, priors, latex_labels=None)

property has_fixed_branches: bool

Whether this bridge has fixed (always-present) branches.

create_sampler(nwalkers, ntemps=1, moves=None, move_cov_factor=0.01, rj_moves=True, checkpoint_file=None, **kwargs)

Create the Eryn ensemble sampler for RJMCMC.

Parameters:
  • nwalkers (int) – Number of walkers per temperature.

  • ntemps (int, optional) – Number of temperatures for parallel tempering. Default 1.

  • moves (eryn Move, optional) – Custom move proposal. If None, uses GaussianMove with diagonal covariance.

  • move_cov_factor (float, optional) – Factor for diagonal covariance in default GaussianMove. Default 0.01.

  • rj_moves (bool or str or list, optional) – Reversible-jump move configuration passed to Eryn. For multi-branch setups, "separate_branches" is recommended. Default True.

  • checkpoint_file (str, optional) – Path to an HDF5 file for checkpointing. If provided, the sampler stores all chain data to this file and can be resumed from it. If the file already exists and contains data, the sampler will resume from the last saved state when run_sampler is called with initial_state=None.

  • **kwargs – Additional arguments passed to EnsembleSampler.

Returns:

The configured Eryn sampler.

Return type:

EnsembleSampler
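The default in-model move is described as a GaussianMove with a diagonal covariance scaled by move_cov_factor. The sketch below shows, without Eryn, what such a symmetric random-walk proposal does; diagonal_cov and gaussian_step are hypothetical names, and it assumes the covariance is move_cov_factor times the identity for every branch.

```python
import math
import random
from typing import Dict, List

def diagonal_cov(ndims: Dict[str, int], factor: float = 0.01) -> Dict[str, List[float]]:
    """Per-branch covariance factor * I, stored as its diagonal."""
    return {branch: [factor] * n for branch, n in ndims.items()}

def gaussian_step(point: List[float], cov_diag: List[float], rng: random.Random) -> List[float]:
    """Symmetric random-walk proposal x' = x + N(0, diag(cov_diag))."""
    return [x + rng.gauss(0.0, math.sqrt(v)) for x, v in zip(point, cov_diag)]

cov = diagonal_cov({"psrn": 3, "cw": 8})
rng = random.Random(1)
proposal = gaussian_step([-14.0, -8.0, 0.5, 1.0, 0.0, 2.0, 1.0, 0.3], cov["cw"], rng)
print(len(proposal))  # 8
```

With the default factor of 0.01, every parameter gets a step of standard deviation 0.1 in its own units. That may be too large for tightly constrained parameters and too small for broad ones, which is a reason to pass a custom move through the moves argument.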

initialize_state(initial_nleaves=None, initial_points=None, scatter=1e-06)

Initialize the sampler state.

Parameters:
  • initial_nleaves (int or dict, optional) – Number of active leaves to start with. If int, applies to the first RJ branch. If dict, maps branch name -> count. Fixed branches always start with 1 leaf. Defaults to nleaves_min for RJ branches.

  • initial_points (dict, optional) – Initial parameter values per branch. Keys are branch names, values are arrays of shape (nleaves, ndim). If None, draws from priors.

  • scatter (float, optional) – Standard deviation for Gaussian scatter around initial points. Default 1e-6.

Returns:

Eryn State object ready for sampling.

Return type:

State
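The initial_nleaves rules above (an int applies to the first RJ branch, a dict maps branch -> count, None falls back to nleaves_min, fixed branches always get 1) and the Gaussian scatter can be sketched with NumPy. The coordinate layout (ntemps, nwalkers, nleaves_max, ndim) with a boolean mask of active leaves is an assumption about how an Eryn state is organized; normalize_nleaves and initial_coords are hypothetical helpers, and the sketch assumes explicit initial points with at least as many rows as active leaves (it does not draw from priors).

```python
from typing import Dict, List, Optional, Union
import numpy as np

def normalize_nleaves(initial_nleaves: Optional[Union[int, Dict[str, int]]],
                      rj_branch_names: List[str], fixed_branch_names: List[str],
                      nleaves_min: Dict[str, int]) -> Dict[str, int]:
    """int -> first RJ branch; dict -> per branch; None -> nleaves_min. Fixed: always 1."""
    counts = {b: nleaves_min[b] for b in rj_branch_names}
    if isinstance(initial_nleaves, int):
        counts[rj_branch_names[0]] = initial_nleaves
    elif initial_nleaves is not None:
        counts.update({b: n for b, n in initial_nleaves.items() if b in counts})
    counts.update({b: 1 for b in fixed_branch_names})
    return counts

def initial_coords(points, counts, nleaves_max, ntemps, nwalkers, scatter, rng):
    """Coords (ntemps, nwalkers, nleaves_max, ndim) plus a boolean mask of active leaves."""
    coords, inds = {}, {}
    for b, pts in points.items():
        pts = np.atleast_2d(np.asarray(pts, dtype=float))   # (>= n active, ndim)
        n, ndim = counts[b], pts.shape[1]
        c = np.zeros((ntemps, nwalkers, nleaves_max[b], ndim))
        # Scatter each active leaf around its starting point.
        c[:, :, :n, :] = pts[:n] + scatter * rng.standard_normal((ntemps, nwalkers, n, ndim))
        mask = np.zeros((ntemps, nwalkers, nleaves_max[b]), dtype=bool)
        mask[:, :, :n] = True
        coords[b], inds[b] = c, mask
    return coords, inds

counts = normalize_nleaves(2, ["cw"], ["psrn"], {"cw": 0, "psrn": 1})
coords, inds = initial_coords(
    {"psrn": [[3.0, -14.0]], "cw": [[-14.0, -8.0], [-15.0, -8.5]]},
    counts, {"psrn": 1, "cw": 4}, ntemps=2, nwalkers=8, scatter=1e-6,
    rng=np.random.default_rng(0))
print(coords["cw"].shape, int(inds["cw"].sum()))  # (2, 8, 4, 2) 32
```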

run_sampler(nsteps, initial_state=None, initial_nleaves=None, initial_points=None, burn=0, thin_by=1, progress=True, initial_point=None, **kwargs)

Run the RJMCMC sampler.

Parameters:
  • nsteps (int) – Number of MCMC steps to run.

  • initial_state (State, optional) – Starting state. If None, creates one using initialize_state().

  • initial_nleaves (int or dict, optional) – Passed to initialize_state() if initial_state is None.

  • initial_points (dict, optional) – Passed to initialize_state() if initial_state is None.

  • initial_point (array, optional) – Deprecated. Use initial_points instead. If provided, applied to the first RJ branch.

  • burn (int, optional) – Burn-in steps to discard. Default 0.

  • thin_by (int, optional) – Thinning factor. Default 1.

  • progress (bool, optional) – Show progress bar. Default True.

  • **kwargs – Additional arguments passed to sampler.run_mcmc().

Returns:

Final state after sampling.

Return type:

State
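The deprecated initial_point argument is documented as applying to the first RJ branch. One way such a shim could fold it into initial_points is sketched below; merge_initial_point is hypothetical, and the precedence rule (an explicit initial_points entry for that branch wins) is an assumption, not documented behavior.

```python
import warnings
from typing import Dict, List, Optional

def merge_initial_point(initial_point, initial_points: Optional[Dict],
                        rj_branch_names: List[str]) -> Optional[Dict]:
    """Fold the deprecated single-array argument into the per-branch dict."""
    if initial_point is None:
        return initial_points
    warnings.warn("initial_point is deprecated; use initial_points instead",
                  DeprecationWarning, stacklevel=2)
    merged = dict(initial_points or {})
    # Assumption: an explicit initial_points entry for the branch takes precedence.
    merged.setdefault(rj_branch_names[0], initial_point)
    return merged
```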

get_last_state()

Return the last sampler state.

This can be passed as initial_state to run_sampler to continue sampling from where the previous run left off, even without an HDF5 backend.

Returns:

The last sampler state.

Return type:

State

property checkpoint_file: str | None

Path to the HDF5 checkpoint file, or None.

property can_resume: bool

Whether the sampler can resume from a checkpoint file.

property completed_steps: int

Number of completed steps stored in the backend.

resume(nsteps, progress=True, **kwargs)

Resume sampling from the last checkpoint.

Continues from the last stored state in the HDF5 backend (or the last in-memory state) for nsteps additional steps.

Parameters:
  • nsteps (int) – Number of additional MCMC steps to run.

  • progress (bool, optional) – Show progress bar. Default True.

  • **kwargs – Additional arguments passed to sampler.run_mcmc().

Returns:

Final state after the additional steps.

Return type:

State
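Checkpointed runs can be driven in chunks using only the documented completed_steps property and the run_sampler and resume methods. run_to_target is a hypothetical helper; it assumes each call adds exactly n stored steps (thin_by=1) and raises instead of looping forever if the count stops advancing.

```python
def run_to_target(bridge, target_steps: int, chunk: int = 1000, progress: bool = False) -> int:
    """Run or resume in chunks until the backend holds target_steps steps."""
    while bridge.completed_steps < target_steps:
        before = bridge.completed_steps
        n = min(chunk, target_steps - before)
        if before == 0:
            bridge.run_sampler(n, progress=progress)   # nothing stored yet: fresh start
        else:
            bridge.resume(n, progress=progress)        # continue from the last stored state
        if bridge.completed_steps <= before:           # guard against an infinite loop
            raise RuntimeError("completed_steps did not advance")
    return bridge.completed_steps
```

With a real bridge, one would first call create_sampler(..., checkpoint_file=...) and then run_to_target(bridge, 20000, chunk=2000); after an interruption, re-running the same code should continue from the stored completed_steps rather than starting over.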

return_sampled_samples(branch=None, temperature=0)

Return the sampled parameter chains for a branch.

Parameters:
  • branch (str, optional) – Branch name. Defaults to first RJ branch.

  • temperature (int, optional) – Temperature index. Default 0 (coldest).

Returns:

Dictionary with ‘names’, ‘labels’, and ‘chain’ keys.

Return type:

dict

return_flat_samples(branch=None, temperature=0)

Return flattened samples, excluding inactive (NaN) entries.

Parameters:
  • branch (str, optional) – Branch name. Defaults to first RJ branch.

  • temperature (int, optional) – Temperature index. Default 0.

Returns:

Flattened samples of shape (n_valid_samples, ndim).

Return type:

ndarray
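Dropping inactive leaves amounts to flattening every axis except the last and discarding rows that contain NaN. flatten_valid is a hypothetical helper; it assumes the chain's last axis is ndim and that inactive leaves are stored as NaN, as the description above states.

```python
import numpy as np

def flatten_valid(chain) -> np.ndarray:
    """(..., ndim) chain with NaN rows for inactive leaves -> (n_valid, ndim)."""
    arr = np.asarray(chain, dtype=float)
    flat = arr.reshape(-1, arr.shape[-1])
    return flat[~np.isnan(flat).any(axis=1)]

# Toy chain: 3 steps, 2 walkers, nleaves_max = 4, ndim = 2.
chain = np.full((3, 2, 4, 2), np.nan)
chain[:, :, 0, :] = 1.0        # leaf 0 active everywhere
chain[0, 0, 1, :] = 2.0        # leaf 1 active once
print(flatten_valid(chain).shape)  # (7, 2)
```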

return_nleaves(branch=None)

Return the number of active leaves at each step.

Parameters:

branch (str, optional) – Branch name. Defaults to first RJ branch.

Returns:

Array of shape (nsteps, ntemps, nwalkers) with leaf counts.

Return type:

ndarray
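For the model-selection use case in the hint at the top of this page, the leaf-count array can be turned into posterior probabilities for each number of sources at one temperature. nleaves_posterior and bayes_factor are hypothetical helpers; the ratio of visit fractions equals a Bayes factor only under a uniform prior on the number of sources.

```python
import numpy as np

def nleaves_posterior(nleaves, temperature: int = 0, max_leaves=None) -> np.ndarray:
    """Fraction of (step, walker) samples at each leaf count for one temperature."""
    counts = np.asarray(nleaves, dtype=int)[:, temperature, :].ravel()
    minlength = 0 if max_leaves is None else max_leaves + 1
    hist = np.bincount(counts, minlength=minlength)
    return hist / hist.sum()

def bayes_factor(probs: np.ndarray, k_num: int, k_den: int) -> float:
    """Posterior odds of k_num vs k_den sources; inf if k_den was never visited."""
    if probs[k_den] == 0:
        return float("inf")
    return float(probs[k_num] / probs[k_den])

nleaves = np.array([[[1, 2, 2, 1]], [[2, 2, 2, 1]]])   # (nsteps=2, ntemps=1, nwalkers=4)
p = nleaves_posterior(nleaves, max_leaves=4)
print(bayes_factor(p, 2, 1))  # 5/3, about 1.667
```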

return_logZ(*, results=None)

Not supported for MCMC samplers.

Return type:

Dict[str, float]

plot_nleaves_histogram(branch=None, figsize=(6, 12))

Plot histogram of the number of active leaves at each temperature.

Parameters:
  • branch (str, optional) – RJ branch to plot. Defaults to first RJ branch.

  • figsize (tuple, optional) – Figure size. Default (6, 12).

Returns:

The figure object.

Return type:

matplotlib.figure.Figure

plot_corner(branch=None, temperature=0, **kwargs)

Create a corner plot of the sampled parameters.

Parameters:
  • branch (str, optional) – Branch name. Defaults to first RJ branch.

  • temperature (int, optional) – Temperature index. Default 0.

  • **kwargs – Additional arguments passed to corner.corner().

Returns:

The corner plot figure.

Return type:

matplotlib.figure.Figure

print_config_summary()

Print a summary of all model configurations.

Return type:

None

discoverysamplers.eryn_RJ_interface.ErynRJBridge

alias of DiscoveryErynRJBridge

See Also