Eryn Reversible-Jump Interface
This module provides the bridge interface for reversible-jump MCMC (RJMCMC) using Eryn. RJMCMC enables trans-dimensional sampling where the number of model components can vary during sampling.
Hint
RJMCMC is useful for model selection problems where you want to determine the optimal number of signal components (e.g., how many gravitational wave sources are present in the data).
Reversible-jump bridge for Discovery models using Eryn.
This module provides an interface between Discovery models with variable dimensions (e.g., variable number of gravitational wave sources) and Eryn’s reversible-jump MCMC sampler.
The key components are:
- RJ_Discovery_model: A wrapper that caches likelihoods for all model configurations (e.g., 1 source, 2 sources, etc.) and provides a unified logL function that Eryn’s RJ sampler can call. Supports both variable branches (RJ) and fixed branches (always-present parameter groups like pulsar noise or GW background).
- DiscoveryErynRJBridge: The interface class that sets up the Eryn sampler with proper priors and handles sampling, result extraction, and plotting.
Example usage (single-branch, backward compatible):
rj_model = RJ_Discovery_model(
    psrs=pulsars,
    fixed_components={'per_psr': {'base': make_fixed_components}},
    variable_components={'global': {'cw': (signal_constructor, base_param_names)}},
    variable_component_numbers={'cw': (1, 4)},
)
priors = {"cw": {0: uniform_dist(-20, -11), ...}}
bridge = DiscoveryErynRJBridge(rj_model, priors=priors)
bridge.create_sampler(nwalkers=32, ntemps=2)
bridge.run_sampler(nsteps=5000)
Example usage (multi-branch with fixed noise/GW branches):
rj_model = RJ_Discovery_model(
    psrs=pulsars,
    fixed_components={'per_psr': {'base': make_fixed_components}},
    variable_components={'global': {'cw': (signal_constructor, cw_param_names)}},
    variable_component_numbers={'cw': (0, 4)},
    fixed_branches={
        'psrn': ['red_noise', 'dm_gp', 'chrom', 'dmexp'],
        'gw': ['gw_'],
    },
)
priors = {
    "psrn": {0: ..., 1: ..., ...},
    "gw": {0: ..., 1: ...},
    "cw": {0: ..., 1: ..., ...},
}
bridge = DiscoveryErynRJBridge(rj_model, priors=priors)
bridge.create_sampler(nwalkers=32, ntemps=2)
bridge.run_sampler(nsteps=5000)
Example usage (single-likelihood mode — one JIT’d likelihood for all configs):
rj_model = RJ_Discovery_model(
    psrs=pulsars,
    fixed_components={
        'per_psr': {'base': make_fixed_components},
        'global': {'globalgp': globalgp},
    },
    variable_components={'global': {'cw': (signal_constructor, cw_param_names)}},
    variable_component_numbers={'cw': (0, 5)},
    fixed_branches={
        'psrn': ['red_noise', 'dm_gp', 'chrom', 'dmexp'],
        'gw': ['gw_'],
    },
    single_likelihood=True,  # Build 1 likelihood, silence dead sources
)
- Tested with:
discovery >= 0.5
eryn >= 1.2
- class discoverysamplers.eryn_RJ_interface.RJ_Discovery_model(psrs, fixed_components, variable_components, variable_component_numbers, fixed_branches=None, custom_logL=None, verbose=False, single_likelihood=False, zero_amplitude_value=-300.0, zero_amplitude_param='log10_h0')
Bases: object
Discovery model wrapper for reversible-jump MCMC sampling with Eryn.
This class manages multiple model configurations (e.g., different numbers of gravitational wave sources) by pre-computing and caching the likelihood for each configuration. It provides a unified logL interface that Eryn’s RJ sampler can call with nested parameter lists.
It also supports fixed branches: groups of always-present parameters (e.g., pulsar noise, GW background) that are exposed as separate Eryn branches with nleaves=1. This enables Gibbs-style block updates between noise, background, and signal parameters.
- Parameters:
psrs (list) – List of pulsar objects (Discovery Pulsar instances).
fixed_components (dict) –
Components that don’t change in number. Structure:
{
    'per_psr': {'name': constructor_function},
    'global': {'name': constructor_function},  # optional
}
Where constructor_function(psr) returns a list of model components.
variable_components (dict) –
Components that can vary in number. Structure:
{ 'global': {'branch_name': (constructor_func, base_param_names)} }
Where constructor_func() returns (delay_function, param_names) and base_param_names is a list like ['log10_h0', 'log10_f0', ...].
variable_component_numbers (dict) –
Min/max counts for each variable component:
{'branch_name': (min_count, max_count)}
fixed_branches (dict or None, optional) –
Groups of always-present parameters to expose as separate Eryn branches. Keys are branch names, values are lists of substring patterns used to match Discovery parameter names. Example:
{ 'psrn': ['red_noise', 'dm_gp', 'chrom', 'dmexp'], 'gw': ['gw_'], }
Parameters matching any pattern for a branch are assigned to that branch. Parameters not matching any fixed branch or variable component become truly fixed (not sampled). Default None (no fixed branches; backward-compatible behavior).
custom_logL (callable or None, optional) – Custom log-likelihood function with signature custom_logL(param_dict, config) -> float, where param_dict is a flat dictionary of all parameter values and config is a dict mapping variable component names to their active counts (e.g., {'cw': 1}). When provided, this replaces the standard GlobalLikelihood.logL call. Useful for phase-marginalized likelihoods. Default None.
verbose (bool, optional) – Print detailed information during setup. Default False.
single_likelihood (bool, optional) – If True, build only ONE likelihood with the maximum number of sources and silence dead sources by setting their amplitude to zero_amplitude_value. This avoids caching N separate likelihoods and ensures the JAX-compiled likelihood function is never recompiled during sampling. Default False.
zero_amplitude_value (float, optional) – Value assigned to the amplitude parameter of dead sources in single_likelihood mode. Default -300.0 (i.e., h0 = 10^-300 ≈ 0).
zero_amplitude_param (str, optional) – Substring used to identify the amplitude parameter among the base parameter names. Default "log10_h0".
Notes
single_likelihood edge cases:
- When all sources are dead (n_active=0), all CW amplitudes are set to zero_amplitude_value and the likelihood equals the noise-only value. Verified to match the standard 0-source likelihood to machine precision.
- Dead-source non-amplitude parameters default to 0.0. This is safe because h0 ~ 0 makes the CW contribution vanish regardless of other parameter values (the signal is linear in amplitude).
- This mode requires that the signal model is linear in the amplitude parameter (i.e., setting amplitude to zero produces zero signal). All Discovery CW delay models satisfy this.
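The reasoning behind these edge cases can be checked numerically on a toy model. The sketch below is an illustration under stated assumptions, not Discovery code: toy_cw_delay is a hypothetical sinusoid that is linear in 10**log10_h0, and toy_logl is a plain white-noise Gaussian log-likelihood.

```python
import numpy as np


def toy_cw_delay(t, log10_h0, log10_f0, phase):
    """Hypothetical signal, linear in the amplitude h0 = 10**log10_h0."""
    return 10.0**log10_h0 * np.sin(2.0 * np.pi * 10.0**log10_f0 * t + phase)


def toy_logl(residuals, t, sources, sigma=1e-7):
    """White-noise Gaussian log-likelihood (up to an additive constant)."""
    model = sum((toy_cw_delay(t, *src) for src in sources), np.zeros_like(t))
    return -0.5 * np.sum(((residuals - model) / sigma) ** 2)


rng = np.random.default_rng(0)
t = np.linspace(0.0, 3.0e8, 200)                 # observation times in seconds
residuals = 1e-7 * rng.standard_normal(t.size)   # synthetic noise-only data

# Noise-only likelihood: no sources in the model at all.
logl_noise_only = toy_logl(residuals, t, sources=[])

# "Silenced" source: amplitude set to -300, remaining parameters set to 0.0.
dead_source = (-300.0, 0.0, 0.0)
logl_dead = toy_logl(residuals, t, sources=[dead_source])
```

In this toy setting the two values agree because a contribution of order 1e-300 is far below floating-point resolution relative to residuals of order 1e-7.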
- fixed_branch_param_names
Maps each fixed branch name to its ordered list of Discovery parameter names.
- Type:
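The substring matching that populates this mapping can be pictured with a minimal sketch. It is not the module's implementation: the helper assign_to_branches and the example parameter names are hypothetical, and the sketch assumes that the first branch with a matching pattern wins.

```python
def assign_to_branches(param_names, fixed_branches):
    """Group parameter names by the first fixed branch whose pattern matches.

    Hypothetical helper for illustration; the real matching rules may differ.
    """
    assigned = {branch: [] for branch in fixed_branches}
    unmatched = []
    for name in param_names:
        for branch, patterns in fixed_branches.items():
            if any(pattern in name for pattern in patterns):
                assigned[branch].append(name)
                break
        else:
            # No fixed-branch pattern matched this parameter.
            unmatched.append(name)
    return assigned, unmatched


# Made-up parameter names, chosen only to exercise the patterns.
names = [
    "J0030+0451_red_noise_log10_A",
    "J0030+0451_dm_gp_gamma",
    "gw_log10_A",
    "cw0_log10_h0",
]
branches = {"psrn": ["red_noise", "dm_gp", "chrom", "dmexp"], "gw": ["gw_"]}
assigned, unmatched = assign_to_branches(names, branches)
```

Here cw0_log10_h0 is left unmatched; per the parameter description, such a parameter would either belong to a variable component or be held fixed.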
Examples
>>> rj_model = RJ_Discovery_model(
...     psrs=pulsars,
...     fixed_components={'per_psr': {'base': make_fixed}},
...     variable_components={'global': {'cw': (make_cw_signal, param_names)}},
...     variable_component_numbers={'cw': (1, 4)},
...     fixed_branches={'psrn': ['red_noise', 'dm_gp'], 'gw': ['gw_']},
... )
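A callable for the custom_logL argument takes the form custom_logL(param_dict, config) -> float. The following toy function only illustrates that signature. The key naming 'cw{i}_log10_h0' is an assumption made for the example and may not match the names the module actually places in param_dict, and the Gaussian penalty stands in for a real likelihood.

```python
import math


def custom_logL(param_dict, config):
    """Toy log-likelihood with the documented (param_dict, config) signature.

    ASSUMPTION: per-source parameters are named 'cw{i}_log10_h0'; the real
    naming scheme used by the module may differ.
    """
    n_active = config.get("cw", 0)
    logl = 0.0
    for i in range(n_active):
        log10_h0 = param_dict[f"cw{i}_log10_h0"]
        # Toy Gaussian penalty centred on log10_h0 = -14 with unit width.
        logl += -0.5 * (log10_h0 + 14.0) ** 2
    # Return -inf rather than NaN or +inf for unusable values.
    return float(logl) if math.isfinite(logl) else -math.inf


value = custom_logL({"cw0_log10_h0": -13.0, "cw1_log10_h0": -14.0}, {"cw": 2})
empty = custom_logL({}, {"cw": 0})
```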
- __init__(psrs, fixed_components, variable_components, variable_component_numbers, fixed_branches=None, custom_logL=None, verbose=False, single_likelihood=False, zero_amplitude_value=-300.0, zero_amplitude_param='log10_h0')
- get_likelihood_for_config(config)
Get the cached likelihood for a given configuration.
- Return type:
- get_current_config_from_params(params)
Determine the current model configuration from the parameter dictionary.
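The configuration-keyed cache behind these two methods can be pictured as follows. This is a sketch under assumptions rather than the class's code: enumerate_configs, config_key, and the placeholder build_likelihood are hypothetical names, max_count is assumed to be inclusive, and the real class caches Discovery likelihoods rather than strings.

```python
from itertools import product


def enumerate_configs(variable_component_numbers):
    """Yield every allowed {branch: count} combination (max assumed inclusive)."""
    names = list(variable_component_numbers)
    ranges = [
        range(lo, hi + 1) for lo, hi in variable_component_numbers.values()
    ]
    for counts in product(*ranges):
        yield dict(zip(names, counts))


def config_key(config):
    """Turn a config dict into a hashable, order-independent cache key."""
    return tuple(sorted(config.items()))


def build_likelihood(config):
    """Placeholder for an expensive likelihood construction."""
    return f"likelihood for {config}"


# Build one cache entry per configuration up front.
cache = {
    config_key(cfg): build_likelihood(cfg)
    for cfg in enumerate_configs({"cw": (1, 4)})
}

# Look up the entry for a configuration with two active sources.
entry = cache[config_key({"cw": 2})]
```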
- logL(*params)
Log-likelihood function for Eryn’s RJ sampler.
When called with a single branch (no fixed branches), Eryn passes a single 2D array of shape (n_active_leaves, ndim).
When called with multiple branches (fixed + variable), Eryn passes a list of 2D arrays, one per branch in the order given by branch_names (fixed branches first, then variable branches).
- Parameters:
*params (arrays) – Nested structure of parameters from Eryn.
- Returns:
Log-likelihood value (or -inf for invalid configurations).
- Return type:
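The two calling conventions described for logL can be handled by one unpacking step. The sketch below is a hypothetical helper (unpack_branches is not part of the module) that assumes the inputs arrive either as one 2D array or as a list of 2D arrays in branch_names order.

```python
import numpy as np


def unpack_branches(branch_names, *params):
    """Map Eryn-style positional inputs to {branch_name: 2D array}.

    Hypothetical helper: accepts either one 2D array (single branch) or a
    list/tuple of 2D arrays (one per branch, in branch_names order).
    """
    if len(params) == 1 and isinstance(params[0], (list, tuple)):
        params = tuple(params[0])
    if len(params) != len(branch_names):
        raise ValueError("expected one array per branch")
    return {
        name: np.atleast_2d(np.asarray(arr, dtype=float))
        for name, arr in zip(branch_names, params)
    }


# Multi-branch call: fixed branches have one leaf, 'cw' has two active leaves.
branches = unpack_branches(
    ["psrn", "gw", "cw"],
    [np.zeros((1, 4)), np.zeros((1, 2)), np.zeros((2, 3))],
)
config = {"cw": branches["cw"].shape[0]}  # number of active sources

# Single-branch call: one 2D array of shape (n_active_leaves, ndim).
single = unpack_branches(["cw"], np.ones((3, 2)))
```

The number of rows in a variable branch's array gives the active count for that branch, which is the information a configuration lookup needs.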
- class discoverysamplers.eryn_RJ_interface.DiscoveryErynRJBridge(rj_model, priors, latex_labels=None)
Bases: object
Bridge between RJ_Discovery_model and Eryn’s reversible-jump MCMC sampler.
Supports two modes:
- Single-branch (backward compatible): only variable (RJ) branches.
- Multi-branch: fixed branches (always-present, nleaves=1) alongside variable (RJ) branches, enabling Gibbs-style block updates.
- Parameters:
rj_model (RJ_Discovery_model) – The model with cached likelihoods for all configurations.
priors (dict) –
Priors in Eryn RJ format:
{ "branch_name": { 0: prior_for_param_0, 1: prior_for_param_1, ... } }
Must include entries for every fixed branch (if any) and every variable branch.
latex_labels (dict, optional) – LaTeX labels for parameter names.
- nleaves_min_dict / nleaves_max_dict
Maps branch name -> min/max leaf count.
- Type:
- create_sampler(nwalkers, ntemps=1, moves=None, move_cov_factor=0.01, rj_moves=True, checkpoint_file=None, **kwargs)
Create the Eryn ensemble sampler for RJMCMC.
- Parameters:
nwalkers (int) – Number of walkers per temperature.
ntemps (int, optional) – Number of temperatures for parallel tempering. Default 1.
moves (eryn Move, optional) – Custom move proposal. If None, uses GaussianMove with diagonal covariance.
move_cov_factor (float, optional) – Factor for diagonal covariance in default GaussianMove. Default 0.01.
rj_moves (bool or str or list, optional) – Reversible-jump move configuration passed to Eryn. For multi-branch setups, "separate_branches" is recommended. Default True.
checkpoint_file (str, optional) – Path to an HDF5 file for checkpointing. If provided, the sampler stores all chain data to this file and can be resumed from it. If the file already exists and contains data, the sampler will resume from the last saved state when run_sampler is called with initial_state=None.
**kwargs – Additional arguments passed to EnsembleSampler.
- Returns:
The configured Eryn sampler.
- Return type:
EnsembleSampler
- initialize_state(initial_nleaves=None, initial_points=None, scatter=1e-06)
Initialize the sampler state.
- Parameters:
initial_nleaves (int or dict, optional) – Number of active leaves to start with. If int, applies to the first RJ branch. If dict, maps branch name -> count. Fixed branches always start with 1 leaf. Defaults to nleaves_min for RJ branches.
initial_points (dict, optional) – Initial parameter values per branch. Keys are branch names, values are arrays of shape (nleaves, ndim). If None, draws from priors.
scatter (float, optional) – Standard deviation for Gaussian scatter around initial points. Default 1e-6.
- Returns:
Eryn State object ready for sampling.
- Return type:
State
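The effect of the scatter argument can be sketched with plain NumPy. This is an assumption-based illustration, not the method's code: scatter_initial_points is a hypothetical helper, and it assumes walker coordinates are laid out as (ntemps, nwalkers, nleaves, ndim).

```python
import numpy as np


def scatter_initial_points(points, ntemps, nwalkers, scatter=1e-6, seed=None):
    """Replicate (nleaves, ndim) points per walker and add Gaussian scatter."""
    rng = np.random.default_rng(seed)
    points = np.asarray(points, dtype=float)
    # One copy of the points for every temperature and walker.
    shape = (ntemps, nwalkers) + points.shape
    return points + scatter * rng.standard_normal(shape)


# One leaf with two parameters, spread over 2 temperatures and 32 walkers.
start = scatter_initial_points([[-14.0, -8.0]], ntemps=2, nwalkers=32, seed=1)
```

A small scatter keeps all walkers close to the supplied point while still giving each walker a distinct starting position.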
- run_sampler(nsteps, initial_state=None, initial_nleaves=None, initial_points=None, burn=0, thin_by=1, progress=True, initial_point=None, **kwargs)
Run the RJMCMC sampler.
- Parameters:
nsteps (int) – Number of MCMC steps to run.
initial_state (State, optional) – Starting state. If None, creates one using initialize_state().
initial_nleaves (int or dict, optional) – Passed to initialize_state() if initial_state is None.
initial_points (dict, optional) – Passed to initialize_state() if initial_state is None.
initial_point (array, optional) – Deprecated. Use initial_points instead. If provided, applied to the first RJ branch.
burn (int, optional) – Burn-in steps to discard. Default 0.
thin_by (int, optional) – Thinning factor. Default 1.
progress (bool, optional) – Show progress bar. Default True.
**kwargs – Additional arguments passed to sampler.run_mcmc().
- Returns:
Final state after sampling.
- Return type:
State
- get_last_state()
Return the last sampler state.
This can be passed as initial_state to run_sampler to continue sampling from where the previous run left off, even without an HDF5 backend.
- Returns:
The last sampler state.
- Return type:
State
- resume(nsteps, progress=True, **kwargs)
Resume sampling from the last checkpoint.
Continues from the last stored state in the HDF5 backend (or the last in-memory state) for nsteps additional steps.
- return_sampled_samples(branch=None, temperature=0)
Return the sampled parameter chains for a branch.
- return_flat_samples(branch=None, temperature=0)
Return flattened samples, excluding inactive (NaN) entries.
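Dropping inactive entries amounts to masking out NaN rows after flattening. The sketch below uses a synthetic array and a hypothetical helper, flatten_active_samples. It assumes the chain has shape (nsteps, ntemps, nwalkers, nleaves_max, ndim) with inactive leaves stored as NaN.

```python
import numpy as np


def flatten_active_samples(chain, temperature=0):
    """Flatten a chain to (n_samples, ndim), dropping NaN (inactive) leaves.

    ASSUMPTION: chain has shape (nsteps, ntemps, nwalkers, nleaves_max, ndim).
    """
    chain = np.asarray(chain, dtype=float)
    ndim = chain.shape[-1]
    # Select one temperature, then merge steps, walkers, and leaves.
    flat = chain[:, temperature].reshape(-1, ndim)
    active = ~np.isnan(flat).any(axis=1)
    return flat[active]


# Synthetic chain: 5 steps, 2 temperatures, 3 walkers, up to 4 leaves, 2 dims.
chain = np.ones((5, 2, 3, 4, 2))
chain[:, :, :, 2:, :] = np.nan      # leaves 2 and 3 are inactive everywhere
flat = flatten_active_samples(chain)
```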
- return_nleaves(branch=None)
Return the number of active leaves at each step.
- Parameters:
branch (str, optional) – Branch name. Defaults to first RJ branch.
- Returns:
Array of shape (nsteps, ntemps, nwalkers) with leaf counts.
- Return type:
ndarray
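An array of this shape can be turned into an estimate of the posterior over the number of components. The sketch below works on synthetic data with a hypothetical helper, nleaves_posterior, and assumes that index 0 along the temperature axis is the cold (untempered) chain.

```python
import numpy as np


def nleaves_posterior(nleaves, nleaves_min, nleaves_max, temperature=0, burn=0):
    """Estimate P(n components | data) from leaf counts at one temperature."""
    counts = np.asarray(nleaves)[burn:, temperature].ravel()
    support = np.arange(nleaves_min, nleaves_max + 1)
    hist = np.array([(counts == n).sum() for n in support], dtype=float)
    return support, hist / hist.sum()


# Synthetic leaf counts with shape (nsteps, ntemps, nwalkers) = (4, 2, 3).
nleaves = np.array([
    [[1, 1, 2], [3, 3, 3]],
    [[2, 2, 2], [3, 4, 3]],
    [[2, 1, 2], [4, 4, 3]],
    [[2, 2, 3], [1, 4, 3]],
])
support, prob = nleaves_posterior(nleaves, nleaves_min=1, nleaves_max=4)
```

The ratio of two entries of prob gives a rough odds ratio between the corresponding component counts, provided the chain has converged and the priors on the counts are equal.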
- plot_nleaves_histogram(branch=None, figsize=(6, 12))
Plot histogram of the number of active leaves at each temperature.
- discoverysamplers.eryn_RJ_interface.ErynRJBridge
alias of DiscoveryErynRJBridge
See Also
Eryn MCMC Sampler - User guide for Eryn (includes RJMCMC section)
Reversible-Jump MCMC - Advanced RJMCMC guide
Parallel Tempering - Parallel tempering (recommended for RJMCMC)
Eryn Interface - Standard fixed-dimensional Eryn interface